Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CircuitOperation: change use_repetition_ids default to False #6910

Merged
merged 5 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
repetition_ids: Optional[Sequence[str]] = None,
parent_path: Tuple[str, ...] = (),
extern_keys: FrozenSet['cirq.MeasurementKey'] = frozenset(),
use_repetition_ids: bool = True,
use_repetition_ids: Optional[bool] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit - consider updating the use_repetition_ids docstring below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

repeat_until: Optional['cirq.Condition'] = None,
):
"""Initializes a CircuitOperation.
Expand Down Expand Up @@ -120,7 +120,8 @@ def __init__(
use_repetition_ids: When True, any measurement key in the subcircuit
will have its path prepended with the repetition id for each
repetition. When False, this will not happen and the measurement
key will be repeated.
key will be repeated. When None, default to False unless the caller
passes `repetition_ids` explicitly.
repeat_until: A condition that will be tested after each iteration of
the subcircuit. The subcircuit will repeat until condition returns
True, but will always run at least once, and the measurement key
Expand Down Expand Up @@ -156,6 +157,8 @@ def __init__(
# Ensure that the circuit is invertible if the repetitions are negative.
self._repetitions = repetitions
self._repetition_ids = None if repetition_ids is None else list(repetition_ids)
if use_repetition_ids is None:
use_repetition_ids = repetition_ids is not None
self._use_repetition_ids = use_repetition_ids
if isinstance(self._repetitions, float):
if math.isclose(self._repetitions, round(self._repetitions)):
Expand Down Expand Up @@ -263,7 +266,7 @@ def replace(self, **changes) -> 'cirq.CircuitOperation':
'repetition_ids': self.repetition_ids,
'parent_path': self.parent_path,
'extern_keys': self._extern_keys,
'use_repetition_ids': self.use_repetition_ids,
'use_repetition_ids': True if 'repetition_ids' in changes else self.use_repetition_ids,
'repeat_until': self.repeat_until,
**changes,
}
Expand Down Expand Up @@ -448,11 +451,9 @@ def __repr__(self):
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
if self.parent_path:
args += f'parent_path={proper_repr(self.parent_path)},\n'
if self.repetition_ids != self._default_repetition_ids():
if self.use_repetition_ids:
# Default repetition_ids need not be specified.
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
args += 'use_repetition_ids=False,\n'
if self.repeat_until:
args += f'repeat_until={self.repeat_until!r},\n'
indented_args = args.replace('\n', '\n ')
Expand All @@ -477,14 +478,15 @@ def dict_str(d: Mapping) -> str:
args.append(f'params={self.param_resolver.param_dict}')
if self.parent_path:
args.append(f'parent_path={self.parent_path}')
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args.append(f'repetition_ids={self.repetition_ids}')
if self.use_repetition_ids:
if self.repetition_ids != self._default_repetition_ids():
args.append(f'repetition_ids={self.repetition_ids}')
else:
# Default repetition_ids need not be specified.
args.append(f'loops={self.repetitions}, use_repetition_ids=True')
elif self.repetitions != 1:
# Only add loops if we haven't added repetition_ids.
# Add loops if not using repetition_ids.
args.append(f'loops={self.repetitions}')
if not self.use_repetition_ids:
args.append('no_rep_ids')
if self.repeat_until:
args.append(f'until={self.repeat_until}')
if not args:
Expand Down Expand Up @@ -529,10 +531,9 @@ def _json_dict_(self):
'measurement_key_map': self.measurement_key_map,
'param_resolver': self.param_resolver,
'repetition_ids': self.repetition_ids,
'use_repetition_ids': self.use_repetition_ids,
'parent_path': self.parent_path,
}
if not self.use_repetition_ids:
resp['use_repetition_ids'] = False
if self.repeat_until:
resp['repeat_until'] = self.repeat_until
return resp
Expand Down Expand Up @@ -566,7 +567,10 @@ def _from_json_dict_(
# Methods for constructing a similar object with one field modified.

def repeat(
self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None
self,
repetitions: Optional[IntParam] = None,
repetition_ids: Optional[Sequence[str]] = None,
use_repetition_ids: Optional[bool] = None,
) -> 'CircuitOperation':
"""Returns a copy of this operation repeated 'repetitions' times.
Each repetition instance will be identified by a single repetition_id.
Expand All @@ -577,6 +581,10 @@ def repeat(
defaults to the length of `repetition_ids`.
repetition_ids: List of IDs, one for each repetition. If unset,
defaults to `default_repetition_ids(repetitions)`.
use_repetition_ids: If given, this specifies the value for `use_repetition_ids`
of the resulting circuit operation. If not given, we enable ids if
`repetition_ids` is not None, and otherwise fall back to
`self.use_repetition_ids`.

Returns:
A copy of this operation repeated `repetitions` times with the
Expand All @@ -591,6 +599,9 @@ def repeat(
ValueError: Unexpected length of `repetition_ids`.
ValueError: Both `repetitions` and `repetition_ids` are None.
"""
if use_repetition_ids is None:
use_repetition_ids = True if repetition_ids is not None else self.use_repetition_ids

if repetitions is None:
if repetition_ids is None:
raise ValueError('At least one of repetitions and repetition_ids must be set')
Expand All @@ -604,7 +615,7 @@ def repeat(
expected_repetition_id_length: int = np.abs(repetitions)

if repetition_ids is None:
if self.use_repetition_ids:
if use_repetition_ids:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
Expand All @@ -617,7 +628,11 @@ def repeat(

# The eventual number of repetitions of the returned CircuitOperation.
final_repetitions = protocols.mul(self.repetitions, repetitions)
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
return self.replace(
repetitions=final_repetitions,
repetition_ids=repetition_ids,
use_repetition_ids=use_repetition_ids,
)

def __pow__(self, power: IntParam) -> 'cirq.CircuitOperation':
return self.repeat(power)
Expand Down
63 changes: 38 additions & 25 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
op_with_reps: Optional[cirq.CircuitOperation] = None
rep_ids = []
if use_default_ids_for_initial_rep:
op_with_reps = op_base.repeat(initial_repetitions)
rep_ids = ['0', '1', '2']
assert op_base**initial_repetitions == op_with_reps
op_with_reps = op_base.repeat(initial_repetitions, use_repetition_ids=True)
else:
rep_ids = ['a', 'b', 'c']
op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
assert op_base**initial_repetitions != op_with_reps
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
assert op_base**initial_repetitions != op_with_reps
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
assert op_with_reps.repetitions == initial_repetitions
assert op_with_reps.use_repetition_ids
assert op_with_reps.repetition_ids == rep_ids
assert op_with_reps.repeat(1) is op_with_reps

Expand Down Expand Up @@ -436,6 +436,7 @@ def test_parameterized_repeat_side_effects():
op = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
repetitions=sympy.Symbol('a'),
use_repetition_ids=True,
)

# Control keys can be calculated because they only "lift" if there's a matching
Expand Down Expand Up @@ -689,7 +690,6 @@ def test_string_format():
),
),
]),
use_repetition_ids=False,
)"""
)
op7 = cirq.CircuitOperation(
Expand All @@ -706,7 +706,6 @@ def test_string_format():
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
)"""
)
Expand Down Expand Up @@ -737,6 +736,7 @@ def test_json_dict():
'param_resolver': op.param_resolver,
'parent_path': op.parent_path,
'repetition_ids': None,
'use_repetition_ids': False,
}


Expand Down Expand Up @@ -843,6 +843,26 @@ def test_decompose_loops_with_measurements():
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
base_op = cirq.CircuitOperation(circuit)

op = base_op.with_qubits(b, a).repeat(3)
expected_circuit = cirq.Circuit(
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
)
assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit


def test_decompose_loops_with_measurements_use_rep_ids():
a, b = cirq.LineQubit.range(2)
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
base_op = cirq.CircuitOperation(circuit, use_repetition_ids=True)

op = base_op.with_qubits(b, a).repeat(3)
expected_circuit = cirq.Circuit(
cirq.H(b),
Expand Down Expand Up @@ -999,7 +1019,9 @@ def test_keys_under_parent_path():
op3 = cirq.with_key_path_prefix(op2, ('C',))
assert cirq.measurement_key_names(op3) == {'C:B:A'}
op4 = op3.repeat(2)
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}
assert cirq.measurement_key_names(op4) == {'C:B:A'}
op4_rep = op3.repeat(2).replace(use_repetition_ids=True)
assert cirq.measurement_key_names(op4_rep) == {'C:B:0:A', 'C:B:1:A'}


def test_mapped_circuit_preserves_moments():
Expand Down Expand Up @@ -1077,12 +1099,8 @@ def test_mapped_circuit_allows_repeated_keys():
def test_simulate_no_repetition_ids_both_levels(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
)
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=False
)
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['a'].shape == (1, 4, 1)
Expand All @@ -1092,10 +1110,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
def test_simulate_no_repetition_ids_outer(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=False
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=True)
)
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['0:a'].shape == (1, 2, 1)
Expand All @@ -1106,10 +1124,10 @@ def test_simulate_no_repetition_ids_outer(sim):
def test_simulate_no_repetition_ids_inner(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=True
)
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['0:a'].shape == (1, 2, 1)
Expand All @@ -1124,7 +1142,6 @@ def test_repeat_until(sim):
cirq.X(q),
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q), cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
Expand All @@ -1139,7 +1156,6 @@ def test_repeat_until_sympy(sim):
q1, q2 = cirq.LineQubit.range(2)
circuitop = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q2), cirq.measure(q2, key='b')),
use_repetition_ids=False,
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
)
c = cirq.Circuit(cirq.measure(q1, key='a'), circuitop)
Expand All @@ -1159,7 +1175,6 @@ def test_post_selection(sim):
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
)
)
Expand All @@ -1175,14 +1190,13 @@ def test_repeat_until_diagram():
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
)
)
cirq.testing.assert_has_diagram(
c,
"""
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
""",
use_unicode_characters=True,
)
Expand All @@ -1199,7 +1213,6 @@ def test_repeat_until_error():
with pytest.raises(ValueError, match='Infinite loop'):
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q, key='m')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)

Expand Down
Loading
Loading