def _create_process_no_window(cmd_line: str) -> bool:
    """Start ``cmd_line`` with Win32 ``CreateProcessW`` and no console window.

    Going through ctypes instead of :mod:`subprocess` bypasses pydevd's
    ``subprocess.Popen`` patch entirely. Windows only: ``ctypes.windll`` does
    not exist on other platforms.

    Args:
        cmd_line: Full command line (executable plus arguments) to run.

    Returns:
        ``True`` if ``CreateProcessW`` reported success, ``False`` otherwise.

    """
    import ctypes
    from ctypes import wintypes

    CREATE_NO_WINDOW = 0x08000000  # dwCreationFlags: do not allocate a console

    class STARTUPINFOW(ctypes.Structure):
        _fields_ = (
            ("cb", wintypes.DWORD),
            ("lpReserved", wintypes.LPWSTR),
            ("lpDesktop", wintypes.LPWSTR),
            ("lpTitle", wintypes.LPWSTR),
            ("dwX", wintypes.DWORD),
            ("dwY", wintypes.DWORD),
            ("dwXSize", wintypes.DWORD),
            ("dwYSize", wintypes.DWORD),
            ("dwXCountChars", wintypes.DWORD),
            ("dwYCountChars", wintypes.DWORD),
            ("dwFillAttribute", wintypes.DWORD),
            ("dwFlags", wintypes.DWORD),
            ("wShowWindow", wintypes.WORD),
            ("cbReserved2", wintypes.WORD),
            ("lpReserved2", ctypes.c_void_p),
            ("hStdInput", wintypes.HANDLE),
            ("hStdOutput", wintypes.HANDLE),
            ("hStdError", wintypes.HANDLE),
        )

    class ProcessInformation(ctypes.Structure):
        _fields_ = (
            ("hProcess", wintypes.HANDLE),
            ("hThread", wintypes.HANDLE),
            ("dwProcessId", wintypes.DWORD),
            ("dwThreadId", wintypes.DWORD),
        )

    kernel32 = ctypes.windll.kernel32
    startup = STARTUPINFOW()
    startup.cb = ctypes.sizeof(STARTUPINFOW)
    proc_info = ProcessInformation()
    # CreateProcessW is allowed to modify the command line in place, so it
    # needs a writable buffer rather than an immutable Python string.
    cmd_buf = ctypes.create_unicode_buffer(cmd_line)
    ok = kernel32.CreateProcessW(
        None,  # lpApplicationName: taken from the command line
        cmd_buf,
        None,  # lpProcessAttributes
        None,  # lpThreadAttributes
        False,  # bInheritHandles
        CREATE_NO_WINDOW,
        None,  # lpEnvironment: inherit
        None,  # lpCurrentDirectory: inherit
        ctypes.byref(startup),
        ctypes.byref(proc_info),
    )
    if not ok:
        return False

    # We never wait on the child, so release both handles immediately.
    # Reading a HANDLE field yields a plain Python int, and ctypes passes bare
    # ints as a (masked) C int; re-wrapping keeps the value pointer-sized
    # without changing the prototypes on the process-wide ``windll`` object.
    for raw_handle in (proc_info.hProcess, proc_info.hThread):
        kernel32.CloseHandle(wintypes.HANDLE(raw_handle))
    return True
def start_heartbeat_sentinel(pid: int, heartbeat_path: str, kill_delay: float) -> None:
    """Spawn an out-of-job sentinel that kills the process tree when the heartbeat goes stale.

    Spawn chain: ctypes.CreateProcessW -> powershell.exe -File -> WMI Win32_Process.Create
    -> python.exe _child.py. The final process runs under WmiPrvSE, outside PyCharm's Job
    Object, with no pydevd injected.

    The whole operation is best-effort: it is a no-op on non-Windows platforms and
    never raises. If the launcher script cannot be written, nothing is spawned.

    Args:
        pid: Process id the sentinel should watch and eventually tree-kill.
        heartbeat_path: File whose modification time serves as the heartbeat.
        kill_delay: Heartbeat age in seconds after which the sentinel kills the tree.

    """
    if sys.platform != "win32":
        return
    import tempfile

    child_py = os.path.join(os.path.dirname(__file__), "_child.py")
    cmd = f'"{sys.executable}" "{child_py}" {pid} "{heartbeat_path}" {kill_delay}'
    spawn_ps1 = os.path.join(tempfile.gettempdir(), f"litserve_spawn_sentinel_{pid}.ps1")
    ps1_arg = cmd.replace("'", "''")  # escape single quotes for PS single-quoted string
    script = (
        "$code = 1\n"
        "try {\n"
        "    $r = Invoke-WmiMethod -Class Win32_Process -Name Create"
        f" -ArgumentList '{ps1_arg}'\n"
        "    $code = [int]$r.ReturnValue\n"
        "} catch { $code = 1 }\n"
        # The launcher is single-use: let it remove itself so temp files do not pile up.
        "Remove-Item -LiteralPath $PSCommandPath -Force -ErrorAction SilentlyContinue\n"
        "exit $code\n"
    )

    try:
        # UTF-8 *with* BOM: Windows PowerShell decodes BOM-less scripts with the ANSI
        # code page, which would corrupt non-ASCII characters in the embedded paths.
        with open(spawn_ps1, "w", encoding="utf-8-sig") as f:
            f.write(script)
    except (OSError, UnicodeError):
        # Without a launcher script there is nothing meaningful to run.
        return

    ps_cmd = f'powershell.exe -NoProfile -NonInteractive -ExecutionPolicy Bypass -File "{spawn_ps1}"'
    with contextlib.suppress(Exception):
        _create_process_no_window(ps_cmd)
# Win32 constants used by the sentinel.
_PROCESS_TERMINATE = 0x0001
_PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
_SYNCHRONIZE = 0x00100000
_TH32CS_SNAPPROCESS = 0x00000002
_WAIT_OBJECT_0 = 0x00000000
# INVALID_HANDLE_VALUE as seen through a pointer-sized restype, and as the
# plain -1 that an int-typed return would produce.
_INVALID_HANDLE_VALUES = (ctypes.c_void_p(-1).value, -1)


class _ProcessEntry32(ctypes.Structure):
    """ANSI ``PROCESSENTRY32`` record filled by ``Process32First``/``Process32Next``."""

    _fields_ = [
        ("dwSize", ctypes.c_ulong),
        ("cntUsage", ctypes.c_ulong),
        ("th32ProcessID", ctypes.c_ulong),
        ("th32DefaultHeapID", ctypes.c_size_t),
        ("th32ModuleID", ctypes.c_ulong),
        ("cntThreads", ctypes.c_ulong),
        ("th32ParentProcessID", ctypes.c_ulong),
        ("pcPriClassBase", ctypes.c_long),
        ("dwFlags", ctypes.c_ulong),
        ("szExeFile", ctypes.c_char * 260),
    ]


def _kernel32():
    """Return ``kernel32`` with explicit prototypes for every call made here.

    Without prototypes ctypes assumes ``int`` for arguments and results, which
    truncates handles and turns ``INVALID_HANDLE_VALUE`` into ``-1``. This
    module runs as its own sentinel process, so configuring the shared
    ``windll`` function objects does not affect anyone else.
    """
    k32 = ctypes.windll.kernel32
    handle, dword, bool_ = ctypes.c_void_p, ctypes.c_ulong, ctypes.c_int
    entry_ptr = ctypes.POINTER(_ProcessEntry32)
    prototypes = {
        "OpenProcess": (handle, (dword, bool_, dword)),
        "CloseHandle": (bool_, (handle,)),
        "WaitForSingleObject": (dword, (handle, dword)),
        "TerminateProcess": (bool_, (handle, ctypes.c_uint)),
        "CreateToolhelp32Snapshot": (handle, (dword, dword)),
        "Process32First": (bool_, (handle, entry_ptr)),
        "Process32Next": (bool_, (handle, entry_ptr)),
    }
    for name, (restype, argtypes) in prototypes.items():
        func = getattr(k32, name)
        func.restype = restype
        func.argtypes = argtypes
    return k32


def _alive(pid):
    """Return ``True`` while process ``pid`` exists and has not exited.

    Being able to open the PID is not enough: an exited process stays openable
    for as long as any other process still holds a handle to it. A process
    handle becomes signaled on exit, so a zero-timeout wait tells the two
    states apart. A failed wait is treated as "still alive".
    """
    k32 = _kernel32()
    handle = k32.OpenProcess(_SYNCHRONIZE | _PROCESS_QUERY_LIMITED_INFORMATION, False, pid)
    if not handle:
        return False
    try:
        return k32.WaitForSingleObject(handle, 0) != _WAIT_OBJECT_0
    finally:
        k32.CloseHandle(handle)


def _descendants(children, root_pid):
    """Return every PID reachable from ``root_pid`` in the parent -> children map.

    ``root_pid`` itself is excluded. The ``seen`` set guarantees termination
    and prevents listing a PID twice even if the map is inconsistent.
    """
    found = []
    seen = {root_pid}
    stack = [root_pid]
    while stack:
        for child in children.get(stack.pop(), ()):
            if child not in seen:
                seen.add(child)
                found.append(child)
                stack.append(child)
    return found


def _kill_subtree(root_pid):
    """Terminate all descendants of ``root_pid`` and then ``root_pid`` itself.

    Walks a CreateToolhelp32Snapshot to find ALL descendants of ``root_pid``.
    th32ParentProcessID is fixed at creation time, so orphaned children
    (whose parent already exited) are still found and killed. Does nothing if
    the snapshot cannot be taken; processes that cannot be opened are skipped.

    NOTE(review): parent PIDs are matched by number only; a recycled PID could
    make an unrelated process look like a descendant - confirm this is acceptable.
    """
    k32 = _kernel32()
    snap = k32.CreateToolhelp32Snapshot(_TH32CS_SNAPPROCESS, 0)
    if snap is None or snap in _INVALID_HANDLE_VALUES:
        return

    children = {}
    entry = _ProcessEntry32()
    entry.dwSize = ctypes.sizeof(_ProcessEntry32)  # required before Process32First
    try:
        more = k32.Process32First(snap, ctypes.byref(entry))
        while more:
            children.setdefault(entry.th32ParentProcessID, []).append(entry.th32ProcessID)
            more = k32.Process32Next(snap, ctypes.byref(entry))
    finally:
        k32.CloseHandle(snap)

    # Descendants first, the root last.
    for target in [*_descendants(children, root_pid), root_pid]:
        handle = k32.OpenProcess(_PROCESS_TERMINATE, False, target)
        if handle:
            k32.TerminateProcess(handle, 1)
            k32.CloseHandle(handle)


def _heartbeat_lost(pid, hb, delay):
    """Return ``True`` once the watched process or its heartbeat is gone or stale."""
    if not _alive(pid):
        # Main process exited; any orphaned children still need killing.
        return True
    try:
        age = time.time() - os.path.getmtime(hb)
    except OSError:
        # Heartbeat file gone; assume fatal.
        return True
    # Heartbeat stale: main thread likely suspended by pydevd.
    return age > delay


def main():
    """Watch a process and tree-kill it once its heartbeat stops.

    Command line: ``_child.py <pid> <heartbeat_path> <kill_delay_seconds>``.
    Polls every 0.5 s after an initial 2 s grace period and exits after the
    first kill.
    """
    pid, hb, delay = int(sys.argv[1]), sys.argv[2], float(sys.argv[3])
    time.sleep(2)  # grace period: let server finish startup file I/O
    while True:
        time.sleep(0.5)
        if _heartbeat_lost(pid, hb, delay):
            break
    _kill_subtree(pid)
    # Nobody is left to touch the heartbeat file; remove it (best-effort)
    # so it does not linger in the temp directory.
    try:
        os.remove(hb)
    except OSError:
        pass


if __name__ == "__main__":
    main()
for i in range(max_token_length): yield prediction + """ ) if ( diff --git a/src/litserve/loops/continuous_batching_loop.py b/src/litserve/loops/continuous_batching_loop.py index 2058e456b..8b8412066 100644 --- a/src/litserve/loops/continuous_batching_loop.py +++ b/src/litserve/loops/continuous_batching_loop.py @@ -55,8 +55,9 @@ class Output: class ContinuousBatchingLoop(LitLoop): def __init__(self, max_sequence_length: int = 2048, no_pending_requests: bool = False, sleep_delay: float = 0.001): - """Runs continuous batching loop. This loop handles adding new requests, processing them in batches, and - managing the state of active sequences. + """Runs continuous batching loop. + + This loop handles adding new requests, processing them in batches, and managing the state of active sequences. The loop requires the following methods to be implemented in the LitAPI: - setup: sets up the model on the device @@ -84,28 +85,26 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None): ) if not hasattr(lit_api, "step") and not hasattr(lit_api, "predict"): - raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to -have a `predict` method which accepts decoded request inputs and a list of generated_sequence. -Please implement the has_finished method in the lit_api. - - class ExampleAPI(LitAPI): - ... - def predict(self, inputs, generated_sequence): - # implement predict logic - # return list of new tokens - ... - """) + raise ValueError( + """Using the default step method with Continuous batching loop requires the lit_api to have a `predict` + method which accepts decoded request inputs and a list of generated_sequence. Please implement the + has_finished method in the lit_api. + + class ExampleAPI(LitAPI): ... def predict(self, inputs, generated_sequence): # implement predict + logic # return list of new tokens ... 
+ + """ + ) if not hasattr(lit_api, "step") and not hasattr(lit_api, "has_finished"): - raise ValueError("""Using the default step method with Continuous batching loop -requires the lit_api to have a has_finished method. Please implement the has_finished method in the lit_api. + raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to have + a has_finished method. Please implement the has_finished method in the lit_api. + + class ExampleAPI(LitAPI): ... def has_finished(self, uid: str, token: str, + max_sequence_length: int) -> bool: # implement has_finished logic return + False - class ExampleAPI(LitAPI): - ... - def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool: - # implement has_finished logic - return False - """) + """) def add_request( self, @@ -223,7 +222,6 @@ async def run( callback_runner: CallbackRunner, ): """Main loop that processes batches of requests.""" - warning_counter = 0 lit_spec = lit_api.spec try: diff --git a/src/litserve/mcp.py b/src/litserve/mcp.py index 22ea62e5a..e02a2cb12 100644 --- a/src/litserve/mcp.py +++ b/src/litserve/mcp.py @@ -52,8 +52,9 @@ def extract_input_schema(func) -> dict[str, Any]: - """Extract JSON schema for function input parameters from a Python function. Supports regular type annotations, - Pydantic Fields, and Pydantic BaseModel classes. + """Extract JSON schema for function input parameters from a Python function. + + Supports regular type annotations, Pydantic Fields, and Pydantic BaseModel classes. Args: func: Python function to analyze @@ -561,19 +562,15 @@ def connect_mcp_server(self, mcp_tools: list[ToolType], app: FastAPI): app: LitServer's FastAPI app to mount the MCP server to. """ - if len(mcp_tools) == 0: return for tool in mcp_tools: self.add_tool(tool) - logger.warning( "MCP support is in beta and APIs are subject to change. 
Please report any issues to https://github.com/Lightning-AI/litserve/issues" ) - self._mount_with_fastapi(app) - logger.info( "================================================" "\nšŸŽ‰ Enabled MCP server ⚔\n" diff --git a/src/litserve/server.py b/src/litserve/server.py index ccefb5c23..43412f67b 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -23,6 +23,7 @@ import secrets import socket import sys +import tempfile import threading import time import uuid @@ -42,7 +43,7 @@ from starlette.formparsers import MultiPartParser from starlette.middleware.gzip import GZipMiddleware -from litserve import LitAPI +from litserve import LitAPI, _win_shutdown_fix from litserve.callbacks.base import Callback, CallbackRunner, EventTypes from litserve.connector import _Connector from litserve.loggers import Logger, _LoggerConnector @@ -87,6 +88,10 @@ def no_auth(): pass +def _is_pydevd_active() -> bool: + return sys.platform == "win32" and "pydevd" in sys.modules + + def api_key_auth(x_api_key: str = Depends(APIKeyHeader(name="X-API-Key"))): if x_api_key != LIT_SERVER_API_KEY: raise HTTPException( @@ -1098,7 +1103,6 @@ def _get_request_queue(self, api_path: str): def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type): """Register endpoint routes for the FastAPI app.""" - self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self) # Create handlers @@ -1208,11 +1212,14 @@ def _perform_graceful_shutdown( logger.warning(f"{log_prefix}: Already not alive.") continue try: - uw.terminate() - uw.join(timeout=self.uvicorn_graceful_timeout) - if uw.is_alive(): - logger.warning(f"{log_prefix}: Did not terminate gracefully. Forcibly killing.") - uw.kill() + if isinstance(uw, threading.Thread): + uw.join(timeout=self.uvicorn_graceful_timeout) + else: + uw.terminate() + uw.join(timeout=self.uvicorn_graceful_timeout) + if uw.is_alive(): + logger.warning(f"{log_prefix}: Did not terminate gracefully. 
Forcibly killing.") + uw.kill() except Exception as e: logger.error(f"Error during termination of {log_prefix}: {e}") @@ -1502,6 +1509,14 @@ def run( self.verify_worker_status() + _pydevd_active = _is_pydevd_active() + _heartbeat_path = None + if _pydevd_active: + _heartbeat_path = os.path.join(tempfile.gettempdir(), f"litserve_hb_{os.getpid()}.tmp") + with contextlib.suppress(Exception): + open(_heartbeat_path, "w").close() + _win_shutdown_fix.start_heartbeat_sentinel(os.getpid(), _heartbeat_path, kill_delay=3.0) + shutdown_reason = "normal" uvicorn_workers = {} try: @@ -1515,7 +1530,13 @@ def run( if self._monitor_workers: self._start_worker_monitoring(manager, uvicorn_workers) - self._shutdown_event.wait() + if _pydevd_active and _heartbeat_path: + while not self._shutdown_event.is_set(): + with contextlib.suppress(Exception): + os.utime(_heartbeat_path, None) + time.sleep(0.5) + else: + self._shutdown_event.wait() except KeyboardInterrupt: logger.info("KeyboardInterrupt received. Initiating graceful shutdown.") diff --git a/src/litserve/utils.py b/src/litserve/utils.py index 03de04cad..9d11d9e9e 100644 --- a/src/litserve/utils.py +++ b/src/litserve/utils.py @@ -322,7 +322,6 @@ def add_ssl_context_from_env(kwargs: dict[str, Any]) -> dict[str, Any]: returns an empty dictionary. 
""" - if "ssl_keyfile" in kwargs and "ssl_certfile" in kwargs: return kwargs diff --git a/tests/conftest.py b/tests/conftest.py index a3a499caf..4b691fef6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -372,6 +372,9 @@ def set(self): def wait(self): pass + def is_set(self): + return True + class MockQueue: def put(self, item): diff --git a/tests/parity_fastapi/benchmark.py b/tests/parity_fastapi/benchmark.py index 4f1ab94fb..77ffe288c 100644 --- a/tests/parity_fastapi/benchmark.py +++ b/tests/parity_fastapi/benchmark.py @@ -52,7 +52,6 @@ def send_request(args): def benchmark(num_requests=100, concurrency_level=100, port=8000): """Benchmark the ML server.""" - # Create a session with appropriate pool size session = create_session(pool_connections=min(concurrency_level, 100), pool_maxsize=min(concurrency_level, 100)) diff --git a/tests/unit/test_lit_server.py b/tests/unit/test_lit_server.py index d319ae192..89102f41f 100644 --- a/tests/unit/test_lit_server.py +++ b/tests/unit/test_lit_server.py @@ -16,7 +16,9 @@ import json import os import sys +import threading import time +import types from time import sleep from unittest.mock import MagicMock, patch @@ -907,6 +909,125 @@ def Process(self, target, args, name): # noqa: N802 assert total_by_path == expected_total_by_path +# ── _is_pydevd_active ───────────────────────────────────────────────────────── + + +def test_is_pydevd_active_false_no_pydevd(monkeypatch): + from litserve.server import _is_pydevd_active + + monkeypatch.delitem(sys.modules, "pydevd", raising=False) + monkeypatch.setattr(sys, "platform", "win32") + assert _is_pydevd_active() is False + + +def test_is_pydevd_active_false_on_non_windows(monkeypatch): + from litserve.server import _is_pydevd_active + + monkeypatch.setitem(sys.modules, "pydevd", types.ModuleType("pydevd")) + monkeypatch.setattr(sys, "platform", "linux") + assert _is_pydevd_active() is False + + +def test_is_pydevd_active_true_when_both(monkeypatch): + from litserve.server 
import _is_pydevd_active + + monkeypatch.setitem(sys.modules, "pydevd", types.ModuleType("pydevd")) + monkeypatch.setattr(sys, "platform", "win32") + assert _is_pydevd_active() is True + + +# ── _perform_graceful_shutdown: Thread vs Process branch ────────────────────── + + +class _FakeUvicornThread(threading.Thread): + def __init__(self, alive=True): + super().__init__() + self.join = MagicMock() + self.is_alive = MagicMock(return_value=alive) + + +def test_perform_graceful_shutdown_joins_thread_without_terminate(): + server = LitServer(SimpleLitAPI()) + server._transport = MagicMock() + server.inference_workers = [] + t = _FakeUvicornThread(alive=True) + + server._perform_graceful_shutdown(MagicMock(), {0: t}) + + t.join.assert_called_once_with(timeout=server.uvicorn_graceful_timeout) + + +def test_perform_graceful_shutdown_terminates_process_worker(): + mock_proc = MagicMock() + mock_proc.is_alive.side_effect = [True, True] # alive pre-check passes; alive after join → kill + + server = LitServer(SimpleLitAPI()) + server._transport = MagicMock() + server.inference_workers = [] + + server._perform_graceful_shutdown(MagicMock(), {0: mock_proc}) + + mock_proc.terminate.assert_called_once() + mock_proc.join.assert_called_once_with(timeout=server.uvicorn_graceful_timeout) + mock_proc.kill.assert_called_once() + + +def test_perform_graceful_shutdown_skips_dead_workers(): + server = LitServer(SimpleLitAPI()) + server._transport = MagicMock() + server.inference_workers = [] + t = _FakeUvicornThread(alive=False) + + server._perform_graceful_shutdown(MagicMock(), {0: t}) + + t.join.assert_not_called() + + +# ── LitServer.run(): pydevd heartbeat branch ────────────────────────────────── + + +@patch("litserve.server.uvicorn") +@patch("litserve.server._is_pydevd_active", return_value=True) +def test_run_starts_heartbeat_sentinel_when_pydevd_active(mock_pydevd, mock_uvicorn, mock_manager): + server = LitServer(SimpleLitAPI()) + server.verify_worker_status = MagicMock() + 
server.launch_inference_worker = MagicMock(return_value=[MagicMock()]) + server._start_server = MagicMock(return_value={}) + server._perform_graceful_shutdown = MagicMock() + server._monitor_workers = False + + with ( + patch("litserve._win_shutdown_fix.start_heartbeat_sentinel") as mock_sentinel, + patch("litserve.server.mp.Manager", return_value=mock_manager), + ): + server.run(port=8000) + + mock_sentinel.assert_called_once() + pid_arg, path_arg = mock_sentinel.call_args[0][:2] + assert pid_arg == os.getpid() + assert f"litserve_hb_{os.getpid()}.tmp" in path_arg + assert mock_sentinel.call_args[1].get("kill_delay") == 3.0 + + +@patch("litserve.server.uvicorn") +@patch("litserve.server._is_pydevd_active", return_value=False) +def test_run_no_sentinel_when_pydevd_inactive(mock_pydevd, mock_uvicorn, mock_manager): + server = LitServer(SimpleLitAPI()) + server.verify_worker_status = MagicMock() + server.launch_inference_worker = MagicMock(return_value=[MagicMock()]) + server._start_server = MagicMock(return_value={}) + server._perform_graceful_shutdown = MagicMock() + server._monitor_workers = False + + with ( + patch("litserve._win_shutdown_fix.start_heartbeat_sentinel") as mock_sentinel, + patch("litserve.server.mp.Manager", return_value=mock_manager), + ): + server.run(port=8000) + + mock_sentinel.assert_not_called() + + @pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix") def test_workers_per_device_per_route_raises_on_unknown_route(): sentiment = MultiRouteAPI(api_path="/sentiment") diff --git a/tests/unit/test_middlewares.py b/tests/unit/test_middlewares.py index c4a5346b0..c63befad7 100644 --- a/tests/unit/test_middlewares.py +++ b/tests/unit/test_middlewares.py @@ -92,7 +92,6 @@ def test_middleware_multiple_initialization(): def test_track_requests_middleware_isolation(): """Test that _prepare_app_run doesn't modify the original app's middleware list.""" - lit_api = ls.test_examples.SimpleLitAPI() server = ls.LitServer(lit_api, 
track_requests=True) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 046102caf..8fc5002d6 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -77,7 +77,6 @@ async def test_azip_propagates_stream_errors(): @pytest.mark.skipif(sys.platform == "win32", reason="This test is for non-Windows platforms only.") def test_generate_random_zmq_address_non_windows(tmpdir): """Test generate_random_zmq_address on non-Windows platforms.""" - temp_dir = str(tmpdir) address1 = generate_random_zmq_address(temp_dir=temp_dir) address2 = generate_random_zmq_address(temp_dir=temp_dir) diff --git a/tests/unit/test_win_shutdown_fix.py b/tests/unit/test_win_shutdown_fix.py new file mode 100644 index 000000000..06f5af2b8 --- /dev/null +++ b/tests/unit/test_win_shutdown_fix.py @@ -0,0 +1,223 @@ +import ctypes +import sys +from unittest.mock import MagicMock, patch + +import pytest + +# ── start_heartbeat_sentinel ────────────────────────────────────────────────── + + +def test_start_heartbeat_sentinel_noop_on_non_windows(monkeypatch): + import litserve._win_shutdown_fix as mod + + monkeypatch.setattr(mod.sys, "platform", "linux") + with patch.object(mod, "_create_process_no_window") as mock_spawn: + mod.start_heartbeat_sentinel(1234, "/tmp/hb.tmp", 3.0) + mock_spawn.assert_not_called() + + +def test_start_heartbeat_sentinel_writes_ps1_and_spawns(monkeypatch, tmp_path): + import litserve._win_shutdown_fix as mod + + monkeypatch.setattr(mod.sys, "platform", "win32") + + with ( + patch.object(mod, "_create_process_no_window", return_value=True) as mock_spawn, + patch("tempfile.gettempdir", return_value=str(tmp_path)), + ): + mod.start_heartbeat_sentinel(99999, r"C:\hb.tmp", 3.0) + + ps1 = tmp_path / "litserve_spawn_sentinel_99999.ps1" + assert ps1.exists() + content = ps1.read_text() + assert "Invoke-WmiMethod" in content + assert "Win32_Process" in content + assert "99999" in content + assert "3.0" in content + assert "_child.py" in content + + 
mock_spawn.assert_called_once() + spawn_arg = mock_spawn.call_args[0][0] + assert spawn_arg.startswith("powershell.exe -NoProfile") + assert str(ps1) in spawn_arg + + +def test_start_heartbeat_sentinel_escapes_single_quotes(monkeypatch, tmp_path): + import litserve._win_shutdown_fix as mod + + monkeypatch.setattr(mod.sys, "platform", "win32") + + with ( + patch.object(mod, "_create_process_no_window", return_value=True), + patch("tempfile.gettempdir", return_value=str(tmp_path)), + ): + mod.start_heartbeat_sentinel(12345, r"C:\o'malley\hb.tmp", 3.0) + + content = (tmp_path / "litserve_spawn_sentinel_12345.ps1").read_text() + assert "''" in content # single quotes doubled for PS single-quoted string + + +def test_start_heartbeat_sentinel_swallows_spawn_errors(monkeypatch, tmp_path): + import litserve._win_shutdown_fix as mod + + monkeypatch.setattr(mod.sys, "platform", "win32") + + with ( + patch.object(mod, "_create_process_no_window", side_effect=RuntimeError("kaboom")), + patch("tempfile.gettempdir", return_value=str(tmp_path)), + ): + mod.start_heartbeat_sentinel(99999, r"C:\hb.tmp", 3.0) # must not raise + + +# ── _create_process_no_window ───────────────────────────────────────────────── + + +@pytest.mark.skipif(sys.platform != "win32", reason="Windows-only: tests ctypes.windll directly") +def test_create_process_no_window_returns_true_on_success(monkeypatch): + k32 = MagicMock() + k32.CreateProcessW.return_value = 1 + monkeypatch.setattr(ctypes, "windll", MagicMock(kernel32=k32), raising=False) + + from litserve._win_shutdown_fix import _create_process_no_window + + assert _create_process_no_window("cmd.exe") is True + assert k32.CloseHandle.call_count == 2 + + +@pytest.mark.skipif(sys.platform != "win32", reason="Windows-only: tests ctypes.windll directly") +def test_create_process_no_window_returns_false_on_failure(monkeypatch): + k32 = MagicMock() + k32.CreateProcessW.return_value = 0 + monkeypatch.setattr(ctypes, "windll", MagicMock(kernel32=k32), 
raising=False) + + from litserve._win_shutdown_fix import _create_process_no_window + + assert _create_process_no_window("cmd.exe") is False + k32.CloseHandle.assert_not_called() + + +# ── _child._alive ───────────────────────────────────────────────────────────── + + +def test_alive_true_when_handle_returned(monkeypatch): + k32 = MagicMock() + k32.OpenProcess.return_value = 0x1234 + monkeypatch.setattr(ctypes, "windll", MagicMock(kernel32=k32), raising=False) + + from litserve._win_shutdown_fix import _child + + assert _child._alive(123) is True + k32.CloseHandle.assert_called_once_with(0x1234) + + +def test_alive_false_when_no_handle(monkeypatch): + k32 = MagicMock() + k32.OpenProcess.return_value = 0 + monkeypatch.setattr(ctypes, "windll", MagicMock(kernel32=k32), raising=False) + + from litserve._win_shutdown_fix import _child + + assert _child._alive(123) is False + k32.CloseHandle.assert_not_called() + + +# ── _child._kill_subtree ────────────────────────────────────────────────────── + + +def test_kill_subtree_invalid_snapshot_is_noop(monkeypatch): + k32 = MagicMock() + k32.CreateToolhelp32Snapshot.return_value = ctypes.c_void_p(-1).value + monkeypatch.setattr(ctypes, "windll", MagicMock(kernel32=k32), raising=False) + + from litserve._win_shutdown_fix import _child + + _child._kill_subtree(100) + + k32.Process32First.assert_not_called() + k32.OpenProcess.assert_not_called() + k32.TerminateProcess.assert_not_called() + + +def test_kill_subtree_empty_snapshot_kills_only_root(monkeypatch): + k32 = MagicMock() + k32.CreateToolhelp32Snapshot.return_value = 42 + k32.Process32First.return_value = 0 # empty snapshot: no entries + k32.OpenProcess.return_value = 0xAB + monkeypatch.setattr(ctypes, "windll", MagicMock(kernel32=k32), raising=False) + + from litserve._win_shutdown_fix import _child + + _child._kill_subtree(100) + + k32.CloseHandle.assert_any_call(42) # snapshot handle closed + k32.OpenProcess.assert_called_once() + assert k32.OpenProcess.call_args[0][2] 
== 100 # root pid targeted + k32.TerminateProcess.assert_called_once() + + +# ── _child.main ─────────────────────────────────────────────────────────────── + + +def test_main_kills_when_pid_dies(monkeypatch): + from litserve._win_shutdown_fix import _child + + mock_kill = MagicMock() + monkeypatch.setattr(_child, "_alive", MagicMock(return_value=False)) + monkeypatch.setattr(_child, "_kill_subtree", mock_kill) + monkeypatch.setattr(_child.time, "sleep", MagicMock()) + monkeypatch.setattr(sys, "argv", ["_child.py", "9999", "/tmp/hb.tmp", "3.0"]) + + _child.main() + + mock_kill.assert_called_once_with(9999) + + +def test_main_kills_when_heartbeat_file_missing(monkeypatch): + from litserve._win_shutdown_fix import _child + + mock_kill = MagicMock() + monkeypatch.setattr(_child, "_alive", MagicMock(return_value=True)) + monkeypatch.setattr(_child, "_kill_subtree", mock_kill) + monkeypatch.setattr(_child.time, "sleep", MagicMock()) + monkeypatch.setattr(_child.time, "time", MagicMock(return_value=1000.0)) + monkeypatch.setattr(_child.os.path, "getmtime", MagicMock(side_effect=OSError("gone"))) + monkeypatch.setattr(sys, "argv", ["_child.py", "9999", "/tmp/hb.tmp", "3.0"]) + + _child.main() + + mock_kill.assert_called_once_with(9999) + + +def test_main_kills_when_heartbeat_stale(monkeypatch): + from litserve._win_shutdown_fix import _child + + mock_kill = MagicMock() + monkeypatch.setattr(_child, "_alive", MagicMock(return_value=True)) + monkeypatch.setattr(_child, "_kill_subtree", mock_kill) + monkeypatch.setattr(_child.time, "sleep", MagicMock()) + monkeypatch.setattr(_child.time, "time", MagicMock(return_value=1000.0)) + monkeypatch.setattr(_child.os.path, "getmtime", MagicMock(return_value=990.0)) # age=10 > delay=3 + monkeypatch.setattr(sys, "argv", ["_child.py", "9999", "/tmp/hb.tmp", "3.0"]) + + _child.main() + + mock_kill.assert_called_once_with(9999) + + +def test_main_loops_until_stale(monkeypatch): + from litserve._win_shutdown_fix import _child + + 
mock_sleep = MagicMock() + mock_kill = MagicMock() + monkeypatch.setattr(_child, "_alive", MagicMock(return_value=True)) + monkeypatch.setattr(_child, "_kill_subtree", mock_kill) + monkeypatch.setattr(_child.time, "sleep", mock_sleep) + # ages: 1.0, 2.0, 6.0 — kill triggered on third iteration when age > delay=3 + monkeypatch.setattr(_child.time, "time", MagicMock(side_effect=[1000.0, 1001.0, 1005.0])) + monkeypatch.setattr(_child.os.path, "getmtime", MagicMock(return_value=999.0)) + monkeypatch.setattr(sys, "argv", ["_child.py", "9999", "/tmp/hb.tmp", "3.0"]) + + _child.main() + + mock_kill.assert_called_once_with(9999) + assert mock_sleep.call_count == 4 # initial sleep(2) + 3 Ɨ sleep(0.5)