diff --git a/src/litserve/server.py b/src/litserve/server.py index ccefb5c2..ea56241e 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -77,6 +77,12 @@ SHUTDOWN_API_KEY = os.environ.get("LIT_SHUTDOWN_API_KEY") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +def _display_host_for_url(host: str) -> str: + """Return a browser-usable host for user-facing URLs.""" + return "localhost" if host in {"0.0.0.0", "::"} else host + + # FastAPI writes form files to disk over 1MB by default, which prevents serialization by multiprocessing MultiPartParser.max_file_size = sys.maxsize # renamed in PR: https://github.com/encode/starlette/pull/2780 @@ -1510,7 +1516,8 @@ def run( ) if not self._disable_openapi_url: - print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") + display_host = _display_host_for_url(host) + print(f"Swagger UI is available at http://{display_host}:{port}/docs") if self._monitor_workers: self._start_worker_monitoring(manager, uvicorn_workers) diff --git a/tests/unit/test_lit_server.py b/tests/unit/test_lit_server.py index d319ae19..604a4f45 100644 --- a/tests/unit/test_lit_server.py +++ b/tests/unit/test_lit_server.py @@ -340,6 +340,14 @@ def test_server_terminate(): server._transport.close.assert_called() +def test_display_host_for_url_uses_browser_address(): + from litserve.server import _display_host_for_url + + assert _display_host_for_url("0.0.0.0") == "localhost" + assert _display_host_for_url("::") == "localhost" + assert _display_host_for_url("127.0.0.1") == "127.0.0.1" + + @pytest.mark.parametrize(("disable_openapi_url", "should_print"), [(False, True), (True, False)]) @patch("builtins.print") @patch("litserve.server.uvicorn") @@ -356,7 +364,7 @@ def test_disable_openapi_url_print_message(mock_uvicorn, mock_print, mock_manage server.run(port=8000) if should_print: - mock_print.assert_called_with("Swagger UI is available at http://0.0.0.0:8000/docs") + mock_print.assert_called_with("Swagger UI is available at http://localhost:8000/docs") else: mock_print.assert_not_called()