diff --git a/rollout_refactor_requirements.md b/rollout_refactor_requirements.md new file mode 100644 index 0000000000..eb86bb8b35 --- /dev/null +++ b/rollout_refactor_requirements.md @@ -0,0 +1,275 @@ +# Rollout Generate Refactor Requirements + +## 背景 + +当前 rollout controller 和 rollout worker 同时承担控制面、生成数据面、健康检测与失败恢复等多类职责,导致调用路径复杂、序列化开销高,并且普通 agent 与 agentic/sandbox 场景存在不一致的生成入口。 + +本轮重构先聚焦职责拆分和生成路径解耦,不在第一版引入健康检测和失败恢复。 + +## 核心需求 + +### 1. 删除健康检测和失败恢复 + +第一版使用纯净模型,假设 rollout backend、server、worker 不会失败。 + +需要删除或避免引入以下逻辑: + +- worker/server 健康检测、探活、ready/liveness 轮询 +- worker 失败标记、降级、屏蔽 +- 自动恢复、重启 failed worker +- 健康检测状态在 controller、worker、router 之间同步 + +这类能力后续如果重新引入,需要作为独立设计处理,不应混入本轮 generate 解耦。 + +### 2. generate 从 controller 和 worker 中彻底移出 + +`RolloutController.generate` 和 `RolloutWorker.generate` 都不应该保留。 + +目标职责边界: + +- `RolloutController`:只负责中心化控制面,例如 runtime 初始化、worker metadata、权重同步、offload/onload、pause/continue、endpoint/router 启动。 +- `RolloutWorker`:只负责 backend server 生命周期和 backend 控制,例如启动服务、权重更新、KV cache 控制、pause/continue。 +- 独立生成模块:负责真正的数据生成调用、请求/响应转换、parser、partial rollout 等生成相关逻辑。 + +### 3. worker 侧生成类需要评估普通类与 Ray actor 形态 + +controller 侧拆出的生成入口相对轻量,可以优先按普通 async 类实现。 + +worker 侧生成模块可能包含 tokenizer、partial rollout handler、parser、状态转换等重操作,因此需要评估两种实现方式: + +#### 方案 A:生成类本身是 Ray actor + +每个 active worker 绑定一个 generation actor,数量与 active worker 一致,并尽量调度到对应 worker 所在节点或资源附近。 + +优点: + +- 重逻辑天然隔离,不阻塞调用方进程。 +- partial rollout handler、tokenizer 等状态可以常驻 actor 内。 +- 和 worker 一一绑定,路由关系清晰。 + +缺点: + +- 引入 Ray actor/grpc 序列化开销。 +- 需要管理 actor 生命周期和 placement。 +- 轻量单轮推理场景可能不划算。 + +#### 方案 B:生成类保持普通类,重操作外包 + +主 generate 类保持普通 async 类;当启用 tokenizer、partial rollout handler 等重操作时,将这些重操作单独外包给 Ray actor 或其他执行器。 + +优点: + +- 无重操作时路径最短,避免额外 Ray 序列化。 +- 可按需引入远程执行能力。 +- 更适合简单单轮推理场景。 + +缺点: + +- 接口拆分更复杂,需要明确哪些步骤属于主流程,哪些步骤可外包。 +- partial/tokenize actor 的生命周期和路由关系仍需设计。 +- 如果重操作很多,主流程可能变成多个远程调用,反而增加复杂度。 + +当前第一版选择方案 A。普通类/重操作外包方案先不进入实现,后续只有在明确需要进一步降低 Ray actor 调用开销时再单独评估。 + +当前实现原则: + +- 第一版统一使用 `RolloutWorkerGenerator`,每个 active rollout worker 绑定一个生成 actor。 +- agentloop 运行时只负责按 session 选择 worker,然后直接调用对应 `RolloutWorkerGenerator`,不经过 controller。 +- partial rollout 不再通过 controller 全局开关传播,而是作为单次 generate 调用参数进入生成模块,避免跨请求共享状态。 +- 对外入口集中在 `xtuner/v1/rl/rollout/rollout_generator.py`,内部实现集中在 `xtuner/v1/rl/rollout/_generation/`,避免 controller/worker 文件继续承载生成细节。 + +### 4. controller 拆分出的生成入口只是可选路径 + +generate 解耦后,生成路径需要支持至少三种运行模式: + +#### 模式 1:调用内部 generate 接口 + +agentloop 直接调用 xtuner 内部的独立 generate 类或接口。 + +适用场景: + +- 普通内联 agentloop +- 希望减少 controller 中转 +- 希望减少不必要序列化 + +#### 模式 2:调用 xtuner 内部 routed url + +xtuner 内部启动 routed HTTP endpoint,对外暴露一个 routed url。agentloop 通过该 url 生成,请求再路由到对应 worker 的生成 URL。 + +适用场景: + +- 需要 HTTP 接口兼容 +- agentloop 不适合直接调用 Ray/Python 接口 +- 希望由 xtuner 内部管理 router + +#### 模式 3:调用外部 routed 注册服务 + +xtuner 将每个 worker 的生成 URL 注册到外部 router,由外部服务提供 routed url。agentloop/sandbox 使用外部 routed url 生成。 + +适用场景: + +- agentic/sandbox 场景 +- 需要第三方 router、隔离部署或独立扩容 +- 不希望 xtuner 内部 router 成为中心化瓶颈 + +实现上,模式 2 和模式 3 对 agentloop 都是 HTTP 生成路径,因此统一为 `kind="http"`。 +二者区别只体现在 HTTP 入口来源:`http_entry="internal"` 表示 XTuner 启动内部 router, +`http_entry="external"` 表示 XTuner 将 worker 生成 URL 注册到外部 router。 + +代码上这两类入口对称命名为 `InternalRolloutHttpEntry` 和 `ExternalRolloutHttpEntry`。 +`InternalRolloutHttpEntry` 是 XTuner 内部 FastAPI 转发入口;`ExternalRolloutHttpEntry` +不转发请求,只复用 routedapiproxy 注册逻辑,将 worker 生成 URL 注册到外部 router。 + +worker 生成 URL 由 `http_worker_url_source` 控制: + +- `backend`:使用 worker backend url,适合评测或不需要 SessionServer 增强逻辑的简单场景。 +- `session`:使用每个 worker 外层的 SessionServer url,适合需要 session id、token cache、trace/replay 等增强逻辑的 agentic 场景。内部 router 会自动把选出的 `session_id` 写入请求体;外部 router 场景要求调用方或外部 router 保留/注入 `session_id`。 + +## 运行流程图 + +### 公共启动流程 + +所有模式共享同一套 worker 构建流程。`RolloutWorkerBuilder` 启动 active worker、收集 backend URL 和 SessionServer URL,并为每个 active worker 创建一个 `RolloutWorkerGenerator` actor。 + +```mermaid +flowchart LR + Cfg[RolloutConfig.build] --> Builder[RolloutWorkerBuilder] + Builder --> Worker[RolloutWorker] + Worker --> Backend[Backend server
lmdeploy/vLLM/SGLang] + Worker --> SessionServer[SessionServerActor
optional proxy URL] + Builder --> Generator[RolloutWorkerGenerator
one actor per active worker] + Builder --> Handle[RolloutWorkerHandle
rank, worker_actor, backend_url,
generator_actor, session_server_url] + Handle --> Controller[RolloutController metadata] +``` + +### 1. 本地生成:`kind="local"` + +`AgentLoop` 持有 `RolloutGenerateHandle`,其中包含 `LocalRolloutGenerator`。运行时按 session 选择 worker,然后直接调用该 worker 绑定的 `RolloutWorkerGenerator` actor。controller 不在数据生成路径上。 + +```mermaid +flowchart LR + AgentLoop[AgentLoop] --> Handle[RolloutGenerateHandle
kind=local] + Handle --> LocalGen[LocalRolloutGenerator] + LocalGen --> Selector[SessionWorkerSelector] + Selector --> WorkerHandle[RolloutWorkerHandle] + WorkerHandle --> GenActor[RolloutWorkerGenerator actor] + GenActor --> Backend[worker backend_url] +``` + +参与类: + +- `RolloutGenerateHandle` +- `LocalRolloutGenerator` +- `SessionWorkerSelector` +- `RolloutWorkerHandle` +- `RolloutWorkerGenerator` + +### 2. 内部 router,直接走 backend:`kind="http", http_entry="internal", http_worker_url_source="backend"` + +XTuner 启动 `InternalRolloutHttpEntry`。agentloop 只看到一个 internal router base URL。router 按 session 选择 worker,并把请求转发到 worker 的 `backend_url`。这条路径不经过 `SessionServerActor`。 + +```mermaid +flowchart LR + AgentLoop[AgentLoop] --> Handle[RolloutGenerateHandle
kind=http] + Handle --> RouterURL[internal_http_entry_url] + RouterURL --> InternalRouter[InternalRolloutHttpEntry] + InternalRouter --> Selector[SessionWorkerSelector] + Selector --> WorkerHandle[RolloutWorkerHandle] + WorkerHandle --> Backend[worker backend_url] +``` + +参与类: + +- `RolloutGenerateHandle` +- `InternalRolloutHttpEntry` +- `SessionWorkerSelector` +- `RolloutWorkerHandle` + +### 3. 内部 router,走 SessionServer:`kind="http", http_entry="internal", http_worker_url_source="session"` + +XTuner 仍然启动 `InternalRolloutHttpEntry`,但 router 选择 worker 后转发到 `session_server_url`。router 会把选出的 `session_id` 写入请求体,满足 `SessionServer` 的 session/cache/trace 逻辑。 + +```mermaid +flowchart LR + AgentLoop[AgentLoop] --> Handle[RolloutGenerateHandle
kind=http] + Handle --> RouterURL[internal_http_entry_url] + RouterURL --> InternalRouter[InternalRolloutHttpEntry] + InternalRouter --> Selector[SessionWorkerSelector] + Selector --> WorkerHandle[RolloutWorkerHandle] + WorkerHandle --> SessionServer[SessionServerActor
session_server_url] + SessionServer --> Backend[worker backend_url] +``` + +参与类: + +- `RolloutGenerateHandle` +- `InternalRolloutHttpEntry` +- `SessionWorkerSelector` +- `RolloutWorkerHandle` +- `SessionServerActor` + +### 4. 外部 router,注册 backend:`kind="http", http_entry="external", http_worker_url_source="backend"` + +XTuner 不转发请求,只通过 `ExternalRolloutHttpEntry` 复用 routedapiproxy 注册逻辑,将每个 worker 的 `backend_url` 注册到外部 router。agentloop/sandbox 访问外部 router,由外部 router 转发到 worker backend。 + +```mermaid +flowchart LR + Entry[ExternalRolloutHttpEntry] --> WorkerHandle[RolloutWorkerHandle] + WorkerHandle --> Register[register backend_url
to external router] + + AgentLoop[AgentLoop or Sandbox] --> Handle[RolloutGenerateHandle
kind=http] + Handle --> ExternalURL[external router base_url] + ExternalURL --> ExternalRouter[External router] + ExternalRouter --> Backend[worker backend_url] +``` + +参与类: + +- `RolloutGenerateHandle` +- `ExternalRolloutHttpEntry` +- `RolloutWorkerHandle` +- 外部 router 服务 + +### 5. 外部 router,注册 SessionServer:`kind="http", http_entry="external", http_worker_url_source="session"` + +XTuner 通过 `ExternalRolloutHttpEntry` 复用 routedapiproxy 注册逻辑,将每个 worker 的 `session_server_url` 注册到外部 router。agentloop/sandbox 访问外部 router,由外部 router 转发到 SessionServer,再到 worker backend。该模式要求调用方或外部 router 保留/注入 `session_id`。 + +```mermaid +flowchart LR + Entry[ExternalRolloutHttpEntry] --> WorkerHandle[RolloutWorkerHandle] + WorkerHandle --> Register[register session_server_url
to external router] + + AgentLoop[AgentLoop or Sandbox] --> Handle[RolloutGenerateHandle
kind=http] + Handle --> ExternalURL[external router base_url] + ExternalURL --> ExternalRouter[External router] + ExternalRouter --> SessionServer[SessionServerActor
session_server_url] + SessionServer --> Backend[worker backend_url] +``` + +参与类: + +- `RolloutGenerateHandle` +- `ExternalRolloutHttpEntry` +- `RolloutWorkerHandle` +- `SessionServerActor` +- 外部 router 服务 + +## 非目标 + +本轮暂不处理: + +- 健康检测、失败恢复、自动重启 +- worker/router 的复杂容错状态同步 +- CI 和旧单测兼容修复 +- SGLang 和 vLLM 分支的逻辑正确性 +- 外部 router 的完整协议细节 +- gateway 内部的所有东西都可以忽略 + +## 当前判断 + +本轮重构应先完成纯净职责拆分: + +1. 删除健康检测和失败恢复。 +2. 将 controller/worker 的 generate 完全外移。 +3. 明确独立生成模块的接口和运行模式。 +4. 再评估 worker 侧生成模块采用普通类、Ray actor,还是普通类加重操作外包的混合方案。 diff --git a/tests/rl/test_rollout_utils.py b/tests/rl/test_rollout_utils.py index 7a4995acc0..4812f1ecb7 100644 --- a/tests/rl/test_rollout_utils.py +++ b/tests/rl/test_rollout_utils.py @@ -1,26 +1,26 @@ import ray import torch import threading -import time import unittest import os import tempfile +from pathlib import Path from types import SimpleNamespace from unittest.mock import patch -from xtuner.v1.data_proto.rl_data import Status, RolloutState, SampleParams +from xtuner.v1.data_proto.rl_data import Status, RolloutState from xtuner.v1.rl.rollout.worker import RolloutConfig -from xtuner.v1.rl.rollout.controller import RolloutController, WorkerInfo +from xtuner.v1.rl.rollout.controller import RolloutController +from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager, RolloutWorkerRouteInfo +from xtuner.v1.rl.rollout.runtime import build_rollout_runtime from xtuner.v1.rl.rollout.utils import ( PartialRolloutHandler, - RolloutHealthChecker, SessionRouter, ) -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "") RESOURCE_MAP = {"npu": "NPU", "cuda": "GPU"} -TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] class _FakeRemoteMethod: @@ -40,55 +40,53 @@ def __init__(self): self.shutdown = _FakeRemoteMethod("shutdown", self.call_log) -class TestRolloutHealthChecker(unittest.TestCase): - def _build_checker(self, workers_info): - config = SimpleNamespace(health_check_interval_seconds=10, health_check_failure_threshold=1) - return RolloutHealthChecker(config, workers_info) - - def test_shutdown_runs_when_offload_fails(self): - worker = _FakeWorker() - workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=True)} - checker = self._build_checker(workers_info) - - async def unhealthy_worker(*args, **kwargs): - return False - - def ray_get(ref, timeout=None): - worker.call_log.append((ref, "get")) - if ref == "offload": - raise RuntimeError("offload failed") - return None +class _FakeGenerator: + def __init__(self): + self.call_log = [] + self.ping = _FakeRemoteMethod("ping", self.call_log) - with ( - patch("xtuner.v1.rl.rollout.utils.check_worker_health", side_effect=unhealthy_worker), - patch("xtuner.v1.rl.rollout.utils.ray.get", side_effect=ray_get), - ): - checker.run_once() - self.assertFalse(workers_info[0].is_active) - self.assertEqual( - worker.call_log, +class TestRolloutHealthManager(unittest.TestCase): + def _build_manager(self, worker): + config = SimpleNamespace( + health_check_interval_seconds=10, + health_check_failure_threshold=1, + worker_log_dir=Path("."), + ) + manager = RolloutHealthManager( + config, [ - ("offload", "remote"), - ("offload", "get"), - ("shutdown", "remote"), - ("shutdown", "get"), + RolloutWorkerRouteInfo( + rank=0, + actor=worker, + url="http://worker-0", + generator=_FakeGenerator(), + is_active=True, + ) ], ) + manager.stop() + return manager - def test_inactive_worker_is_not_cleaned_up_again(self): + def test_run_once_does_not_probe_or_deactivate_workers(self): worker = _FakeWorker() - workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=False)} - checker = self._build_checker(workers_info) + manager = self._build_manager(worker) - with ( - patch("xtuner.v1.rl.rollout.utils.check_worker_health") as check_worker_health_mock, - patch("xtuner.v1.rl.rollout.utils.ray.get") as ray_get_mock, - ): - checker.run_once() + with patch("xtuner.v1.rl.rollout.health_manager.ray.get", side_effect=ray_get): + manager.run_once() + + self.assertTrue(manager.rank2info[0].is_active) + self.assertEqual(worker.call_log, []) + + def test_report_worker_failure_is_registry_only_noop(self): + worker = _FakeWorker() + manager = self._build_manager(worker) + + with patch("xtuner.v1.rl.rollout.health_manager.ray.get") as ray_get_mock: + manager.report_worker_failure(0, "request failed") - check_worker_health_mock.assert_not_called() ray_get_mock.assert_not_called() + self.assertTrue(manager.rank2info[0].is_active) self.assertEqual(worker.call_log, []) @@ -179,39 +177,36 @@ def init_rollout_controller(self): health_check_interval_seconds=10, health_check_failure_threshold=1, ) - controller = RolloutController(rollout_cfg, pg) + runtime = build_rollout_runtime(rollout_cfg, pg) + controller = RolloutController(rollout_cfg, runtime) return controller @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_healthcheck_deactivate_and_recover(self): + def test_registry_does_not_deactivate_or_recover_workers(self): controller = self.init_rollout_controller() ranks = list(controller.rank2info.keys()) rank0 = ranks[0] actor0 = controller.rank2info[rank0].actor ray.get(actor0.shutdown.remote()) - time.sleep(3) # wait for the actor to be fully killed health_before_recover = ray.get(actor0.check_health.remote()) url = controller.rank2info[rank0].url self.assertFalse(health_before_recover) - controller.health_checker.run_once() + ray.get(controller.health_manager.run_once.remote()) - self.assertFalse(controller.rank2info[rank0].is_active) - rollout_state = RolloutState( - message=TEST_TEXT_MESSAGES, - sample_params=SampleParams(return_token_ids=True), - ) - out = asyncio_run(controller.generate(rollout_state)) - self.assertEqual(out.status, Status.FAILED) + ready, details = ray.get(controller.health_manager.get_ready_status.remote()) + self.assertTrue(details["worker_routes"][rank0]["is_active"]) + self.assertTrue(ready) controller.recover_failed_workers() - self.assertTrue(controller.rank2info[rank0].is_active) - self.assertEqual(url, controller.rank2info[rank0].url) + ready, details = ray.get(controller.health_manager.get_ready_status.remote()) + self.assertTrue(details["worker_routes"][rank0]["is_active"]) + self.assertTrue(ready) + route_infos = ray.get(controller.health_manager.get_worker_route_infos.remote()) + self.assertEqual(url, route_infos[0].url) health_after_recover = ray.get(actor0.check_health.remote()) - self.assertTrue(health_after_recover) - out = asyncio_run(controller.generate(rollout_state)) - self.assertNotEqual(out.status, Status.FAILED) + self.assertFalse(health_after_recover) if __name__ == "__main__": diff --git a/tests/rl/test_rollout_worker.py b/tests/rl/test_rollout_worker.py index 5057e09715..ab07822d4b 100644 --- a/tests/rl/test_rollout_worker.py +++ b/tests/rl/test_rollout_worker.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.rollout.generation import RolloutGenerationService from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker from xtuner.v1.rl.rollout.sglang import SGLangWorker from xtuner.v1.rl.rollout.worker import RolloutWorker @@ -37,19 +38,22 @@ def test_continue_generation_clears_abort_flag(self): worker._make_request.assert_called_once_with("continue_generation") -class TestRolloutWorker(unittest.IsolatedAsyncioTestCase): +class TestRolloutGenerationService(unittest.IsolatedAsyncioTestCase): async def test_generate_returns_aborted_when_abort_flag_is_set(self): - worker = RolloutWorker.__new__(RolloutWorker) - worker.receive_abort_request = threading.Event() - worker.receive_abort_request.set() + service = RolloutGenerationService.__new__(RolloutGenerationService) + service.receive_abort_request = threading.Event() + service.receive_abort_request.set() rollout_state = MagicMock() - result = await worker.generate(rollout_state) + result = await service.generate(rollout_state) self.assertIs(result, rollout_state) self.assertEqual(rollout_state.finish_reason, "abort") self.assertEqual(rollout_state.status, Status.ABORTED) + +class TestRolloutWorker(unittest.IsolatedAsyncioTestCase): + async def test_pause_generation_sets_abort_flag(self): worker = RolloutWorker.__new__(RolloutWorker) worker.receive_abort_request = threading.Event() @@ -115,9 +119,9 @@ async def test_lmdeploy_cleanup_after_pause_skips_without_routed_experts(self): get_actor.assert_not_called() async def test_safe_post_request_returns_aborted_on_cancellation(self): - worker = RolloutWorker.__new__(RolloutWorker) - worker.receive_abort_request = threading.Event() - worker.logger = MagicMock() + service = RolloutGenerationService.__new__(RolloutGenerationService) + service.receive_abort_request = threading.Event() + service.logger = MagicMock() send_started = asyncio.Event() send_cancelled = asyncio.Event() @@ -133,23 +137,23 @@ async def send(self, req): send_cancelled.set() raise - worker.client = _Client() + service.client = _Client() - task = asyncio.create_task(worker._safe_post_request("http://test", headers={}, payload={"input_ids": [1]})) + task = asyncio.create_task(service._safe_post_request("http://test", headers={}, payload={"input_ids": [1]})) await send_started.wait() task.cancel() result = await task self.assertEqual(result.error_type, HttpRequestErrorType.REQUEST_ABORTED) - self.assertTrue(worker.receive_abort_request.is_set()) + self.assertTrue(service.receive_abort_request.is_set()) self.assertTrue(send_cancelled.is_set()) async def test_safe_post_request_cancels_inflight_request_after_abort_timeout(self): - worker = RolloutWorker.__new__(RolloutWorker) - worker.receive_abort_request = threading.Event() - worker.abort_timeout = 0.01 - worker.logger = MagicMock() + service = RolloutGenerationService.__new__(RolloutGenerationService) + service.receive_abort_request = threading.Event() + service.abort_timeout = 0.01 + service.logger = MagicMock() send_started = asyncio.Event() send_cancelled = asyncio.Event() @@ -165,11 +169,11 @@ async def send(self, req): send_cancelled.set() raise - worker.client = _Client() + service.client = _Client() - task = asyncio.create_task(worker._safe_post_request("http://test", headers={}, payload={"input_ids": [1]})) + task = asyncio.create_task(service._safe_post_request("http://test", headers={}, payload={"input_ids": [1]})) await send_started.wait() - worker.receive_abort_request.set() + service.receive_abort_request.set() result = await task @@ -177,10 +181,10 @@ async def send(self, req): self.assertTrue(send_cancelled.is_set()) async def test_safe_post_request_keeps_abort_response_within_timeout(self): - worker = RolloutWorker.__new__(RolloutWorker) - worker.receive_abort_request = threading.Event() - worker.abort_timeout = 1.0 - worker.logger = MagicMock() + service = RolloutGenerationService.__new__(RolloutGenerationService) + service.receive_abort_request = threading.Event() + service.abort_timeout = 1.0 + service.logger = MagicMock() send_started = asyncio.Event() finish_send = asyncio.Event() @@ -199,11 +203,11 @@ async def send(self, req): await finish_send.wait() return response - worker.client = _Client() + service.client = _Client() - task = asyncio.create_task(worker._safe_post_request("http://test", headers={}, payload={"input_ids": [1]})) + task = asyncio.create_task(service._safe_post_request("http://test", headers={}, payload={"input_ids": [1]})) await send_started.wait() - worker.receive_abort_request.set() + service.receive_abort_request.set() finish_send.set() result = await task diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py index 4ae6000e82..06ec1f8b8a 100644 --- a/xtuner/v1/rl/agent_loop/agent_loop.py +++ b/xtuner/v1/rl/agent_loop/agent_loop.py @@ -10,7 +10,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams from xtuner.v1.rl.judger import Judger -from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle, RolloutGenerateHandleConfig from xtuner.v1.rl.utils import ( CPUActorLauncher, CPUResourcesConfig, @@ -27,10 +27,21 @@ class AgentLoopConfig(ABC, BaseModel): sample_params: SampleParams | None = None cpu_resources: CPUResourcesConfig | None = None - def build(self, rollout_controller, judger: Judger | None = None, logger=None) -> AgentLoopSpec: + def build( + self, + rollout_controller=None, + rollout_generator: RolloutGenerateHandle | None = None, + judger: Judger | None = None, + logger=None, + ) -> AgentLoopSpec: + if rollout_generator is None: + if rollout_controller is None: + raise ValueError("Either rollout_controller or rollout_generator must be provided.") + rollout_generator = RolloutGenerateHandleConfig().build(rollout_controller) if self.cpu_resources is None: return self.build_local( rollout_controller=rollout_controller, + rollout_generator=rollout_generator, judger=judger, logger=logger, ) @@ -43,12 +54,14 @@ def build(self, rollout_controller, judger: Judger | None = None, logger=None) - if self.cpu_resources.num_workers > 1: return self._build_router( rollout_controller=rollout_controller, + rollout_generator=rollout_generator, cpu_resources=self.cpu_resources, judger=judger, logger=logger, ) return self._build_ray_actor( rollout_controller=rollout_controller, + rollout_generator=rollout_generator, cpu_resources=self.cpu_resources, judger=judger, logger=logger, @@ -58,6 +71,7 @@ def build(self, rollout_controller, judger: Judger | None = None, logger=None) - def build_local( self, rollout_controller, + rollout_generator: RolloutGenerateHandle | None = None, judger: Judger | None = None, logger=None, ) -> AgentLoop: ... @@ -65,6 +79,7 @@ def build_local( def _build_ray_actor( self, rollout_controller: RolloutController, + rollout_generator: RolloutGenerateHandle, cpu_resources: CPUResourcesConfig, pg: PlacementGroup | None = None, judger: Judger | None = None, @@ -76,6 +91,7 @@ def _build_ray_actor( AgentLoopActor, self, rollout_controller, + rollout_generator, judger, pg=pg, bundle_idx=0, @@ -88,6 +104,7 @@ def _build_ray_actor( def _build_ray_actors( self, rollout_controller: RolloutController, + rollout_generator: RolloutGenerateHandle, cpu_resources: CPUResourcesConfig, pg: PlacementGroup | None = None, judger: Judger | None = None, @@ -100,6 +117,7 @@ def _build_ray_actors( AgentLoopActor, self, rollout_controller, + rollout_generator, judger, pg=pg, start_bundle_idx=start_bundle_idx, @@ -113,6 +131,7 @@ def _build_ray_actors( def _build_router( self, rollout_controller: RolloutController, + rollout_generator: RolloutGenerateHandle, cpu_resources: CPUResourcesConfig, pg: PlacementGroup | None = None, judger: Judger | None = None, @@ -122,6 +141,7 @@ def _build_router( return RouterAgentLoop( workers=self._build_ray_actors( rollout_controller=rollout_controller, + rollout_generator=rollout_generator, cpu_resources=cpu_resources, pg=pg, judger=judger, @@ -136,12 +156,14 @@ class AgentLoop(ABC): def __init__( self, rollout_ctl: RolloutController | None, + rollout_generator: RolloutGenerateHandle | None, sample_params: SampleParams | None, hf_checkpoint: str, judger: Judger | None = None, logger=None, - ) -> None: + ) -> None: self.rollout_ctl = rollout_ctl + self.rollout_generator = rollout_generator self.hf_checkpoint = hf_checkpoint self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True) self.processor = load_processor(hf_checkpoint, trust_remote_code=True) @@ -155,6 +177,24 @@ def __init__( @abstractmethod async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: ... + async def rollout_generate(self, rollout_state: RolloutState, *, enable_partial_rollout: bool = False) -> RolloutState: + if self.rollout_generator is None: + raise RuntimeError("AgentLoop requires rollout_generator; RolloutController.generate has been removed.") + if self.rollout_generator.kind == "local": + return await self.rollout_generator.require_local_generator().generate( + rollout_state, + enable_partial_rollout=enable_partial_rollout, + ) + return await self.rollout_generate_from_url( + rollout_state=rollout_state, + base_url=self.rollout_generator.require_base_url(), + ) + + async def rollout_generate_from_url(self, rollout_state: RolloutState, base_url: str) -> RolloutState: + raise NotImplementedError( + f"{type(self).__name__} does not implement URL rollout generation for endpoint at {base_url!r}." + ) + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: pending_tasks = [] for state in rollout_state: @@ -221,11 +261,13 @@ def __init__( self, agent_loop_config: AgentLoopConfig, rollout_controller: RolloutController, + rollout_generator: RolloutGenerateHandle, judger: Judger | None = None, logger=None, ): self.agent_loop = agent_loop_config.build_local( rollout_controller=rollout_controller, + rollout_generator=rollout_generator, judger=judger, logger=logger, ) diff --git a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py index feb1a7c9ce..38c868f6df 100644 --- a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py +++ b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py @@ -8,7 +8,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig from xtuner.v1.rl.judger import Judger -from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle from xtuner.v1.utils import get_logger @@ -18,10 +18,17 @@ class GSM8KToolAgentLoopConfig(AgentLoopConfig): max_turns: int - def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "GSM8KToolAgentLoop": + def build_local( + self, + rollout_controller, + rollout_generator: RolloutGenerateHandle | None = None, + judger: Judger | None = None, + logger=None, + ) -> "GSM8KToolAgentLoop": return GSM8KToolAgentLoop( max_turns=self.max_turns, rollout_ctl=rollout_controller, + rollout_generator=rollout_generator, hf_checkpoint=self.hf_checkpoint, sample_params=self.sample_params, judger=judger, @@ -40,12 +47,17 @@ def __init__( self, max_turns: int, rollout_ctl: RolloutController, + rollout_generator: RolloutGenerateHandle | None, hf_checkpoint: str, sample_params: SampleParams, judger: Judger | None = None, ): super().__init__( - rollout_ctl=rollout_ctl, hf_checkpoint=hf_checkpoint, sample_params=sample_params, judger=judger + rollout_ctl=rollout_ctl, + rollout_generator=rollout_generator, + hf_checkpoint=hf_checkpoint, + sample_params=sample_params, + judger=judger, ) self.max_turns = max_turns self.tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL) @@ -99,7 +111,10 @@ async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> Rollou rollout_state.sample_params = copy.deepcopy(base_sample_params) rollout_state.sample_params.max_tokens = remaining_max_tokens - rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] + rollout_state = await self.rollout_generate( + rollout_state, + enable_partial_rollout=kwargs.get("enable_partial_rollout", False), + ) cur_turn += 1 response_ids = cast(list[int], rollout_state.response_ids) cur_turn_tokens.extend(response_ids) diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py index a4db6e80c2..d503217717 100644 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py +++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py @@ -11,7 +11,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status from xtuner.v1.rl.judger import Judger -from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle from xtuner.v1.rl.utils import create_task from ..agent_loop import AgentLoop, AgentLoopConfig @@ -54,9 +54,16 @@ class AgentInSandboxLoopConfig(AgentLoopConfig): """ max_concurrent_samples: int | None = None - def build_local(self, rollout_controller: RolloutController | None = None, judger: Judger | None = None, logger=None) -> "AgentInSandboxLoop": + def build_local( + self, + rollout_controller: RolloutController | None = None, + rollout_generator: RolloutGenerateHandle | None = None, + judger: Judger | None = None, + logger=None, + ) -> "AgentInSandboxLoop": return AgentInSandboxLoop( rollout_ctl=rollout_controller, + rollout_generator=rollout_generator, hf_checkpoint=self.hf_checkpoint, judger=judger, logger=logger, @@ -68,12 +75,13 @@ class AgentInSandboxLoop(AgentLoop): def __init__( self, rollout_ctl: RolloutController | None = None, + rollout_generator: RolloutGenerateHandle | None = None, hf_checkpoint: str = None, judger: Judger | None = None, logger=None, max_concurrent_samples: int | None = None, ): - super().__init__(rollout_ctl, None, hf_checkpoint, judger, logger) + super().__init__(rollout_ctl, rollout_generator, None, hf_checkpoint, judger, logger) self.max_concurrent_samples = max_concurrent_samples self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py index 40edc2dca7..9e550607a8 100644 --- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py +++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py @@ -2,7 +2,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status from xtuner.v1.rl.judger import Judger -from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle from xtuner.v1.rl.utils import create_task from .agent_loop import AgentLoop, AgentLoopConfig @@ -39,9 +39,16 @@ class SingleTurnAgentLoopConfig(AgentLoopConfig): enable_batch_judge: bool = False - def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop": + def build_local( + self, + rollout_controller, + rollout_generator: RolloutGenerateHandle | None = None, + judger: Judger | None = None, + logger=None, + ) -> "SingleTurnAgentLoop": return SingleTurnAgentLoop( rollout_ctl=rollout_controller, + rollout_generator=rollout_generator, sample_params=self.sample_params, hf_checkpoint=self.hf_checkpoint, judger=judger, @@ -54,13 +61,14 @@ class SingleTurnAgentLoop(AgentLoop): def __init__( self, rollout_ctl: RolloutController, + rollout_generator: RolloutGenerateHandle | None, sample_params: SampleParams, hf_checkpoint: str, judger: Judger | None = None, logger=None, enable_batch_judge: bool = False, ): - super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) + super().__init__(rollout_ctl, rollout_generator, sample_params, hf_checkpoint, judger, logger) self.enable_batch_judge = enable_batch_judge async def generate_sample( @@ -72,7 +80,10 @@ async def generate_sample( rollout_state.tokens = rollout_state.prompt_ids # 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上 - rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] + rollout_state = await self.rollout_generate( + rollout_state, + enable_partial_rollout=kwargs.get("enable_partial_rollout", False), + ) # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 if rollout_state.status != Status.COMPLETED: return rollout_state diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py index 7663ecc85e..54b7caeeab 100644 --- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py +++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py @@ -13,7 +13,7 @@ from xtuner.v1.rl.agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger from xtuner.v1.rl.replay_buffer import ReplayBuffer -from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle from xtuner.v1.rl.utils import asyncio_run from xtuner.v1.utils import get_logger @@ -300,6 +300,7 @@ def build( replay_buffer: ReplayBuffer, logger=None, sync_weights_interval: int = 1, + rollout_generator: RolloutGenerateHandle | None = None, ) -> "AgentLoopManager": tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks] if not tasks: @@ -314,6 +315,7 @@ def build( agent_loop = task_cfg.agent_loop_config.build( rollout_controller=rollout_controller, + rollout_generator=rollout_generator, judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None, logger=logger, ) diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py index 35db08f27d..7dc34d5aa7 100644 --- a/xtuner/v1/rl/agent_loop_manager/producer.py +++ b/xtuner/v1/rl/agent_loop_manager/producer.py @@ -502,10 +502,6 @@ def build( sync_weights_interval: int = 1, rollout_controller: "Optional[RolloutControllerProxy]" = None, ) -> "AsyncProduceStrategy": - if rollout_controller is not None: - import ray - - ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout)) return AsyncProduceStrategy( over_sample_threshold=self.over_sample_threshold, enable_partial_rollout=self.enable_partial_rollout, diff --git a/xtuner/v1/rl/gateway/backend/local_backend.py b/xtuner/v1/rl/gateway/backend/local_backend.py index ae8e773cb9..af22badef3 100644 --- a/xtuner/v1/rl/gateway/backend/local_backend.py +++ b/xtuner/v1/rl/gateway/backend/local_backend.py @@ -11,6 +11,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, RolloutToolCall, SampleParams, Status from xtuner.v1.rl.rollout.parser.factory import build_tool_call_parser from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.rollout.rollout_generator import LocalRolloutGenerator, LocalRolloutGeneratorConfig from ..adapters.base import coerce_content_to_text from ..adapters.trace import normalize_trace_payload @@ -44,6 +45,7 @@ def __init__( ): self._controller = controller self._config = rollout_config or self._resolve_rollout_config(controller) + self._generator: LocalRolloutGenerator = LocalRolloutGeneratorConfig().build(controller) if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) resolved_tokenizer = tokenizer @@ -57,12 +59,12 @@ def __init__( async def generate(self, request: CanonicalGenerateRequest) -> CanonicalGenerateResponse: rollout_state = self._canonical_request_to_rollout_state(request) - rollout_state = await self._controller.generate.remote(rollout_state) + rollout_state = await self._generator.generate(rollout_state) self._raise_for_failed_rollout(rollout_state, request_id=str(rollout_state.uid)) return self._rollout_state_to_canonical_response(rollout_state, request) async def health(self) -> BackendHealth: - ready, details = await self._controller.get_ready_status.remote() + ready, details = await self._controller.get_runtime_status.remote() return BackendHealth( ready=ready, status="ready" if ready else "unavailable", diff --git a/xtuner/v1/rl/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py index dd73324c55..b669bd7afe 100644 --- a/xtuner/v1/rl/rollout/__init__.py +++ b/xtuner/v1/rl/rollout/__init__.py @@ -1,6 +1,21 @@ import os +from ._generation.external_http_entry import ExternalRolloutHttpEntry, ExternalRolloutHttpEntryConfig +from ._generation.internal_http_entry import ( + InternalRolloutHttpEntry, + InternalRolloutHttpEntryConfig, + build_internal_rollout_http_entry_app, + serve_internal_rollout_http_entry_in_thread, +) +from ._generation.session_worker_selector import RolloutWorkerHandle, SessionWorkerSelector from .controller import RolloutController +from .rollout_generator import ( + LocalRolloutGenerator, + LocalRolloutGeneratorConfig, + RolloutGenerateHandle, + RolloutGenerateHandleConfig, +) +from .rollout_worker_build import RolloutRuntime, RolloutWorkerBuilder, RolloutWorkerRuntime, build_rollout_runtime from .worker import RolloutWorker diff --git a/xtuner/v1/rl/rollout/_generation/__init__.py b/xtuner/v1/rl/rollout/_generation/__init__.py new file mode 100644 index 0000000000..0e345c2f2d --- /dev/null +++ b/xtuner/v1/rl/rollout/_generation/__init__.py @@ -0,0 +1 @@ +"""Internal rollout generation helpers.""" diff --git a/xtuner/v1/rl/rollout/_generation/external_http_entry.py b/xtuner/v1/rl/rollout/_generation/external_http_entry.py new file mode 100644 index 0000000000..2d5170c5cd --- /dev/null +++ b/xtuner/v1/rl/rollout/_generation/external_http_entry.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from xtuner.v1.rl.utils.misc import check_chat_completions, delete_from_routedapiproxy, register_to_routedapiproxy +from xtuner.v1.utils import get_logger + +from .session_worker_selector import RolloutWorkerHandle, RolloutWorkerUrlSource +from ..worker import RolloutConfig + + +@dataclass +class ExternalRolloutHttpEntryConfig: + """Control-plane config for an external routedapiproxy HTTP entry. + + This class does not proxy user requests. It registers rollout worker + generation URLs into an externally managed router, then AgentLoop uses + ``base_url`` as the HTTP generation entry. + """ + + # The external router URL used by AgentLoop for generation. + base_url: str + + # Which per-worker URL should be registered into the external router. + worker_url_source: RolloutWorkerUrlSource = "backend" + + # Existing routedapiproxy semantics: delete old model registration before + # registering the current worker URLs. + delete_existing: bool = True + + # Optional smoke checks. They are useful for debug rollout, but can be + # disabled when registration happens before the external router is ready. + check_worker_urls: bool = True + check_base_url: bool = True + + +class ExternalRolloutHttpEntry: + def __init__( + self, + worker_handles: list[RolloutWorkerHandle], + rollout_config: RolloutConfig, + config: ExternalRolloutHttpEntryConfig, + *, + log_dir: str | None = None, + ) -> None: + self.worker_handles = worker_handles + self.rollout_config = rollout_config + self.config = config + self.logger = get_logger(log_dir=log_dir, tag="ExternalRolloutHttpEntry") + self._registered_urls: list[str] = [] + + def start(self) -> None: + model_name = self.rollout_config.model_name + if self.config.delete_existing: + delete_from_routedapiproxy(model_name) + self.logger.info(f"Deleted existing routedapiproxy registrations for model {model_name}.") + + self.logger.info("Registering rollout worker URLs to routedapiproxy.") + for worker in sorted(self.worker_handles, key=lambda item: item.rank): + worker_url = worker.get_generate_url(self.config.worker_url_source) + register_to_routedapiproxy(model_name, worker_url) + self._registered_urls.append(worker_url) + self.logger.info(f"Registered rollout worker {worker.rank} to routedapiproxy: {worker_url}") + + if self.config.check_worker_urls and not check_chat_completions(worker_url, model_name): + raise RuntimeError(f"check chat completions failed for rollout worker URL {worker_url}") + + if self.config.check_base_url and not check_chat_completions(self.config.base_url, model_name): + raise RuntimeError(f"check chat completions failed for external router URL {self.config.base_url}") + + self.logger.info("Registered rollout worker URLs to routedapiproxy.") + + def stop(self) -> None: + self._registered_urls.clear() diff --git a/xtuner/v1/rl/rollout/_generation/internal_http_entry.py b/xtuner/v1/rl/rollout/_generation/internal_http_entry.py new file mode 100644 index 0000000000..73b02d2d38 --- /dev/null +++ b/xtuner/v1/rl/rollout/_generation/internal_http_entry.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import json +import hashlib +import socket +import threading +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +import httpx +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import Response, StreamingResponse + +from xtuner.v1.utils import get_logger + +from .session_worker_selector import RolloutWorkerHandle, RolloutWorkerUrlSource, SessionWorkerSelector +from ..worker import RolloutConfig + + +@dataclass +class InternalRolloutHttpEntryConfig: + port: int + host: str = "0.0.0.0" + title: str = "XTuner Internal Rollout Router" + version: str = "0.1.0" + log_level: str = "warning" + request_timeout: float | None = None + stream_timeout: float | None = None + worker_url_source: RolloutWorkerUrlSource = "backend" + + +class InternalRolloutHttpEntry: + def __init__( + self, + worker_handles: list[RolloutWorkerHandle], + rollout_config: RolloutConfig, + config: InternalRolloutHttpEntryConfig, + ) -> None: + self.worker_selector = SessionWorkerSelector(worker_handles) + self.rollout_config = rollout_config + self.config = config + timeout = config.request_timeout or rollout_config.rollout_timeout + self.client = httpx.AsyncClient(timeout=timeout) + self.stream_timeout = config.stream_timeout or rollout_config.rollout_timeout + self.logger = get_logger(log_dir=rollout_config.worker_log_dir, tag="InternalRolloutHttpEntry") + + async def models(self) -> dict[str, Any]: + model_id = self.rollout_config.model_name or "xtuner-rollout" + return { + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": 0, + "owned_by": "xtuner", + } + ], + } + + async def chat_completions(self, request: Request) -> Response: + payload = await request.json() + session_id = self._extract_session_id(payload, request) + worker = await self.worker_selector.select(session_id) + if worker is None: + raise HTTPException(status_code=503, detail={"error": "No active rollout worker available."}) + + if self.config.worker_url_source == "session": + payload.setdefault("session_id", session_id) + try: + worker_base_url = worker.get_generate_url(self.config.worker_url_source) + except (RuntimeError, ValueError) as exc: + raise HTTPException(status_code=503, detail={"error": str(exc)}) from exc + url = f"{worker_base_url.rstrip('/')}/v1/chat/completions" + headers = self._forward_headers(request) + if payload.get("stream") is True: + return await self._stream_chat_completions(url, payload, headers, worker) + return await self._post_chat_completions(url, payload, headers, worker) + + def _extract_session_id(self, payload: dict[str, Any], request: Request) -> int: + for header_name in ("x-session-uid", "x-session-id", "x-request-id"): + header_value = request.headers.get(header_name) + if header_value: + return self._stable_int(header_value) + + metadata = payload.get("metadata") + if isinstance(metadata, dict): + for key in ("session_uid", "session_id", "conversation_id", "thread_id"): + if key in metadata and metadata[key] is not None: + return self._stable_int(metadata[key]) + + for key in ("session_uid", "session_id"): + if key in payload and payload[key] is not None: + return self._stable_int(payload[key]) + + return uuid4().int + + def _stable_int(self, value: Any) -> int: + if isinstance(value, int): + return value + if isinstance(value, str): + try: + return int(value) + except ValueError: + return uuid4().int if not value else self._hash_to_int(value) + return self._hash_to_int(json.dumps(value, sort_keys=True, default=str)) + + def _hash_to_int(self, value: str) -> int: + return int.from_bytes(hashlib.sha256(value.encode("utf-8")).digest()[:16], byteorder="big") + + def _forward_headers(self, request: Request) -> dict[str, str]: + ignored = { + "host", + "content-length", + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + } + headers = {key: value for key, value in request.headers.items() if key.lower() not in ignored} + if "content-type" not in {key.lower() for key in headers}: + headers["content-type"] = "application/json" + return headers + + async def _post_chat_completions( + self, + url: str, + payload: dict[str, Any], + headers: dict[str, str], + worker: RolloutWorkerHandle, + ) -> Response: + try: + response = await self.client.post(url, json=payload, headers=headers) + except Exception as exc: + raise HTTPException(status_code=502, detail={"error": str(exc)}) from exc + + return Response( + content=response.content, + status_code=response.status_code, + media_type=response.headers.get("content-type", "application/json"), + ) + + async def _stream_chat_completions( + self, + url: str, + payload: dict[str, Any], + headers: dict[str, str], + worker: RolloutWorkerHandle, + ) -> StreamingResponse: + async def stream_response(): + try: + async with self.client.stream("POST", url, json=payload, headers=headers, timeout=self.stream_timeout) as response: + if response.status_code >= 500: + body = await response.aread() + yield body + return + async for chunk in response.aiter_bytes(): + yield chunk + except Exception as exc: + self.logger.error(f"Streaming chat completion failed for worker {worker.rank}: {exc}") + yield f'data: {{"error": {json.dumps(str(exc))}}}\n\n'.encode() + + return StreamingResponse(stream_response(), media_type="text/event-stream") + + +def build_internal_rollout_http_entry_app( + worker_handles: list[RolloutWorkerHandle], + rollout_config: RolloutConfig, + config: InternalRolloutHttpEntryConfig, +) -> FastAPI: + entry = InternalRolloutHttpEntry(worker_handles=worker_handles, rollout_config=rollout_config, config=config) + app = FastAPI(title=config.title, version=config.version) + app.state.internal_rollout_http_entry = entry + + @app.get("/v1/models") + async def models(): + return await entry.models() + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return await entry.chat_completions(request) + + return app + + +def serve_internal_rollout_http_entry(app: FastAPI, config: InternalRolloutHttpEntryConfig) -> None: + _ensure_port_available(config) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) + + +def serve_internal_rollout_http_entry_in_thread( + app: FastAPI, config: InternalRolloutHttpEntryConfig +) -> threading.Thread: + thread = threading.Thread( + target=serve_internal_rollout_http_entry, + args=(app, config), + daemon=True, + name="internal-rollout-http-entry", + ) + thread.start() + return thread + + +def _ensure_port_available(config: InternalRolloutHttpEntryConfig) -> None: + host = "127.0.0.1" if config.host in ("", "0.0.0.0") else config.host + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1.0) + if sock.connect_ex((host, config.port)) == 0: + raise OSError(f"Internal rollout HTTP entry port already in use: {config.host}:{config.port}") diff --git a/xtuner/v1/rl/rollout/_generation/rollout_worker_generator.py b/xtuner/v1/rl/rollout/_generation/rollout_worker_generator.py new file mode 100644 index 0000000000..5588f95eb0 --- /dev/null +++ b/xtuner/v1/rl/rollout/_generation/rollout_worker_generator.py @@ -0,0 +1,658 @@ +from __future__ import annotations + +import asyncio +import base64 +import copy +import json +import threading +import traceback +from typing import Any, cast + +import httpx +import numpy as np +import ray +from transformers import AutoConfig, AutoTokenizer + +from xtuner.v1.data_proto.rl_data import ( + RolloutState, + SampleParams, + Status, + reset_rollout_response, + update_status_from_finish_reason, +) +from xtuner.v1.rl.utils import cancel_and_drain, get_eos_token +from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger +from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult + +from ..utils import PartialRolloutHandler +from ..worker import RolloutConfig + + +LMDEPLOY_SHARED_STORE = "shared_store" +LMDEPLOY_SHARED_STORE_NAMESPACE = "lmdeploy" + + +class RolloutWorkerGenerator: + """Generator bound to one rollout worker backend URL.""" + + def __init__(self, config: RolloutConfig, rank: int, server_url: str) -> None: + self.config = config + self.rank = rank + self.server_url = server_url + self.backend = config.rollout_backend + self.logger = get_logger(log_dir=config.worker_log_dir, tag=f"RolloutWorkerGenerator-{rank}") + self.endpoints = self._build_endpoints() + tokenizer_path = config.tokenizer_path or config.model_path + self.tokenizer_path = tokenizer_path + self._tokenizer = None + self.model_name = config.model_name + self.enable_return_routed_experts = config.enable_return_routed_experts + self._partial_rollout_handler: PartialRolloutHandler | None = None + self.receive_abort_request = threading.Event() + self.abort_timeout = 10.0 + http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio + limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) + self.client = httpx.AsyncClient(limits=limits, timeout=config.rollout_timeout) + eos_token = get_eos_token(config.model_path) + self.eos_token: list[int] = [eos_token] if isinstance(eos_token, int) else eos_token + self.lmdeploy_actor = None + self.routed_experts_num_hidden_layers = None + self.routed_experts_num_experts_per_tok = None + if self.backend == "sglang": + model_config = AutoConfig.from_pretrained(config.model_path, trust_remote_code=True) + text_config = getattr(model_config, "text_config", model_config) + self.routed_experts_num_hidden_layers = getattr(text_config, "num_hidden_layers", None) + self.routed_experts_num_experts_per_tok = getattr(text_config, "num_experts_per_tok", None) + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True) + return self._tokenizer + + @property + def partial_rollout_handler(self) -> PartialRolloutHandler: + if self._partial_rollout_handler is None: + self._partial_rollout_handler = PartialRolloutHandler() + return self._partial_rollout_handler + + def _build_endpoints(self) -> dict[str, str]: + if self.backend == "lmdeploy": + return { + "generate": "generate", + "v1/chat/completions": "v1/chat/completions", + "abort_request": "abort_request", + } + if self.backend == "sglang": + return { + "generate": "generate", + "v1/chat/completions": "v1/chat/completions", + "abort_request": "abort_request", + } + if self.backend == "vllm": + return { + "generate": "v1/chat/completions", + "v1/chat/completions": "v1/chat/completions", + "abort_request": "abort_request", + } + raise ValueError(f"Unsupported rollout backend: {self.backend}") + + def update_server_url(self, server_url: str) -> None: + self.server_url = server_url + self.receive_abort_request.clear() + + async def pause_generation(self) -> bool: + self.receive_abort_request.set() + return await self._send_abort_request() + + def continue_generation(self) -> None: + self.receive_abort_request.clear() + + async def _send_abort_request(self) -> bool: + endpoint = self.endpoints.get("abort_request", "abort_request") + url = f"{self.server_url}/{endpoint}" + try: + response = await self.client.post(url) + response.raise_for_status() + return True + except Exception as exc: + self.logger.warning(f"Failed to send abort request to {url}: {exc}") + return False + + async def _wait_abort_request(self) -> None: + while not self.receive_abort_request.is_set(): + await asyncio.sleep(1) + + async def generate(self, rollout_state: RolloutState, *, enable_partial_rollout: bool = False) -> RolloutState: + if self.receive_abort_request.is_set(): + rollout_state.finish_reason = "abort" + rollout_state.status = Status.ABORTED + return rollout_state + + uid = rollout_state.uid + sample_params: SampleParams = rollout_state.sample_params + max_tokens = sample_params.max_tokens + if sample_params.return_token_ids: + endpoint_url = f"{self.server_url}/{self.endpoints['generate']}" + else: + endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}" + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config.api_key}", + } + + if enable_partial_rollout: + rollout_state = self.partial_rollout_handler.preprocess(rollout_state, max_tokens) + elif rollout_state.status == Status.ABORTED: + rollout_state = reset_rollout_response(rollout_state) + rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens}) + rollout_state.status = Status.INIT + + payload = self._get_request_payload(rollout_state) + max_retries = self.config.max_retry_per_sample + + if rollout_state.status == Status.COMPLETED: + self.logger.debug(f"Request {uid} is already marked as COMPLETED, skipping generation.") + return rollout_state + + input_ids = payload.get("input_ids", []) + max_tokens = payload.get("max_tokens") or payload.get("max_new_tokens") + sampling_params = payload.get("sampling_params") + if max_tokens is None and isinstance(sampling_params, dict): + max_tokens = sampling_params.get("max_tokens") or sampling_params.get("max_new_tokens") + max_tokens = cast(int | None, max_tokens) + last_id = input_ids[-1] if len(input_ids) > 0 else "None" + is_max_tokens_zero = max_tokens is not None and max_tokens <= 0 + is_eos_reached = len(input_ids) > 0 and input_ids[-1] in self.eos_token + if is_max_tokens_zero or is_eos_reached: + self.logger.debug( + f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token." + ) + rollout_state.finish_reason = "stop" if is_eos_reached else "length" + rollout_state.status = Status.COMPLETED + return rollout_state + + for attempt in range(max_retries + 1): + is_last_attempt = attempt == max_retries + http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload) + if http_result.response: + rollout_state = await self._safe_handle_response(rollout_state, http_result.response) + if rollout_state.status in [Status.COMPLETED, Status.ABORTED]: + return rollout_state + if is_last_attempt: + self.logger.warning( + f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED." + ) + rollout_state.status = Status.FAILED + rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts." + return rollout_state + self.logger.warning( + f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}." + ) + await asyncio.sleep(0.1) + continue + + if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED: + rollout_state.finish_reason = "abort" + rollout_state.status = update_status_from_finish_reason("abort") + return rollout_state + + if http_result.is_client_error: + self.logger.warning( + f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = ( + f"Client error {http_result.error_type} with message: {http_result.error_msg}" + ) + rollout_state.status = Status.FAILED + return rollout_state + + if http_result.is_server_error: + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = ( + f"Server error {http_result.error_type} with message: {http_result.error_msg}" + ) + rollout_state.status = Status.FAILED + return rollout_state + + if http_result.is_retryable: + if is_last_attempt: + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}" + rollout_state.status = Status.FAILED + return rollout_state + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}." + ) + await asyncio.sleep(0.1) + continue + + if http_result.is_unknown_error: + raise RuntimeError( + f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}" + ) + return rollout_state + + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + send_task = None + abort_task = None + try: + if self.receive_abort_request.is_set(): + self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.") + return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) + req = self.client.build_request("POST", url, headers=headers, json=payload) + send_task = asyncio.create_task(self.client.send(req)) + abort_task = asyncio.create_task(self._wait_abort_request()) + done, _ = await asyncio.wait({send_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + if send_task in done: + response = await send_task + else: + try: + response = await asyncio.wait_for(asyncio.shield(send_task), timeout=self.abort_timeout) + except asyncio.TimeoutError: + self.logger.debug( + f"Request to {url} did not return within {self.abort_timeout:.2f}s after abort signal." + ) + await cancel_and_drain([send_task]) + return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) + response.raise_for_status() + return HttpRequestResult(response=response) + except asyncio.CancelledError: + self.logger.debug(f"Request to {url} was cancelled while waiting for the response.") + await cancel_and_drain([send_task, abort_task]) + self.receive_abort_request.set() + return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) + except Exception as exc: + error_type = HttpRequestErrorType.from_exception(exc) + return HttpRequestResult(error_type=error_type, exception=exc, url=url, payload=payload) + finally: + await cancel_and_drain([abort_task]) + + def _get_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]: + if self.backend == "lmdeploy": + return self._get_lmdeploy_request_payload(rollout_state) + if self.backend == "sglang": + return self._get_sglang_request_payload(rollout_state) + if self.backend == "vllm": + return self._get_vllm_request_payload(rollout_state) + raise ValueError(f"Unsupported rollout backend: {self.backend}") + + def _get_lmdeploy_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]: + sample_params = rollout_state.sample_params + optional_fields: dict[str, object] = {} + if rollout_state.tools is not None: + optional_fields["tools"] = rollout_state.tools + if rollout_state.tool_choice is not None: + optional_fields["tool_choice"] = rollout_state.tool_choice + + if sample_params.return_token_ids: + payload: dict[str, Any] = {"model": self.model_name, **optional_fields} + if "image_data" in rollout_state.extra_fields: + assert rollout_state.tokens is not None, "input_tokens is required when image_data is provided." + payload["image_data"] = rollout_state.extra_fields["image_data"] + if rollout_state.tokens is not None: + payload["input_ids"] = rollout_state.tokens + else: + text_prompt = self.tokenizer.apply_chat_template( + rollout_state.message, tokenize=False, add_generation_prompt=True + ) + payload["input_ids"] = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + sample_params.return_routed_experts = True if self.enable_return_routed_experts else False + payload.update(sample_params.model_dump(exclude_none=True)) + return payload + + payload = {"model": self.model_name, "messages": rollout_state.message, **optional_fields} + lmdeploy_sample_params: dict[str, Any] = { + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "n": sample_params.n, + "stream": sample_params.stream, + "max_tokens": sample_params.max_tokens, + "repetition_penalty": sample_params.repetition_penalty, + "top_k": sample_params.top_k, + "skip_special_tokens": sample_params.skip_special_tokens, + } + if sample_params.stops: + lmdeploy_sample_params["stop"] = sample_params.stops + if sample_params.min_tokens > 0: + lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens + payload.update(lmdeploy_sample_params) + return payload + + def _get_sglang_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]: + sample_params = rollout_state.sample_params + payload: dict[str, Any] = {"model": self.model_name} + if rollout_state.tools is not None: + payload["tools"] = rollout_state.tools + if rollout_state.tool_choice is not None: + payload["tool_choice"] = rollout_state.tool_choice + + sample_params_dict = sample_params.model_dump() + sglang_sample_params = self._transform_sglang_sample_params(sample_params_dict) + sglang_extra_params = self._transform_sglang_extra_params(sample_params_dict) + payload.update(sglang_extra_params) + if self.enable_return_routed_experts and not rollout_state.extra_fields.get("disable_routed_experts", False): + payload["return_routed_experts"] = True + + if sample_params.return_token_ids: + if "image_data" in rollout_state.extra_fields: + assert rollout_state.tokens is not None, "input_ids is required when image_data is provided." + payload["image_data"] = rollout_state.extra_fields["image_data"] + if rollout_state.tokens is not None: + payload["input_ids"] = rollout_state.tokens + else: + text_prompt = self.tokenizer.apply_chat_template( + rollout_state.message, tokenize=False, add_generation_prompt=True + ) + payload["input_ids"] = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + payload["sampling_params"] = sglang_sample_params + return payload + + payload["messages"] = rollout_state.message + payload.update(sglang_sample_params) + payload["max_tokens"] = sglang_sample_params["max_new_tokens"] + payload["min_tokens"] = sglang_sample_params["min_new_tokens"] + payload.pop("max_new_tokens", None) + payload.pop("min_new_tokens", None) + return payload + + def _get_vllm_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]: + sample_params = rollout_state.sample_params + prompt = copy.deepcopy(rollout_state.message) + extra_fields = rollout_state.extra_fields + if "image_data" in extra_fields: + image_index = 0 + for message in prompt: + if not isinstance(message, dict) or message.get("role") != "user": + continue + new_content = [] + for content_part in message.get("content", []): + if not isinstance(content_part, dict): + new_content.append(content_part) + continue + if content_part.get("type") == "image_url": + content_part["image_url"]["url"] = f"file://{extra_fields['image_data'][image_index]}" + content_part["image_url"].pop("image_wh", None) + image_index += 1 + new_content.append(content_part) + message["content"] = new_content + assert image_index == len(extra_fields["image_data"]), ( + f"Expected {len(extra_fields['image_data'])} images, but processed {image_index}." + ) + + payload: dict[str, Any] = { + "model": self.config.model_path, + "messages": prompt, + "stream": sample_params.stream, + } + if rollout_state.tokens is not None: + payload["input_ids"] = rollout_state.tokens + elif "train_prompt_ids" in extra_fields: + payload["input_ids"] = extra_fields["train_prompt_ids"] + payload.update(self._transform_vllm_sample_params(sample_params.model_dump(), sample_params.model_dump())) + return payload + + def _transform_sglang_sample_params(self, sample_params: dict[str, Any]) -> dict[str, Any]: + if sample_params["top_p"] > 0: + sample_params["top_k"] = -1 + sglang_sample_params = { + "n": sample_params["n"], + "top_k": sample_params["top_k"], + "top_p": sample_params["top_p"], + "temperature": sample_params["temperature"], + "repetition_penalty": sample_params["repetition_penalty"], + "presence_penalty": sample_params["presence_penalty"], + "frequency_penalty": sample_params["frequency_penalty"], + "max_new_tokens": sample_params["max_tokens"], + "min_new_tokens": sample_params["min_tokens"], + "stop": sample_params["stops"], + "stop_token_ids": sample_params["stop_token_ids"], + "skip_special_tokens": sample_params["skip_special_tokens"], + } + sampling_seed = sample_params.get("sampling_seed") + if sampling_seed is None and XTUNER_DETERMINISTIC: + sampling_seed = self.config.random_seed + if sampling_seed is not None: + sglang_sample_params["sampling_seed"] = sampling_seed + return sglang_sample_params + + def _transform_sglang_extra_params(self, extra_params: dict[str, Any]) -> dict[str, Any]: + return { + "stream": extra_params["stream"], + "return_logprob": extra_params["return_logprob"], + "include_stop_str_in_output": extra_params["include_stop_str_in_output"], + "no_stop_trim": extra_params.get("no_stop_trim", False), + "spaces_between_special_tokens": extra_params.get("spaces_between_special_tokens", False), + } + + def _transform_vllm_sample_params( + self, sample_params: dict[str, Any], extra_params: dict[str, Any] | None = None + ) -> dict[str, Any]: + vllm_sample_params = copy.deepcopy(sample_params) + if extra_params: + vllm_sample_params.update(extra_params) + if "stops" in vllm_sample_params: + vllm_sample_params["stop"] = vllm_sample_params.pop("stops") + if "no_stop_trim" in vllm_sample_params: + vllm_sample_params["include_stop_str_in_output"] = vllm_sample_params.pop("no_stop_trim") + if "top_logprobs" in vllm_sample_params and "return_logprob" in vllm_sample_params: + vllm_sample_params["logprobs"] = vllm_sample_params.pop("return_logprob") + return vllm_sample_params + + async def _decode_routed_experts(self, routed_experts: Any) -> Any: + if self.backend == "lmdeploy": + if isinstance(routed_experts, str): + if self.lmdeploy_actor is None: + self.lmdeploy_actor = ray.get_actor(LMDEPLOY_SHARED_STORE, namespace=LMDEPLOY_SHARED_STORE_NAMESPACE) + routed_experts_data = await self.lmdeploy_actor.get.remote(routed_experts) + return ray.put(np.asarray(routed_experts_data)) + return np.asarray(routed_experts) + if self.backend == "sglang": + if isinstance(routed_experts, str): + routed_experts_flat = np.frombuffer(base64.b64decode(routed_experts), dtype=np.int32) + routed_experts_array = routed_experts_flat.reshape( + -1, + self.routed_experts_num_hidden_layers, + self.routed_experts_num_experts_per_tok, + ) + return routed_experts_array.copy() + return np.asarray(routed_experts) + if self.backend == "vllm": + if isinstance(routed_experts, str): + routed_experts = ray.cloudpickle.loads(base64.b64decode(routed_experts)) + return np.asarray(routed_experts) + return routed_experts + + async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState: + if self.backend == "vllm": + return await self._safe_handle_vllm_response(rollout_state, http_response) + return await self._safe_handle_openai_or_token_response(rollout_state, http_response) + + async def _safe_handle_openai_or_token_response( + self, rollout_state: RolloutState, http_response: httpx.Response + ) -> RolloutState: + uid = rollout_state.message_uid + sample_params = rollout_state.sample_params + response = http_response.json() + + if sample_params.return_token_ids: + response_ids: list[int] = [] + logprobs: list[float] = [] + routed_experts = None + returned_response = "" + try: + meta_info = response.get("meta_info") or {} + finish_reason_info = meta_info.get("finish_reason") or {} + finish_reason = finish_reason_info.get("type") + if finish_reason is None: + if self.receive_abort_request.is_set(): + rollout_state.finish_reason = "abort" + rollout_state.status = Status.ABORTED + else: + rollout_state.finish_reason = "error" + rollout_state.status = Status.FAILED + rollout_state.error_msg = "Missing finish_reason in response meta_info" + return rollout_state + returned_response = response.get("text", "") + if meta_info.get("output_token_logprobs") is not None: + response_ids = [item[1] for item in meta_info["output_token_logprobs"]] + logprobs = [item[0] for item in meta_info["output_token_logprobs"]] + else: + num_return_tokens = meta_info.get("completion_tokens", 0) + response_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else [] + + if self.enable_return_routed_experts: + assert "routed_experts" in meta_info, ( + "enable_return_routed_experts is True, but routed_experts is not in meta_info" + ) + routed_experts = meta_info["routed_experts"] + if routed_experts is not None: + routed_experts = await self._decode_routed_experts(routed_experts) + if not isinstance(routed_experts, ray.ObjectRef): + routed_experts = ray.put(routed_experts) + + rollout_status = update_status_from_finish_reason(finish_reason) + if rollout_status == Status.COMPLETED: + validation_errors = [] + if not response_ids: + validation_errors.append("empty response_ids") + if not returned_response: + validation_errors.append("empty response text") + if sample_params.return_logprob and not logprobs: + validation_errors.append("missing logprobs") + if self.enable_return_routed_experts and routed_experts is None: + validation_errors.append("missing routed_experts") + if validation_errors: + error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}" + self.logger.error(error_msg) + rollout_state.status = Status.FAILED + rollout_state.error_msg = error_msg + return rollout_state + elif rollout_status == Status.FAILED: + error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}" + self.logger.error(error_msg) + rollout_state.status = Status.FAILED + rollout_state.error_msg = error_msg + return rollout_state + + if enable_partial_rollout: + expect_len = meta_info.get("prompt_tokens", 0) + meta_info.get("completion_tokens", 0) - 1 + rollout_state = await self.partial_rollout_handler.postprocess( + rollout_state, + response=returned_response, + response_ids=response_ids, + logprobs=logprobs, + routed_experts=routed_experts, + finish_reason=finish_reason, + status=rollout_status, + routed_experts_expect_len=expect_len, + ) + else: + rollout_state.response = returned_response + rollout_state.response_ids = response_ids + rollout_state.logprobs = logprobs + rollout_state.routed_experts = routed_experts + rollout_state.finish_reason = finish_reason + rollout_state.status = rollout_status + return rollout_state + except Exception as exc: + raise self._response_error(exc, response, uid) + + try: + returned_response = response["choices"][0]["message"]["content"] + finish_reason = response["choices"][0]["finish_reason"] + rollout_status = update_status_from_finish_reason(finish_reason) + if rollout_status == Status.COMPLETED and not returned_response: + rollout_state.status = Status.FAILED + rollout_state.error_msg = "Empty response text" + return rollout_state + rollout_state.response = returned_response + rollout_state.finish_reason = finish_reason + rollout_state.status = rollout_status + return rollout_state + except Exception as exc: + raise self._response_error(exc, response, uid) + + async def _safe_handle_vllm_response(self, rollout_state: RolloutState, http_response) -> RolloutState: + uid = rollout_state.uid or rollout_state.message_uid + sample_params = rollout_state.sample_params + last_token_ids: list[int] = [] + last_logprobs: list[float] = [] + routed_experts = None + + response_json = http_response.json() + try: + response_choice = response_json["choices"][0] + if response_choice.get("logprobs") is not None: + last_token_ids = response_choice.get("token_ids", response_json.get("token_ids", [])) + last_logprobs = [ + item["logprob"] for item in response_choice["logprobs"].get("content", []) if "logprob" in item + ] + assert len(last_token_ids) == len(last_logprobs) + assert len(last_token_ids) <= sample_params.max_tokens, ( + f"Generation length exceeds limit: generated {len(last_token_ids)}, limit {sample_params.max_tokens}" + ) + + last_trajectory = response_choice["message"].get("content") or "" + finish_reason = response_choice.get("finish_reason") + if finish_reason == "abort" and not self.receive_abort_request.is_set(): + self.receive_abort_request.set() + self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}") + + if self.enable_return_routed_experts: + routed_experts = response_choice.get("routed_experts", response_json.get("routed_experts")) + if routed_experts is not None: + routed_experts = await self._decode_routed_experts(routed_experts) + if not isinstance(routed_experts, ray.ObjectRef): + routed_experts = ray.put(routed_experts) + + rollout_status = update_status_from_finish_reason(finish_reason) + if rollout_status == Status.COMPLETED: + validation_errors = [] + if sample_params.return_token_ids and len(last_token_ids) == 0: + validation_errors.append("empty response_ids") + if sample_params.return_logprob and len(last_logprobs) == 0: + validation_errors.append("missing logprobs") + if not last_trajectory: + validation_errors.append("empty response text") + if self.enable_return_routed_experts and routed_experts is None: + validation_errors.append("missing routed_experts") + if validation_errors: + error_msg = f"Incomplete rollout data for request {uid}: {', '.join(validation_errors)}" + self.logger.error(f"{error_msg}. Raw response: {response_json}") + rollout_state.status = Status.FAILED + rollout_state.error_msg = error_msg + return rollout_state + + rollout_state.response = last_trajectory + rollout_state.response_ids = last_token_ids if len(last_token_ids) > 0 else None + rollout_state.logprobs = last_logprobs if len(last_logprobs) > 0 else None + rollout_state.routed_experts = routed_experts + rollout_state.finish_reason = finish_reason + rollout_state.status = rollout_status + return rollout_state + except Exception as exc: + raise self._response_error(exc, response_json, uid) + + def _response_error(self, exc: Exception, response: dict[str, Any], uid: Any) -> RuntimeError: + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + if isinstance(exc, KeyError): + return RuntimeError(f"Missing expected key {exc} in response {response_for_log} for {uid}") + if isinstance(exc, IndexError): + return RuntimeError(f"Index error {exc} while processing response {response_for_log} for {uid}") + if isinstance(exc, AssertionError): + return RuntimeError(f"AssertionError: {exc} when processing response {response_for_log} for {uid}") + if isinstance(exc, json.JSONDecodeError): + return RuntimeError(f"JSONDecodeError: {exc} when processing response {response} for {uid}") + if isinstance(exc, TypeError): + return RuntimeError(f"TypeError: {exc} when processing response {response_for_log} for {uid}") + return RuntimeError( + f"Unexpected error: {exc} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}" + ) + +RayRolloutWorkerGenerator = ray.remote(RolloutWorkerGenerator) diff --git a/xtuner/v1/rl/rollout/_generation/session_worker_selector.py b/xtuner/v1/rl/rollout/_generation/session_worker_selector.py new file mode 100644 index 0000000000..e563637305 --- /dev/null +++ b/xtuner/v1/rl/rollout/_generation/session_worker_selector.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import asyncio +import time +from collections import OrderedDict +from dataclasses import dataclass +from itertools import cycle +from typing import TYPE_CHECKING, Any, Literal + + +if TYPE_CHECKING: + from .rollout_worker_generator import RolloutWorkerGenerator + + +RolloutWorkerUrlSource = Literal["backend", "session"] + + +@dataclass +class RolloutWorkerHandle: + """Runtime handles for one active rollout worker. + + This object intentionally groups the control-plane worker actor with the + data-plane generation entries selected from that worker. + """ + + # Stable active-worker rank. Used for session affinity, logging, and + # registering worker entries into HTTP routers. + rank: int + + # Original RolloutWorker Ray actor. Controller control-plane operations + # such as offload/onload/pause/shutdown are sent to this actor. + worker_actor: Any + + # Raw backend server URL exposed by lmdeploy/vLLM/SGLang. + backend_url: str + + # Per-active-worker generation actor used by local Python/Ray generation. + generator_actor: RolloutWorkerGenerator + + # Optional SessionServer proxy URL wrapping backend_url. It is only used + # when HTTP generation is configured to require session/cache/trace logic. + session_server_url: str | None = None + + def require_session_server_url(self) -> str: + if self.session_server_url is None: + raise RuntimeError(f"Rollout worker {self.rank} does not have a SessionServer URL.") + return self.session_server_url + + def get_generate_url(self, source: RolloutWorkerUrlSource) -> str: + if source == "backend": + return self.backend_url + if source == "session": + return self.require_session_server_url() + raise ValueError(f"Unsupported rollout worker URL source: {source!r}") + + +class SessionWorkerSelector: + def __init__( + self, + workers: list[RolloutWorkerHandle], + *, + max_sessions: int = 10000, + max_idle_seconds: float | None = 3600.0, + ) -> None: + self._workers = {worker.rank: worker for worker in workers} + self._rank_cycle = cycle(self._workers) + self._max_sessions = max_sessions + self._max_idle_seconds = max_idle_seconds + self._sessions: OrderedDict[int, tuple[int, float]] = OrderedDict() + self._lock = asyncio.Lock() + + async def select(self, session_id: int) -> RolloutWorkerHandle | None: + async with self._lock: + self._evict_expired() + if session_id in self._sessions: + rank, _ = self._sessions.pop(session_id) + worker = self._workers.get(rank) + if worker is not None: + self._sessions[session_id] = (rank, self._now()) + return worker + + worker = self._next_worker() + if worker is None: + return None + self._sessions[session_id] = (worker.rank, self._now()) + self._evict_to_capacity() + return worker + + def _next_worker(self) -> RolloutWorkerHandle | None: + if not self._workers: + return None + return self._workers[next(self._rank_cycle)] + + def _evict_expired(self) -> None: + if self._max_idle_seconds is None: + return + now = self._now() + expired = [] + for session_id, (_, last_used_at) in self._sessions.items(): + if now - last_used_at > self._max_idle_seconds: + expired.append(session_id) + else: + break + for session_id in expired: + self._sessions.pop(session_id, None) + + def _evict_to_capacity(self) -> None: + while len(self._sessions) > self._max_sessions: + self._sessions.popitem(last=False) + + def _now(self) -> float: + return time.time() diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index 86137e9167..0ce82d7b7d 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -1,40 +1,21 @@ -import asyncio -import math -import os -import threading -from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, TypedDict -from uuid import uuid4 import ray from ray.actor import ActorProxy -from ray.util.placement_group import PlacementGroup -from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import RolloutState, Status -from xtuner.v1.rl.utils import AutoAcceleratorWorkers -from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger +from xtuner.v1.utils import get_logger -from .parser.factory import build_reasoning_parser, build_tool_call_parser -from .parser.reasoning_parser import ReasoningParser -from .parser.tool_parser import ToolCallParser -from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter -from .worker import ROLLOUT_CONCURRENCY_GROUP_GENERATE, RolloutConfig, RolloutWorker +from ._generation.session_worker_selector import RolloutWorkerHandle +from .rollout_worker_build import RolloutRuntime +from .utils import ROLLOUT_RAY_GET_TIMEOUT +from .worker import RolloutConfig if TYPE_CHECKING: from xtuner.v1.rl.gateway.config import GatewayConfig - - -@dataclass -class WorkerInfo: - """A data class to hold all state information for a single worker.""" - - actor: RolloutWorker - url: str - session_url: str | None = None - is_active: bool = True + from xtuner.v1.rl.rollout._generation.external_http_entry import ExternalRolloutHttpEntryConfig + from xtuner.v1.rl.rollout._generation.internal_http_entry import InternalRolloutHttpEntryConfig class RolloutWorkerMetadata(TypedDict): @@ -59,24 +40,28 @@ class RolloutWorkerMetadata(TypedDict): # 包括:并行策略(TP/EP)、超时设置、后端类型(LMDeploy/vLLM/SGLang)等 rollout_config: RolloutConfig - # 每个 worker 服务器 URL 的当前活跃状态 - # 键:服务器 URL 字符串 - # 值:布尔值,True 表示该 worker 处于活跃状态,False 表示已失效或停用 + # Worker server URL map used by trainer-side control paths. First-version + # rollout refactor assumes all workers remain available. worker_server_urls_status: Dict[str, bool] # Gateway HTTP server URL (e.g. "http://1.2.3.4:8080"). # Set after start_gateway() is called; None if the gateway has not been started. api_server_url: Optional[str] - # worker rank -> SessionServer proxy URL. These are the externally - # registered URLs for routedapiproxy; server_url_dict keeps the original - # worker URLs for trainer-side weight update / backend control paths. + # Internal rollout HTTP entry URL. Set after start_internal_http_entry() is called. + internal_http_entry_url: Optional[str] + + # worker rank -> SessionServer proxy URL. server_url_dict keeps the + # original worker URLs for trainer-side weight update / backend control paths. worker_session_url_dict: Dict[int, str] - # SessionServer URL -> active status. This mirrors worker_server_urls_status - # but is keyed by the proxy URL that external traffic uses. + # SessionServer URL -> availability status. First-version rollout refactor + # does not deactivate workers, so these values are always True. worker_session_urls_status: Dict[str, bool] + # Runtime worker handles consumed by local/http/external generation paths. + worker_handles: List[RolloutWorkerHandle] + # Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller # state; passing a normal Python object would serialize a separate copy into each actor. @@ -87,37 +72,30 @@ class RolloutController: def __init__( self, infer_config: RolloutConfig, - placement_group: PlacementGroup, + runtime: RolloutRuntime, ): """Initialize the RolloutController. Args: infer_config (RolloutConfig): The configuration for the rollout. - placement_group (PlacementGroup): The placement group for the - RolloutWorker actors. + runtime: Pre-built rollout runtime, including worker handles and + backend server URLs. """ self.config = infer_config self.num_gpus_per_engine = self.config.num_gpus_per_engine self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") - self.engine_rank_mesh_array: List[List[int]] = [] - self.worker_server_urls_map: dict[int, str] = {} - self.rank2info: dict[int, WorkerInfo] = {} - self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group) - self.num_active_workers = len(self.rank2info) - self.worker_info_lock = threading.RLock() + self.engine_rank_mesh_array = runtime.engine_rank_mesh_array + self.worker_server_urls_map = runtime.worker_server_urls_map + self.rank2worker = runtime.rank2worker + self.worker_handles = runtime.worker_handles + self.num_rollout_workers = len(self.rank2worker) # The timeout for the environment to wait for the rollout controller's response. # This should be longer than the controller's internal timeout (`rollout_timeout`) # to account for potential queuing delays and other overheads. self.timeout_multiplier = 2.0 - self.router = SessionRouter(self.rank2info, worker_infos_lock=self.worker_info_lock) - self.health_checker = RolloutHealthChecker( - config=self.config, - workers_info=self.rank2info, - worker_infos_lock=self.worker_info_lock, - ) - self.health_checker.start() - self._tool_call_parser, self._reasoning_parser = self._build_output_parsers() self._gateway_url: str | None = None + self._internal_http_entry_url: str | None = None + self._external_http_entries = [] def start_gateway(self, config: "GatewayConfig") -> str | None: """Start the gateway HTTP server in a daemon thread and return its URL. @@ -152,6 +130,36 @@ def start_gateway(self, config: "GatewayConfig") -> str | None: self.logger.info(f"Gateway server started at {url}, capture_folder: {config.capture_folder}") return url + def start_internal_http_entry(self, config: "InternalRolloutHttpEntryConfig") -> str: + from xtuner.v1.rl.rollout._generation.internal_http_entry import ( + build_internal_rollout_http_entry_app, + serve_internal_rollout_http_entry_in_thread, + ) + + app = build_internal_rollout_http_entry_app( + worker_handles=self.worker_handles, + rollout_config=self.config, + config=config, + ) + serve_internal_rollout_http_entry_in_thread(app, config) + host = ray.util.get_node_ip_address() if config.host in ("", "0.0.0.0") else config.host + url = f"http://{host}:{config.port}" + self._internal_http_entry_url = url + self.logger.info(f"Internal rollout HTTP entry started at {url}") + return url + + def start_external_http_entry(self, config: "ExternalRolloutHttpEntryConfig") -> None: + from xtuner.v1.rl.rollout._generation.external_http_entry import ExternalRolloutHttpEntry + + entry = ExternalRolloutHttpEntry( + worker_handles=self.worker_handles, + rollout_config=self.config, + config=config, + log_dir=str(self.config.worker_log_dir), + ) + entry.start() + self._external_http_entries.append(entry) + def get_rollout_metadata(self) -> RolloutWorkerMetadata: """Get information about the current rollout setup. @@ -159,124 +167,67 @@ def get_rollout_metadata(self) -> RolloutWorkerMetadata: dict: A dictionary containing the engine mesh list, server URL dictionary, and the rollout configuration. """ - with self.worker_info_lock: - worker_server_urls_status = {info.url: info.is_active for info in self.rank2info.values()} - worker_session_url_dict = { - rank: info.session_url for rank, info in self.rank2info.items() if info.session_url is not None - } - worker_session_urls_status = { - info.session_url: info.is_active for info in self.rank2info.values() if info.session_url is not None - } + worker_server_urls_map = {worker.rank: worker.backend_url for worker in self.worker_handles} + worker_server_urls_status = {worker.backend_url: True for worker in self.worker_handles} + worker_session_url_dict = { + worker.rank: worker.session_server_url for worker in self.worker_handles if worker.session_server_url is not None + } + worker_session_urls_status = { + worker.session_server_url: True for worker in self.worker_handles if worker.session_server_url is not None + } rollout_metadata: RolloutWorkerMetadata = { "engine_rank_mesh_array": self.engine_rank_mesh_array, - "server_url_dict": self.worker_server_urls_map, + "server_url_dict": worker_server_urls_map, "rollout_config": self.config, "worker_server_urls_status": worker_server_urls_status, - "api_server_url": self._gateway_url, + "api_server_url": self._internal_http_entry_url or self._gateway_url, "worker_session_url_dict": worker_session_url_dict, "worker_session_urls_status": worker_session_urls_status, + "internal_http_entry_url": self._internal_http_entry_url, + "worker_handles": list(self.worker_handles), } return rollout_metadata - def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]: - tool_call_parser = None - reasoning_parser = None - - if self.config.tool_call_parser != "none": - tool_call_parser = build_tool_call_parser(self.config.tool_call_parser) - - if self.config.reasoning_parser != "none": - tokenizer_path = self.config.tokenizer_path or self.config.model_path - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) - reasoning_parser = build_reasoning_parser(self.config.reasoning_parser, tokenizer) - - return tool_call_parser, reasoning_parser - - def get_ready_status(self) -> tuple[bool, dict[str, Any]]: - with self.worker_info_lock: - active_workers = sum(1 for info in self.rank2info.values() if info.is_active) - total_workers = len(self.rank2info) - return active_workers > 0, { - "active_workers": active_workers, - "total_workers": total_workers, + def get_worker_handles(self) -> list[RolloutWorkerHandle]: + return list(self.worker_handles) + + def get_runtime_status(self) -> tuple[bool, dict[str, Any]]: + return bool(self.worker_handles), { + "rollout_workers": len(self.worker_handles), + "total_workers": len(self.worker_handles), + "workers": { + worker.rank: { + "url": worker.backend_url, + "session_url": worker.session_server_url, + "has_rollout_worker_generator": True, + } + for worker in self.worker_handles + }, } - @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE) - async def generate(self, rollout_state: RolloutState) -> RolloutState: - if XTUNER_DETERMINISTIC: - sample_params = rollout_state.sample_params.model_copy(deep=True) - sample_params.sampling_seed = self.config.random_seed + ( - (rollout_state.uid or 0) - (rollout_state.message_uid or 0) - ) - rollout_state.sample_params = sample_params - - session_id = rollout_state.session_uid if rollout_state.session_uid is not None else uuid4().int - worker = await self.router.get_worker(session_id) - if worker is None: - rollout_state.status = Status.FAILED - rollout_state.error_msg = "No active rollout worker available." - return rollout_state - - response_ref = worker.generate.remote(rollout_state=rollout_state) # type: ignore[attr-defined] - try: - response_rollout_state = await asyncio.wait_for( - response_ref, - timeout=self.config.rollout_timeout * self.timeout_multiplier, - ) - self._apply_output_parsers(response_rollout_state) - return response_rollout_state - except asyncio.TimeoutError: - self.logger.error(f"Rollout timeout for worker {worker}. Skipping sample.") - rollout_state.status = Status.FAILED - rollout_state.error_msg = ( - f"Rollout request timed out after {self.config.rollout_timeout * self.timeout_multiplier} seconds." - ) - return rollout_state - - def _apply_output_parsers(self, rollout_state: RolloutState) -> None: - """Apply tool-call and reasoning parsers to the rollout state in- - place.""" - if self._tool_call_parser is not None: - parsed = self._tool_call_parser.parse(rollout_state) - rollout_state.tool_calls = parsed.tool_calls - rollout_state.response = parsed.remaining_text or None - if self._reasoning_parser is not None: - parsed_reasoning = self._reasoning_parser.parse(rollout_state) - rollout_state.response = parsed_reasoning.remaining_text - if parsed_reasoning.reasoning_text: - rollout_state.extra_fields["reasoning_text"] = parsed_reasoning.reasoning_text - else: - rollout_state.extra_fields.pop("reasoning_text", None) - - def set_enable_partial_rollout(self, enable: bool) -> None: - """Propagate enable_partial_rollout flag to all active workers.""" - with self.worker_info_lock: - active_actors = [info.actor for info in self.rank2info.values() if info.is_active] - ray.get([actor.set_enable_partial_rollout.remote(enable) for actor in active_actors]) # type: ignore[attr-defined] - def pause_generation(self): - self.health_checker.pause() - self._broadcast_to_active_workers("pause_generation") + self._broadcast_to_rollout_worker_generators("pause_generation") + self._broadcast_to_workers("pause_generation") def cleanup_after_pause(self): - self._broadcast_to_active_workers("cleanup_after_pause") + self._broadcast_to_workers("cleanup_after_pause") def continue_generation(self): - self.health_checker.resume() - self._broadcast_to_active_workers("continue_generation") + self._broadcast_to_rollout_worker_generators("continue_generation") + self._broadcast_to_workers("continue_generation") def offload(self): - self._broadcast_to_active_workers("offload") + self._broadcast_to_workers("offload") def onload(self): - self._broadcast_to_active_workers("onload_weights") - self._broadcast_to_active_workers("onload_kvcache") + self._broadcast_to_workers("onload_weights") + self._broadcast_to_workers("onload_kvcache") def onload_weights(self): - self._broadcast_to_active_workers("onload_weights") + self._broadcast_to_workers("onload_weights") def onload_kvcache(self): - self._broadcast_to_active_workers("onload_kvcache") + self._broadcast_to_workers("onload_kvcache") def shutdown(self): """Shuts down all active rollout workers. @@ -284,80 +235,13 @@ def shutdown(self): Args: block (bool): Whether to block until the operation completes. """ - self.health_checker.stop() - self._broadcast_to_active_workers("shutdown") - - def recover_failed_workers(self): - """Recovers from worker failures by restarting failed workers and - reinitializing the rollout setup.""" - self.health_checker.pause() - with self.worker_info_lock: - failed_workers = [info for info in self.rank2info.values() if not info.is_active] - if not failed_workers: - self.logger.info("No failed workers detected during recovery.") - return - - self.logger.warning(f"Detected {len(failed_workers)} failed workers. Initiating recovery process.") - for worker in failed_workers: - if self._restart_failed_workers(worker.actor): - with self.worker_info_lock: - rank = self._get_rank_by_actor(worker.actor) - if rank is not None: - self.rank2info[rank].is_active = True - self.health_checker.resume() - - def _restart_failed_workers(self, worker: RolloutWorker) -> bool: - try: - dist_init_addr = ray.get(worker.init_dist_port.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - _, url = ray.get(worker.init.remote(dist_init_addr), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - _, session_url = ray.get(worker.get_session_server_info.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - is_healthy = ray.get(worker.check_health.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - - if is_healthy: - self.logger.info(f"Successfully restarted worker {worker} with URL {url}.") - with self.worker_info_lock: - rank = self._get_rank_by_actor(worker) - if rank is not None: - self.rank2info[rank].url = url - self.rank2info[rank].session_url = session_url - self.worker_server_urls_map[rank] = url - return True - else: - self.logger.error(f"Worker {worker} is still unhealthy after restart.") - return False - except Exception as e: - self.logger.error(f"Failed to restart worker: {e}") - return False - - def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size): - """Update the distributed initialization addresses for workers. - - This is used to group workers that belong to the same inference engine. + for entry in self._external_http_entries: + entry.stop() + self._external_http_entries.clear() + self._broadcast_to_workers("shutdown") - Args: - nodes_per_engine (int): The number of nodes per inference engine. - server_urls_per_engine (int): The number of server urls per inference engine. - dist_init_addrs (list): The list of initial addresses. - tp_size (int): The tensor parallel size. - - Returns: - list: The updated list of distributed initialization addresses. - """ - # lmdeploy pytorch ep: server_urls_per_engine > 1 - # sglang cross node engine: nodes_per_engine > 1 - assert server_urls_per_engine == 1 or nodes_per_engine == 1 - if nodes_per_engine > 1: - index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers] - for i in range(1, len(index)): - dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1]) - if server_urls_per_engine > 1: - activate_servers = len(dist_init_addrs) - for i in range(0, activate_servers, server_urls_per_engine): - dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine - return dist_init_addrs - - def _broadcast_to_active_workers(self, method_name: str): - """Helper function to call a method on all active workers. + def _broadcast_to_workers(self, method_name: str): + """Helper function to call a method on all rollout workers. Args: method_name (str): The name of the method to call. @@ -366,147 +250,17 @@ def _broadcast_to_active_workers(self, method_name: str): Returns: A list of futures if `block` is False, otherwise a list of results. """ - futures = [] - with self.worker_info_lock: - active_actors = [info.actor for info in self.rank2info.values() if info.is_active] - futures = [getattr(actor, method_name).remote() for actor in active_actors] + worker_actors = [worker.worker_actor for worker in self.worker_handles] + futures = [getattr(actor, method_name).remote() for actor in worker_actors] results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) return results - def _get_worker_cls(self): - if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": - from .lmdeploy import LMDeployWorker - - worker_cls = LMDeployWorker - elif os.environ.get("XTUNER_USE_VLLM") == "1": - from .vllm import vLLMWorker - - worker_cls = vLLMWorker - elif os.environ.get("XTUNER_USE_SGLANG") == "1": - from .sglang import SGLangWorker - - worker_cls = SGLangWorker - else: - raise NotImplementedError( - "Rollout backend is not supported." - "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM" - " or XTUNER_USE_SGLANG environment variable." - ) - assert self.config.rollout_max_batch_size_per_instance is not None, ( - "rollout_max_batch_size_per_instance must be set before building RolloutWorker." - ) - worker_generate_max_concurrency = max( - 1000, # Ray async actor default max_concurrency. - math.ceil(self.config.rollout_max_batch_size_per_instance * self.config.allow_over_concurrency_ratio), - ) - return ray.remote( - concurrency_groups={ - ROLLOUT_CONCURRENCY_GROUP_GENERATE: worker_generate_max_concurrency, - }, - )(worker_cls) - - def _get_rank_by_actor(self, actor: RolloutWorker) -> Optional[int]: - """Get rank by actor object. - - Args: - actor: The RolloutWorker actor object. - - Returns: - The rank of the worker, or None if not found. - """ - for rank, info in self.rank2info.items(): - if info.actor == actor: - return rank - return None - - def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): - """Update the list of active rollout workers and their server URLs. - - When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with - tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input. - Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0 - workers and their corresponding URLs. - """ - if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: - return active_rollout_workers, worker_server_urls_map - else: - active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node - active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] - active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] - return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) - - def _init_workers(self, placement_group: PlacementGroup): - """Initializes and configures the pool of RolloutWorker actors. - - This method creates workers from the placement group, configures distributed - inference engines by grouping workers, where each group forms a tensor-parallel - inference engine. It determines the `active_workers` to act as the head of each - engine, constructs the `engine_rank_mesh_array` to define engine topology, - acquires necessary distributed communication ports, and finally launches servers - on the `active_workers` to get their addresses. - - Returns: - Tuple[List, Dict]: A tuple where the first element is - `engine_rank_mesh_array`, a list of lists containing the ranks of workers - in each engine, and the second element is `worker_server_urls_map`, - a dictionary mapping the rank of each active worker to its - corresponding server URL. - """ - # Create workers from placement group - workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( - self._get_worker_cls(), self.config, placement_group - ) - active_servers_count, nodes_per_engine = self.config.get_active_servers_count(len(workers)) - interval = len(workers) // active_servers_count - active_rollout_workers = workers[::interval] - server_urls_per_engine = self.config.server_urls_per_engine - - set_bundle_idxs_objectref = [] - engine_rank_mesh_array = [] - activate_worker_idx = 0 - for active_worker in active_rollout_workers: - head_rank, _ = rank_bundle_idx_list[activate_worker_idx] - engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval] - engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx) - set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined] - engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta]) - activate_worker_idx += interval - ray.get(set_bundle_idxs_objectref) - # set engine mesh list for each worker - ray.get( - [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers] - ) # type: ignore[attr-defined] - # init dist_init_addr for each worker according to parallel settings - init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] - dist_init_addrs = self._update_dist_init_addr( - nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine - ) - # launch rollout servers - init_results = ray.get( - [worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)] - ) - worker_server_urls_map = dict(init_results) # rank -> url - worker_session_url_dict = dict( - ray.get([worker.get_session_server_info.remote() for worker in active_rollout_workers]) - ) - active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( - active_rollout_workers, worker_server_urls_map - ) - active_ranks = list(worker_server_urls_map.keys()) - worker_session_url_dict = {rank: worker_session_url_dict[rank] for rank in active_ranks} - workers_info = {} - for i in range(len(active_rollout_workers)): - rank = list(worker_server_urls_map.keys())[i] - url = worker_server_urls_map[rank] - workers_info[rank] = WorkerInfo( - actor=active_rollout_workers[i], - url=url, - session_url=worker_session_url_dict[rank], - ) - self.logger.info(f"Rollout worker server URLs: {[info.url for info in workers_info.values()]}") - self.logger.info(f"Rollout worker session server URLs: {[info.session_url for info in workers_info.values()]}") - return engine_rank_mesh_array, worker_server_urls_map, workers_info - + def _broadcast_to_rollout_worker_generators(self, method_name: str): + generators = [worker.generator_actor for worker in self.worker_handles] + futures = [getattr(actor, method_name).remote() for actor in generators] + if not futures: + return [] + return ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) RayRolloutController = ray.remote(RolloutController) RolloutControllerProxy: TypeAlias = ActorProxy[RayRolloutController] diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index e8eb241038..d3e2eb7037 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -3,14 +3,10 @@ from itertools import chain from typing import Any, Dict, List -import numpy as np import ray import requests from ray.util.placement_group import placement_group_table -from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams - from .worker import RolloutConfig, RolloutWorker @@ -67,14 +63,12 @@ def __init__( super().__init__(config, rank, master_addr, master_port, world_size, accelerator) self.server_func = run_lmdeploy_server_wrapper self.router_func_str = "lmdeploy.serve.proxy.proxy.proxy" - self.endpoints["health_generate"] = "health" self.endpoints["generate"] = "generate" self.endpoints["v1/chat/completions"] = "v1/chat/completions" self.endpoints["output_ids"] = "output_ids" self.endpoints["response"] = "text" self.endpoints["sleep"] = "sleep" self.endpoints["wake_up"] = "wakeup" - self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) self.api_keys = self.config.api_key self.model_name = self.config.model_name self.enable_return_routed_experts = self.config.enable_return_routed_experts @@ -92,58 +86,6 @@ def onload_kvcache(self): """Onloads the KV cache by waking up the model.""" return self._wake_up(tags=["kv_cache"]) - def _get_request_payload(self, rollout_state: RolloutState) -> dict: - tools = rollout_state.tools - tool_choice = rollout_state.tool_choice - sample_params = rollout_state.sample_params - message = rollout_state.message - input_tokens = rollout_state.tokens - - optional_fields: dict[str, object] = {} - if tools is not None: - optional_fields["tools"] = tools - if tool_choice is not None: - optional_fields["tool_choice"] = tool_choice - - if sample_params.return_token_ids: - payload = {"model": self.model_name, **optional_fields} - - if "image_data" in rollout_state.extra_fields: - assert input_tokens is not None, "input_tokens is required when image_data is provided." - payload["image_data"] = rollout_state.extra_fields["image_data"] - - if input_tokens is not None: - payload["input_ids"] = input_tokens - else: - text_prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] - payload["input_ids"] = prompt_token_ids - sample_params.return_routed_experts = True if self.enable_return_routed_experts else False - lmdeploy_sample_params = self._transform_sample_params(sample_params) - payload.update(lmdeploy_sample_params) - else: - payload = { - "model": self.model_name, - "messages": rollout_state.message, - **optional_fields, - } - lmdeploy_sample_params = { - "temperature": sample_params.temperature, - "top_p": sample_params.top_p, - "n": sample_params.n, - "stream": sample_params.stream, - "max_tokens": sample_params.max_tokens, - "repetition_penalty": sample_params.repetition_penalty, - "top_k": sample_params.top_k, - "skip_special_tokens": sample_params.skip_special_tokens, - } - if sample_params.stops: - lmdeploy_sample_params["stop"] = sample_params.stops - if sample_params.min_tokens > 0: - lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens - payload.update(lmdeploy_sample_params) - return payload - def _sleep(self, level: int = 1): """Put the model into a sleep state to save resources. @@ -177,15 +119,6 @@ def _wake_up(self, tags: List[str] | None = None): assert response.status_code == 200, response.status_code return response.text - async def _decode_routed_experts(self, routed_experts: Any) -> Any: - if isinstance(routed_experts, str): - if self.lmdeploy_actor is None: - self.lmdeploy_actor = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE) - assert self.lmdeploy_actor is not None, "LMDeploy actor should be available in the shared store." - routed_experts_data = await self.lmdeploy_actor.get.remote(routed_experts) - return ray.put(np.asarray(routed_experts_data)) - return np.asarray(routed_experts) - async def cleanup_after_pause(self) -> None: if not self.enable_return_routed_experts: return @@ -385,6 +318,3 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: speculative_config=speculative_config, **lmdeploy_config_kwargs, ) - - def _transform_sample_params(self, sample_params: SampleParams) -> dict: - return sample_params.model_dump(exclude_none=True) diff --git a/xtuner/v1/rl/rollout/rollout_generator.py b/xtuner/v1/rl/rollout/rollout_generator.py new file mode 100644 index 0000000000..39bd18243b --- /dev/null +++ b/xtuner/v1/rl/rollout/rollout_generator.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Any, Literal +from uuid import uuid4 + +import ray +from pydantic import BaseModel, ConfigDict, Field +from ray.exceptions import RayActorError +from transformers import AutoTokenizer + +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger + +from ._generation.external_http_entry import ExternalRolloutHttpEntryConfig +from ._generation.internal_http_entry import InternalRolloutHttpEntryConfig +from ._generation.session_worker_selector import RolloutWorkerHandle, RolloutWorkerUrlSource, SessionWorkerSelector +from .parser.factory import build_reasoning_parser, build_tool_call_parser +from .parser.reasoning_parser import ReasoningParser +from .parser.tool_parser import ToolCallParser +from .worker import RolloutConfig + + +class LocalRolloutGeneratorConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + timeout_multiplier: float = Field(default=2.0, gt=0) + + def build(self, rollout_controller) -> "LocalRolloutGenerator": + rollout_metadata = ray.get(rollout_controller.get_rollout_metadata.remote()) + return LocalRolloutGenerator( + worker_handles=rollout_metadata["worker_handles"], + rollout_config=rollout_metadata["rollout_config"], + timeout_multiplier=self.timeout_multiplier, + ) + + +class LocalRolloutGenerator: + """Local AgentLoop generation path. + + It chooses one active rollout worker by session id, then calls the worker's + bound RolloutWorkerGenerator actor directly. The RolloutController is not + on the runtime generation path. + """ + + def __init__( + self, + worker_handles: list[RolloutWorkerHandle], + rollout_config: RolloutConfig, + timeout_multiplier: float = 2.0, + ) -> None: + self.worker_handles = worker_handles + self.worker_selector = SessionWorkerSelector(worker_handles) + self.config = rollout_config + self.timeout_multiplier = timeout_multiplier + self.logger = get_logger(log_dir=rollout_config.worker_log_dir, tag="LocalRolloutGenerator") + self._tool_call_parser, self._reasoning_parser = self._build_output_parsers() + + def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]: + tool_call_parser = None + reasoning_parser = None + + if self.config.tool_call_parser != "none": + tool_call_parser = build_tool_call_parser(self.config.tool_call_parser) + + if self.config.reasoning_parser != "none": + tokenizer_path = self.config.tokenizer_path or self.config.model_path + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + reasoning_parser = build_reasoning_parser(self.config.reasoning_parser, tokenizer) + + return tool_call_parser, reasoning_parser + + def _apply_output_parsers(self, rollout_state: RolloutState) -> None: + if self._tool_call_parser is not None: + parsed = self._tool_call_parser.parse(rollout_state) + rollout_state.tool_calls = parsed.tool_calls + rollout_state.response = parsed.remaining_text or None + if self._reasoning_parser is not None: + parsed_reasoning = self._reasoning_parser.parse(rollout_state) + rollout_state.response = parsed_reasoning.remaining_text + if parsed_reasoning.reasoning_text: + rollout_state.extra_fields["reasoning_text"] = parsed_reasoning.reasoning_text + else: + rollout_state.extra_fields.pop("reasoning_text", None) + + async def generate(self, rollout_state: RolloutState, *, enable_partial_rollout: bool = False) -> RolloutState: + if XTUNER_DETERMINISTIC: + sample_params = rollout_state.sample_params.model_copy(deep=True) + sample_params.sampling_seed = self.config.random_seed + ( + (rollout_state.uid or 0) - (rollout_state.message_uid or 0) + ) + rollout_state.sample_params = sample_params + + session_id = rollout_state.session_uid if rollout_state.session_uid is not None else uuid4().int + worker = await self.worker_selector.select(session_id) + if worker is None: + rollout_state.status = Status.FAILED + rollout_state.error_msg = "No rollout worker available." + return rollout_state + + try: + response_rollout_state = await asyncio.wait_for( + worker.generator_actor.generate.remote( + rollout_state=rollout_state, + enable_partial_rollout=enable_partial_rollout, + ), + timeout=self.config.rollout_timeout * self.timeout_multiplier, + ) + self._apply_output_parsers(response_rollout_state) + return response_rollout_state + except (asyncio.TimeoutError, RayActorError) as exc: + self.logger.error(f"Rollout failed for worker {worker.rank}. Skipping sample. Error: {exc}") + rollout_state.status = Status.FAILED + if isinstance(exc, asyncio.TimeoutError): + rollout_state.error_msg = ( + f"Rollout request timed out after {self.config.rollout_timeout * self.timeout_multiplier} seconds." + ) + else: + rollout_state.error_msg = f"Rollout worker generator actor failed with error: {exc}" + return rollout_state + + +RolloutGenerateKind = Literal["local", "http"] +RolloutHttpEntryKind = Literal["internal", "external"] + + +@dataclass +class RolloutGenerateHandle: + kind: RolloutGenerateKind + local_generator: LocalRolloutGenerator | None = None + base_url: str | None = None + rollout_controller: Any | None = None + + def require_local_generator(self) -> LocalRolloutGenerator: + if self.local_generator is None: + raise RuntimeError(f"Rollout generate handle {self.kind!r} does not provide a local generator.") + return self.local_generator + + def require_base_url(self) -> str: + if self.base_url is None: + raise RuntimeError(f"Rollout generate handle {self.kind!r} does not provide a base URL.") + return self.base_url + + +class RolloutGenerateHandleConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + kind: RolloutGenerateKind = "local" + http_entry: RolloutHttpEntryKind = "internal" + local_generator_config: LocalRolloutGeneratorConfig = Field(default_factory=LocalRolloutGeneratorConfig) + + base_url: str | None = None + + internal_http_entry_host: str = "0.0.0.0" + internal_http_entry_port: int = 8081 + internal_http_entry_title: str = "XTuner Internal Rollout Router" + internal_http_entry_version: str = "0.1.0" + internal_http_entry_log_level: str = "warning" + internal_http_entry_request_timeout: float | None = None + internal_http_entry_stream_timeout: float | None = None + http_worker_url_source: RolloutWorkerUrlSource = "backend" + + external_http_entry_delete_existing: bool = True + external_http_entry_check_worker_urls: bool = True + external_http_entry_check_base_url: bool = True + + def build(self, rollout_controller) -> RolloutGenerateHandle: + if self.kind == "local": + return RolloutGenerateHandle( + kind=self.kind, + local_generator=self.local_generator_config.build(rollout_controller), + rollout_controller=rollout_controller, + ) + + if self.kind != "http": + raise ValueError(f"Unsupported rollout generate kind: {self.kind!r}") + + base_url = self._resolve_http_base_url(rollout_controller) + if base_url is None: + raise ValueError(f"Rollout generate handle {self.kind!r} requires base_url.") + return RolloutGenerateHandle(kind=self.kind, base_url=base_url, rollout_controller=rollout_controller) + + def _resolve_http_base_url(self, rollout_controller) -> str | None: + if self.base_url is not None: + return self.base_url + + if self.http_entry == "internal": + rollout_metadata = ray.get(rollout_controller.get_rollout_metadata.remote()) + base_url = rollout_metadata.get("internal_http_entry_url") + if base_url is None: + raise ValueError( + "Rollout generate handle kind='http' and http_entry='internal' requires the internal HTTP " + "entry to be started before building AgentLoop, or base_url to be provided as a runtime override." + ) + return base_url + + if self.http_entry == "external": + raise ValueError("Rollout generate handle kind='http' and http_entry='external' requires base_url.") + + raise ValueError(f"Unsupported rollout HTTP entry: {self.http_entry!r}") + + def build_internal_http_entry_config(self) -> InternalRolloutHttpEntryConfig | None: + if self.kind != "http" or self.http_entry != "internal": + return None + return InternalRolloutHttpEntryConfig( + host=self.internal_http_entry_host, + port=self.internal_http_entry_port, + title=self.internal_http_entry_title, + version=self.internal_http_entry_version, + log_level=self.internal_http_entry_log_level, + request_timeout=self.internal_http_entry_request_timeout, + stream_timeout=self.internal_http_entry_stream_timeout, + worker_url_source=self.http_worker_url_source, + ) + + def build_external_http_entry_config(self) -> ExternalRolloutHttpEntryConfig | None: + if self.kind != "http" or self.http_entry != "external": + return None + if self.base_url is None: + raise ValueError("Rollout generate handle kind='http' and http_entry='external' requires base_url.") + return ExternalRolloutHttpEntryConfig( + base_url=self.base_url, + worker_url_source=self.http_worker_url_source, + delete_existing=self.external_http_entry_delete_existing, + check_worker_urls=self.external_http_entry_check_worker_urls, + check_base_url=self.external_http_entry_check_base_url, + ) diff --git a/xtuner/v1/rl/rollout/rollout_worker_build.py b/xtuner/v1/rl/rollout/rollout_worker_build.py new file mode 100644 index 0000000000..f7d34bea7e --- /dev/null +++ b/xtuner/v1/rl/rollout/rollout_worker_build.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any + +import ray +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from xtuner.v1.rl.utils import AutoAcceleratorWorkers +from xtuner.v1.utils import get_logger + +from ._generation.rollout_worker_generator import RayRolloutWorkerGenerator +from ._generation.session_worker_selector import RolloutWorkerHandle +from .worker import RolloutConfig, RolloutWorker + + +@dataclass +class RolloutWorkerRuntime: + """Runtime handles for one rollout worker.""" + + worker_actor: RolloutWorker + backend_url: str + session_server_url: str | None = None + generator_actor: Any | None = None + bundle_idx: int | None = None + + +@dataclass +class RolloutRuntime: + engine_rank_mesh_array: list[list[int]] + worker_server_urls_map: dict[int, str] + rank2worker: dict[int, RolloutWorkerRuntime] + worker_handles: list[RolloutWorkerHandle] + + +class RolloutWorkerBuilder: + """Bootstrap rollout workers and worker generation actors.""" + + def __init__(self, config: RolloutConfig, placement_group: PlacementGroup) -> None: + self.config = config + self.placement_group = placement_group + self.num_gpus_per_engine = config.num_gpus_per_engine + self.logger = get_logger(log_dir=config.worker_log_dir, tag="RolloutWorkerBuilder") + self.num_active_workers = 0 + + def build(self) -> RolloutRuntime: + engine_rank_mesh_array, worker_server_urls_map, rank2worker = self._init_workers() + self._init_rollout_worker_generators(rank2worker) + worker_handles = self._build_worker_handles(rank2worker) + return RolloutRuntime( + engine_rank_mesh_array=engine_rank_mesh_array, + worker_server_urls_map=worker_server_urls_map, + rank2worker=rank2worker, + worker_handles=worker_handles, + ) + + def _build_worker_handles(self, rank2worker: dict[int, RolloutWorkerRuntime]) -> list[RolloutWorkerHandle]: + worker_handles = [] + for rank, worker in rank2worker.items(): + if worker.generator_actor is None: + raise RuntimeError(f"Missing RolloutWorkerGenerator for rollout worker rank {rank}.") + worker_handles.append( + RolloutWorkerHandle( + rank=rank, + worker_actor=worker.worker_actor, + backend_url=worker.backend_url, + generator_actor=worker.generator_actor, + session_server_url=worker.session_server_url, + ) + ) + return worker_handles + + def _init_rollout_worker_generators(self, rank2worker: dict[int, RolloutWorkerRuntime]) -> None: + for rank, worker in rank2worker.items(): + if worker.bundle_idx is None: + raise RuntimeError(f"Missing placement bundle index for active rollout worker rank {rank}.") + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.placement_group, + placement_group_capture_child_tasks=False, + placement_group_bundle_index=worker.bundle_idx, + ) + worker.generator_actor = RayRolloutWorkerGenerator.options( + scheduling_strategy=scheduling_strategy, + num_cpus=0, + ).remote(self.config, rank, worker.backend_url) + + def _get_worker_cls(self): + if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": + from .lmdeploy import LMDeployWorker + + worker_cls = LMDeployWorker + elif os.environ.get("XTUNER_USE_VLLM") == "1": + from .vllm import vLLMWorker + + worker_cls = vLLMWorker + elif os.environ.get("XTUNER_USE_SGLANG") == "1": + from .sglang import SGLangWorker + + worker_cls = SGLangWorker + else: + raise NotImplementedError( + "Rollout backend is not supported." + "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM" + " or XTUNER_USE_SGLANG environment variable." + ) + return ray.remote(worker_cls) + + def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size): + assert server_urls_per_engine == 1 or nodes_per_engine == 1 + if nodes_per_engine > 1: + index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers] + for i in range(1, len(index)): + dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1]) + if server_urls_per_engine > 1: + activate_servers = len(dist_init_addrs) + for i in range(0, activate_servers, server_urls_per_engine): + dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine + return dist_init_addrs + + def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): + if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: + return active_rollout_workers, worker_server_urls_map + + active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node + active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] + active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] + return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) + + def _init_workers(self): + workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( + self._get_worker_cls(), self.config, self.placement_group + ) + active_servers_count, nodes_per_engine = self.config.get_active_servers_count(len(workers)) + interval = len(workers) // active_servers_count + active_rollout_workers = workers[::interval] + self.num_active_workers = len(active_rollout_workers) + server_urls_per_engine = self.config.server_urls_per_engine + + set_bundle_idxs_objectref = [] + engine_rank_mesh_array = [] + activate_worker_idx = 0 + for active_worker in active_rollout_workers: + head_rank, _ = rank_bundle_idx_list[activate_worker_idx] + engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval] + engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] + set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined] + engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta]) + activate_worker_idx += interval + ray.get(set_bundle_idxs_objectref) + ray.get( + [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers] + ) # type: ignore[attr-defined] + + init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] + dist_init_addrs = self._update_dist_init_addr( + nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine + ) + init_results = ray.get( + [worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)] + ) + worker_server_urls_map = dict(init_results) + worker_session_url_dict = dict( + ray.get([worker.get_session_server_info.remote() for worker in active_rollout_workers]) + ) + active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( + active_rollout_workers, worker_server_urls_map + ) + active_ranks = list(worker_server_urls_map.keys()) + worker_session_url_dict = {rank: worker_session_url_dict[rank] for rank in active_ranks} + + worker_runtimes = {} + rank_to_bundle_idx = {rank: bundle_idx for rank, bundle_idx in rank_bundle_idx_list} + for i, rank in enumerate(worker_server_urls_map.keys()): + worker_runtimes[rank] = RolloutWorkerRuntime( + worker_actor=active_rollout_workers[i], + backend_url=worker_server_urls_map[rank], + session_server_url=worker_session_url_dict[rank], + bundle_idx=rank_to_bundle_idx[rank], + ) + self.logger.info(f"Rollout worker server URLs: {[worker.backend_url for worker in worker_runtimes.values()]}") + self.logger.info( + f"Rollout worker session server URLs: {[worker.session_server_url for worker in worker_runtimes.values()]}" + ) + return engine_rank_mesh_array, worker_server_urls_map, worker_runtimes + + +def build_rollout_runtime(config: RolloutConfig, placement_group: PlacementGroup) -> RolloutRuntime: + return RolloutWorkerBuilder(config, placement_group).build() diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index aedd490e23..61763b2a82 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -28,7 +28,6 @@ def __init__( from sglang.srt.entrypoints.http_server import launch_server self.server_func = launch_server - self.endpoints["health_generate"] = "health" self.endpoints["generate"] = "generate" self.endpoints["v1/chat/completions"] = "v1/chat/completions" self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path, trust_remote_code=True) @@ -136,14 +135,6 @@ def _make_request(self, endpoint: str, payload=None): response.raise_for_status() return response.json() - def check_health(self) -> bool: - try: - response = requests.get(f"{self.server_url}/{self.endpoints['health_generate']}", timeout=5.0) - return response.status_code == 200 - except requests.RequestException as e: - self.logger.error(f"Health check failed for server {self.server_url}: {e}") - return False - def flush_cache(self): """Flush the cache of the server.""" # TODO: 支持 tp diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index 69fdf12f52..cf7039822c 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -1,254 +1,20 @@ -import asyncio import os -import threading import time -from collections import OrderedDict -from itertools import cycle -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import cast import numpy as np import ray from ray import ObjectRef as RayObjectRef from xtuner.v1.data_proto.rl_data import RolloutState, Status -from xtuner.v1.rl.utils import asyncio_run +from xtuner.v1.rl.utils import free_object_refs from xtuner.v1.utils import get_logger -if TYPE_CHECKING: - from .controller import WorkerInfo - from .worker import RolloutConfig, RolloutWorker - ROLLOUT_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", str(5 * 3600))) # default 5 hours logger = get_logger() -class SessionRouter: - def __init__( - self, - worker_infos: dict[int, "WorkerInfo"], # worker: worker_status - worker_infos_lock: Optional[threading.RLock] = None, - max_sessions: int = 10000, - max_idle_seconds: Optional[float] = 3600.0, - ): - self._worker_infos = worker_infos - self._worker_infos_lock = worker_infos_lock - self._max_sessions = max_sessions - self._max_idle = max_idle_seconds - - # OrderedDict: key=session_id -> value=(worker_rank, last_used_ts) - self._map: OrderedDict[int, tuple[int, float]] = OrderedDict() - - self._worker_cycler = cycle(worker_infos.keys()) - self._lock = asyncio.Lock() - self.logger = get_logger() - - def _now(self) -> float: - return time.time() - - def _evict_expired(self): - if self._max_idle is None: - return - now = self._now() - - to_delete = [] - for sid, (_, last_used) in self._map.items(): - if now - last_used > self._max_idle: - to_delete.append(sid) - else: - break - for sid in to_delete: - self._map.pop(sid, None) - - def _evict_lru_to_capacity(self): - while len(self._map) > self._max_sessions: - self._map.popitem(last=False) - - def _choose_next_active_worker(self) -> tuple[int, Any]: - n = len(self._worker_infos) - for _ in range(n): - rank = next(self._worker_cycler) - if self._worker_infos_lock is None: - info = self._worker_infos[rank] - if info and info.is_active: - return rank, info.actor - else: - with self._worker_infos_lock: - info = self._worker_infos[rank] - if info and info.is_active: - return rank, info.actor - return -1, None - - async def get_worker(self, session_id: int) -> Optional[Any]: - async with self._lock: - self._evict_expired() - - if session_id in self._map: - worker_rank, _ = self._map.pop(session_id) - if self._worker_infos_lock is None: - info = self._worker_infos.get(worker_rank) - else: - with self._worker_infos_lock: - info = self._worker_infos.get(worker_rank) - if info and info.is_active: - self._map[session_id] = (worker_rank, self._now()) - return info.actor - - rank, worker = self._choose_next_active_worker() - if rank == -1: - return None - self._map[session_id] = (rank, self._now()) - self._evict_lru_to_capacity() - return worker - - -class RolloutHealthChecker: - def __init__( - self, - config: "RolloutConfig", - workers_info: dict[int, "WorkerInfo"], - worker_infos_lock: Optional[threading.RLock] = None, - ): - self._workers_info = workers_info - self._worker_infos_lock = worker_infos_lock - self._check_interval = config.health_check_interval_seconds - self._check_failure_threshold = config.health_check_failure_threshold - self._stop_event: Optional[threading.Event] = None - self._pause_event: Optional[threading.Event] = None - self._thread: Optional[threading.Thread] = None - - def start(self) -> None: - if self._thread and self._thread.is_alive(): - return - - self._stop_event = threading.Event() - self._pause_event = threading.Event() - self._pause_event.set() # 启动时设置为暂停状态,开始generation后再调用restart方法恢复 - self._thread = threading.Thread(target=self._run_loop, daemon=True) - self._thread.start() - logger.info("RolloutHealthChecker started.") - - def stop(self) -> None: - if not self._thread: - return - - assert self._stop_event is not None - self._stop_event.set() - if self._pause_event: - self._pause_event.clear() - self._thread.join(timeout=5) - self._thread = None - self._stop_event = None - logger.info("RolloutHealthChecker stopped.") - - def pause(self) -> None: - if self._pause_event is None: - return - self._pause_event.set() - logger.info("RolloutHealthChecker paused.") - - def resume(self) -> None: - if self._pause_event is None: - return - self._pause_event.clear() - logger.info("RolloutHealthChecker restarted.") - - def run_once(self) -> None: - logger.debug("RolloutHealthChecker running health checks for all workers.") - if self._worker_infos_lock is None: - workers_snapshot = { - rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items() - } - else: - with self._worker_infos_lock: - workers_snapshot = { - rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items() - } - - workers_to_check = [ - (rank, actor, url, is_active) for rank, (actor, url, is_active) in workers_snapshot.items() if is_active - ] - if not workers_to_check: - return - - tasks = [ - check_worker_health(actor, rank, url, is_active, self._check_failure_threshold) - for rank, actor, url, is_active in workers_to_check - ] - - async def _run_checks() -> list[bool]: - return await asyncio.gather(*tasks) - - check_results = asyncio_run(_run_checks()) - inactive_workers = [] - for (rank, _, _, _), is_healthy in zip(workers_to_check, check_results): - if not is_healthy: - logger.warning(f"Worker {rank} failed health check. Marking as inactive.") - if self._worker_infos_lock is None: - self._workers_info[rank].is_active = False - inactive_worker = self._workers_info[rank].actor - else: - with self._worker_infos_lock: - self._workers_info[rank].is_active = False - inactive_worker = self._workers_info[rank].actor - if inactive_worker is None: - logger.error(f"[RolloutHealthChecker] Worker {rank} has no actor reference. Skipping shutdown.") - continue - inactive_workers.append((rank, inactive_worker)) - else: - logger.debug(f"[RolloutHealthChecker] Worker {rank} passed health check.") - - for rank, inactive_worker in inactive_workers: - try: - ray.get(inactive_worker.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - except Exception as e: - logger.error(f"Exception while offloading worker {rank}: {e}") - - try: - ray.get(inactive_worker.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - except Exception as e: - logger.error(f"Exception while shutting down worker {rank}: {e}") - - def _run_loop(self) -> None: - assert self._stop_event is not None and self._pause_event is not None - logger.info("RolloutHealthChecker loop started.") - - while not self._stop_event.is_set(): - while self._pause_event.is_set() and not self._stop_event.is_set(): - self._stop_event.wait(timeout=0.5) - - if self._stop_event.is_set(): - break - - if not self._pause_event.is_set() and not self._stop_event.is_set(): - self.run_once() - - if self._stop_event.wait(self._check_interval): - break - - -async def check_worker_health( - worker: "RolloutWorker", rank: int, url: str, is_active: bool, failure_threshold: int = 3 -) -> bool: - if worker is None or not is_active: - logger.warning("Worker has no actor reference or is marked inactive.") - return False - failing_count = 0 - while failing_count < failure_threshold: - try: - health_status = await worker.check_health.remote() # type: ignore[attr-defined] - if health_status: - return True - failing_count += 1 - logger.warning(f"Health check failed for worker {rank} at {url}. Failure count: {failing_count}") - except Exception as e: - failing_count += 1 - logger.error( - f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}" - ) - return False - - async def _resolve_routed_experts(routed_experts: np.ndarray | RayObjectRef) -> np.ndarray: if isinstance(routed_experts, RayObjectRef): routed_experts = await routed_experts @@ -294,7 +60,8 @@ async def postprocess( routed_experts: np.ndarray | RayObjectRef | None, finish_reason: str, status: Status, - routed_experts_expect_len: int, + routed_experts_expect_len: int | None = None, + enable_partial_rollout: bool | None = None, ) -> RolloutState: rollout_state.finish_reason = finish_reason rollout_state.status = status @@ -312,7 +79,9 @@ async def postprocess( if history_routed_experts is not None and routed_experts is not None: # case 1: 上一次 rolloutstate 有 response, 本次推理也有 response,需要对 routed experts 进行拼接 start_time = time.perf_counter() - history_routed_experts = await _resolve_routed_experts(history_routed_experts) # type: ignore[assignment] + history_routed_experts_ref = history_routed_experts + cur_routed_experts_ref = routed_experts + history_routed_experts = await _resolve_routed_experts(history_routed_experts_ref) # type: ignore[assignment] cur_routed_experts = await _resolve_routed_experts(routed_experts) # type: ignore[assignment] history_routed_experts_len = len(history_routed_experts) cur_routed_experts_len = len(cur_routed_experts) @@ -320,18 +89,26 @@ async def postprocess( f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}, history_response_ids len: {len(history_response_ids)}, current response_ids len: {len(response_ids)}" ) cur_routed_experts = cur_routed_experts[history_routed_experts_len:] - concat_routed_experts = np.concatenate([history_routed_experts, cur_routed_experts], axis=0) + if isinstance(history_routed_experts, list) and isinstance(cur_routed_experts, list): + concat_routed_experts = history_routed_experts + cur_routed_experts + else: + concat_routed_experts = np.concatenate([history_routed_experts, cur_routed_experts], axis=0) rollout_state.routed_experts = ray.put(concat_routed_experts) - expected_len = len(cast(list[int], rollout_state.prompt_ids)) + len(rollout_state.response_ids) - 1 - assert expected_len == routed_experts_expect_len, ( - f"Expected routed_experts len: {expected_len}, routed_experts_expect_len: {routed_experts_expect_len}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}" - ) - assert len(concat_routed_experts) == routed_experts_expect_len, ( - f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected len: {expected_len}, history_routed_experts_len: {history_routed_experts_len}, current_routed_experts_len: {len(cur_routed_experts)}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}" + if routed_experts_expect_len is not None: + expected_len = len(cast(list[int], rollout_state.prompt_ids)) + len(rollout_state.response_ids) - 1 + assert expected_len == routed_experts_expect_len, ( + f"Expected routed_experts len: {expected_len}, routed_experts_expect_len: {routed_experts_expect_len}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}" + ) + assert len(concat_routed_experts) == routed_experts_expect_len, ( + f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected len: {expected_len}, history_routed_experts_len: {history_routed_experts_len}, current_routed_experts_len: {len(cur_routed_experts)}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}" + ) + free_object_refs( + [ + ref + for ref in (history_routed_experts_ref, cur_routed_experts_ref) + if isinstance(ref, RayObjectRef) + ] ) - # free_object_refs( - # [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)] - # ) end_time = time.perf_counter() self.logger.debug( f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds" diff --git a/xtuner/v1/rl/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py index b1ff1794bb..6b5c211d63 100644 --- a/xtuner/v1/rl/rollout/vllm.py +++ b/xtuner/v1/rl/rollout/vllm.py @@ -121,13 +121,13 @@ def __init__(self, server_namespace: Namespace): print(stack_trace) raise # Re-raise the exception to prevent silent failure - def actor_health(self): - return "healthy" + def started(self): + return "started" # Add a dummy task. def run_lmdeploy_server_wrapper(server_namespace: Namespace): - return ray.get(VllmServerWrapper.remote(server_namespace).actor_health.remote()) # type: ignore + return ray.get(VllmServerWrapper.remote(server_namespace).started.remote()) # type: ignore class vLLMWorker(RolloutWorker): @@ -143,7 +143,6 @@ def __init__( super().__init__(config, rank, master_addr, master_port, world_size, accelerator) self.router_func = "" self.server_func = run_lmdeploy_server_wrapper - self.endpoints["health_generate"] = "health" self.endpoints["v1/chat/completions"] = "v1/chat/completions" self.endpoints["generate"] = "v1/chat/completions" self.endpoints["sleep"] = "sleep" diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 4dabb19f7c..7655a83022 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -1,48 +1,34 @@ import asyncio -import copy import json -import math import multiprocessing import os import socket import threading import time -import traceback from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union import httpx import ray -import requests # type: ignore[import-untyped] from cyclopts import Group, Parameter -from packaging.version import Version from pydantic import BaseModel, ConfigDict, PrivateAttr from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from typing_extensions import Annotated from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import ( - RolloutState, - SampleParams, - Status, - reset_rollout_response, - update_status_from_finish_reason, -) from xtuner.v1.rl.utils import ( AutoAcceleratorWorkers, CPUResourcesConfig, SingleAcceleratorWorker, - cancel_and_drain, find_master_addr_and_port, get_eos_token, register_cpu_resources, ) from xtuner.v1.utils import get_logger -from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult from .session_server import SessionServerActor -from .utils import ROLLOUT_RAY_GET_TIMEOUT, PartialRolloutHandler +from .utils import ROLLOUT_RAY_GET_TIMEOUT if TYPE_CHECKING: @@ -50,7 +36,6 @@ infer_group = Group("inference", help="Inference worker configuration.") -ROLLOUT_CONCURRENCY_GROUP_GENERATE = "generate" class RolloutConfig(BaseModel): @@ -272,13 +257,6 @@ class RolloutConfig(BaseModel): help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.', ), ] = {} - max_retry_per_worker: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Maximum number of retries per rollout worker before deactivation.", - ), - ] = None max_retry_per_sample: Annotated[ int, Parameter( @@ -308,20 +286,6 @@ class RolloutConfig(BaseModel): ), ] = False worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - health_check_interval_seconds: Annotated[ - float, - Parameter( - group=infer_group, - help="Interval in seconds between rollout worker health checks.", - ), - ] = 30.0 - health_check_failure_threshold: Annotated[ - int, - Parameter( - group=infer_group, - help="Number of consecutive health check failures required before marking a worker inactive.", - ), - ] = 3 _logged_server_urls_per_engine: bool = PrivateAttr(default=False) @property @@ -375,17 +339,6 @@ def get_active_servers_count(self, num_rollout_workers: int) -> tuple[int, int]: ) return active_servers_count, nodes_per_engine - def get_controller_generate_concurrency(self, placement_group: "PlacementGroup") -> int: - active_worker_count, _ = self.get_active_servers_count(len(placement_group.bundle_specs)) - assert self.rollout_max_batch_size_per_instance is not None, ( - "rollout_max_batch_size_per_instance must be set before building RolloutController." - ) - concurrency_per_worker = math.ceil( - self.rollout_max_batch_size_per_instance * self.allow_over_concurrency_ratio - ) - generate_max_concurrency = active_worker_count * concurrency_per_worker - return generate_max_concurrency - def model_post_init(self, __context: Any) -> None: if self.model_name is None: model_name_from_config = None @@ -433,9 +386,6 @@ def model_post_init(self, __context: Any) -> None: else: self.rollout_max_batch_size_per_instance = 128 - if self.max_retry_per_worker is None: - self.max_retry_per_worker = self.rollout_max_batch_size_per_instance - self.worker_log_dir.mkdir(parents=True, exist_ok=True) def build(self, placement_group: "PlacementGroup"): @@ -450,31 +400,22 @@ def build(self, placement_group: "PlacementGroup"): import ray from xtuner.v1.rl.rollout.controller import RolloutController + from xtuner.v1.rl.rollout.rollout_worker_build import build_rollout_runtime num_workers = 1 register_cpu_resources( name="rollout_controller", cpu_resources=CPUResourcesConfig(num_workers=num_workers), ) - generate_max_concurrency = self.get_controller_generate_concurrency(placement_group) - get_logger().info(f"Calculated RolloutController generate concurrency: {generate_max_concurrency}") - return ( - ray.remote( - concurrency_groups={ - ROLLOUT_CONCURRENCY_GROUP_GENERATE: generate_max_concurrency, - }, - )(RolloutController) - .options(num_cpus=num_workers) - .remote(self, placement_group) - ) + runtime = build_rollout_runtime(self, placement_group) + return ray.remote(RolloutController).options(num_cpus=num_workers).remote(self, runtime=runtime) class RolloutWorker(SingleAcceleratorWorker): """Base class for a rollout worker that runs an inference server. - This class manages the lifecycle of a distributed inference server, including initialization, launching, and - handling generation requests. It is designed to be subclassed for specific inference backends like LMDeploy, vLLM - or SGLang. + This class manages the lifecycle of a distributed inference server, including initialization, launching, weight + updates, and backend control. Runtime generation is handled by RolloutWorkerGenerator. """ def __init__( @@ -506,13 +447,9 @@ def __init__( self.server_func: Callable self.endpoints: dict[str, str] = dict() self.engine_rank_mesh_array: list[list[int]] - # http_concurrency is calculated based on the max batch size per engine and the total number of engines assert config.rollout_max_batch_size_per_instance, ( "rollout_max_batch_size_per_instance must be set in RolloutConfig" ) - http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio - limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) - self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) self.server_task = None self.engine_bundle_idxs: list[int] = [] self.server_process: Optional[multiprocessing.Process] = None @@ -533,11 +470,6 @@ def __init__( self.abort_timeout = 10.0 self.dist_init_addr: str = "" self.serverl_url: str = "" - self.partial_rollout_handler = PartialRolloutHandler() - self.enable_partial_rollout: bool = False - - def set_enable_partial_rollout(self, enable: bool) -> None: - self.enable_partial_rollout = enable def init(self, dist_init_addr: str) -> tuple[int, str]: """Initialize the worker and launch the server. @@ -686,185 +618,9 @@ def continue_generation(self): """Resume the worker's generation process.""" self.receive_abort_request.clear() - def check_health(self) -> bool: - """Check the health of the worker's server. - - Returns: - bool: True if the server is healthy, False otherwise. - """ - try: - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {self.config.api_key}", - } - response = requests.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 - ) - return response.status_code == 200 - except requests.RequestException as e: - self.logger.error(f"Health check failed for server {self.server_url}: {e}") - return False - - async def _decode_routed_experts(self, routed_experts: Any) -> Any: - return routed_experts - - @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE) - async def generate(self, rollout_state: RolloutState) -> RolloutState: - # TODO(@duanyanhui): - # 1. support claude format input - # 2. 需要看下新的输入输出(RolloutState)怎么适配PartialRollout的逻辑,先跑起来 - # 3. 对于流式返回的response先删掉,目前还用不上,等需要的时候再加上 - - if self.receive_abort_request.is_set(): - rollout_state.finish_reason = "abort" - rollout_state.status = Status.ABORTED - return rollout_state - - uid = rollout_state.uid - sample_params: SampleParams = rollout_state.sample_params - max_tokens = sample_params.max_tokens - enable_partial_rollout = self.enable_partial_rollout - if sample_params.return_token_ids: - endpoint_url = f"{self.server_url}/{self.endpoints['generate']}" - else: - endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}" - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.config.api_key}", - } - - if enable_partial_rollout: - rollout_state = self.partial_rollout_handler.preprocess(rollout_state, max_tokens) - elif rollout_state.status == Status.ABORTED: - # ABORTED samples can be replayed; without partial rollout, rerun from the original prompt. - rollout_state = reset_rollout_response(rollout_state) - rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens}) - rollout_state.status = Status.INIT - payload = self._get_request_payload(rollout_state) - max_retries = self.config.max_retry_per_sample - - # 早退逻辑 1:检查是否已被标记为完成 - if rollout_state.status == Status.COMPLETED: - self.logger.debug(f"Request {uid} is already marked as COMPLETED, skipping generation.") - return rollout_state - - # 早退逻辑 2:检测输入是否还需要 generation (安全获取变量) - input_ids = payload.get("input_ids", []) - max_tokens = cast(int, payload.get("max_tokens")) - - last_id = input_ids[-1] if len(input_ids) > 0 else "None" - is_max_tokens_zero = max_tokens is not None and max_tokens <= 0 - is_eos_reached = len(input_ids) > 0 and input_ids[-1] in self.eos_token - - if is_max_tokens_zero or is_eos_reached: - self.logger.debug( - f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token." - ) - finish_reason = "stop" if is_eos_reached else "length" - # 对于是否开 partial rollout 的情况都直接标记为完成并返回,因为本轮 rollout 未开始,也不需要拼接 - rollout_state.finish_reason = finish_reason - rollout_state.status = Status.COMPLETED - return rollout_state - - for attempt in range(max_retries + 1): - is_last_attempt = attempt == max_retries - http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload) - - # Case 1: HTTP Request is Successful - if http_result.response: - # Case 1.1: Valid rollout response - rollout_state = await self._safe_handle_response(rollout_state, http_result.response) - if rollout_state.status in [Status.COMPLETED, Status.ABORTED]: - return rollout_state - - if is_last_attempt: - # Case 1.2: Invalid rollout response and no retries left, so we return FAILED - self.logger.warning( - f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED." - ) - rollout_state.status = Status.FAILED - rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts." - return rollout_state - - # Case 1.3: Invalid rollout response but we have retries left - self.logger.warning( - f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}." - ) - await asyncio.sleep(0.1) - continue - - # Case 2: Error occurred during HTTP Request - if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED: - # Case 2.1: The request was aborted due to an signal set by `receive_abort_request` - rollout_state.finish_reason = "abort" - rollout_state.status = update_status_from_finish_reason("abort") - return rollout_state - - if http_result.is_client_error: - # Case 2.2: A non-retryable client error occurred (such as 4xx HTTP status) - self.logger.warning( - f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}" - ) - rollout_state.error_msg = ( - f"Client error {http_result.error_type} with message: {http_result.error_msg}" - ) - rollout_state.status = Status.FAILED - return rollout_state - - if http_result.is_server_error: - # Case 2.3: A non-retryable server error occurred (such as 5xx HTTP status) - self.logger.warning( - f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}" - ) - rollout_state.error_msg = ( - f"Server error {http_result.error_type} with message: {http_result.error_msg}" - ) - rollout_state.status = Status.FAILED - return rollout_state - - # Case 3: Retryable error occurred during HTTP Request - if http_result.is_retryable: - if is_last_attempt: - self.logger.warning( - f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}" - ) - rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}" - rollout_state.status = Status.FAILED - return rollout_state - - self.logger.warning( - f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}." - ) - await asyncio.sleep(0.1) - continue - - # Case 4: Unknown error occurred during HTTP Request and stop the rollout - if http_result.is_unknown_error: - raise RuntimeError( - f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}" - ) - return rollout_state - def _launch_server(self): - """Launch the inference server as a separate process or Ray task. - - It waits for the server to become healthy before returning. - - Raises: - TimeoutError: If the server fails to start within the specified - timeout. - Exception: If the server task terminates unexpectedly. - """ + """Launch the inference server as a separate process or Ray task.""" server_configs = self._transform_rollout_config_to_server_configs() - timeout = 3600.0 # Increased timeout to 5 minutes for downloading large models - start_time = time.perf_counter() - last_log_time = start_time - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {server_configs.api_key}", - } - self.logger.info(f"Launch server task on server_url: {self.server_url}") # note(@duanyanhui): launch server as multiprocessing for sglang temporarily @@ -873,30 +629,8 @@ def _launch_server(self): process = mp_ctx.Process(target=self.server_func, args=(server_configs,)) process.start() self.server_process = process - time.sleep(60) # Wait for the server to start - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers - ) - if response.status_code == 200: - return - except requests.RequestException as e: - self.logger.error( - f"can't connect to server url {self.server_url}/{self.endpoints['health_generate']} because {e}" - ) - - current_time = time.perf_counter() - if current_time - last_log_time >= 15: - self.logger.info( - f"Waiting for server to start, Elapsed time: {current_time - start_time:.2f}s" - ) - last_log_time = current_time - - time.sleep(5) - process.terminate() - raise TimeoutError("Server failed to start within the timeout period.") + time.sleep(60) + return else: # launch the server as ray task # so that the lmdeploy backend could get externl pg @@ -919,312 +653,7 @@ def _launch_server(self): ) .remote(server_configs) ) - - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers - ) - if response.status_code == 200: - return - except requests.RequestException: - pass - - try: - ray.get(self.server_task, timeout=0.1) - raise Exception("Server task terminated unexpectedly.") - except ray.exceptions.GetTimeoutError: - pass - except Exception as e: - raise e - - current_time = time.perf_counter() - if current_time - last_log_time >= 15: - self.logger.info( - f"Waiting for server to start... Elapsed time: {current_time - start_time:.2f}s" - ) - last_log_time = current_time - - ray.cancel(self.server_task) - raise TimeoutError("Server failed to start within the timeout period.") - - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - send_task = None - abort_task = None - - try: - if self.receive_abort_request.is_set(): - self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.") - return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) - req = self.client.build_request( - "POST", - url, - headers=headers, - json=payload, - ) - send_task = asyncio.create_task(self.client.send(req)) - abort_task = asyncio.create_task(self._wait_abort_request()) - done, _ = await asyncio.wait( - {send_task, abort_task}, - return_when=asyncio.FIRST_COMPLETED, - ) - if send_task in done: - r = await send_task - else: - try: - r = await asyncio.wait_for(asyncio.shield(send_task), timeout=self.abort_timeout) - except asyncio.TimeoutError: - self.logger.debug( - f"Request to {url} did not return within {self.abort_timeout:.2f}s after abort signal." - ) - await cancel_and_drain([send_task]) - return HttpRequestResult( - error_type=HttpRequestErrorType.REQUEST_ABORTED, - url=url, - payload=payload, - ) - r.raise_for_status() - return HttpRequestResult(response=r) - - except asyncio.CancelledError: - self.logger.debug(f"Request to {url} was cancelled while waiting for the response.") - await cancel_and_drain([send_task, abort_task]) - self.receive_abort_request.set() - return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - return result - finally: - await cancel_and_drain([abort_task]) - - async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState: - uid = rollout_state.message_uid - sample_params = rollout_state.sample_params - is_token_out = sample_params.return_token_ids - response = http_response.json() - - if is_token_out: - response_ids: list[int] = [] - logprobs: list[float] = [] - routed_experts = None - returned_response = "" - try: - meta_info = response.get("meta_info") or {} - finish_reason_info = meta_info.get("finish_reason") or {} - finish_reason = finish_reason_info.get("type") - if finish_reason is None: - if self.receive_abort_request.is_set(): - rollout_state.finish_reason = "abort" - rollout_state.status = Status.ABORTED - self.logger.warning( - f"finish_reason is missing in response meta_info when waiting for aborted message {uid}, defaulting to 'abort'. Response: {response}" - ) - else: - rollout_state.finish_reason = "error" - rollout_state.status = Status.FAILED - self.logger.warning( - f"finish_reason is missing in response meta_info for message {uid}, defaulting to 'error'. Response: {response}" - ) - rollout_state.error_msg = "Missing finish_reason in response meta_info" - return rollout_state - returned_response = response.get("text", "") - # 获取response_ids && respoonse_ids - if ( - "output_token_logprobs" in response["meta_info"] - and response["meta_info"]["output_token_logprobs"] is not None - ): - response_ids = [item[1] for item in response["meta_info"]["output_token_logprobs"]] - logprobs = [item[0] for item in response["meta_info"]["output_token_logprobs"]] - else: - num_return_tokens = response["meta_info"].get("completion_tokens", 0) - response_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else [] - - # 获取 routed_experts - if self.enable_return_routed_experts: - assert "routed_experts" in response["meta_info"], ( - "enable_return_routed_experts is True, but routed_experts is not in meta_info" - ) - routed_experts = response["meta_info"]["routed_experts"] # token[layer[expert]] - if routed_experts is not None: - routed_experts = await self._decode_routed_experts(routed_experts) - if not isinstance(routed_experts, ray.ObjectRef): - routed_experts = ray.put(routed_experts) - - # 获取 status - rollout_status = update_status_from_finish_reason(finish_reason) - - # 检查输出结果 - if rollout_status == Status.COMPLETED: - validation_errors = [] - - if not response_ids: - validation_errors.append("empty response_ids") - - if not response: - validation_errors.append("empty response text") - - if sample_params.return_logprob and not logprobs: - validation_errors.append("missing logprobs") - - if self.enable_return_routed_experts and routed_experts is None: - validation_errors.append("missing routed_experts") - - if validation_errors: - error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}" - self.logger.error(error_msg) - rollout_state.status = Status.FAILED - rollout_state.error_msg = error_msg - return rollout_state - elif rollout_status == Status.FAILED: - error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}" - self.logger.error(error_msg) - rollout_state.status = Status.FAILED - rollout_state.error_msg = error_msg - return rollout_state - - if self.enable_partial_rollout: - expect_len = ( - response["meta_info"]["prompt_tokens"] + response["meta_info"]["completion_tokens"] - 1 - ) - rollout_state = await self.partial_rollout_handler.postprocess( - rollout_state, - response=returned_response, - response_ids=response_ids, - logprobs=logprobs, - routed_experts=routed_experts, - finish_reason=finish_reason, - status=rollout_status, - routed_experts_expect_len=expect_len, - ) - else: - rollout_state.response = returned_response - rollout_state.response_ids = response_ids - rollout_state.logprobs = logprobs - rollout_state.routed_experts = routed_experts - rollout_state.finish_reason = finish_reason - rollout_state.status = rollout_status - return rollout_state - except KeyError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except IndexError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"Index error {e} while processing response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except AssertionError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except json.JSONDecodeError as e: - error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except TypeError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except Exception as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}" - raise RuntimeError(error_msg) - else: - # v1/chat/completions API response - try: - returned_response = response["choices"][0]["message"]["content"] - finish_reason = response["choices"][0]["finish_reason"] - rollout_status = update_status_from_finish_reason(finish_reason) - if rollout_status == Status.COMPLETED and not returned_response: - self.logger.error(f"Empty response text for msg {uid} with finish_reason {finish_reason}") - rollout_state.status = Status.FAILED - rollout_state.error_msg = "Empty response text" - return rollout_state - - rollout_state.response = returned_response - rollout_state.finish_reason = finish_reason - rollout_state.status = rollout_status - return rollout_state - except KeyError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except IndexError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"Index error {e} while processing response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except AssertionError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except json.JSONDecodeError as e: - error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except TypeError as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}" - raise RuntimeError(error_msg) - except Exception as e: - response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} - error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}" - raise RuntimeError(error_msg) - - def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice): - openai_prompts = [] - openai_tools = [] - # transform claude spec to openai spec - # 1. transform system prompt: concat provided system_prompt to input prompt - system_prompt = self.config.system_prompt - if system_prompt: - system_prompt_json = {"role": "system", "content": f"{system_prompt}"} - prompts.insert(0, system_prompt_json) - # 2. transform multi-modal usage - for prompt in prompts: - content = prompt["content"] - openai_content = [] - for item in content: - if item["type"] == "image": - if item["source"]["type"] == "base64": - openai_url = f"data:{item['source']['media_type']};base64,{item['source']['data']}" - if item["source"]["type"] == "url": - openai_url = item["source"]["url"] - new_prompt = {"type": "image_url", "image_url": {"url": openai_url}} - openai_content.append(new_prompt) - elif item["type"] == "text": - openai_content.append(item) - new_prompt = copy.deepcopy(prompt) - new_prompt["content"] = openai_content - openai_prompts.append(new_prompt) - # 3. transform tool use - for tool in tools: - openai_tool = { - "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": tool["input_schema"], - }, - } - openai_tools.append(openai_tool) - return openai_prompts, openai_tools - - def _check_infer_engine_version(self, return_token_ids: bool): - # TODO(@duanyanhui): remove this check when all backends support return_token_ids - if self.check_flag: - if os.environ.get("XTUNER_USE_VLLM", "0") == "1": - if return_token_ids: - self.logger.error( - "VLLM backend does not support return_token_ids or generate with input_ids as input in Xtuner now" - ) - elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": - import lmdeploy - - lmdeploy_version = lmdeploy.__version__ - if return_token_ids and Version(lmdeploy_version) < Version("0.10.2"): - self.logger.error( - f"You should use lmdeploy >= v0.10.2 to support return_token_ids, but current version is {lmdeploy_version}" - ) - self.check_flag = False + return def _set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]): self.engine_rank_mesh_array = engine_rank_mesh_array @@ -1240,14 +669,6 @@ def _set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]): """ self.engine_bundle_idxs = engine_bundle_idxs - @abstractmethod - def _get_request_payload(self, rollout_state: RolloutState) -> dict: - """Abstract method to create a generation request. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_create_request must be implemented in subclass") - @abstractmethod def _transform_rollout_config_to_server_configs(self): """Abstract method to transform rollout config to server configs. @@ -1256,14 +677,6 @@ def _transform_rollout_config_to_server_configs(self): """ raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") - @abstractmethod - def _transform_sample_params(self, sample_params: SampleParams) -> dict: - """Abstract method to transform rollout config to server configs. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") - @abstractmethod def offload(self): """Abstract method to offload the model and KVcache. diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index c713d30c3f..0babc22ecf 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -36,6 +36,7 @@ _snapshot_nested_objectrefs, ) from xtuner.v1.rl.rollout.controller import RolloutControllerProxy +from xtuner.v1.rl.rollout.rollout_generator import RolloutGenerateHandleConfig from xtuner.v1.rl.rollout.worker import RolloutConfig from xtuner.v1.rl.trainer.controller import TrainingController from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem @@ -48,7 +49,6 @@ set_cpu_resource_manager, sort_rollout_state_for_deterministic, ) -from xtuner.v1.rl.utils.misc import check_chat_completions, delete_from_routedapiproxy, register_to_routedapiproxy from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer from xtuner.v1.utils.device import get_device, get_torch_device_module @@ -283,6 +283,7 @@ class BaseRLTrainerConfig(BaseModel): advantage_estimator_config: BaseAdvantageConfig = Field(default_factory=GRPOAdvantageConfig) sync_weights_interval: int = 1 gateway_config: GatewayConfig | None = None + rollout_generator_config: RolloutGenerateHandleConfig = Field(default_factory=RolloutGenerateHandleConfig) enable_evaluate: bool = True enable_initial_evaluate: bool = False @@ -623,10 +624,23 @@ def _maybe_start_gateway(self, cfg: BaseRLTrainerConfig) -> None: # gateway 依赖 rollout controller,因此在 rollout controller 构建完成后统一启动。 ray.get(self.rollout_controller.start_gateway.remote(cfg.gateway_config)) + def _build_rollout_generate_handle(self, cfg: BaseRLTrainerConfig): + internal_http_config = cfg.rollout_generator_config.build_internal_http_entry_config() + if internal_http_config is not None: + ray.get(self.rollout_controller.start_internal_http_entry.remote(internal_http_config)) + + external_http_entry_config = cfg.rollout_generator_config.build_external_http_entry_config() + if external_http_entry_config is not None: + ray.get(self.rollout_controller.start_external_http_entry.remote(external_http_entry_config)) + + return cfg.rollout_generator_config.build(self.rollout_controller) + def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer) -> None: self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True) + rollout_generator = self._build_rollout_generate_handle(cfg) self.agent_loop_manager = cfg.agent_loop_manager_cfg.build( rollout_controller=self.rollout_controller, + rollout_generator=rollout_generator, tokenizer=self.tokenizer, replay_buffer=replay_buffer, logger=self.logger, @@ -637,6 +651,7 @@ def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer) assert cfg.eval_agent_loop_manager_cfg is not None self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build( rollout_controller=self.rollout_controller, + rollout_generator=rollout_generator, tokenizer=self.tokenizer, replay_buffer=replay_buffer, logger=self.logger, @@ -1320,37 +1335,6 @@ def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]): ) self._global_train_step += len(workers_log_item[0]["train_metrics"]) - -def add_apiproxy(self): - info_dict = ray.get(self.rollout_controller.get_rollout_metadata.remote()) - model_name = info_dict["rollout_config"].model_name - - delete_from_routedapiproxy(model_name) - self.logger.info(f"deleted {model_name} from routedapiproxy") - self.logger.info("registering to routedapiproxy") - - worker_session_url_dict = info_dict["worker_session_url_dict"] - worker_session_urls_status = info_dict["worker_session_urls_status"] - for _, worker_session_url in sorted(worker_session_url_dict.items()): - if not worker_session_urls_status.get(worker_session_url, False): - continue - register_to_routedapiproxy(model_name, worker_session_url) - - # test server url - recheck_status_orig = check_chat_completions(worker_session_url, model_name) - if not recheck_status_orig: - raise ValueError(f"check chat completions failed for {worker_session_url}") - - # test routed url - routed_url = "http://s-20260104203038-22bhb.ailab-evalservice.pjh-service.org.cn/v1" - recheck_status_routed = check_chat_completions(routed_url, model_name) - if not recheck_status_routed: - raise ValueError(f"check chat completions failed for {routed_url}") - self.logger.info("registered to routedapiproxy") - # import time - # time.sleep(1000000) - - class RLColocateTrainer(BaseRLTrainer): _META_PATH = ".xtuner_rl_colocate_trainer" @@ -1373,7 +1357,6 @@ def __init__(self, cfg: RLColocateTrainerConfig): self._rollout_config.skip_load_weights = False self.rollout_controller = self._rollout_config.build(self._pg) # self._maybe_start_gateway(cfg) - add_apiproxy(self) replay_buffer = cfg.replay_buffer_config.build() self._build_agent_loop_components(cfg, replay_buffer) @@ -1408,8 +1391,6 @@ def __init__(self, cfg: RLColocateTrainerConfig): # self._maybe_start_gateway(cfg) bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) - add_apiproxy(self) - replay_buffer = cfg.replay_buffer_config.build() self._build_agent_loop_components(cfg, replay_buffer) if checkpoint_path is not None: @@ -1540,7 +1521,6 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool self._maybe_save_checkpoint(train_step) self._maybe_save_hf(train_step) - ray.get(self.rollout_controller.recover_failed_workers.remote()) timer_name = "sync_weight" if should_sync_weights else "switch_to_rollout" with timer(timer_name, step_timer_dict): if should_sync_weights: @@ -1714,7 +1694,6 @@ async def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict): self._maybe_save_checkpoint(train_step) self._maybe_save_hf(train_step) - ray.get(self.rollout_controller.recover_failed_workers.remote()) with timer("sync_weight", step_timer_dict): bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) self.update_weights()