Source code for graphqomb.schedule_solver

"""Scheduling solver for optimizing MBQC pattern execution.

This module provides:

- `Strategy`: Enumeration of scheduling optimization strategies
- `ScheduleConfig`: Configuration for scheduling parameters and constraints
- `solve_schedule`: Solve scheduling optimization using constraint programming
"""

from __future__ import annotations

import enum
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

from ortools.sat.python import cp_model

if TYPE_CHECKING:
    from collections.abc import Mapping
    from collections.abc import Set as AbstractSet

    from graphqomb.graphstate import BaseGraphState


[docs] class Strategy(Enum): """Enumeration for scheduling strategies.""" MINIMIZE_SPACE = enum.auto() MINIMIZE_TIME = enum.auto()
[docs] @dataclass class ScheduleConfig: """Configuration for scheduling strategy, constraints, and parameters.""" strategy: Strategy max_qubit_count: int | None = None max_time: int | None = None use_greedy: bool = False
@dataclass class _ModelContext: """Internal context for model and graph data.""" model: cp_model.CpModel graph: BaseGraphState def _add_constraints( model: cp_model.CpModel, graph: BaseGraphState, dag: Mapping[int, AbstractSet[int]], node2prep: Mapping[int, cp_model.IntVar], node2meas: Mapping[int, cp_model.IntVar], ) -> None: """Add constraints to the scheduling model.""" # Measurement order constraints for node, children in dag.items(): for child in children: if node in node2meas and child in node2meas: model.add(node2meas[node] < node2meas[child]) # A non-input, non-output node must be prepared before it is measured. for node in graph.physical_nodes: if node in node2prep and node in node2meas: model.add(node2prep[node] < node2meas[node]) # Edge constraints for node in graph.physical_nodes - set(graph.output_node_indices): for neighbor in graph.neighbors(node): if neighbor in graph.input_node_indices: if node in graph.input_node_indices: model.add(node2meas[node] > 0) continue model.add(node2prep[neighbor] < node2meas[node]) def _set_objective( ctx: _ModelContext, node2prep: Mapping[int, cp_model.IntVar], node2meas: Mapping[int, cp_model.IntVar], config: ScheduleConfig, max_time: int, ) -> None: """Set the objective function for the scheduling model. Raises ------ ValueError If the scheduling strategy is unknown. """ if config.strategy == Strategy.MINIMIZE_SPACE: _set_minimize_space_objective(ctx, node2prep, node2meas, max_time) elif config.strategy == Strategy.MINIMIZE_TIME: _set_minimize_time_objective(ctx, node2prep, node2meas, max_time, config.max_qubit_count) else: msg = f"Unknown scheduling strategy: {config.strategy}" raise ValueError(msg) def _compute_alive_nodes_at_time( ctx: _ModelContext, node2prep: Mapping[int, cp_model.IntVar], node2meas: Mapping[int, cp_model.IntVar], t: int, ) -> list[cp_model.IntVar]: """Compute the list of alive nodes at time t. Returns ------- list[cp_model.IntVar] Boolean variables indicating whether each node is alive at time t. """ alive_at_t: list[cp_model.IntVar] = [] for node in ctx.graph.physical_nodes: a_pre = ctx.model.new_bool_var(f"alive_pre_{node}_{t}") if node in ctx.graph.input_node_indices: ctx.model.add(a_pre == 1) else: p = node2prep[node] ctx.model.add(p <= t).only_enforce_if(a_pre) ctx.model.add(p > t).only_enforce_if(a_pre.negated()) a_meas = ctx.model.new_bool_var(f"alive_meas_{node}_{t}") if node in ctx.graph.output_node_indices: ctx.model.add(a_meas == 0) else: q = node2meas[node] ctx.model.add(q <= t).only_enforce_if(a_meas) ctx.model.add(q > t).only_enforce_if(a_meas.negated()) alive = ctx.model.new_bool_var(f"alive_{node}_{t}") ctx.model.add_implication(alive, a_pre) ctx.model.add_implication(alive, a_meas.negated()) ctx.model.add(a_pre - a_meas <= alive) alive_at_t.append(alive) return alive_at_t def _set_minimize_space_objective( ctx: _ModelContext, node2prep: Mapping[int, cp_model.IntVar], node2meas: Mapping[int, cp_model.IntVar], max_time: int, ) -> None: """Set objective to minimize the maximum number of qubits used at any time.""" max_space = ctx.model.new_int_var(0, len(ctx.graph.physical_nodes), "max_space") for t in range(max_time): alive_at_t = _compute_alive_nodes_at_time(ctx, node2prep, node2meas, t) ctx.model.add(max_space >= sum(alive_at_t)) ctx.model.minimize(max_space) def _set_minimize_time_objective( ctx: _ModelContext, node2prep: Mapping[int, cp_model.IntVar], node2meas: Mapping[int, cp_model.IntVar], max_time: int, max_qubit_count: int | None = None, ) -> None: """Set objective to minimize the total execution time.""" # Add space constraint if max_qubit_count is specified if max_qubit_count is not None: for t in range(max_time): alive_at_t = _compute_alive_nodes_at_time(ctx, node2prep, node2meas, t) ctx.model.add(sum(alive_at_t) <= max_qubit_count) # Time objective: minimize makespan. Pure input-output graphs have no # measurement/preparation variables, but can still have edge constraints. meas_vars = list(node2meas.values()) or list(node2prep.values()) if not meas_vars: ctx.model.minimize(0) return makespan = ctx.model.new_int_var(0, max_time, "makespan") ctx.model.add_max_equality(makespan, meas_vars) ctx.model.minimize(makespan)
[docs] def solve_schedule( graph: BaseGraphState, dag: Mapping[int, AbstractSet[int]], config: ScheduleConfig, timeout: int = 60, ) -> tuple[dict[int, int], dict[int, int]] | None: r"""Solve the scheduling problem for the given graph. Parameters ---------- graph : `BaseGraphState` The graph state to optimize. dag : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] The directed acyclic graph representing dependencies. config : `ScheduleConfig` The scheduling configuration including strategy and constraints. timeout : `int`, optional Maximum solve time in seconds, by default 60 Returns ------- `tuple`\[`dict`\[`int`, `int`\], `dict`\[`int`, `int`\]\] | `None` A tuple of (prepare_time, measure_time) dictionaries if solved, None if no solution found. """ # Construct model model = cp_model.CpModel() # Determine max_time from config or calculate default max_time = config.max_time if config.max_time is not None else 2 * len(graph.physical_nodes) # Create variables node2prep: dict[int, cp_model.IntVar] = {} node2meas: dict[int, cp_model.IntVar] = {} for node in graph.physical_nodes: if node not in graph.input_node_indices: node2prep[node] = model.new_int_var(0, max_time, f"prep_{node}") if node not in graph.output_node_indices: node2meas[node] = model.new_int_var(0, max_time, f"meas_{node}") # Add constraints _add_constraints(model, graph, dag, node2prep, node2meas) # Set objective ctx = _ModelContext(model, graph) _set_objective(ctx, node2prep, node2meas, config, max_time) # Solve solver = cp_model.CpSolver() solver.parameters.max_time_in_seconds = timeout status = solver.Solve(model) if status in {cp_model.OPTIMAL, cp_model.FEASIBLE}: prepare_time: dict[int, int] = {node: int(solver.Value(var)) for node, var in node2prep.items()} measure_time: dict[int, int] = {node: int(solver.Value(var)) for node, var in node2meas.items()} return prepare_time, measure_time return None