diff --git a/qadence/measurements/shadow.py b/qadence/measurements/shadow.py index 29978d22..9c4fded6 100644 --- a/qadence/measurements/shadow.py +++ b/qadence/measurements/shadow.py @@ -6,7 +6,7 @@ from qadence.backend import Backend from qadence.backends.pyqtorch import Backend as PyQBackend -from qadence.blocks import AbstractBlock, KronBlock, chain, kron +from qadence.blocks import AbstractBlock, KronBlock, kron from qadence.blocks.block_to_tensor import HMAT, IMAT, SDAGMAT from qadence.blocks.composite import CompositeBlock from qadence.blocks.primitive import PrimitiveBlock @@ -16,7 +16,7 @@ from qadence.measurements.utils import get_qubit_indices_for_op from qadence.noise import NoiseHandler from qadence.operations import H, I, SDagger, X, Y, Z -from qadence.types import Endianness +from qadence.types import BackendName, Endianness pauli_gates = [X, Y, Z] pauli_rotations = [ @@ -134,20 +134,38 @@ def classical_shadow( shadow: list = list() all_rotations = extract_operators(unitary_ids, circuit.n_qubits) + initial_state = state + backend_name = backend.name if hasattr(backend, "name") else backend.backend.name + if backend_name == BackendName.PYQTORCH: + # run the initial circuit without rotations + # to save computation time + conv_circ = backend.circuit(circuit) + initial_state = backend.run( + circuit=conv_circ, + param_values=param_values, + state=state, + endianness=endianness, + ) + all_rotations = [ + QuantumCircuit(circuit.n_qubits, rots) if rots else QuantumCircuit(circuit.n_qubits) + for rots in all_rotations + ] + else: + all_rotations = [ + QuantumCircuit(circuit.n_qubits, circuit.block, rots) + if rots + else QuantumCircuit(circuit.n_qubits, circuit.block) + for rots in all_rotations + ] + for i in range(shadow_size): - if all_rotations[i]: - rotated_circuit = QuantumCircuit( - circuit.register, chain(circuit.block, all_rotations[i]) - ) - else: - rotated_circuit = circuit # Reverse endianness to get sample bitstrings in ILO. - conv_circ = backend.circuit(rotated_circuit) + conv_circ = backend.circuit(all_rotations[i]) batch_samples = backend.sample( circuit=conv_circ, param_values=param_values, n_shots=1, - state=state, + state=initial_state, noise=noise, endianness=endianness, )