diff --git a/linopy/solvers.py b/linopy/solvers.py index 2783e7b8..597d49de 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -13,6 +13,7 @@ import re import subprocess as sub import sys +import threading import warnings from abc import ABC, abstractmethod from collections import namedtuple @@ -56,6 +57,73 @@ which = "where" if os.name == "nt" else "which" + +def _run_highs_with_keyboard_interrupt(h: Any) -> None: + """ + Run `highspy.Highs.run()` while ensuring Ctrl-C cancels the solve. + + HiGHS can run for a long time inside a C-extension call. Running it in a + worker thread allows the main thread to reliably receive KeyboardInterrupt + and signal HiGHS to stop via `cancelSolve()`. + """ + + handle_keyboard_interrupt = getattr(h, "HandleKeyboardInterrupt", None) + handle_user_interrupt = getattr(h, "HandleUserInterrupt", None) + + old_handle_keyboard_interrupt = ( + handle_keyboard_interrupt if not callable(handle_keyboard_interrupt) else None + ) + old_handle_user_interrupt = ( + handle_user_interrupt if not callable(handle_user_interrupt) else None + ) + + try: + if callable(handle_keyboard_interrupt): + handle_keyboard_interrupt(True) + elif handle_keyboard_interrupt is not None: + h.HandleKeyboardInterrupt = True + + if callable(handle_user_interrupt): + handle_user_interrupt(True) + elif handle_user_interrupt is not None: + h.HandleUserInterrupt = True + + finished = threading.Event() + run_error: BaseException | None = None + + def _target() -> None: + nonlocal run_error + try: + h.run() + except BaseException as exc: # pragma: no cover + run_error = exc + finally: + finished.set() + + thread = threading.Thread(target=_target, name="linopy-highs-run", daemon=True) + thread.start() + + try: + while not finished.wait(0.1): + pass + except KeyboardInterrupt: + cancel_solve = getattr(h, "cancelSolve", None) + if callable(cancel_solve): + with contextlib.suppress(Exception): + cancel_solve() + while not finished.wait(0.1): + pass + raise + + if run_error is not None: + raise run_error + finally: + if old_handle_keyboard_interrupt is not None: + h.HandleKeyboardInterrupt = old_handle_keyboard_interrupt + if old_handle_user_interrupt is not None: + h.HandleUserInterrupt = old_handle_user_interrupt + + # the first available solver will be the default solver with contextlib.suppress(ModuleNotFoundError): import gurobipy @@ -912,7 +980,7 @@ def _solve( elif warmstart_fn: h.readBasis(path_to_string(warmstart_fn)) - h.run() + _run_highs_with_keyboard_interrupt(h) condition = h.getModelStatus() termination_condition = CONDITION_MAP.get( diff --git a/test/test_highs_keyboard_interrupt.py b/test/test_highs_keyboard_interrupt.py new file mode 100644 index 00000000..71d5ce80 --- /dev/null +++ b/test/test_highs_keyboard_interrupt.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import _thread +import threading +import time + +import pytest + +from linopy.solvers import _run_highs_with_keyboard_interrupt + + +class DummyHighs: + def __init__(self) -> None: + self.HandleKeyboardInterrupt = False + self.HandleUserInterrupt = False + self._cancel_event = threading.Event() + self.started = threading.Event() + self.finished = threading.Event() + self.cancel_calls = 0 + + def run(self) -> None: + self.started.set() + self._cancel_event.wait(timeout=5) + self.finished.set() + + def cancelSolve(self) -> None: + self.cancel_calls += 1 + self._cancel_event.set() + + +def test_run_highs_cancels_on_keyboard_interrupt() -> None: + dummy = DummyHighs() + + def interrupter() -> None: + assert dummy.started.wait(timeout=1) + time.sleep(0.05) + _thread.interrupt_main() + + threading.Thread(target=interrupter, daemon=True).start() + + with pytest.raises(KeyboardInterrupt): + _run_highs_with_keyboard_interrupt(dummy) + + assert dummy.cancel_calls >= 1 + assert dummy.finished.wait(timeout=1) + assert dummy.HandleKeyboardInterrupt is False + assert dummy.HandleUserInterrupt is False