Skip to content

Commit

Permalink
Apply as_sweep for SqrtCZGauge
Browse files Browse the repository at this point in the history
  • Loading branch information
babacry committed Jan 9, 2025
1 parent ecf9ed7 commit 033d607
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 30 deletions.
80 changes: 59 additions & 21 deletions cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,16 @@ def __call__(self, prng: np.random.Generator) -> Gauge:

@transformer_api.transformer
class GaugeTransformer:

def __init__(
self,
# target can be either a specific gate, gatefamily or gateset
# which allows matching parametric gates.
target: Union[ops.Gate, ops.Gateset, ops.GateFamily],
gauge_selector: Callable[[np.random.Generator], Gauge],
symbolize_2_qubit_gate_fn: Optional[
Callable[[ConstantGauge, sympy.Symbol], Tuple[ops.Gate, float]]
] = None,
) -> None:
"""Constructs a GaugeTransformer.
Expand All @@ -170,6 +174,7 @@ def __init__(
"""
self.target = ops.GateFamily(target) if isinstance(target, ops.Gate) else target
self.gauge_selector = gauge_selector
self.symbolize_2_qubit_gate_fn = symbolize_2_qubit_gate_fn

def __call__(
self,
Expand Down Expand Up @@ -235,15 +240,23 @@ def as_sweep(
if context.deep:
raise ValueError('GaugeTransformer cannot be used with deep=True')
new_moments: List[List[ops.Operation]] = [] # Store parameterized circuits.
values_by_params: Dict[str, List[float]] = {} # map from symbol name to N values.
symbol_count = count()
phxz_symbol_count = count()
two_qubit_gate_symbol_count = count()
# Map from "((pre|post),$qid,$moment_id)" to gate parameters.
# E.g. {(post,q1,2): {"x_exponent": "x1", "z_exponent": "z1", "axis_phase": "a1"}}
symbols_by_loc: Dict[Tuple[str, ops.Qid, int], Dict[str, sympy.Symbol]] = {}
# E.g., {(post,q1,2): {"x_exponent": "x1", "z_exponent": "z1", "axis_phase": "a1"}}
phxz_symbols_by_locs: Dict[Tuple[str, ops.Qid, int], Dict[str, sympy.Symbol]] = {}
# Map from "($q0,$q1,$moment_id)" to gate parameters.
# E.g., {(q0,q1,0): "s0"}.
two_qubit_gate_symbols_by_locs: Dict[Tuple[ops.Qid, ops.Qid, int], sympy.Symbol] = {}

def single_qubit_next_symbol() -> Dict[str, sympy.Symbol]:
sid = next(symbol_count)
return _parameterize(1, sid)
sid = next(phxz_symbol_count)
return _parameterize_to_phxz(sid)

# Returns a single symbol for 2 qubit gate parameterization.
def two_qubit_gate_next_symbol() -> sympy.Symbol:
sid = next(two_qubit_gate_symbol_count)
return sympy.Symbol(f"s{sid}")

# Build parameterized circuit.
for moment_id, moment in enumerate(circuit):
Expand All @@ -257,17 +270,25 @@ def single_qubit_next_symbol() -> Dict[str, sympy.Symbol]:
center_moment.append(op)
continue
if op.gate is not None and op in self.target:
random_gauge = self.gauge_selector(rng).sample(op.gate, rng)
# Build symbols for 2-qubit-gates if the transformer might transform it,
# otherwise, keep it as it is.
if self.symbolize_2_qubit_gate_fn is not None:
symbol: sympy.Symbol = two_qubit_gate_next_symbol()
two_qubit_gate_symbols_by_locs[(*op.qubits, moment_id)] = symbol
parameterized_2_qubit_gate, _ = self.symbolize_2_qubit_gate_fn(
random_gauge, symbol
)
center_moment.append(parameterized_2_qubit_gate.on(*op.qubits))
else:
center_moment.append(op)
# Build symbols for the gauge, for a 2-qubit gauge, symbols will be built for
# pre/post q0/q1 and the new 2-qubit gate if the 2-qubit gate is updated in
# the gauge compiling.
center_moment.append(op)
for prefix, q in itertools.product(["pre", "post"], op.qubits):
xza_by_symbols = single_qubit_next_symbol() # xza in phased xz gate.
loc = (prefix, q, moment_id)
symbols_by_loc[loc] = xza_by_symbols
phxz_symbols_by_locs[(prefix, q, moment_id)] = xza_by_symbols
new_op = ops.PhasedXZGate(**xza_by_symbols).on(q)
for symbol in xza_by_symbols.values():
values_by_params.update({str(symbol): []})
if prefix == "pre":
left_moment.append(new_op)
else:
Expand All @@ -278,6 +299,16 @@ def single_qubit_next_symbol() -> Dict[str, sympy.Symbol]:
[moment for moment in [left_moment, center_moment, right_moment] if moment]
)

# Initialize the map from symbol names to their N values.
values_by_params: Dict[str, List[float]] = {
**{
str(symbol): []
for symbols_by_names in phxz_symbols_by_locs.values()
for symbol in symbols_by_names.values()
},
**{str(symbol): [] for symbol in two_qubit_gate_symbols_by_locs.values()},
}

# Assign values for parameters via randomly chosen GaugeSelector.
for _ in range(N):
for moment_id, moment in enumerate(circuit):
Expand All @@ -292,13 +323,21 @@ def single_qubit_next_symbol() -> Dict[str, sympy.Symbol]:
raise NotImplementedError(
f"as_sweep isn't supported for {gauge.two_qubit_gate} gauge"
)

# Get the params for 2 qubit gates.
if self.symbolize_2_qubit_gate_fn is not None:
symbol = two_qubit_gate_symbols_by_locs[(*op.qubits, moment_id)]
_, val = self.symbolize_2_qubit_gate_fn(gauge, symbol)
values_by_params[str(symbol)].append(val)

# Get the params of pre/post q0/q1 gates.
for pre_or_post, idx in itertools.product(["pre", "post"], [0, 1]):
symbols = symbols_by_loc[(pre_or_post, op.qubits[idx], moment_id)]
symbols = phxz_symbols_by_locs[(pre_or_post, op.qubits[idx], moment_id)]
gates = getattr(gauge, f"{pre_or_post}_q{idx}")
phxz_params = _gate_sequence_to_phxz_params(gates, symbols)
for key, value in phxz_params.items():
values_by_params[key].append(value)

sweeps: List[Points] = [
Points(key=key, points=values) for key, values in values_by_params.items()
]
Expand All @@ -318,17 +357,16 @@ def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[
return moments


def _parameterize(num_qubits: int, symbol_id: int) -> Dict[str, sympy.Symbol]:
def _parameterize_to_phxz(symbol_id: int) -> Dict[str, sympy.Symbol]:
"""Returns symbolized parameters for the gate."""

if num_qubits == 1: # Convert single qubit gate to parameterized PhasedXZGate.
phased_xz_params = {
"x_exponent": sympy.Symbol(f"x{symbol_id}"),
"z_exponent": sympy.Symbol(f"z{symbol_id}"),
"axis_phase_exponent": sympy.Symbol(f"a{symbol_id}"),
}
return phased_xz_params
raise NotImplementedError("parameterization for non single qubit gates is not supported yet")
# Parameterize single qubit gate to parameterized PhasedXZGate.
phased_xz_params = {
"x_exponent": sympy.Symbol(f"x{symbol_id}"),
"z_exponent": sympy.Symbol(f"z{symbol_id}"),
"axis_phase_exponent": sympy.Symbol(f"a{symbol_id}"),
}
return phased_xz_params


def _gate_sequence_to_phxz_params(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from cirq.transformers.gauge_compiling import (
GaugeTransformer,
CZGaugeTransformer,
SqrtCZGaugeTransformer,
ConstantGauge,
GaugeSelector,
)
from cirq.transformers.gauge_compiling.sqrt_cz_gauge import SqrtCZGauge
from cirq.transformers.gauge_compiling.iswap_gauge import RZRotation
from cirq.transformers.analytical_decompositions import single_qubit_decompositions


Expand Down Expand Up @@ -110,16 +113,27 @@ def test_as_sweep_convert_to_phxz_failed():
qs = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.CZ(*qs))

def mock_single_qubit_matrix_to_phxz(*args, **kwargs):
# Return an non PhasedXZ gate, so we expect errors from as_sweep().
return cirq.X

with unittest.mock.patch.object(
single_qubit_decompositions,
"single_qubit_matrix_to_phxz",
new=mock_single_qubit_matrix_to_phxz,
# Return an non PhasedXZ gate, so we expect errors from as_sweep().
return_value=cirq.X,
):
with pytest.raises(
ValueError, match="Failed to convert the gate sequence to a PhasedXZ gate."
):
_ = CZGaugeTransformer.as_sweep(c, context=cirq.TransformerContext(), N=1)


def test_symbolize_2_qubits_gate_failed():
qs = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.CZPowGate(exponent=0.5).on(*qs))

with unittest.mock.patch.object(
SqrtCZGauge,
"sample",
# ISWAP gate is not a CZPowGate; errors are expected when symbolizing the 2-qubit gate.
return_value=ConstantGauge(two_qubit_gate=cirq.ISWAP),
):
with pytest.raises(ValueError, match="Can't symbolize non-CZPowGate as CZ\\*\\*symbol."):
_ = SqrtCZGaugeTransformer.as_sweep(c, N=1)
21 changes: 17 additions & 4 deletions cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@
"""A Gauge transformer for CZ**0.5 and CZ**-0.5 gates."""


from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
import numpy as np
import sympy

from cirq.transformers.gauge_compiling.gauge_compiling import (
GaugeTransformer,
GaugeSelector,
ConstantGauge,
Gauge,
)
from cirq.ops import CZ, S, X, Gateset
from cirq.ops import CZ, S, X, Gateset, Gate, CZPowGate


if TYPE_CHECKING:
import cirq
Expand All @@ -40,7 +42,7 @@ def weight(self) -> float:

def sample(self, gate: 'cirq.Gate', prng: np.random.Generator) -> ConstantGauge:
if prng.choice([True, False]):
return ConstantGauge(two_qubit_gate=gate)
return ConstantGauge(two_qubit_gate=gate, support_sweep=True)
swap_qubits = prng.choice([True, False])
if swap_qubits:
return ConstantGauge(
Expand All @@ -49,16 +51,27 @@ def sample(self, gate: 'cirq.Gate', prng: np.random.Generator) -> ConstantGauge:
post_q0=S if gate == _SQRT_CZ else _ADJ_S,
two_qubit_gate=gate**-1,
swap_qubits=True,
support_sweep=True,
)
else:
return ConstantGauge(
pre_q0=X,
post_q0=X,
post_q1=S if gate == _SQRT_CZ else _ADJ_S,
two_qubit_gate=gate**-1,
support_sweep=True,
)


def _symbolize_as_cz_pow(gauge: ConstantGauge, symbol: sympy.Symbol) -> Tuple[Gate, float]:
if not isinstance(gauge.two_qubit_gate, CZPowGate):
raise ValueError("Can't symbolize non-CZPowGate as CZ**symbol.")
gate: CZPowGate = gauge.two_qubit_gate
return CZ**symbol, gate.exponent


SqrtCZGaugeTransformer = GaugeTransformer(
target=Gateset(_SQRT_CZ, _SQRT_CZ**-1), gauge_selector=GaugeSelector(gauges=[SqrtCZGauge()])
target=Gateset(_SQRT_CZ, _SQRT_CZ**-1),
gauge_selector=GaugeSelector(gauges=[SqrtCZGauge()]),
symbolize_2_qubit_gate_fn=_symbolize_as_cz_pow,
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
class TestSqrtCZGauge(GaugeTester):
two_qubit_gate = cirq.CZ**0.5
gauge_transformer = SqrtCZGaugeTransformer
sweep_must_pass = True


class TestAdjointSqrtCZGauge(GaugeTester):
two_qubit_gate = cirq.CZ**-0.5
gauge_transformer = SqrtCZGaugeTransformer
sweep_must_pass = True

0 comments on commit 033d607

Please sign in to comment.