"""Graph scheduler for preparation, entanglement, and measurement timing in MBQC patterns.

This module provides:

- `compress_schedule`: Compress preparation, measurement, and entanglement times by removing unused time slots.
- `Scheduler`: Schedule graph node preparation, edge entanglement, and node measurement operations.
"""
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, NamedTuple
from graphqomb.feedforward import dag_from_flow
from graphqomb.greedy_scheduler import greedy_minimize_space, greedy_minimize_time
from graphqomb.schedule_solver import ScheduleConfig, Strategy, solve_schedule
if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Set as AbstractSet
from graphqomb.graphstate import BaseGraphState
[docs]
class ScheduleTimings(NamedTuple):
    """Scheduling timings for preparation, measurement, and entanglement.

    Returned by `compress_schedule`. Fields are, in order, ``prepare_time``,
    ``measure_time`` and ``entangle_time``. A value of ``None`` marks an item
    that has no scheduled time; such entries are carried through unchanged.
    """

    prepare_time: dict[int, int | None]  #: Mapping from node indices to their preparation time.
    measure_time: dict[int, int | None]  #: Mapping from node indices to their measurement time.
    entangle_time: dict[tuple[int, int], int | None]  #: Mapping from edges to their entanglement time.
[docs]
class TimeSlice(NamedTuple):
    """Operations for a single time slice in the schedule.

    Produced by `Scheduler.timeline`, one instance per slice index. Fields are
    listed in the order preparation, entanglement, measurement.
    """

    prepare_nodes: set[int]  #: Set of node indices to prepare in this time slice.
    entangle_edges: set[tuple[int, int]]  #: Set of edges to entangle in this time slice.
    measure_nodes: set[int]  #: Set of node indices to measure in this time slice.
[docs]
def compress_schedule(
    prepare_time: Mapping[int, int | None],
    measure_time: Mapping[int, int | None],
    entangle_time: Mapping[tuple[int, int], int | None] | None = None,
) -> ScheduleTimings:
    r"""Compress a schedule by removing gaps in time indices.

    Every distinct time index that is actually used (by a preparation, a
    measurement, or an entanglement) is renumbered to its rank among all used
    indices, so the smallest used time becomes ``0`` and consecutive used times
    become consecutive integers. Relative ordering is preserved exactly;
    ``None`` entries are passed through untouched.

    Parameters
    ----------
    prepare_time : `collections.abc.Mapping`\[`int`, `int` | `None`\]
        A mapping from node indices to their preparation time.
    measure_time : `collections.abc.Mapping`\[`int`, `int` | `None`\]
        A mapping from node indices to their measurement time.
    entangle_time : `collections.abc.Mapping`\[`tuple`\[`int`, `int`\], `int` | `None`\] | `None`, optional
        A mapping from edges (as tuples) to their entanglement time.
        Treated as an empty mapping when omitted.

    Returns
    -------
    ScheduleTimings
        A NamedTuple of freshly built dictionaries with compressed times:

        - prepare_time: `dict`\[`int`, `int` | `None`\]
        - measure_time: `dict`\[`int`, `int` | `None`\]
        - entangle_time: `dict`\[`tuple`\[`int`, `int`\], `int` | `None`\]
    """
    edge_times: Mapping[tuple[int, int], int | None] = {} if entangle_time is None else entangle_time

    # Every time index referenced by any operation, ignoring unscheduled entries.
    used_times = {
        t
        for schedule in (prepare_time, measure_time, edge_times)
        for t in schedule.values()
        if t is not None
    }

    if not used_times:
        # Nothing is scheduled: return plain copies of the inputs.
        return ScheduleTimings(dict(prepare_time), dict(measure_time), dict(edge_times))

    # Rank of each used time among all used times == its compressed index.
    rank = {old: new for new, old in enumerate(sorted(used_times))}

    new_prepare: dict[int, int | None] = {
        node: None if t is None else rank[t] for node, t in prepare_time.items()
    }
    new_measure: dict[int, int | None] = {
        node: None if t is None else rank[t] for node, t in measure_time.items()
    }
    new_entangle: dict[tuple[int, int], int | None] = {
        edge: None if t is None else rank[t] for edge, t in edge_times.items()
    }
    return ScheduleTimings(new_prepare, new_measure, new_entangle)
[docs]
class Scheduler:
    r"""Schedule graph preparation and measurements.

    Attributes
    ----------
    graph : `BaseGraphState`
        The graph state to be scheduled.
    dag : `dict`\[`int`, `set`\[`int`\]\]
        The directed acyclic graph representing dependencies.
    prepare_time : `dict`\[`int`, `int` | `None`\]
        A mapping from node indices to their preparation time.
    measure_time : `dict`\[`int`, `int` | `None`\]
        A mapping from node indices to their measurement time.
    entangle_time : `dict`\[`tuple`\[`int`, `int`\], `int` | `None`\]
        A mapping from edge (as tuple of two node indices) to their entanglement time.
    """

    graph: BaseGraphState
    dag: dict[int, set[int]]
    prepare_time: dict[int, int | None]
    measure_time: dict[int, int | None]
    entangle_time: dict[tuple[int, int], int | None]

    def __init__(
        self,
        graph: BaseGraphState,
        xflow: Mapping[int, AbstractSet[int]],
        zflow: Mapping[int, AbstractSet[int]] | None = None,
    ) -> None:
        r"""Create an unscheduled scheduler for ``graph``.

        Parameters
        ----------
        graph : `BaseGraphState`
            The graph state to be scheduled.
        xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
            Flow mapping passed to ``dag_from_flow`` to build `dag`.
        zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`, optional
            Optional second flow mapping, also passed to ``dag_from_flow``.

        Notes
        -----
        Input nodes get no entry in `prepare_time` and output nodes get no entry in
        `measure_time`. Every other node, and every physical edge in `entangle_time`,
        starts out unscheduled (``None``).
        """
        self.graph = graph
        self.dag = dag_from_flow(graph, xflow, zflow)
        self.prepare_time = dict.fromkeys(graph.physical_nodes - graph.input_node_indices.keys())
        self.measure_time = dict.fromkeys(graph.physical_nodes - graph.output_node_indices.keys())
        # Initialize entangle_time for all physical edges
        self.entangle_time = dict.fromkeys(graph.physical_edges)

    def num_slices(self) -> int:
        r"""Return the number of slices in the schedule.

        Returns
        -------
        `int`
            The number of slices, which is the maximum time across all nodes and edges plus one.
            Unscheduled (``None``) entries are ignored, so a completely unscheduled
            schedule reports a single slice.
        """
        return (
            max(
                max((t for t in self.prepare_time.values() if t is not None), default=0),
                max((t for t in self.measure_time.values() if t is not None), default=0),
                max((t for t in self.entangle_time.values() if t is not None), default=0),
            )
            + 1
        )

    @property
    def timeline(self) -> list[TimeSlice]:
        r"""Get the per-slice operations for preparation, entanglement, and measurement.

        Returns
        -------
        `list`\[`TimeSlice`\]
            One `TimeSlice` per slice index in ``range(num_slices())``, each containing:

            - prepare_nodes: Nodes to prepare
            - entangle_edges: Edges to entangle
            - measure_nodes: Nodes to measure

            Unscheduled (``None``) entries do not appear in any slice.
        """
        prep_time: defaultdict[int, set[int]] = defaultdict(set)
        for node, time in self.prepare_time.items():
            if time is not None:
                prep_time[time].add(node)
        ent_time: defaultdict[int, set[tuple[int, int]]] = defaultdict(set)
        for edge, time in self.entangle_time.items():
            if time is not None:
                ent_time[time].add(edge)
        meas_time: defaultdict[int, set[int]] = defaultdict(set)
        for node, time in self.measure_time.items():
            if time is not None:
                meas_time[time].add(node)
        return [TimeSlice(prep_time[time], ent_time[time], meas_time[time]) for time in range(self.num_slices())]

    def manual_schedule(
        self,
        prepare_time: Mapping[int, int | None],
        measure_time: Mapping[int, int | None],
        entangle_time: Mapping[tuple[int, int], int | None] | None = None,
    ) -> None:
        r"""Set the schedule manually.

        Parameters
        ----------
        prepare_time : `collections.abc.Mapping`\[`int`, `int` | `None`\]
            A mapping from node indices to their preparation time.
            Entries for input nodes are ignored; missing non-input nodes become ``None``.
        measure_time : `collections.abc.Mapping`\[`int`, `int` | `None`\]
            A mapping from node indices to their measurement time.
            Entries for output nodes are ignored; missing non-output nodes become ``None``.
        entangle_time : `collections.abc.Mapping`\[`tuple`\[`int`, `int`\], `int` | `None`\] | `None`, optional
            A mapping from edges (as tuples) to their entanglement time.
            If None, unscheduled entanglement times are auto-scheduled based on preparation times.

        Notes
        -----
        After setting preparation and measurement times, any unscheduled entanglement times
        (with `None` value) are automatically scheduled using `auto_schedule_entanglement()`.

        When `entangle_time` is ``None``, entries of `entangle_time` that are already set
        (e.g. by an earlier call) are kept as they are; only ``None`` entries are filled in.

        The graph is treated as undirected. For convenience, `entangle_time` accepts edges
        in either order: both ``(u, v)`` and ``(v, u)`` are recognized. If both keys are
        provided, the canonical order (as returned by :attr:`BaseGraphState.physical_edges`)
        takes precedence, even when the value is ``None``.
        """
        self.prepare_time = {
            node: prepare_time.get(node, None)
            for node in self.graph.physical_nodes - self.graph.input_node_indices.keys()
        }
        self.measure_time = {
            node: measure_time.get(node, None)
            for node in self.graph.physical_nodes - self.graph.output_node_indices.keys()
        }
        if entangle_time is not None:
            resolved_entangle_time: dict[tuple[int, int], int | None] = {}
            for edge in self.entangle_time:
                if edge in entangle_time:
                    # Canonical orientation wins, even if its value is None.
                    resolved_entangle_time[edge] = entangle_time[edge]
                else:
                    u, v = edge
                    resolved_entangle_time[edge] = entangle_time.get((v, u), None)
            self.entangle_time = resolved_entangle_time
        # Auto-schedule unscheduled entanglement times
        if any(time is None for time in self.entangle_time.values()):
            self.auto_schedule_entanglement()

    def _validate_node_sets(self) -> None:
        """Validate that node sets are correctly configured.

        Raises
        ------
        ValueError
            If input/output nodes are incorrectly included in prepare/measure times,
            or if node sets do not match expected sets.
        """
        input_nodes = self.graph.input_node_indices.keys()
        output_nodes = self.graph.output_node_indices.keys()
        physical_nodes = self.graph.physical_nodes
        # Input nodes should not be in prepare_time
        invalid_prep = input_nodes & self.prepare_time.keys()
        if invalid_prep:
            msg = f"Input nodes {sorted(invalid_prep)} should not be in prepare_time"
            raise ValueError(msg)
        # Output nodes should not be in measure_time
        invalid_meas = output_nodes & self.measure_time.keys()
        if invalid_meas:
            msg = f"Output nodes {sorted(invalid_meas)} should not be in measure_time"
            raise ValueError(msg)
        # Check expected node sets
        expected_prep_nodes = physical_nodes - input_nodes
        expected_meas_nodes = physical_nodes - output_nodes
        if self.prepare_time.keys() != expected_prep_nodes:
            missing = expected_prep_nodes - self.prepare_time.keys()
            unexpected = self.prepare_time.keys() - expected_prep_nodes
            msg_parts: list[str] = []
            if missing:
                msg_parts.append(f"missing nodes {sorted(missing)}")
            if unexpected:
                msg_parts.append(f"unexpected nodes {sorted(unexpected)}")
            msg = f"prepare_time has incorrect node set: {', '.join(msg_parts)}"
            raise ValueError(msg)
        if self.measure_time.keys() != expected_meas_nodes:
            missing = expected_meas_nodes - self.measure_time.keys()
            unexpected = self.measure_time.keys() - expected_meas_nodes
            msg_parts = []
            if missing:
                msg_parts.append(f"missing nodes {sorted(missing)}")
            if unexpected:
                msg_parts.append(f"unexpected nodes {sorted(unexpected)}")
            msg = f"measure_time has incorrect node set: {', '.join(msg_parts)}"
            raise ValueError(msg)

    def _validate_all_nodes_scheduled(self) -> None:
        """Validate that all required nodes are scheduled.

        Raises
        ------
        ValueError
            If any node in prepare_time or measure_time has None as its time value.
        """
        # All nodes in prepare_time must have non-None values
        unscheduled_prep = [node for node, time in self.prepare_time.items() if time is None]
        if unscheduled_prep:
            msg = f"Nodes {sorted(unscheduled_prep)} have no preparation time scheduled (time is None)"
            raise ValueError(msg)
        # All nodes in measure_time must have non-None values
        unscheduled_meas = [node for node, time in self.measure_time.items() if time is None]
        if unscheduled_meas:
            msg = f"Nodes {sorted(unscheduled_meas)} have no measurement time scheduled (time is None)"
            raise ValueError(msg)

    def _validate_executable_times_are_nonnegative(self) -> None:
        """Validate that all executable schedule times are non-negative.

        Raises
        ------
        ValueError
            If any executable preparation, measurement, or entanglement time is negative.
        """
        for schedule_name, schedule in (
            ("preparation", self.prepare_time),
            ("measurement", self.measure_time),
            ("entanglement", self.entangle_time),
        ):
            negative_times = {item: time for item, time in schedule.items() if time is not None and time < 0}
            if negative_times:
                msg = f"{schedule_name.capitalize()} schedule contains negative executable times: {negative_times}"
                raise ValueError(msg)

    def _validate_dag_constraints(self) -> None:
        """Validate that measurement order respects DAG dependencies.

        Raises
        ------
        ValueError
            If measurement times violate DAG ordering constraints
            (a node must be measured before all its successors in the DAG).
        """
        for u, successors in self.dag.items():
            u_time = self.measure_time.get(u)
            if u_time is None:
                continue
            for v in successors:
                # Successors without a measurement time (e.g. output nodes) impose no constraint.
                v_time = self.measure_time.get(v)
                if v_time is not None and u_time >= v_time:
                    msg = (
                        f"DAG violation: node {u} (measure_time={u_time}) "
                        f"must be measured before node {v} (measure_time={v_time})"
                    )
                    raise ValueError(msg)

    def _effective_prepare_time(self, node: int) -> int | None:
        """Return the preparation time of ``node``, treating input nodes as prepared at time -1.

        Returns ``None`` when ``node`` has no preparation time and is not an input node.
        """
        time = self.prepare_time.get(node)
        if time is None and node in self.graph.input_node_indices:
            return -1
        return time

    def auto_schedule_entanglement(self) -> None:
        r"""Automatically schedule entanglement operations based on preparation times.

        Each edge is scheduled at the time when both of its endpoints are prepared.
        For edges involving input nodes, they are scheduled when the non-input node is prepared.
        Input nodes are considered to be prepared at time -1 (before the first time slice).
        Edges whose endpoints are both input nodes are scheduled at time 0.

        Note
        ----
        Only schedules entanglement for edges with `None` time. Preserves manually set times.
        Edges with an endpoint that has no preparation time are left unscheduled.
        Validation of measurement causality is performed by `validate_schedule()`.
        """
        for edge in self.graph.physical_edges:
            node1, node2 = edge
            time1 = self._effective_prepare_time(node1)
            time2 = self._effective_prepare_time(node2)
            # Edge can be created when both nodes are prepared
            # Only schedule if not already scheduled (preserve manual settings)
            if time1 is not None and time2 is not None and self.entangle_time[edge] is None:
                # Keep executable slices non-negative even when both endpoints are inputs.
                self.entangle_time[edge] = max(time1, time2, 0)

    def _validate_entangle_time_constraints(self) -> None:
        """Validate that entanglement times respect preparation and measurement constraints.

        Checks that:
        - Entanglement happens AFTER both nodes are prepared
        - Entanglement happens BEFORE either node is measured

        Raises
        ------
        ValueError
            If entanglement times violate preparation or measurement causality constraints.
        """
        for edge, ent_time in self.entangle_time.items():
            if ent_time is None:
                # Entanglement not scheduled yet is okay (might be auto-scheduled later)
                continue
            node1, node2 = edge
            # Input nodes are considered prepared at time -1
            time1 = self._effective_prepare_time(node1)
            time2 = self._effective_prepare_time(node2)
            # Both nodes must be prepared before or at entanglement time
            if time1 is None or time2 is None:
                # Cannot validate if preparation times are not set
                msg = f"Edge {edge} entanglement validation failed: preparation times not set"
                raise ValueError(msg)
            if ent_time < time1:
                msg = f"Edge {edge} entanglement at time {ent_time} is before node {node1} preparation at time {time1}"
                raise ValueError(msg)
            if ent_time < time2:
                msg = f"Edge {edge} entanglement at time {ent_time} is before node {node2} preparation at time {time2}"
                raise ValueError(msg)
            # Entanglement must happen BEFORE measurement of either node
            # Get measurement times (output nodes are not measured)
            meas_time1 = self.measure_time.get(node1)
            meas_time2 = self.measure_time.get(node2)
            # If node is measured, entanglement must be strictly before measurement
            if meas_time1 is not None and ent_time >= meas_time1:
                msg = (
                    f"Edge {edge} entanglement at time {ent_time} "
                    f"is not before node {node1} measurement at time {meas_time1}"
                )
                raise ValueError(msg)
            if meas_time2 is not None and ent_time >= meas_time2:
                msg = (
                    f"Edge {edge} entanglement at time {ent_time} "
                    f"is not before node {node2} measurement at time {meas_time2}"
                )
                raise ValueError(msg)

    def _validate_time_ordering(self) -> None:
        """Validate that every node is prepared strictly before it is measured.

        Raises
        ------
        ValueError
            If any node is both prepared and measured at the same time, or if any
            node is measured at an earlier time than it is prepared.
        """
        # Group nodes by time
        time_to_prep_nodes: defaultdict[int, set[int]] = defaultdict(set)
        time_to_meas_nodes: defaultdict[int, set[int]] = defaultdict(set)
        for node, time in self.prepare_time.items():
            if time is not None:
                time_to_prep_nodes[time].add(node)
        for node, time in self.measure_time.items():
            if time is not None:
                time_to_meas_nodes[time].add(node)
        # Check that no node is both prepared and measured at the same time
        all_times = time_to_prep_nodes.keys() | time_to_meas_nodes.keys()
        for time in all_times:
            prep_nodes = time_to_prep_nodes[time]
            meas_nodes = time_to_meas_nodes[time]
            conflicting_nodes = prep_nodes & meas_nodes
            if conflicting_nodes:
                msg = f"Nodes {sorted(conflicting_nodes)} cannot be both prepared and measured at time {time}"
                raise ValueError(msg)
        # A node must not be measured before it has been prepared. Without this check,
        # such a node slips through whenever it has no scheduled entanglement.
        for node in sorted(self.measure_time):
            meas = self.measure_time[node]
            prep = self.prepare_time.get(node)
            if meas is not None and prep is not None and meas < prep:
                msg = f"Node {node} is measured at time {meas} before it is prepared at time {prep}"
                raise ValueError(msg)

    def validate_schedule(self) -> None:
        r"""Validate that the schedule is consistent with the graph state and DAG.

        Checks:
        - Input nodes are not prepared (assumed to be prepared before time 0)
        - Output nodes are not measured
        - All non-input nodes have a preparation time
        - All non-output nodes have a measurement time
        - No scheduled time is negative
        - Measurement order respects DAG dependencies
        - Every node is prepared strictly before it is measured (not in the same slice)
        - Entanglement times respect causality constraints (if entanglement is scheduled):
            - Entanglement happens AFTER both nodes are prepared
            - Entanglement happens BEFORE either node is measured

        Raises
        ------
        ValueError
            If any of the checks above fails.
        """
        self._validate_node_sets()
        self._validate_all_nodes_scheduled()
        self._validate_executable_times_are_nonnegative()
        self._validate_dag_constraints()
        self._validate_time_ordering()
        # Validate entanglement times only if at least one edge has a scheduled time
        if any(time is not None for time in self.entangle_time.values()):
            self._validate_entangle_time_constraints()

    def solve_schedule(
        self,
        config: ScheduleConfig | None = None,
        timeout: int = 60,
    ) -> bool:
        r"""Compute the schedule using constraint programming or greedy heuristics.

        Parameters
        ----------
        config : `ScheduleConfig` | `None`, optional
            The scheduling configuration. If None, defaults to MINIMIZE_TIME strategy.
        timeout : `int`, optional
            Maximum solve time in seconds for CP-SAT solver, by default 60.
            Ignored when ``config.use_greedy`` is True.

        Returns
        -------
        `bool`
            True if a solution was found and applied, False otherwise.
            On False the current schedule is left unchanged.

        Note
        ----
        After solving, any unscheduled entanglement times (with `None` value) are
        automatically scheduled using `auto_schedule_entanglement()`. Entanglement
        times that are already set (e.g. by an earlier call) are kept as they are.
        Finally all times are renumbered with `compress_schedule` to remove gaps.
        """
        if config is None:
            config = ScheduleConfig(Strategy.MINIMIZE_TIME)
        result: tuple[dict[int, int], dict[int, int]] | None
        if config.use_greedy:
            # Use fast greedy heuristics
            if config.strategy == Strategy.MINIMIZE_TIME:
                result = greedy_minimize_time(self.graph, self.dag, max_qubit_count=config.max_qubit_count)
            else:  # Strategy.MINIMIZE_SPACE
                result = greedy_minimize_space(self.graph, self.dag)
        else:
            # Use CP-SAT solver for optimal solution
            result = solve_schedule(self.graph, self.dag, config, timeout)
        if result is None:
            return False
        prepare_time, measure_time = result
        prep_time = {
            node: prepare_time.get(node, None)
            for node in self.graph.physical_nodes - self.graph.input_node_indices.keys()
        }
        meas_time = {
            node: measure_time.get(node, None)
            for node in self.graph.physical_nodes - self.graph.output_node_indices.keys()
        }
        self.prepare_time = prep_time
        self.measure_time = meas_time
        # Auto-schedule unscheduled entanglement times before compression so
        # slice-0 entanglement remains part of the executable time set.
        if any(time is None for time in self.entangle_time.values()):
            self.auto_schedule_entanglement()
        # Compress the schedule to minimize time indices
        timings = compress_schedule(self.prepare_time, self.measure_time, self.entangle_time)
        self.prepare_time = timings.prepare_time
        self.measure_time = timings.measure_time
        self.entangle_time = timings.entangle_time
        return True