# Source code for discrete_optimization.generic_tools.ortools_cpsat_tools

#  Copyright (c) 2024 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
import logging
from abc import abstractmethod
from typing import Any, List, Optional

from ortools.sat.python.cp_model import (
    FEASIBLE,
    INFEASIBLE,
    OPTIMAL,
    UNKNOWN,
    CpModel,
    CpSolver,
    CpSolverSolutionCallback,
)

from discrete_optimization.generic_tools.callbacks.callback import (
    Callback,
    CallbackList,
)
from discrete_optimization.generic_tools.cp_tools import (
    CPSolver,
    ParametersCP,
    StatusSolver,
)
from discrete_optimization.generic_tools.do_problem import Solution
from discrete_optimization.generic_tools.exceptions import SolveEarlyStop
from discrete_optimization.generic_tools.result_storage.result_storage import (
    ResultStorage,
)

logger = logging.getLogger(__name__)


class OrtoolsCPSatSolver(CPSolver):
    """Generic ortools cp-sat solver."""

    # cp-sat model; built by `self.init_model()` when still None at solve time.
    cp_model: Optional[CpModel] = None
    # Exception recorded during the search (set by the ortools solution callback),
    # handled by `solve()` once the cp-sat solver returns.
    early_stopping_exception: Optional[Exception] = None

    @abstractmethod
    def retrieve_solution(self, cpsolvercb: CpSolverSolutionCallback) -> Solution:
        """Construct a do solution from the cpsat solver internal solution.

        It will be called each time the cpsat solver finds a new solution.
        At that point, values of internal variables are accessible via
        `cpsolvercb.Value(VARIABLE_NAME)`.

        Args:
            cpsolvercb: the ortools callback called when the cpsat solver finds
                a new solution.

        Returns:
            the intermediate solution, at do format.

        """
        ...

    def solve(
        self,
        callbacks: Optional[List[Callback]] = None,
        parameters_cp: Optional[ParametersCP] = None,
        **kwargs: Any,
    ) -> ResultStorage:
        """Solve the problem with a CPSat solver from ortools library.

        A dedicated ortools callback is used to:
        - update a result storage each time a new solution is found by the
          cpsat solver;
        - call the user (do) callbacks at each new solution, with the
          possibility of early stopping if a callback returns True.

        This ortools callback uses the method `self.retrieve_solution()` to
        reconstruct a do Solution from the cpsat solver internal state.

        Args:
            callbacks: list of callbacks used to hook into the various stages
                of the solve.
            parameters_cp: parameters specific to cp solvers. We use here only
                `parameters_cp.time_limit` and `parameters_cp.nb_process`.
                Defaults to `ParametersCP.default_cpsat()`.
            **kwargs: keyword arguments passed to `self.init_model()`. It is
                only called when `self.cp_model` is None.

        Returns:
            the result storage filled by the ortools callback with the
            solutions found during the search.

        Raises:
            Exception: an exception recorded in `self.early_stopping_exception`
                during the search is re-raised after the cpsat solver returns,
                unless it is a `SolveEarlyStop`, which is only logged.

        """
        # Reset any exception left over from a previous solve.
        self.early_stopping_exception = None
        callbacks_list = CallbackList(callbacks=callbacks)
        callbacks_list.on_solve_start(solver=self)
        if self.cp_model is None:
            self.init_model(**kwargs)
        if parameters_cp is None:
            parameters_cp = ParametersCP.default_cpsat()
        solver = CpSolver()
        solver.parameters.max_time_in_seconds = parameters_cp.time_limit
        solver.parameters.num_workers = parameters_cp.nb_process
        ortools_callback = OrtoolsCallback(do_solver=self, callback=callbacks_list)
        status = solver.Solve(self.cp_model, ortools_callback)
        self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status)
        # The ortools callback stores exceptions instead of raising them inside
        # the native search; deal with them now that the search is over.
        if self.early_stopping_exception is not None:
            if isinstance(self.early_stopping_exception, SolveEarlyStop):
                # Deliberate stop requested by a user callback: not an error.
                logger.info(self.early_stopping_exception)
            else:
                raise self.early_stopping_exception
        res = ortools_callback.res
        callbacks_list.on_solve_end(res=res, solver=self)
        return res
class OrtoolsCallback(CpSolverSolutionCallback):
    """Ortools solution callback linking the cpsat search to a do solver.

    At each new solution found by the cpsat solver, it:
    - converts the cpsat internal state into a do solution via
      `do_solver.retrieve_solution()` and stores it in `self.res`;
    - calls `callback.on_step_end()` and stops the search if it returns True.

    Exceptions raised while doing so are not propagated into the ortools
    search. They are stored in `do_solver.early_stopping_exception` and the
    search is stopped, so that the do solver can handle them after
    `CpSolver.Solve()` returns.

    Args:
        do_solver: the do solver whose model is being solved.
        callback: the (do) callback called at each new solution.

    """

    def __init__(self, do_solver: OrtoolsCPSatSolver, callback: Callback):
        super().__init__()
        self.do_solver = do_solver
        self.callback = callback
        # Unbounded storage of every solution found, using the do solver's
        # optimisation sense.
        self.res = ResultStorage(
            [],
            mode_optim=self.do_solver.params_objective_function.sense_function,
            limit_store=False,
        )
        self.nb_solutions = 0

    def on_solution_callback(self) -> None:
        """Store the new solution, call the do callbacks and stop if requested."""
        try:
            # Storing is inside the try too: `retrieve_solution` and
            # `aggreg_from_sol` are user code that may raise, and their errors
            # must be handled like the ones coming from the callbacks.
            self.store_current_solution()
            self.nb_solutions += 1
            # end of step callback: stopping?
            stopping = self.callback.on_step_end(
                step=self.nb_solutions, res=self.res, solver=self.do_solver
            )
        except Exception as e:
            # Keep the exception for the do solver instead of letting it
            # escape into the ortools search loop.
            self.do_solver.early_stopping_exception = e
            stopping = True
        else:
            if stopping:
                self.do_solver.early_stopping_exception = SolveEarlyStop(
                    f"{self.do_solver.__class__.__name__}.solve() stopped by user callback."
                )
        if stopping:
            self.StopSearch()

    def store_current_solution(self) -> None:
        """Convert the current cpsat solution to a do solution and add it to `self.res`."""
        sol = self.do_solver.retrieve_solution(cpsolvercb=self)
        fit = self.do_solver.aggreg_from_sol(sol)
        self.res.add_solution(solution=sol, fitness=fit)
def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver:
    """Convert an ortools cp-sat status into a do status.

    :param status_from_cpsat: status returned by the ortools cp-sat solver,
        expected to be one of [UNKNOWN, INFEASIBLE, OPTIMAL, FEASIBLE] from
        the ortools.cp api.
    :return: the corresponding StatusSolver. Any other cp-sat status
        (e.g. MODEL_INVALID) is logged as a warning and mapped to
        StatusSolver.UNKNOWN, so that a StatusSolver is always returned.
    """
    if status_from_cpsat == UNKNOWN:
        return StatusSolver.UNKNOWN
    if status_from_cpsat == INFEASIBLE:
        return StatusSolver.UNSATISFIABLE
    if status_from_cpsat == OPTIMAL:
        return StatusSolver.OPTIMAL
    if status_from_cpsat == FEASIBLE:
        return StatusSolver.SATISFIED
    # No do counterpart (e.g. MODEL_INVALID): report it explicitly rather than
    # implicitly returning None, which the return annotation does not allow.
    logger.warning(
        "Unexpected cp-sat status %s, mapped to StatusSolver.UNKNOWN.",
        status_from_cpsat,
    )
    return StatusSolver.UNKNOWN