"""Feedforward correction functions.
This module provides:
- `dag_from_flow`: Construct a directed acyclic graph (DAG) from a flowlike object.
- `inverse_dag_from_dag`: Construct an inverse DAG (node -> dependencies).
- `topo_order_from_inv_dag`: Construct a topological order from an inverse DAG.
- `check_dag`: Check if a directed acyclic graph (DAG) does not contain a cycle.
- `check_flow`: Check if the flowlike object is causal with respect to the graph state.
- `signal_shifting`: Convert the correction maps into more parallel-friendly forms using signal shifting.
- `propagate_correction_map`: Propagate the correction map through a measurement at the target node.
- `pauli_simplification`: Simplify the correction maps by removing redundant Pauli corrections.
"""
from __future__ import annotations
from collections.abc import Iterable, Mapping
from collections.abc import Set as AbstractSet
from graphlib import CycleError, TopologicalSorter
from typing import Any, TypeGuard
import typing_extensions
from graphqomb.common import Axis, Plane, determine_pauli_axis
from graphqomb.graphstate import BaseGraphState, odd_neighbors
# Message of the RuntimeError raised by `topo_order_from_inv_dag` when `graphlib` reports a cycle.
TOPO_ORDER_CYCLE_ERROR_MSG = "No nodes can be measured; possible cyclic dependency or incomplete preparation."
def _is_flow(flowlike: Mapping[int, Any]) -> TypeGuard[Mapping[int, int]]:
r"""Check if the flowlike object is a flow.
Parameters
----------
flowlike : `collections.abc.Mapping`\[`int`, `typing.Any`\]
A flowlike object to check
Returns
-------
`bool`
True if the flowlike object is a flow, False otherwise
"""
return all(isinstance(v, int) for v in flowlike.values())
def _is_gflow(flowlike: Mapping[int, Any]) -> TypeGuard[Mapping[int, AbstractSet[int]]]:
r"""Check if the flowlike object is a GFlow.
Parameters
----------
flowlike : `collections.abc.Mapping`\[`int`, `typing.Any`\]
A flowlike object to check
Returns
-------
`bool`
True if the flowlike object is a GFlow, False otherwise
"""
return all(isinstance(v, AbstractSet) for v in flowlike.values())
[docs]
def dag_from_flow(
    graph: BaseGraphState,
    xflow: Mapping[int, int] | Mapping[int, AbstractSet[int]],
    zflow: Mapping[int, int] | Mapping[int, AbstractSet[int]] | None = None,
) -> dict[int, set[int]]:
    r"""Build the dependency graph (node -> corrected nodes) induced by a flowlike object.

    Every non-output physical node is mapped to the union of its X and Z correction
    targets, excluding itself. Every output node is mapped to an empty set.

    Parameters
    ----------
    graph : `BaseGraphState`
        The graph state
    xflow : `collections.abc.Mapping`\[`int`, `int`\] | `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
        The X correction flow (either a flow or a gflow)
    zflow : `collections.abc.Mapping`\[`int`, `int`\] | `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
        The Z correction flow. If `None`, it is derived from xflow via `odd_neighbors`.

    Returns
    -------
    `dict`\[`int`, `set`\[`int`\]\]
        Mapping from each node to the set of nodes its corrections act on

    Raises
    ------
    TypeError
        If xflow or zflow is neither a Flow nor a GFlow
    """  # noqa: E501
    outputs = set(graph.output_node_indices)
    measured = graph.physical_nodes - outputs

    # Normalise the X map so that every value is a mutable set of targets.
    x_targets: dict[int, set[int]]
    if _is_flow(xflow):
        x_targets = {src: {dst} for src, dst in xflow.items()}
    elif _is_gflow(xflow):
        x_targets = {src: set(dsts) for src, dsts in xflow.items()}
    else:
        msg = "Invalid flowlike object"
        raise TypeError(msg)

    # Same normalisation for Z; when absent it is generated from the normalised X map.
    z_targets: dict[int, set[int]]
    if zflow is None:
        z_targets = {src: odd_neighbors(dsts, graph) for src, dsts in x_targets.items()}
    elif _is_flow(zflow):
        z_targets = {src: {dst} for src, dst in zflow.items()}
    elif _is_gflow(zflow):
        z_targets = {src: set(dsts) for src, dsts in zflow.items()}
    else:
        msg = "Invalid zflow object"
        raise TypeError(msg)

    # Union of both correction targets, with self-loops dropped.
    dag: dict[int, set[int]] = {
        node: (x_targets.get(node, set()) | z_targets.get(node, set())) - {node} for node in measured
    }
    # Output nodes are never measured, so they have no outgoing edges.
    dag.update((output, set()) for output in outputs)
    return dag
[docs]
def check_dag(dag: Mapping[int, Iterable[int]]) -> None:
    r"""Check that a directed graph given as parent -> children contains no cycle.

    Cycles of any length are detected. Children that do not appear as keys of
    ``dag`` are treated as nodes without outgoing edges.

    Parameters
    ----------
    dag : `collections.abc.Mapping`\[`int`, `collections.abc.Iterable`\[`int`\]\]
        directed graph represented as parent node -> children

    Raises
    ------
    ValueError
        If the graph contains a cycle (including self-loops)
    """
    # Materialise the children once so one-shot iterables are consumed a single time
    # and membership tests are O(1).
    children_of: dict[int, set[int]] = {node: set(children) for node, children in dag.items()}

    # Self-loops and 2-cycles: keep the historical, edge-oriented error message.
    for node, children in children_of.items():
        for child in children:
            if node in children_of.get(child, ()):
                msg = f"Cycle detected in the graph: {node} -> {child}"
                raise ValueError(msg)

    # Longer cycles: graphlib detects them regardless of how edge direction is interpreted.
    try:
        TopologicalSorter(children_of).prepare()
    except CycleError as exc:
        # graphlib lists each node followed by one of its dependents; since the children were
        # passed as "predecessors", reversing the list yields parent -> child order.
        cycle = list(reversed(exc.args[1]))
        msg = f"Cycle detected in the graph: {' -> '.join(str(node) for node in cycle)}"
        raise ValueError(msg) from exc
[docs]
def inverse_dag_from_dag(
    dag: Mapping[int, Iterable[int]],
    all_nodes: Iterable[int] | None = None,
) -> dict[int, set[int]]:
    r"""Reverse a parent -> children DAG into a node -> dependencies mapping.

    Parameters
    ----------
    dag : `collections.abc.Mapping`\[`int`, `collections.abc.Iterable`\[`int`\]\]
        DAG represented as parent node -> children.
    all_nodes : `collections.abc.Iterable`\[`int`\] | `None`, optional
        Node set to start from instead of the keys of ``dag``; lets isolated nodes appear
        in the result.

    Returns
    -------
    `dict`\[`int`, `set`\[`int`\]\]
        Inverse DAG represented as node -> dependencies.
    """
    # Result keys: the supplied node set (or the DAG's keys) plus every child seen in the DAG.
    known = set(dag) if all_nodes is None else set(all_nodes)
    known.update(*dag.values())

    dependencies: dict[int, set[int]] = {node: set() for node in known}
    edges = ((parent, child) for parent, children in dag.items() for child in children)
    for parent, child in edges:
        dependencies[child].add(parent)
    return dependencies
[docs]
def topo_order_from_inv_dag(inv_dag: Mapping[int, Iterable[int]]) -> list[int]:
    r"""Compute a topological order for an inverse DAG (node -> dependencies).

    Parameters
    ----------
    inv_dag : `collections.abc.Mapping`\[`int`, `collections.abc.Iterable`\[`int`\]\]
        Inverse DAG where each node maps to the nodes it depends on.

    Returns
    -------
    `list`\[`int`\]
        Nodes ordered so that dependencies come before their dependents.

    Raises
    ------
    RuntimeError
        If no topological order exists because the graph has a cycle.
    """
    sorter = TopologicalSorter(inv_dag)
    try:
        order = [*sorter.static_order()]
    except CycleError as exc:
        raise RuntimeError(TOPO_ORDER_CYCLE_ERROR_MSG) from exc
    return order
[docs]
def check_flow(
    graph: BaseGraphState,
    xflow: Mapping[int, int] | Mapping[int, AbstractSet[int]],
    zflow: Mapping[int, int] | Mapping[int, AbstractSet[int]] | None = None,
) -> None:
    r"""Validate that the flowlike object is causal with respect to the graph state.

    The dependency graph is built with `dag_from_flow` and then handed to `check_dag`;
    any exception raised by either of them propagates to the caller.

    Parameters
    ----------
    graph : `BaseGraphState`
        The graph state
    xflow : `collections.abc.Mapping`\[`int`, `int`\] | `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
        The X correction flow (either a flow or a gflow)
    zflow : `collections.abc.Mapping`\[`int`, `int`\] | `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
        The Z correction flow. If `None`, it is generated from xflow by odd neighbors.
    """  # noqa: E501
    check_dag(dag_from_flow(graph, xflow, zflow))
[docs]
def signal_shifting(
    graph: BaseGraphState, xflow: Mapping[int, AbstractSet[int]], zflow: Mapping[int, AbstractSet[int]] | None = None
) -> tuple[dict[int, set[int]], dict[int, set[int]]]:
    r"""Apply signal shifting to the correction maps to make them more parallel-friendly.

    Each non-output node is visited in a parents-before-children order of the dependency
    graph and the maps are passed through `propagate_correction_map` at that node.

    Parameters
    ----------
    graph : `BaseGraphState`
        Underlying graph state.
    xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
        Correction map for X.
    zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
        Correction map for Z. If `None`, it is generated from xflow by odd neighbors.

    Returns
    -------
    `tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]]
        Updated correction maps for X and Z after signal shifting.
    """
    if zflow is None:
        zflow = {src: odd_neighbors(targets, graph) - {src} for src, targets in xflow.items()}

    dag = dag_from_flow(graph, xflow, zflow)

    # TopologicalSorter reads the mapping as node -> predecessors, so handing it the
    # parent -> children DAG yields children first; reversing gives parents first.
    children_first = tuple(TopologicalSorter(dag).static_order())
    outputs = graph.output_node_indices
    schedule = [node for node in reversed(children_first) if node not in outputs]

    shifted_x = {src: set(targets) for src, targets in xflow.items()}
    shifted_z = {src: set(targets) for src, targets in zflow.items()}
    for node in schedule:
        shifted_x, shifted_z = propagate_correction_map(node, graph, shifted_x, shifted_z)
    return shifted_x, shifted_z
[docs]
def propagate_correction_map(  # noqa: C901, PLR0912
    target_node: int,
    graph: BaseGraphState,
    xflow: Mapping[int, AbstractSet[int]],
    zflow: Mapping[int, AbstractSet[int]] | None = None,
) -> tuple[dict[int, set[int]], dict[int, set[int]]]:
    r"""Propagate the correction map through a measurement at the target node.

    The parents whose correction on ``target_node`` is absorbed depend on the measurement
    plane of the target: Z-parents for XY, X-parents for YZ, and parents present in both
    maps for XZ. The absorbed correction is removed from those parents, and the target's
    own X and Z correction targets are toggled (symmetric difference) into the parents' maps.

    The input maps are not modified; fresh dictionaries of fresh sets are returned.

    Parameters
    ----------
    target_node : `int`
        Node at which the measurement is performed.
    graph : `BaseGraphState`
        Underlying graph state.
    xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
        Correction map for X.
    zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
        Correction map for Z. If `None`, it is generated from xflow by odd neighbors.

    Returns
    -------
    `tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]]
        Updated correction maps for X and Z after measurement at the target node.

    Raises
    ------
    ValueError
        If the target node is an output node.
    ValueError
        If the measurement plane is unsupported (see the review note in the final ``else``
        branch: `typing_extensions.assert_never` is called first).

    Notes
    -----
    This function converts the correction maps into more parallel-friendly forms.
    It is equivalent to the signal shifting technique in the measurement calculus.
    """
    if target_node in graph.output_node_indices:
        msg = "Cannot propagate flow for output nodes."
        raise ValueError(msg)

    if zflow is None:
        zflow = {node: odd_neighbors(xflow[node], graph) - {node} for node in xflow}

    # Only the parents of the target are needed, so collect them directly
    # instead of inverting both maps completely.
    x_parents = {parent for parent, targets in xflow.items() if target_node in targets}
    z_parents = {parent for parent, targets in zflow.items() if target_node in targets}

    new_xflow = {k: set(vs) for k, vs in xflow.items()}
    new_zflow = {k: set(vs) for k, vs in zflow.items()}

    meas_basis = graph.meas_bases[target_node]
    if meas_basis.plane == Plane.XY:
        target_parents = z_parents
        for parent in target_parents:
            new_zflow[parent].discard(target_node)
    elif meas_basis.plane == Plane.YZ:
        target_parents = x_parents
        for parent in target_parents:
            new_xflow[parent].discard(target_node)
    elif meas_basis.plane == Plane.XZ:
        target_parents = x_parents & z_parents
        for parent in target_parents:
            new_xflow[parent].discard(target_node)
            new_zflow[parent].discard(target_node)
    else:
        # NOTE(review): assert_never raises at runtime, so the ValueError below is only
        # reachable if assert_never returns — confirm which exception callers should expect.
        typing_extensions.assert_never(meas_basis.plane)
        msg = f"Unsupported measurement plane: {meas_basis.plane}"
        raise ValueError(msg)

    # Toggle the target's own correction targets into each selected parent. A parent may be a
    # key of only one of the two maps (e.g. explicit zflow with different keys), so create the
    # missing entry on demand instead of failing with KeyError.
    x_children = xflow.get(target_node, set())
    if x_children:
        for parent in target_parents:
            new_xflow.setdefault(parent, set()).symmetric_difference_update(x_children)

    z_children = zflow.get(target_node, set())
    if z_children:
        for parent in target_parents:
            new_zflow.setdefault(parent, set()).symmetric_difference_update(z_children)

    return new_xflow, new_zflow
def pauli_simplification(  # noqa: C901, PLR0912
    graph: BaseGraphState,
    xflow: Mapping[int, AbstractSet[int]],
    zflow: Mapping[int, AbstractSet[int]] | None = None,
) -> tuple[dict[int, set[int]], dict[int, set[int]]]:
    r"""Drop corrections from the maps according to the Pauli measurement axis of their target.

    For every non-output node with a measurement basis for which `determine_pauli_axis`
    returns an axis: X removes the node from its X-parents, Z removes it from its Z-parents,
    and Y removes it from both maps for parents that correct it in both.

    Parameters
    ----------
    graph : `BaseGraphState`
        Underlying graph state.
    xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
        Correction map for X.
    zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
        Correction map for Z. If `None`, it is generated from xflow by odd neighbors.

    Returns
    -------
    `tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]]
        Updated correction maps for X and Z after simplification.
    """

    def sources_of(flow: Mapping[int, AbstractSet[int]]) -> dict[int, set[int]]:
        # Invert a correction map: target node -> set of nodes that correct it.
        inverse: dict[int, set[int]] = {}
        for src, targets in flow.items():
            for dst in targets:
                inverse.setdefault(dst, set()).add(src)
        return inverse

    if zflow is None:
        zflow = {src: odd_neighbors(targets, graph) - {src} for src, targets in xflow.items()}

    simplified_x = {src: set(targets) for src, targets in xflow.items()}
    simplified_z = {src: set(targets) for src, targets in zflow.items()}
    x_sources = sources_of(xflow)
    z_sources = sources_of(zflow)

    measured_nodes = graph.physical_nodes - graph.output_node_indices.keys()
    for node in measured_nodes:
        basis = graph.meas_bases.get(node)
        if basis is None:
            continue
        axis = determine_pauli_axis(basis)
        if axis is None:
            continue

        if axis == Axis.X:
            for parent in x_sources.get(node, set()):
                simplified_x[parent].discard(node)
        elif axis == Axis.Z:
            for parent in z_sources.get(node, set()):
                simplified_z[parent].discard(node)
        elif axis == Axis.Y:
            # Only parents that correct the node in both maps are simplified.
            for parent in x_sources.get(node, set()) & z_sources.get(node, set()):
                simplified_x[parent].discard(node)
                simplified_z[parent].discard(node)

    return simplified_x, simplified_z