Skip to content

Commit

Permalink
CircuitOperation: change use_repetition_ids default to False (#6910)
Browse files Browse the repository at this point in the history
Review: @pavoljuhas
  • Loading branch information
maffoo authored Jan 6, 2025
1 parent b840178 commit 5ffb3ad
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 107 deletions.
49 changes: 32 additions & 17 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
repetition_ids: Optional[Sequence[str]] = None,
parent_path: Tuple[str, ...] = (),
extern_keys: FrozenSet['cirq.MeasurementKey'] = frozenset(),
use_repetition_ids: bool = True,
use_repetition_ids: Optional[bool] = None,
repeat_until: Optional['cirq.Condition'] = None,
):
"""Initializes a CircuitOperation.
Expand Down Expand Up @@ -120,7 +120,8 @@ def __init__(
use_repetition_ids: When True, any measurement key in the subcircuit
will have its path prepended with the repetition id for each
repetition. When False, this will not happen and the measurement
key will be repeated.
key will be repeated. When None, default to False unless the caller
passes `repetition_ids` explicitly.
repeat_until: A condition that will be tested after each iteration of
the subcircuit. The subcircuit will repeat until condition returns
True, but will always run at least once, and the measurement key
Expand Down Expand Up @@ -156,6 +157,8 @@ def __init__(
# Ensure that the circuit is invertible if the repetitions are negative.
self._repetitions = repetitions
self._repetition_ids = None if repetition_ids is None else list(repetition_ids)
if use_repetition_ids is None:
use_repetition_ids = repetition_ids is not None
self._use_repetition_ids = use_repetition_ids
if isinstance(self._repetitions, float):
if math.isclose(self._repetitions, round(self._repetitions)):
Expand Down Expand Up @@ -263,7 +266,7 @@ def replace(self, **changes) -> 'cirq.CircuitOperation':
'repetition_ids': self.repetition_ids,
'parent_path': self.parent_path,
'extern_keys': self._extern_keys,
'use_repetition_ids': self.use_repetition_ids,
'use_repetition_ids': True if 'repetition_ids' in changes else self.use_repetition_ids,
'repeat_until': self.repeat_until,
**changes,
}
Expand Down Expand Up @@ -448,11 +451,9 @@ def __repr__(self):
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
if self.parent_path:
args += f'parent_path={proper_repr(self.parent_path)},\n'
if self.repetition_ids != self._default_repetition_ids():
if self.use_repetition_ids:
# Default repetition_ids need not be specified.
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
args += 'use_repetition_ids=False,\n'
if self.repeat_until:
args += f'repeat_until={self.repeat_until!r},\n'
indented_args = args.replace('\n', '\n ')
Expand All @@ -477,14 +478,15 @@ def dict_str(d: Mapping) -> str:
args.append(f'params={self.param_resolver.param_dict}')
if self.parent_path:
args.append(f'parent_path={self.parent_path}')
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args.append(f'repetition_ids={self.repetition_ids}')
if self.use_repetition_ids:
if self.repetition_ids != self._default_repetition_ids():
args.append(f'repetition_ids={self.repetition_ids}')
else:
# Default repetition_ids need not be specified.
args.append(f'loops={self.repetitions}, use_repetition_ids=True')
elif self.repetitions != 1:
# Only add loops if we haven't added repetition_ids.
# Add loops if not using repetition_ids.
args.append(f'loops={self.repetitions}')
if not self.use_repetition_ids:
args.append('no_rep_ids')
if self.repeat_until:
args.append(f'until={self.repeat_until}')
if not args:
Expand Down Expand Up @@ -529,10 +531,9 @@ def _json_dict_(self):
'measurement_key_map': self.measurement_key_map,
'param_resolver': self.param_resolver,
'repetition_ids': self.repetition_ids,
'use_repetition_ids': self.use_repetition_ids,
'parent_path': self.parent_path,
}
if not self.use_repetition_ids:
resp['use_repetition_ids'] = False
if self.repeat_until:
resp['repeat_until'] = self.repeat_until
return resp
Expand Down Expand Up @@ -566,7 +567,10 @@ def _from_json_dict_(
# Methods for constructing a similar object with one field modified.

def repeat(
self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None
self,
repetitions: Optional[IntParam] = None,
repetition_ids: Optional[Sequence[str]] = None,
use_repetition_ids: Optional[bool] = None,
) -> 'CircuitOperation':
"""Returns a copy of this operation repeated 'repetitions' times.
Each repetition instance will be identified by a single repetition_id.
Expand All @@ -577,6 +581,10 @@ def repeat(
defaults to the length of `repetition_ids`.
repetition_ids: List of IDs, one for each repetition. If unset,
defaults to `default_repetition_ids(repetitions)`.
use_repetition_ids: If given, this specifies the value for `use_repetition_ids`
of the resulting circuit operation. If not given, we enable ids if
`repetition_ids` is not None, and otherwise fall back to
`self.use_repetition_ids`.
Returns:
A copy of this operation repeated `repetitions` times with the
Expand All @@ -591,6 +599,9 @@ def repeat(
ValueError: Unexpected length of `repetition_ids`.
ValueError: Both `repetitions` and `repetition_ids` are None.
"""
if use_repetition_ids is None:
use_repetition_ids = True if repetition_ids is not None else self.use_repetition_ids

if repetitions is None:
if repetition_ids is None:
raise ValueError('At least one of repetitions and repetition_ids must be set')
Expand All @@ -604,7 +615,7 @@ def repeat(
expected_repetition_id_length: int = np.abs(repetitions)

if repetition_ids is None:
if self.use_repetition_ids:
if use_repetition_ids:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
Expand All @@ -617,7 +628,11 @@ def repeat(

# The eventual number of repetitions of the returned CircuitOperation.
final_repetitions = protocols.mul(self.repetitions, repetitions)
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
return self.replace(
repetitions=final_repetitions,
repetition_ids=repetition_ids,
use_repetition_ids=use_repetition_ids,
)

def __pow__(self, power: IntParam) -> 'cirq.CircuitOperation':
return self.repeat(power)
Expand Down
63 changes: 38 additions & 25 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
op_with_reps: Optional[cirq.CircuitOperation] = None
rep_ids = []
if use_default_ids_for_initial_rep:
op_with_reps = op_base.repeat(initial_repetitions)
rep_ids = ['0', '1', '2']
assert op_base**initial_repetitions == op_with_reps
op_with_reps = op_base.repeat(initial_repetitions, use_repetition_ids=True)
else:
rep_ids = ['a', 'b', 'c']
op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
assert op_base**initial_repetitions != op_with_reps
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
assert op_base**initial_repetitions != op_with_reps
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
assert op_with_reps.repetitions == initial_repetitions
assert op_with_reps.use_repetition_ids
assert op_with_reps.repetition_ids == rep_ids
assert op_with_reps.repeat(1) is op_with_reps

Expand Down Expand Up @@ -436,6 +436,7 @@ def test_parameterized_repeat_side_effects():
op = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
repetitions=sympy.Symbol('a'),
use_repetition_ids=True,
)

# Control keys can be calculated because they only "lift" if there's a matching
Expand Down Expand Up @@ -689,7 +690,6 @@ def test_string_format():
),
),
]),
use_repetition_ids=False,
)"""
)
op7 = cirq.CircuitOperation(
Expand All @@ -706,7 +706,6 @@ def test_string_format():
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
)"""
)
Expand Down Expand Up @@ -737,6 +736,7 @@ def test_json_dict():
'param_resolver': op.param_resolver,
'parent_path': op.parent_path,
'repetition_ids': None,
'use_repetition_ids': False,
}


Expand Down Expand Up @@ -843,6 +843,26 @@ def test_decompose_loops_with_measurements():
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
base_op = cirq.CircuitOperation(circuit)

op = base_op.with_qubits(b, a).repeat(3)
expected_circuit = cirq.Circuit(
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
)
assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit


def test_decompose_loops_with_measurements_use_rep_ids():
a, b = cirq.LineQubit.range(2)
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
base_op = cirq.CircuitOperation(circuit, use_repetition_ids=True)

op = base_op.with_qubits(b, a).repeat(3)
expected_circuit = cirq.Circuit(
cirq.H(b),
Expand Down Expand Up @@ -999,7 +1019,9 @@ def test_keys_under_parent_path():
op3 = cirq.with_key_path_prefix(op2, ('C',))
assert cirq.measurement_key_names(op3) == {'C:B:A'}
op4 = op3.repeat(2)
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}
assert cirq.measurement_key_names(op4) == {'C:B:A'}
op4_rep = op3.repeat(2).replace(use_repetition_ids=True)
assert cirq.measurement_key_names(op4_rep) == {'C:B:0:A', 'C:B:1:A'}


def test_mapped_circuit_preserves_moments():
Expand Down Expand Up @@ -1077,12 +1099,8 @@ def test_mapped_circuit_allows_repeated_keys():
def test_simulate_no_repetition_ids_both_levels(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
)
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=False
)
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['a'].shape == (1, 4, 1)
Expand All @@ -1092,10 +1110,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
def test_simulate_no_repetition_ids_outer(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=False
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=True)
)
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['0:a'].shape == (1, 2, 1)
Expand All @@ -1106,10 +1124,10 @@ def test_simulate_no_repetition_ids_outer(sim):
def test_simulate_no_repetition_ids_inner(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=True
)
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['0:a'].shape == (1, 2, 1)
Expand All @@ -1124,7 +1142,6 @@ def test_repeat_until(sim):
cirq.X(q),
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q), cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
Expand All @@ -1139,7 +1156,6 @@ def test_repeat_until_sympy(sim):
q1, q2 = cirq.LineQubit.range(2)
circuitop = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q2), cirq.measure(q2, key='b')),
use_repetition_ids=False,
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
)
c = cirq.Circuit(cirq.measure(q1, key='a'), circuitop)
Expand All @@ -1159,7 +1175,6 @@ def test_post_selection(sim):
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
)
)
Expand All @@ -1175,14 +1190,13 @@ def test_repeat_until_diagram():
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
)
)
cirq.testing.assert_has_diagram(
c,
"""
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
""",
use_unicode_characters=True,
)
Expand All @@ -1199,7 +1213,6 @@ def test_repeat_until_error():
with pytest.raises(ValueError, match='Infinite loop'):
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q, key='m')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)

Expand Down
Loading

0 comments on commit 5ffb3ad

Please sign in to comment.