Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions src/litserve/_win_shutdown_fix/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Windows/PyCharm debugger shutdown sentinel."""

import contextlib
import os
import sys


def _create_process_no_window(cmd_line: str) -> bool:
    """Launch ``cmd_line`` as a console-less process through the raw Win32 API.

    ``CreateProcessW`` is called via ctypes so that the launch bypasses pydevd's
    ``subprocess.Popen`` patch entirely. The process and thread handles returned
    by Windows are closed immediately; the new process is not waited on.

    Args:
        cmd_line: Complete command line, already quoted for Windows.

    Returns:
        True when ``CreateProcessW`` reports success, False otherwise.

    """
    import ctypes
    from ctypes import wintypes

    create_no_window = 0x08000000  # CREATE_NO_WINDOW process-creation flag

    class STARTUPINFOW(ctypes.Structure):
        _fields_ = (
            ("cb", wintypes.DWORD),
            ("lpReserved", wintypes.LPWSTR),
            ("lpDesktop", wintypes.LPWSTR),
            ("lpTitle", wintypes.LPWSTR),
            ("dwX", wintypes.DWORD),
            ("dwY", wintypes.DWORD),
            ("dwXSize", wintypes.DWORD),
            ("dwYSize", wintypes.DWORD),
            ("dwXCountChars", wintypes.DWORD),
            ("dwYCountChars", wintypes.DWORD),
            ("dwFillAttribute", wintypes.DWORD),
            ("dwFlags", wintypes.DWORD),
            ("wShowWindow", wintypes.WORD),
            ("cbReserved2", wintypes.WORD),
            ("lpReserved2", ctypes.c_void_p),
            ("hStdInput", wintypes.HANDLE),
            ("hStdOutput", wintypes.HANDLE),
            ("hStdError", wintypes.HANDLE),
        )

    class ProcessInformation(ctypes.Structure):
        _fields_ = (
            ("hProcess", wintypes.HANDLE),
            ("hThread", wintypes.HANDLE),
            ("dwProcessId", wintypes.DWORD),
            ("dwThreadId", wintypes.DWORD),
        )

    # Use a private library object: prototypes set on the shared
    # ``ctypes.windll.kernel32`` would leak into every other ctypes user in
    # the host process.
    kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)

    # Without explicit prototypes ctypes marshals Python ints as C ``int``,
    # which is narrower than a HANDLE on 64-bit Windows.
    create_process = kernel32.CreateProcessW
    create_process.argtypes = (
        wintypes.LPCWSTR,  # lpApplicationName
        wintypes.LPWSTR,  # lpCommandLine
        ctypes.c_void_p,  # lpProcessAttributes
        ctypes.c_void_p,  # lpThreadAttributes
        wintypes.BOOL,  # bInheritHandles
        wintypes.DWORD,  # dwCreationFlags
        ctypes.c_void_p,  # lpEnvironment
        wintypes.LPCWSTR,  # lpCurrentDirectory
        ctypes.POINTER(STARTUPINFOW),
        ctypes.POINTER(ProcessInformation),
    )
    create_process.restype = wintypes.BOOL

    close_handle = kernel32.CloseHandle
    close_handle.argtypes = (wintypes.HANDLE,)
    close_handle.restype = wintypes.BOOL

    startup_info = STARTUPINFOW()
    startup_info.cb = ctypes.sizeof(STARTUPINFOW)
    process_info = ProcessInformation()
    # lpCommandLine is a writable LPWSTR, so hand over a mutable buffer rather
    # than an immutable Python string.
    cmd_buf = ctypes.create_unicode_buffer(cmd_line)

    ok = create_process(
        None,
        cmd_buf,
        None,
        None,
        False,
        create_no_window,
        None,
        None,
        ctypes.byref(startup_info),
        ctypes.byref(process_info),
    )
    if ok:
        # The child keeps running; only our references to it are released.
        close_handle(process_info.hProcess)
        close_handle(process_info.hThread)
    return bool(ok)


def start_heartbeat_sentinel(pid: int, heartbeat_path: str, kill_delay: float) -> None:
    """Spawn an out-of-job sentinel that kills the process tree when the heartbeat goes stale.

    Spawn chain: ctypes.CreateProcessW -> powershell.exe -File <ps1> -> WMI Win32_Process.Create
    -> python.exe _child.py. The final process runs under WmiPrvSE, outside PyCharm's Job
    Object, with no pydevd injected.

    The whole operation is best-effort: it is a no-op on non-Windows platforms, and any
    failure to write the helper script or to launch PowerShell is swallowed.

    Args:
        pid: Process id the sentinel should watch (and eventually tree-kill).
        heartbeat_path: File whose modification time serves as the heartbeat.
        kill_delay: Maximum heartbeat age, in seconds, passed through to ``_child.py``.

    """
    if sys.platform != "win32":
        return
    import tempfile

    child_py = os.path.join(os.path.dirname(__file__), "_child.py")
    cmd = f'"{sys.executable}" "{child_py}" {pid} "{heartbeat_path}" {kill_delay}'
    spawn_ps1 = os.path.join(tempfile.gettempdir(), f"litserve_spawn_sentinel_{pid}.ps1")
    ps1_arg = cmd.replace("'", "''")  # escape single quotes for PS single-quoted string
    script = (
        "try {\n"
        " $r = Invoke-WmiMethod -Class Win32_Process -Name Create"
        f" -ArgumentList '{ps1_arg}'\n"
        " exit [int]$r.ReturnValue\n"
        "} catch { exit 1 }\n"
    )

    try:
        # UTF-8 with a BOM is decoded correctly by PowerShell regardless of the
        # ANSI code page or Python's UTF-8 mode, so non-ASCII paths survive.
        with open(spawn_ps1, "w", encoding="utf-8-sig") as f:
            f.write(script)
    except (OSError, ValueError):
        # Without a freshly written script there is nothing valid to run;
        # launching PowerShell would hit a missing or stale file.
        return

    # NOTE(review): the generated .ps1 is left in the temp directory; it cannot
    # be removed here because PowerShell reads it asynchronously.
    ps_cmd = f'powershell.exe -NoProfile -NonInteractive -ExecutionPolicy Bypass -File "{spawn_ps1}"'
    with contextlib.suppress(Exception):
        _create_process_no_window(ps_cmd)
92 changes: 92 additions & 0 deletions src/litserve/_win_shutdown_fix/_child.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Sentinel child process.

Invoked as: python _child.py <pid> <heartbeat_path> <kill_delay>

"""

import ctypes
import os
import sys
import time


def _alive(pid):
    """Return True if process ``pid`` is still running.

    A process that has exited remains openable for as long as anyone holds a
    handle to it, so a successful ``OpenProcess`` alone is not proof of life;
    the exit code is checked as well.

    Args:
        pid: Windows process id to probe.

    Returns:
        False if the process cannot be opened or has an exit code other than
        STILL_ACTIVE; True otherwise.

    """
    process_query_limited_information = 0x1000
    still_active = 259  # exit code reported while a process has not terminated

    k32 = ctypes.windll.kernel32
    # Declare prototypes so handles are marshalled pointer-sized, not as C int.
    k32.OpenProcess.argtypes = (ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong)
    k32.OpenProcess.restype = ctypes.c_void_p
    k32.GetExitCodeProcess.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_ulong))
    k32.GetExitCodeProcess.restype = ctypes.c_int
    k32.CloseHandle.argtypes = (ctypes.c_void_p,)
    k32.CloseHandle.restype = ctypes.c_int

    handle = k32.OpenProcess(process_query_limited_information, False, pid)
    if not handle:
        return False
    try:
        exit_code = ctypes.c_ulong()
        if not k32.GetExitCodeProcess(handle, ctypes.byref(exit_code)):
            # Query failed: fall back to "openable means alive".
            return True
        # A process that deliberately exits with code 259 is indistinguishable
        # from a running one and is reported as alive.
        return exit_code.value == still_active
    finally:
        k32.CloseHandle(handle)


def _kill_subtree(root_pid):
    """Terminate every descendant of ``root_pid`` and then ``root_pid`` itself.

    Descendants are discovered by walking a CreateToolhelp32Snapshot process
    list. th32ParentProcessID is fixed at creation time, so orphaned children
    (whose parent already exited) are still found and killed. If the snapshot
    cannot be taken, only ``root_pid`` is terminated.

    Args:
        root_pid: Process id at the root of the tree to kill.

    """

    class PROCESSENTRY32(ctypes.Structure):
        _fields_ = [
            ("dwSize", ctypes.c_ulong),
            ("cntUsage", ctypes.c_ulong),
            ("th32ProcessID", ctypes.c_ulong),
            ("th32DefaultHeapID", ctypes.c_size_t),
            ("th32ModuleID", ctypes.c_ulong),
            ("cntThreads", ctypes.c_ulong),
            ("th32ParentProcessID", ctypes.c_ulong),
            ("pcPriClassBase", ctypes.c_long),
            ("dwFlags", ctypes.c_ulong),
            ("szExeFile", ctypes.c_char * 260),
        ]

    th32cs_snapprocess = 0x00000002
    process_terminate = 0x0001
    invalid_handle_value = ctypes.c_void_p(-1).value

    k32 = ctypes.windll.kernel32
    # Declare prototypes: with the default c_int restype a failed snapshot comes
    # back as -1 and never equals the pointer-sized INVALID_HANDLE_VALUE, and
    # handles passed as bare ints are marshalled as C int.
    entry_ptr = ctypes.POINTER(PROCESSENTRY32)
    k32.CreateToolhelp32Snapshot.argtypes = (ctypes.c_ulong, ctypes.c_ulong)
    k32.CreateToolhelp32Snapshot.restype = ctypes.c_void_p
    k32.Process32First.argtypes = (ctypes.c_void_p, entry_ptr)
    k32.Process32First.restype = ctypes.c_int
    k32.Process32Next.argtypes = (ctypes.c_void_p, entry_ptr)
    k32.Process32Next.restype = ctypes.c_int
    k32.OpenProcess.argtypes = (ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong)
    k32.OpenProcess.restype = ctypes.c_void_p
    k32.TerminateProcess.argtypes = (ctypes.c_void_p, ctypes.c_uint)
    k32.TerminateProcess.restype = ctypes.c_int
    k32.CloseHandle.argtypes = (ctypes.c_void_p,)
    k32.CloseHandle.restype = ctypes.c_int

    # Map each parent pid to the pids that name it as their parent.
    children_of = {}
    snap = k32.CreateToolhelp32Snapshot(th32cs_snapprocess, 0)
    if snap is not None and snap != invalid_handle_value:
        try:
            entry = PROCESSENTRY32()
            entry.dwSize = ctypes.sizeof(PROCESSENTRY32)
            more = k32.Process32First(snap, ctypes.byref(entry))
            while more:
                children_of.setdefault(entry.th32ParentProcessID, []).append(entry.th32ProcessID)
                more = k32.Process32Next(snap, ctypes.byref(entry))
        finally:
            k32.CloseHandle(snap)

    # Depth-first walk. The ``seen`` set keeps the walk finite even if the
    # parent map contains a cycle (possible when a pid has been recycled).
    # NOTE(review): parent pids are not checked against process creation time,
    # so a recycled pid could be misattributed as a descendant - confirm this
    # is acceptable.
    seen = {root_pid}
    descendants = []
    stack = [root_pid]
    while stack:
        current = stack.pop()
        for child in children_of.get(current, ()):
            if child in seen:
                continue
            seen.add(child)
            descendants.append(child)
            stack.append(child)

    # Descendants first, the root last.
    for victim in descendants + [root_pid]:
        handle = k32.OpenProcess(process_terminate, False, victim)
        if handle:
            k32.TerminateProcess(handle, 1)
            k32.CloseHandle(handle)


def main():
    """Watch the target process and tree-kill it once it is gone or unresponsive.

    Reads ``<pid> <heartbeat_path> <kill_delay>`` from ``sys.argv``, waits a
    short grace period, then polls twice a second. Polling stops - and the
    process subtree is killed - as soon as any of these holds:

    * the watched process is no longer alive (clean up orphaned children),
    * the heartbeat file cannot be stat'ed (treated as fatal),
    * the heartbeat file is older than ``kill_delay`` seconds (main thread
      likely suspended by pydevd).
    """
    target_pid = int(sys.argv[1])
    heartbeat_file = sys.argv[2]
    max_age = float(sys.argv[3])

    # Grace period: let the server finish its startup file I/O.
    time.sleep(2)

    while True:
        time.sleep(0.5)
        if not _alive(target_pid):
            break
        try:
            age = time.time() - os.path.getmtime(heartbeat_file)
        except OSError:
            break
        if age > max_age:
            break

    # Every exit from the loop ends the same way: kill the whole tree.
    _kill_subtree(target_pid)


# Script entry point: run the sentinel loop only when this file is executed
# directly (python _child.py <pid> <heartbeat_path> <kill_delay>), not on import.
if __name__ == "__main__":
    main()
2 changes: 0 additions & 2 deletions src/litserve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def __init__(
enable_async: bool = False,
):
"""Initialize LitAPI with configuration options."""

if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")

Expand Down Expand Up @@ -381,7 +380,6 @@ def pre_setup(self, spec: Optional[LitSpec] = None):

def set_logger_queue(self, queue: Queue):
"""Set the queue for logging events."""

self._logger_queue = queue

def log(self, key, value):
Expand Down
6 changes: 4 additions & 2 deletions src/litserve/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def __init__(self):
self._config = {}

def mount(self, path: str, app: ASGIApp) -> None:
"""Mount an ASGI app endpoint to LitServer. Use this method when you want to add an additional endpoint to the
server such as /metrics endpoint for prometheus metrics.
"""Mount an ASGI app endpoint to LitServer.

Use this method when you want to add an additional endpoint to the server such as /metrics endpoint for
prometheus metrics.

Args:
path (str): The path to mount the app to.
Expand Down
5 changes: 3 additions & 2 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None):
]
):
raise ValueError(
"""When `stream=False`, `lit_api.predict`, `lit_api.encode_response` must not be
generator or async generator functions.
"""When `stream=False`, `lit_api.predict`, `lit_api.encode_response` must not be generator or async
generator functions.

Correct usage:

Expand All @@ -420,6 +420,7 @@ async def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

"""
)
if (
Expand Down
42 changes: 20 additions & 22 deletions src/litserve/loops/continuous_batching_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ class Output:

class ContinuousBatchingLoop(LitLoop):
def __init__(self, max_sequence_length: int = 2048, no_pending_requests: bool = False, sleep_delay: float = 0.001):
"""Runs continuous batching loop. This loop handles adding new requests, processing them in batches, and
managing the state of active sequences.
"""Runs continuous batching loop.

This loop handles adding new requests, processing them in batches, and managing the state of active sequences.

The loop requires the following methods to be implemented in the LitAPI:
- setup: sets up the model on the device
Expand Down Expand Up @@ -84,28 +85,26 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None):
)

if not hasattr(lit_api, "step") and not hasattr(lit_api, "predict"):
raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to
have a `predict` method which accepts decoded request inputs and a list of generated_sequence.
Please implement the has_finished method in the lit_api.

class ExampleAPI(LitAPI):
...
def predict(self, inputs, generated_sequence):
# implement predict logic
# return list of new tokens
...
""")
raise ValueError(
"""Using the default step method with Continuous batching loop requires the lit_api to have a `predict`
method which accepts decoded request inputs and a list of generated_sequence. Please implement the
has_finished method in the lit_api.

class ExampleAPI(LitAPI): ... def predict(self, inputs, generated_sequence): # implement predict
logic # return list of new tokens ...

"""
)

if not hasattr(lit_api, "step") and not hasattr(lit_api, "has_finished"):
raise ValueError("""Using the default step method with Continuous batching loop
requires the lit_api to have a has_finished method. Please implement the has_finished method in the lit_api.
raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to have
a has_finished method. Please implement the has_finished method in the lit_api.

class ExampleAPI(LitAPI): ... def has_finished(self, uid: str, token: str,
max_sequence_length: int) -> bool: # implement has_finished logic return
False

class ExampleAPI(LitAPI):
...
def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
# implement has_finished logic
return False
""")
""")

def add_request(
self,
Expand Down Expand Up @@ -223,7 +222,6 @@ async def run(
callback_runner: CallbackRunner,
):
"""Main loop that processes batches of requests."""

warning_counter = 0
lit_spec = lit_api.spec
try:
Expand Down
9 changes: 3 additions & 6 deletions src/litserve/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@


def extract_input_schema(func) -> dict[str, Any]:
"""Extract JSON schema for function input parameters from a Python function. Supports regular type annotations,
Pydantic Fields, and Pydantic BaseModel classes.
"""Extract JSON schema for function input parameters from a Python function.

Supports regular type annotations, Pydantic Fields, and Pydantic BaseModel classes.

Args:
func: Python function to analyze
Expand Down Expand Up @@ -561,19 +562,15 @@ def connect_mcp_server(self, mcp_tools: list[ToolType], app: FastAPI):
app: LitServer's FastAPI app to mount the MCP server to.

"""

if len(mcp_tools) == 0:
return

for tool in mcp_tools:
self.add_tool(tool)

logger.warning(
"MCP support is in beta and APIs are subject to change. Please report any issues to https://github.com/Lightning-AI/litserve/issues"
)

self._mount_with_fastapi(app)

logger.info(
"================================================"
"\n🎉 Enabled MCP server ⚡\n"
Expand Down
Loading