diff --git a/rollout_refactor_requirements.md b/rollout_refactor_requirements.md
new file mode 100644
index 0000000000..eb86bb8b35
--- /dev/null
+++ b/rollout_refactor_requirements.md
@@ -0,0 +1,275 @@
+# Rollout Generate Refactor Requirements
+
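+## Background
+
+The rollout controller and rollout worker currently share several kinds of responsibility: the control plane, the generation data plane, health checking, and failure recovery. This makes call paths complex and serialization overhead high, and it leaves ordinary agent scenarios and agentic/sandbox scenarios with inconsistent generation entry points.
+
+This round of refactoring focuses first on splitting responsibilities and decoupling the generation path. The first version does not introduce health checking or failure recovery.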
+
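+## Core Requirements
+
+### 1. Remove health checking and failure recovery
+
+The first version uses a clean model that assumes the rollout backend, server, and worker never fail.
+
+The following logic must be removed, or must not be introduced:
+
+- worker/server health checks, probing, and ready/liveness polling
+- marking workers as failed, degrading them, or masking them
+- automatic recovery and restarting of failed workers
+- synchronizing health-check state between controller, worker, and router
+
+If these capabilities are reintroduced later, they need their own independent design and should not be mixed into this round of generate decoupling.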
+
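+### 2. Move generate completely out of the controller and worker
+
+Neither `RolloutController.generate` nor `RolloutWorker.generate` should be kept.
+
+Target responsibility boundaries:
+
+- `RolloutController`: only the centralized control plane, for example runtime initialization, worker metadata, weight synchronization, offload/onload, pause/continue, and endpoint/router startup.
+- `RolloutWorker`: only the backend server lifecycle and backend control, for example starting the service, weight updates, KV cache control, and pause/continue.
+- Independent generation module: the actual generation calls, request/response conversion, parsers, partial rollout, and other generation-related logic.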
+
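+### 3. The worker-side generation class needs an evaluation of plain-class versus Ray-actor form
+
+The generation entry split out on the controller side is relatively lightweight and can be implemented as a plain async class first.
+
+The worker-side generation module may include heavy operations such as the tokenizer, the partial rollout handler, parsers, and state conversion, so two implementation options need to be evaluated: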
+
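+#### Option A: the generation class itself is a Ray actor
+
+Each active worker is bound to one generation actor. The number of actors matches the number of active workers, and each actor is scheduled as close as possible to the node or resources of its worker.
+
+Pros:
+
+- Heavy logic is naturally isolated and does not block the caller's process.
+- State such as the partial rollout handler and tokenizer can stay resident inside the actor.
+- One-to-one binding with workers keeps the routing relationship clear.
+
+Cons:
+
+- Introduces Ray actor/gRPC serialization overhead.
+- Actor lifecycle and placement must be managed.
+- May not pay off for lightweight single-turn inference.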
+
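+#### Option B: the generation class stays a plain class and heavy operations are outsourced
+
+The main generate class stays a plain async class. When heavy operations such as the tokenizer or partial rollout handler are enabled, those operations are outsourced separately to a Ray actor or another executor.
+
+Pros:
+
+- The path is shortest when there are no heavy operations, avoiding extra Ray serialization.
+- Remote execution can be introduced on demand.
+- Better suited to simple single-turn inference.
+
+Cons:
+
+- The interface split is more complex: it must be clear which steps belong to the main flow and which can be outsourced.
+- The lifecycle and routing of the partial/tokenize actors still need to be designed.
+- With many heavy operations, the main flow may turn into multiple remote calls, which adds complexity instead of removing it.
+
+The first version currently chooses Option A. The plain-class/outsourced-heavy-operation option does not enter implementation for now; it will be evaluated separately only if there is a clear need to further reduce Ray actor call overhead.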
+
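+Current implementation principles:
+
+- The first version uniformly uses `RolloutWorkerGenerator`, with one generation actor bound to each active rollout worker.
+- The agentloop runtime only selects a worker by session and then calls the corresponding `RolloutWorkerGenerator` directly, without going through the controller.
+- Partial rollout is no longer propagated through a global controller switch. It enters the generation module as a parameter of each individual generate call, which avoids state shared across requests.
+- The public entry point is concentrated in `xtuner/v1/rl/rollout/rollout_generator.py` and the internal implementation in `xtuner/v1/rl/rollout/_generation/`, so that the controller/worker files stop carrying generation details.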
+
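+### 4. The generation entry split out of the controller is only an optional path
+
+After generate is decoupled, the generation path must support at least three run modes:
+
+#### Mode 1: call the internal generate interface
+
+agentloop directly calls xtuner's internal standalone generate class or interface.
+
+Applicable scenarios:
+
+- Ordinary inline agentloop
+- Reducing relaying through the controller
+- Reducing unnecessary serialization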
+
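+#### Mode 2: call the xtuner-internal routed url
+
+xtuner starts a routed HTTP endpoint internally and exposes a single routed url. agentloop generates through that url, and the request is then routed to the generation URL of the corresponding worker.
+
+Applicable scenarios:
+
+- HTTP interface compatibility is required
+- agentloop is not suited to calling the Ray/Python interface directly
+- The router should be managed inside xtuner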
+
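+#### Mode 3: call an externally registered routed service
+
+xtuner registers each worker's generation URL with an external router, and the external service provides the routed url. agentloop/sandbox generates through the external routed url.
+
+Applicable scenarios:
+
+- agentic/sandbox scenarios
+- A third-party router, isolated deployment, or independent scaling is required
+- The xtuner-internal router should not become a centralized bottleneck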
+
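+In the implementation, Mode 2 and Mode 3 are both HTTP generation paths from the agentloop's point of view, so they are unified as `kind="http"`.
+They differ only in where the HTTP entry comes from: `http_entry="internal"` means XTuner starts an internal router,
+and `http_entry="external"` means XTuner registers the worker generation URLs with an external router.
+
+In code, the two entries are named symmetrically as `InternalRolloutHttpEntry` and `ExternalRolloutHttpEntry`.
+`InternalRolloutHttpEntry` is XTuner's internal FastAPI forwarding entry; `ExternalRolloutHttpEntry`
+does not forward requests and only reuses the routedapiproxy registration logic to register worker generation URLs with the external router.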
+
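+The worker generation URL is controlled by `http_worker_url_source`:
+
+- `backend`: uses the worker backend url. Suited to evaluation or simple scenarios that do not need the SessionServer enhancements.
+- `session`: uses the SessionServer url that wraps each worker. Suited to agentic scenarios that need enhancements such as session id, token cache, and trace/replay. The internal router automatically writes the selected `session_id` into the request body; with an external router, the caller or the external router must preserve or inject `session_id`.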
+
+## Runtime Flow Diagrams
+
+### Common startup flow
+
+All modes share the same worker build flow. `RolloutWorkerBuilder` starts the active workers, collects the backend URLs and SessionServer URLs, and creates one `RolloutWorkerGenerator` actor for each active worker.
+
+```mermaid
+flowchart LR
+ Cfg[RolloutConfig.build] --> Builder[RolloutWorkerBuilder]
+ Builder --> Worker[RolloutWorker]
+ Worker --> Backend[Backend server<br/>lmdeploy/vLLM/SGLang]
+ Worker --> SessionServer[SessionServerActor<br/>optional proxy URL]
+ Builder --> Generator[RolloutWorkerGenerator<br/>one actor per active worker]
+ Builder --> Handle[RolloutWorkerHandle<br/>rank, worker_actor, backend_url,<br/>generator_actor, session_server_url]
+ Handle --> Controller[RolloutController metadata]
+```
+
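+### 1. Local generation: `kind="local"`
+
+`AgentLoop` holds a `RolloutGenerateHandle` that contains a `LocalRolloutGenerator`. At runtime it selects a worker by session and then directly calls the `RolloutWorkerGenerator` actor bound to that worker. The controller is not on the data generation path.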
+
+```mermaid
+flowchart LR
+ AgentLoop[AgentLoop] --> Handle[RolloutGenerateHandle<br/>kind=local]
+ Handle --> LocalGen[LocalRolloutGenerator]
+ LocalGen --> Selector[SessionWorkerSelector]
+ Selector --> WorkerHandle[RolloutWorkerHandle]
+ WorkerHandle --> GenActor[RolloutWorkerGenerator actor]
+ GenActor --> Backend[worker backend_url]
+```
+
+Participating classes:
+
+- `RolloutGenerateHandle`
+- `LocalRolloutGenerator`
+- `SessionWorkerSelector`
+- `RolloutWorkerHandle`
+- `RolloutWorkerGenerator`
+
+### 2. Internal router, straight to backend: `kind="http", http_entry="internal", http_worker_url_source="backend"`
+
+XTuner starts `InternalRolloutHttpEntry`. agentloop sees only a single internal router base URL. The router selects a worker by session and forwards the request to that worker's `backend_url`. This path does not go through `SessionServerActor`.
+
+```mermaid
+flowchart LR
+ AgentLoop[AgentLoop] --> Handle[RolloutGenerateHandle<br/>kind=http]
+ Handle --> RouterURL[internal_http_entry_url]
+ RouterURL --> InternalRouter[InternalRolloutHttpEntry]
+ InternalRouter --> Selector[SessionWorkerSelector]
+ Selector --> WorkerHandle[RolloutWorkerHandle]
+ WorkerHandle --> Backend[worker backend_url]
+```
+
+Participating classes:
+
+- `RolloutGenerateHandle`
+- `InternalRolloutHttpEntry`
+- `SessionWorkerSelector`
+- `RolloutWorkerHandle`
+
+### 3. Internal router, through SessionServer: `kind="http", http_entry="internal", http_worker_url_source="session"`
+
+XTuner still starts `InternalRolloutHttpEntry`, but after selecting a worker the router forwards to `session_server_url`. The router writes the selected `session_id` into the request body so that the session/cache/trace logic of `SessionServer` is satisfied.
+
+```mermaid
+flowchart LR
+ AgentLoop[AgentLoop] --> Handle[RolloutGenerateHandle<br/>kind=http]
+ Handle --> RouterURL[internal_http_entry_url]
+ RouterURL --> InternalRouter[InternalRolloutHttpEntry]
+ InternalRouter --> Selector[SessionWorkerSelector]
+ Selector --> WorkerHandle[RolloutWorkerHandle]
+ WorkerHandle --> SessionServer[SessionServerActor<br/>session_server_url]
+ SessionServer --> Backend[worker backend_url]
+```
+
+Participating classes:
+
+- `RolloutGenerateHandle`
+- `InternalRolloutHttpEntry`
+- `SessionWorkerSelector`
+- `RolloutWorkerHandle`
+- `SessionServerActor`
+
+### 4. External router, registering backend: `kind="http", http_entry="external", http_worker_url_source="backend"`
+
+XTuner does not forward requests. Through `ExternalRolloutHttpEntry` it only reuses the routedapiproxy registration logic to register each worker's `backend_url` with the external router. agentloop/sandbox accesses the external router, which forwards to the worker backend.
+
+```mermaid
+flowchart LR
+ Entry[ExternalRolloutHttpEntry] --> WorkerHandle[RolloutWorkerHandle]
+ WorkerHandle --> Register[register backend_url<br/>to external router]
+
+ AgentLoop[AgentLoop or Sandbox] --> Handle[RolloutGenerateHandle<br/>kind=http]
+ Handle --> ExternalURL[external router base_url]
+ ExternalURL --> ExternalRouter[External router]
+ ExternalRouter --> Backend[worker backend_url]
+```
+
+Participating classes:
+
+- `RolloutGenerateHandle`
+- `ExternalRolloutHttpEntry`
+- `RolloutWorkerHandle`
+- External router service
+
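+### 5. External router, registering SessionServer: `kind="http", http_entry="external", http_worker_url_source="session"`
+
+Through `ExternalRolloutHttpEntry`, XTuner reuses the routedapiproxy registration logic to register each worker's `session_server_url` with the external router. agentloop/sandbox accesses the external router, which forwards to the SessionServer and from there to the worker backend. This mode requires the caller or the external router to preserve or inject `session_id`.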
+
+```mermaid
+flowchart LR
+ Entry[ExternalRolloutHttpEntry] --> WorkerHandle[RolloutWorkerHandle]
+ WorkerHandle --> Register[register session_server_url<br/>to external router]
+
+ AgentLoop[AgentLoop or Sandbox] --> Handle[RolloutGenerateHandle<br/>kind=http]
+ Handle --> ExternalURL[external router base_url]
+ ExternalURL --> ExternalRouter[External router]
+ ExternalRouter --> SessionServer[SessionServerActor<br/>session_server_url]
+ SessionServer --> Backend[worker backend_url]
+```
+
+Participating classes:
+
+- `RolloutGenerateHandle`
+- `ExternalRolloutHttpEntry`
+- `RolloutWorkerHandle`
+- `SessionServerActor`
+- External router service
+
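+## Non-Goals
+
+Not handled in this round:
+
+- Health checking, failure recovery, and automatic restart
+- Complex fault-tolerance state synchronization between worker and router
+- CI and compatibility fixes for old unit tests
+- Logical correctness of the SGLang and vLLM branches
+- Full protocol details of the external router
+- Everything inside the gateway can be ignored
+
+## Current Assessment
+
+This round of refactoring should first complete a clean split of responsibilities:
+
+1. Remove health checking and failure recovery.
+2. Move generate completely out of the controller/worker.
+3. Define the interface and run modes of the independent generation module.
+4. Then evaluate whether the worker-side generation module should be a plain class, a Ray actor, or a hybrid of a plain class with outsourced heavy operations.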
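The five `kind` / `http_entry` / `http_worker_url_source` combinations described in the requirements document above can be sketched as a small resolver. This is only an illustration under stated assumptions: `WorkerRoute`, `GenerateMode`, `select_worker`, and `worker_url` are hypothetical names, and the hash-based session pinning is an assumed policy, not necessarily what `SessionWorkerSelector` does.

```python
import hashlib
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass(frozen=True)
class WorkerRoute:
    """Hypothetical stand-in for the per-worker handle described in the document."""
    rank: int
    backend_url: str
    session_server_url: Optional[str] = None


@dataclass(frozen=True)
class GenerateMode:
    """Hypothetical config mirroring kind / http_entry / http_worker_url_source."""
    kind: Literal["local", "http"] = "local"
    http_entry: Literal["internal", "external"] = "internal"
    http_worker_url_source: Literal["backend", "session"] = "backend"


def select_worker(session_id: str, routes: list) -> WorkerRoute:
    """Pin a session to one worker with a stable hash (an assumed policy)."""
    digest = hashlib.sha256(session_id.encode("utf-8")).digest()
    return routes[int.from_bytes(digest[:8], "big") % len(routes)]


def worker_url(mode: GenerateMode, route: WorkerRoute) -> str:
    """Return the URL a router would forward to, or register, for this worker."""
    if mode.kind != "http":
        # kind="local" calls the generator actor directly, so no URL is involved.
        raise ValueError("kind='local' does not use a worker URL")
    if mode.http_worker_url_source == "backend":
        return route.backend_url
    if route.session_server_url is None:
        raise ValueError(f"worker {route.rank} has no SessionServer URL")
    return route.session_server_url


routes = [
    WorkerRoute(0, "http://w0:8000", "http://w0:9000"),
    WorkerRoute(1, "http://w1:8000", "http://w1:9000"),
]
mode = GenerateMode(kind="http", http_entry="external", http_worker_url_source="session")
chosen = select_worker("session-42", routes)
print(worker_url(mode, chosen))
```

Note that `http_entry` does not affect which URL is chosen in this sketch; per the document it only decides whether XTuner forwards the request itself or registers the URL with an external router.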
diff --git a/tests/rl/test_rollout_utils.py b/tests/rl/test_rollout_utils.py
index 7a4995acc0..4812f1ecb7 100644
--- a/tests/rl/test_rollout_utils.py
+++ b/tests/rl/test_rollout_utils.py
@@ -1,26 +1,26 @@
import ray
import torch
import threading
-import time
import unittest
import os
import tempfile
+from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
-from xtuner.v1.data_proto.rl_data import Status, RolloutState, SampleParams
+from xtuner.v1.data_proto.rl_data import Status, RolloutState
from xtuner.v1.rl.rollout.worker import RolloutConfig
-from xtuner.v1.rl.rollout.controller import RolloutController, WorkerInfo
+from xtuner.v1.rl.rollout.controller import RolloutController
+from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager, RolloutWorkerRouteInfo
+from xtuner.v1.rl.rollout.runtime import build_rollout_runtime
from xtuner.v1.rl.rollout.utils import (
PartialRolloutHandler,
- RolloutHealthChecker,
SessionRouter,
)
-from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "")
RESOURCE_MAP = {"npu": "NPU", "cuda": "GPU"}
-TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}]
class _FakeRemoteMethod:
@@ -40,55 +40,53 @@ def __init__(self):
self.shutdown = _FakeRemoteMethod("shutdown", self.call_log)
-class TestRolloutHealthChecker(unittest.TestCase):
- def _build_checker(self, workers_info):
- config = SimpleNamespace(health_check_interval_seconds=10, health_check_failure_threshold=1)
- return RolloutHealthChecker(config, workers_info)
-
- def test_shutdown_runs_when_offload_fails(self):
- worker = _FakeWorker()
- workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=True)}
- checker = self._build_checker(workers_info)
-
- async def unhealthy_worker(*args, **kwargs):
- return False
-
- def ray_get(ref, timeout=None):
- worker.call_log.append((ref, "get"))
- if ref == "offload":
- raise RuntimeError("offload failed")
- return None
+class _FakeGenerator:
+ def __init__(self):
+ self.call_log = []
+ self.ping = _FakeRemoteMethod("ping", self.call_log)
- with (
- patch("xtuner.v1.rl.rollout.utils.check_worker_health", side_effect=unhealthy_worker),
- patch("xtuner.v1.rl.rollout.utils.ray.get", side_effect=ray_get),
- ):
- checker.run_once()
- self.assertFalse(workers_info[0].is_active)
- self.assertEqual(
- worker.call_log,
+class TestRolloutHealthManager(unittest.TestCase):
+ def _build_manager(self, worker):
+ config = SimpleNamespace(
+ health_check_interval_seconds=10,
+ health_check_failure_threshold=1,
+ worker_log_dir=Path("."),
+ )
+ manager = RolloutHealthManager(
+ config,
[
- ("offload", "remote"),
- ("offload", "get"),
- ("shutdown", "remote"),
- ("shutdown", "get"),
+ RolloutWorkerRouteInfo(
+ rank=0,
+ actor=worker,
+ url="http://worker-0",
+ generator=_FakeGenerator(),
+ is_active=True,
+ )
],
)
+ manager.stop()
+ return manager
- def test_inactive_worker_is_not_cleaned_up_again(self):
+ def test_run_once_does_not_probe_or_deactivate_workers(self):
worker = _FakeWorker()
- workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=False)}
- checker = self._build_checker(workers_info)
+ manager = self._build_manager(worker)
- with (
- patch("xtuner.v1.rl.rollout.utils.check_worker_health") as check_worker_health_mock,
- patch("xtuner.v1.rl.rollout.utils.ray.get") as ray_get_mock,
- ):
- checker.run_once()
+ with patch("xtuner.v1.rl.rollout.health_manager.ray.get") as ray_get_mock:
+ manager.run_once()
+
+ self.assertTrue(manager.rank2info[0].is_active)
+ self.assertEqual(worker.call_log, [])
+
+ def test_report_worker_failure_is_registry_only_noop(self):
+ worker = _FakeWorker()
+ manager = self._build_manager(worker)
+
+ with patch("xtuner.v1.rl.rollout.health_manager.ray.get") as ray_get_mock:
+ manager.report_worker_failure(0, "request failed")
- check_worker_health_mock.assert_not_called()
ray_get_mock.assert_not_called()
+ self.assertTrue(manager.rank2info[0].is_active)
self.assertEqual(worker.call_log, [])
@@ -179,39 +177,36 @@ def init_rollout_controller(self):
health_check_interval_seconds=10,
health_check_failure_threshold=1,
)
- controller = RolloutController(rollout_cfg, pg)
+ runtime = build_rollout_runtime(rollout_cfg, pg)
+ controller = RolloutController(rollout_cfg, runtime)
return controller
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
- def test_healthcheck_deactivate_and_recover(self):
+ def test_registry_does_not_deactivate_or_recover_workers(self):
controller = self.init_rollout_controller()
ranks = list(controller.rank2info.keys())
rank0 = ranks[0]
actor0 = controller.rank2info[rank0].actor
ray.get(actor0.shutdown.remote())
- time.sleep(3) # wait for the actor to be fully killed
health_before_recover = ray.get(actor0.check_health.remote())
url = controller.rank2info[rank0].url
self.assertFalse(health_before_recover)
- controller.health_checker.run_once()
+ ray.get(controller.health_manager.run_once.remote())
- self.assertFalse(controller.rank2info[rank0].is_active)
- rollout_state = RolloutState(
- message=TEST_TEXT_MESSAGES,
- sample_params=SampleParams(return_token_ids=True),
- )
- out = asyncio_run(controller.generate(rollout_state))
- self.assertEqual(out.status, Status.FAILED)
+ ready, details = ray.get(controller.health_manager.get_ready_status.remote())
+ self.assertTrue(details["worker_routes"][rank0]["is_active"])
+ self.assertTrue(ready)
controller.recover_failed_workers()
- self.assertTrue(controller.rank2info[rank0].is_active)
- self.assertEqual(url, controller.rank2info[rank0].url)
+ ready, details = ray.get(controller.health_manager.get_ready_status.remote())
+ self.assertTrue(details["worker_routes"][rank0]["is_active"])
+ self.assertTrue(ready)
+ route_infos = ray.get(controller.health_manager.get_worker_route_infos.remote())
+ self.assertEqual(url, route_infos[0].url)
health_after_recover = ray.get(actor0.check_health.remote())
- self.assertTrue(health_after_recover)
- out = asyncio_run(controller.generate(rollout_state))
- self.assertNotEqual(out.status, Status.FAILED)
+ self.assertFalse(health_after_recover)
if __name__ == "__main__":
diff --git a/tests/rl/test_rollout_worker.py b/tests/rl/test_rollout_worker.py
index 5057e09715..ab07822d4b 100644
--- a/tests/rl/test_rollout_worker.py
+++ b/tests/rl/test_rollout_worker.py
@@ -4,6 +4,7 @@
from unittest.mock import AsyncMock, MagicMock, patch
from xtuner.v1.data_proto.rl_data import Status
+from xtuner.v1.rl.rollout.generation import RolloutGenerationService
from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker
from xtuner.v1.rl.rollout.sglang import SGLangWorker
from xtuner.v1.rl.rollout.worker import RolloutWorker
@@ -37,19 +38,22 @@ def test_continue_generation_clears_abort_flag(self):
worker._make_request.assert_called_once_with("continue_generation")
-class TestRolloutWorker(unittest.IsolatedAsyncioTestCase):
+class TestRolloutGenerationService(unittest.IsolatedAsyncioTestCase):
async def test_generate_returns_aborted_when_abort_flag_is_set(self):
- worker = RolloutWorker.__new__(RolloutWorker)
- worker.receive_abort_request = threading.Event()
- worker.receive_abort_request.set()
+ service = RolloutGenerationService.__new__(RolloutGenerationService)
+ service.receive_abort_request = threading.Event()
+ service.receive_abort_request.set()
rollout_state = MagicMock()
- result = await worker.generate(rollout_state)
+ result = await service.generate(rollout_state)
self.assertIs(result, rollout_state)
self.assertEqual(rollout_state.finish_reason, "abort")
self.assertEqual(rollout_state.status, Status.ABORTED)
+
+class TestRolloutWorker(unittest.IsolatedAsyncioTestCase):
+
async def test_pause_generation_sets_abort_flag(self):
worker = RolloutWorker.__new__(RolloutWorker)
worker.receive_abort_request = threading.Event()
@@ -115,9 +119,9 @@ async def test_lmdeploy_cleanup_after_pause_skips_without_routed_experts(self):
get_actor.assert_not_called()
async def test_safe_post_request_returns_aborted_on_cancellation(self):
- worker = RolloutWorker.__new__(RolloutWorker)
- worker.receive_abort_request = threading.Event()
- worker.logger = MagicMock()
+ service = RolloutGenerationService.__new__(RolloutGenerationService)
+ service.receive_abort_request = threading.Event()
+ service.logger = MagicMock()
send_started = asyncio.Event()
send_cancelled = asyncio.Event()
@@ -133,23 +137,23 @@ async def send(self, req):
send_cancelled.set()
raise
- worker.client = _Client()
+ service.client = _Client()
- task = asyncio.create_task(worker._safe_post_request("http://test", headers={}, payload={"input_ids": [1]}))
+ task = asyncio.create_task(service._safe_post_request("http://test", headers={}, payload={"input_ids": [1]}))
await send_started.wait()
task.cancel()
result = await task
self.assertEqual(result.error_type, HttpRequestErrorType.REQUEST_ABORTED)
- self.assertTrue(worker.receive_abort_request.is_set())
+ self.assertTrue(service.receive_abort_request.is_set())
self.assertTrue(send_cancelled.is_set())
async def test_safe_post_request_cancels_inflight_request_after_abort_timeout(self):
- worker = RolloutWorker.__new__(RolloutWorker)
- worker.receive_abort_request = threading.Event()
- worker.abort_timeout = 0.01
- worker.logger = MagicMock()
+ service = RolloutGenerationService.__new__(RolloutGenerationService)
+ service.receive_abort_request = threading.Event()
+ service.abort_timeout = 0.01
+ service.logger = MagicMock()
send_started = asyncio.Event()
send_cancelled = asyncio.Event()
@@ -165,11 +169,11 @@ async def send(self, req):
send_cancelled.set()
raise
- worker.client = _Client()
+ service.client = _Client()
- task = asyncio.create_task(worker._safe_post_request("http://test", headers={}, payload={"input_ids": [1]}))
+ task = asyncio.create_task(service._safe_post_request("http://test", headers={}, payload={"input_ids": [1]}))
await send_started.wait()
- worker.receive_abort_request.set()
+ service.receive_abort_request.set()
result = await task
@@ -177,10 +181,10 @@ async def send(self, req):
self.assertTrue(send_cancelled.is_set())
async def test_safe_post_request_keeps_abort_response_within_timeout(self):
- worker = RolloutWorker.__new__(RolloutWorker)
- worker.receive_abort_request = threading.Event()
- worker.abort_timeout = 1.0
- worker.logger = MagicMock()
+ service = RolloutGenerationService.__new__(RolloutGenerationService)
+ service.receive_abort_request = threading.Event()
+ service.abort_timeout = 1.0
+ service.logger = MagicMock()
send_started = asyncio.Event()
finish_send = asyncio.Event()
@@ -199,11 +203,11 @@ async def send(self, req):
await finish_send.wait()
return response
- worker.client = _Client()
+ service.client = _Client()
- task = asyncio.create_task(worker._safe_post_request("http://test", headers={}, payload={"input_ids": [1]}))
+ task = asyncio.create_task(service._safe_post_request("http://test", headers={}, payload={"input_ids": [1]}))
await send_started.wait()
- worker.receive_abort_request.set()
+ service.receive_abort_request.set()
finish_send.set()
result = await task
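The abort-flag behaviour that these tests move from `RolloutWorker` to `RolloutGenerationService` can be illustrated in isolation. The sketch below is a minimal stand-in, not the real service: `FakeRolloutState` and `GenerationServiceSketch` are hypothetical, and plain strings replace the `Status` enum.

```python
import asyncio
import threading
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRolloutState:
    """Hypothetical stand-in for RolloutState with only the fields used here."""
    status: str = "init"
    finish_reason: Optional[str] = None


class GenerationServiceSketch:
    """Owns the abort flag, so the worker class no longer needs a generate method."""

    def __init__(self) -> None:
        self.receive_abort_request = threading.Event()

    async def generate(self, state: FakeRolloutState) -> FakeRolloutState:
        # Short-circuit before contacting any backend when an abort is pending.
        if self.receive_abort_request.is_set():
            state.finish_reason = "abort"
            state.status = "aborted"
            return state
        await asyncio.sleep(0)  # placeholder for the real backend request
        state.finish_reason = "stop"
        state.status = "completed"
        return state


async def demo():
    service = GenerationServiceSketch()
    first = await service.generate(FakeRolloutState())
    service.receive_abort_request.set()
    second = await service.generate(FakeRolloutState())
    return first, second


first, second = asyncio.run(demo())
print(first.status, second.status)
```

As in the test above, the aborted call returns the same state object it was given, with only `finish_reason` and `status` updated.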
diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py
index 4ae6000e82..06ec1f8b8a 100644
--- a/xtuner/v1/rl/agent_loop/agent_loop.py
+++ b/xtuner/v1/rl/agent_loop/agent_loop.py
@@ -10,7 +10,7 @@
from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams
from xtuner.v1.rl.judger import Judger
-from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle, RolloutGenerateHandleConfig
from xtuner.v1.rl.utils import (
CPUActorLauncher,
CPUResourcesConfig,
@@ -27,10 +27,21 @@ class AgentLoopConfig(ABC, BaseModel):
sample_params: SampleParams | None = None
cpu_resources: CPUResourcesConfig | None = None
- def build(self, rollout_controller, judger: Judger | None = None, logger=None) -> AgentLoopSpec:
+ def build(
+ self,
+ rollout_controller=None,
+ rollout_generator: RolloutGenerateHandle | None = None,
+ judger: Judger | None = None,
+ logger=None,
+ ) -> AgentLoopSpec:
+ if rollout_generator is None:
+ if rollout_controller is None:
+ raise ValueError("Either rollout_controller or rollout_generator must be provided.")
+ rollout_generator = RolloutGenerateHandleConfig().build(rollout_controller)
if self.cpu_resources is None:
return self.build_local(
rollout_controller=rollout_controller,
+ rollout_generator=rollout_generator,
judger=judger,
logger=logger,
)
@@ -43,12 +54,14 @@ def build(self, rollout_controller, judger: Judger | None = None, logger=None) -
if self.cpu_resources.num_workers > 1:
return self._build_router(
rollout_controller=rollout_controller,
+ rollout_generator=rollout_generator,
cpu_resources=self.cpu_resources,
judger=judger,
logger=logger,
)
return self._build_ray_actor(
rollout_controller=rollout_controller,
+ rollout_generator=rollout_generator,
cpu_resources=self.cpu_resources,
judger=judger,
logger=logger,
@@ -58,6 +71,7 @@ def build(self, rollout_controller, judger: Judger | None = None, logger=None) -
def build_local(
self,
rollout_controller,
+ rollout_generator: RolloutGenerateHandle | None = None,
judger: Judger | None = None,
logger=None,
) -> AgentLoop: ...
@@ -65,6 +79,7 @@ def build_local(
def _build_ray_actor(
self,
rollout_controller: RolloutController,
+ rollout_generator: RolloutGenerateHandle,
cpu_resources: CPUResourcesConfig,
pg: PlacementGroup | None = None,
judger: Judger | None = None,
@@ -76,6 +91,7 @@ def _build_ray_actor(
AgentLoopActor,
self,
rollout_controller,
+ rollout_generator,
judger,
pg=pg,
bundle_idx=0,
@@ -88,6 +104,7 @@ def _build_ray_actor(
def _build_ray_actors(
self,
rollout_controller: RolloutController,
+ rollout_generator: RolloutGenerateHandle,
cpu_resources: CPUResourcesConfig,
pg: PlacementGroup | None = None,
judger: Judger | None = None,
@@ -100,6 +117,7 @@ def _build_ray_actors(
AgentLoopActor,
self,
rollout_controller,
+ rollout_generator,
judger,
pg=pg,
start_bundle_idx=start_bundle_idx,
@@ -113,6 +131,7 @@ def _build_ray_actors(
def _build_router(
self,
rollout_controller: RolloutController,
+ rollout_generator: RolloutGenerateHandle,
cpu_resources: CPUResourcesConfig,
pg: PlacementGroup | None = None,
judger: Judger | None = None,
@@ -122,6 +141,7 @@ def _build_router(
return RouterAgentLoop(
workers=self._build_ray_actors(
rollout_controller=rollout_controller,
+ rollout_generator=rollout_generator,
cpu_resources=cpu_resources,
pg=pg,
judger=judger,
@@ -136,12 +156,14 @@ class AgentLoop(ABC):
def __init__(
self,
rollout_ctl: RolloutController | None,
+ rollout_generator: RolloutGenerateHandle | None,
sample_params: SampleParams | None,
hf_checkpoint: str,
judger: Judger | None = None,
logger=None,
- ) -> None:
+ ) -> None:
self.rollout_ctl = rollout_ctl
+ self.rollout_generator = rollout_generator
self.hf_checkpoint = hf_checkpoint
self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True)
self.processor = load_processor(hf_checkpoint, trust_remote_code=True)
@@ -155,6 +177,24 @@ def __init__(
@abstractmethod
async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: ...
+ async def rollout_generate(self, rollout_state: RolloutState, *, enable_partial_rollout: bool = False) -> RolloutState:
+ if self.rollout_generator is None:
+ raise RuntimeError("AgentLoop requires rollout_generator; RolloutController.generate has been removed.")
+ if self.rollout_generator.kind == "local":
+ return await self.rollout_generator.require_local_generator().generate(
+ rollout_state,
+ enable_partial_rollout=enable_partial_rollout,
+ )
+ return await self.rollout_generate_from_url(
+ rollout_state=rollout_state,
+ base_url=self.rollout_generator.require_base_url(),
+ )
+
+ async def rollout_generate_from_url(self, rollout_state: RolloutState, base_url: str) -> RolloutState:
+ raise NotImplementedError(
+ f"{type(self).__name__} does not implement URL rollout generation for endpoint at {base_url!r}."
+ )
+
async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
pending_tasks = []
for state in rollout_state:
@@ -221,11 +261,13 @@ def __init__(
self,
agent_loop_config: AgentLoopConfig,
rollout_controller: RolloutController,
+ rollout_generator: RolloutGenerateHandle,
judger: Judger | None = None,
logger=None,
):
self.agent_loop = agent_loop_config.build_local(
rollout_controller=rollout_controller,
+ rollout_generator=rollout_generator,
judger=judger,
logger=logger,
)
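The dispatch that `AgentLoop.rollout_generate` introduces in this file can be exercised with fakes. The following is a sketch under assumptions: `HandleSketch`, `FakeLocalGenerator`, and `LoopSketch` are hypothetical stand-ins, rollout state is a plain dict, and the URL branch only records its target instead of issuing an HTTP request.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


class FakeLocalGenerator:
    """Hypothetical local generator that records which path served the request."""

    async def generate(self, rollout_state: dict, *, enable_partial_rollout: bool = False) -> dict:
        return {**rollout_state, "path": "local", "partial": enable_partial_rollout}


@dataclass
class HandleSketch:
    """Hypothetical stand-in for RolloutGenerateHandle."""
    kind: str
    local_generator: Optional[FakeLocalGenerator] = None
    base_url: Optional[str] = None

    def require_local_generator(self) -> FakeLocalGenerator:
        if self.local_generator is None:
            raise RuntimeError("kind='local' requires a local generator")
        return self.local_generator

    def require_base_url(self) -> str:
        if self.base_url is None:
            raise RuntimeError("kind='http' requires a base URL")
        return self.base_url


class LoopSketch:
    def __init__(self, handle: Optional[HandleSketch]) -> None:
        self.rollout_generator = handle

    async def rollout_generate(self, rollout_state: dict, *, enable_partial_rollout: bool = False) -> dict:
        handle = self.rollout_generator
        if handle is None:
            raise RuntimeError("a rollout generator handle is required")
        if handle.kind == "local":
            return await handle.require_local_generator().generate(
                rollout_state, enable_partial_rollout=enable_partial_rollout
            )
        return await self.rollout_generate_from_url(rollout_state, handle.require_base_url())

    async def rollout_generate_from_url(self, rollout_state: dict, base_url: str) -> dict:
        # A real subclass would send an HTTP request; this sketch only records the target.
        return {**rollout_state, "path": "http", "base_url": base_url}


local_loop = LoopSketch(HandleSketch(kind="local", local_generator=FakeLocalGenerator()))
http_loop = LoopSketch(HandleSketch(kind="http", base_url="http://router:8080"))
local_out = asyncio.run(local_loop.rollout_generate({"id": 1}, enable_partial_rollout=True))
http_out = asyncio.run(http_loop.rollout_generate({"id": 2}))
print(local_out["path"], http_out["path"])
```

In the diff, the base class leaves `rollout_generate_from_url` as `NotImplementedError`, so only agent loops that override it can use `kind="http"`; the sketch overrides it solely to make both branches observable.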
diff --git a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
index feb1a7c9ce..38c868f6df 100644
--- a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
+++ b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
@@ -8,7 +8,7 @@
from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams
from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
from xtuner.v1.rl.judger import Judger
-from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle
from xtuner.v1.utils import get_logger
@@ -18,10 +18,17 @@
class GSM8KToolAgentLoopConfig(AgentLoopConfig):
max_turns: int
- def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "GSM8KToolAgentLoop":
+ def build_local(
+ self,
+ rollout_controller,
+ rollout_generator: RolloutGenerateHandle | None = None,
+ judger: Judger | None = None,
+ logger=None,
+ ) -> "GSM8KToolAgentLoop":
return GSM8KToolAgentLoop(
max_turns=self.max_turns,
rollout_ctl=rollout_controller,
+ rollout_generator=rollout_generator,
hf_checkpoint=self.hf_checkpoint,
sample_params=self.sample_params,
judger=judger,
@@ -40,12 +47,17 @@ def __init__(
self,
max_turns: int,
rollout_ctl: RolloutController,
+ rollout_generator: RolloutGenerateHandle | None,
hf_checkpoint: str,
sample_params: SampleParams,
judger: Judger | None = None,
):
super().__init__(
- rollout_ctl=rollout_ctl, hf_checkpoint=hf_checkpoint, sample_params=sample_params, judger=judger
+ rollout_ctl=rollout_ctl,
+ rollout_generator=rollout_generator,
+ hf_checkpoint=hf_checkpoint,
+ sample_params=sample_params,
+ judger=judger,
)
self.max_turns = max_turns
self.tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL)
@@ -99,7 +111,10 @@ async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> Rollou
rollout_state.sample_params = copy.deepcopy(base_sample_params)
rollout_state.sample_params.max_tokens = remaining_max_tokens
- rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined]
+ rollout_state = await self.rollout_generate(
+ rollout_state,
+ enable_partial_rollout=kwargs.get("enable_partial_rollout", False),
+ )
cur_turn += 1
response_ids = cast(list[int], rollout_state.response_ids)
cur_turn_tokens.extend(response_ids)
diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py
index a4db6e80c2..d503217717 100644
--- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py
+++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py
@@ -11,7 +11,7 @@
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.judger import Judger
-from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle
from xtuner.v1.rl.utils import create_task
from ..agent_loop import AgentLoop, AgentLoopConfig
@@ -54,9 +54,16 @@ class AgentInSandboxLoopConfig(AgentLoopConfig):
"""
max_concurrent_samples: int | None = None
- def build_local(self, rollout_controller: RolloutController | None = None, judger: Judger | None = None, logger=None) -> "AgentInSandboxLoop":
+ def build_local(
+ self,
+ rollout_controller: RolloutController | None = None,
+ rollout_generator: RolloutGenerateHandle | None = None,
+ judger: Judger | None = None,
+ logger=None,
+ ) -> "AgentInSandboxLoop":
return AgentInSandboxLoop(
rollout_ctl=rollout_controller,
+ rollout_generator=rollout_generator,
hf_checkpoint=self.hf_checkpoint,
judger=judger,
logger=logger,
@@ -68,12 +75,13 @@ class AgentInSandboxLoop(AgentLoop):
def __init__(
self,
rollout_ctl: RolloutController | None = None,
+ rollout_generator: RolloutGenerateHandle | None = None,
hf_checkpoint: str = None,
judger: Judger | None = None,
logger=None,
max_concurrent_samples: int | None = None,
):
- super().__init__(rollout_ctl, None, hf_checkpoint, judger, logger)
+ super().__init__(rollout_ctl, rollout_generator, None, hf_checkpoint, judger, logger)
self.max_concurrent_samples = max_concurrent_samples
self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None
diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
index 40edc2dca7..9e550607a8 100644
--- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
+++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
@@ -2,7 +2,7 @@
from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.judger import Judger
-from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle
from xtuner.v1.rl.utils import create_task
from .agent_loop import AgentLoop, AgentLoopConfig
@@ -39,9 +39,16 @@ class SingleTurnAgentLoopConfig(AgentLoopConfig):
enable_batch_judge: bool = False
- def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop":
+ def build_local(
+ self,
+ rollout_controller,
+ rollout_generator: RolloutGenerateHandle | None = None,
+ judger: Judger | None = None,
+ logger=None,
+ ) -> "SingleTurnAgentLoop":
return SingleTurnAgentLoop(
rollout_ctl=rollout_controller,
+ rollout_generator=rollout_generator,
sample_params=self.sample_params,
hf_checkpoint=self.hf_checkpoint,
judger=judger,
@@ -54,13 +61,14 @@ class SingleTurnAgentLoop(AgentLoop):
def __init__(
self,
rollout_ctl: RolloutController,
+ rollout_generator: RolloutGenerateHandle | None,
sample_params: SampleParams,
hf_checkpoint: str,
judger: Judger | None = None,
logger=None,
enable_batch_judge: bool = False,
):
- super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
+ super().__init__(rollout_ctl, rollout_generator, sample_params, hf_checkpoint, judger, logger)
self.enable_batch_judge = enable_batch_judge
async def generate_sample(
@@ -72,7 +80,10 @@ async def generate_sample(
rollout_state.tokens = rollout_state.prompt_ids
# The inference engine's generate call overwrites rollout_state.response_ids with the generated result
- rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined]
+ rollout_state = await self.rollout_generate(
+ rollout_state,
+ enable_partial_rollout=kwargs.get("enable_partial_rollout", False),
+ )
# Return early for any non-COMPLETED status (for example truncated or aborted) without triggering scoring
if rollout_state.status != Status.COMPLETED:
return rollout_state
diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
index 7663ecc85e..54b7caeeab 100644
--- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
+++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
@@ -13,7 +13,7 @@
from xtuner.v1.rl.agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl
from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger
from xtuner.v1.rl.replay_buffer import ReplayBuffer
-from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout import RolloutController, RolloutGenerateHandle
from xtuner.v1.rl.utils import asyncio_run
from xtuner.v1.utils import get_logger
@@ -300,6 +300,7 @@ def build(
replay_buffer: ReplayBuffer,
logger=None,
sync_weights_interval: int = 1,
+ rollout_generator: RolloutGenerateHandle | None = None,
) -> "AgentLoopManager":
tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks]
if not tasks:
@@ -314,6 +315,7 @@ def build(
agent_loop = task_cfg.agent_loop_config.build(
rollout_controller=rollout_controller,
+ rollout_generator=rollout_generator,
judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None,
logger=logger,
)
diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py
index 35db08f27d..7dc34d5aa7 100644
--- a/xtuner/v1/rl/agent_loop_manager/producer.py
+++ b/xtuner/v1/rl/agent_loop_manager/producer.py
@@ -502,10 +502,6 @@ def build(
sync_weights_interval: int = 1,
rollout_controller: "Optional[RolloutControllerProxy]" = None,
) -> "AsyncProduceStrategy":
- if rollout_controller is not None:
- import ray
-
- ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout))
return AsyncProduceStrategy(
over_sample_threshold=self.over_sample_threshold,
enable_partial_rollout=self.enable_partial_rollout,
diff --git a/xtuner/v1/rl/gateway/backend/local_backend.py b/xtuner/v1/rl/gateway/backend/local_backend.py
index ae8e773cb9..af22badef3 100644
--- a/xtuner/v1/rl/gateway/backend/local_backend.py
+++ b/xtuner/v1/rl/gateway/backend/local_backend.py
@@ -11,6 +11,7 @@
from xtuner.v1.data_proto.rl_data import RolloutState, RolloutToolCall, SampleParams, Status
from xtuner.v1.rl.rollout.parser.factory import build_tool_call_parser
from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.rollout.rollout_generator import LocalRolloutGenerator, LocalRolloutGeneratorConfig
from ..adapters.base import coerce_content_to_text
from ..adapters.trace import normalize_trace_payload
@@ -44,6 +45,7 @@ def __init__(
):
self._controller = controller
self._config = rollout_config or self._resolve_rollout_config(controller)
+ self._generator: LocalRolloutGenerator = LocalRolloutGeneratorConfig().build(controller)
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
resolved_tokenizer = tokenizer
@@ -57,12 +59,12 @@ def __init__(
async def generate(self, request: CanonicalGenerateRequest) -> CanonicalGenerateResponse:
rollout_state = self._canonical_request_to_rollout_state(request)
- rollout_state = await self._controller.generate.remote(rollout_state)
+ rollout_state = await self._generator.generate(rollout_state)
self._raise_for_failed_rollout(rollout_state, request_id=str(rollout_state.uid))
return self._rollout_state_to_canonical_response(rollout_state, request)
async def health(self) -> BackendHealth:
- ready, details = await self._controller.get_ready_status.remote()
+ ready, details = await self._controller.get_runtime_status.remote()
return BackendHealth(
ready=ready,
status="ready" if ready else "unavailable",
diff --git a/xtuner/v1/rl/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py
index dd73324c55..b669bd7afe 100644
--- a/xtuner/v1/rl/rollout/__init__.py
+++ b/xtuner/v1/rl/rollout/__init__.py
@@ -1,6 +1,21 @@
import os
+from ._generation.external_http_entry import ExternalRolloutHttpEntry, ExternalRolloutHttpEntryConfig
+from ._generation.internal_http_entry import (
+ InternalRolloutHttpEntry,
+ InternalRolloutHttpEntryConfig,
+ build_internal_rollout_http_entry_app,
+ serve_internal_rollout_http_entry_in_thread,
+)
+from ._generation.session_worker_selector import RolloutWorkerHandle, SessionWorkerSelector
from .controller import RolloutController
+from .rollout_generator import (
+ LocalRolloutGenerator,
+ LocalRolloutGeneratorConfig,
+ RolloutGenerateHandle,
+ RolloutGenerateHandleConfig,
+)
+from .rollout_worker_build import RolloutRuntime, RolloutWorkerBuilder, RolloutWorkerRuntime, build_rollout_runtime
from .worker import RolloutWorker
diff --git a/xtuner/v1/rl/rollout/_generation/__init__.py b/xtuner/v1/rl/rollout/_generation/__init__.py
new file mode 100644
index 0000000000..0e345c2f2d
--- /dev/null
+++ b/xtuner/v1/rl/rollout/_generation/__init__.py
@@ -0,0 +1 @@
+"""Internal rollout generation helpers."""
diff --git a/xtuner/v1/rl/rollout/_generation/external_http_entry.py b/xtuner/v1/rl/rollout/_generation/external_http_entry.py
new file mode 100644
index 0000000000..2d5170c5cd
--- /dev/null
+++ b/xtuner/v1/rl/rollout/_generation/external_http_entry.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from xtuner.v1.rl.utils.misc import check_chat_completions, delete_from_routedapiproxy, register_to_routedapiproxy
+from xtuner.v1.utils import get_logger
+
+from .session_worker_selector import RolloutWorkerHandle, RolloutWorkerUrlSource
+from ..worker import RolloutConfig
+
+
+@dataclass
+class ExternalRolloutHttpEntryConfig:
+ """Control-plane config for an external routedapiproxy HTTP entry.
+
+ This class does not proxy user requests. It registers rollout worker
+ generation URLs into an externally managed router, then AgentLoop uses
+ ``base_url`` as the HTTP generation entry.
+ """
+
+ # The external router URL used by AgentLoop for generation.
+ base_url: str
+
+ # Which per-worker URL should be registered into the external router.
+ worker_url_source: RolloutWorkerUrlSource = "backend"
+
+ # Existing routedapiproxy semantics: delete old model registration before
+ # registering the current worker URLs.
+ delete_existing: bool = True
+
+ # Optional smoke checks. They are useful for debug rollout, but can be
+ # disabled when registration happens before the external router is ready.
+ check_worker_urls: bool = True
+ check_base_url: bool = True
+
+
+class ExternalRolloutHttpEntry:
+ def __init__(
+ self,
+ worker_handles: list[RolloutWorkerHandle],
+ rollout_config: RolloutConfig,
+ config: ExternalRolloutHttpEntryConfig,
+ *,
+ log_dir: str | None = None,
+ ) -> None:
+ self.worker_handles = worker_handles
+ self.rollout_config = rollout_config
+ self.config = config
+ self.logger = get_logger(log_dir=log_dir, tag="ExternalRolloutHttpEntry")
+ self._registered_urls: list[str] = []
+
+ def start(self) -> None:
+ model_name = self.rollout_config.model_name
+ if self.config.delete_existing:
+ delete_from_routedapiproxy(model_name)
+ self.logger.info(f"Deleted existing routedapiproxy registrations for model {model_name}.")
+
+ self.logger.info("Registering rollout worker URLs to routedapiproxy.")
+ for worker in sorted(self.worker_handles, key=lambda item: item.rank):
+ worker_url = worker.get_generate_url(self.config.worker_url_source)
+ register_to_routedapiproxy(model_name, worker_url)
+ self._registered_urls.append(worker_url)
+ self.logger.info(f"Registered rollout worker {worker.rank} to routedapiproxy: {worker_url}")
+
+ if self.config.check_worker_urls and not check_chat_completions(worker_url, model_name):
+ raise RuntimeError(f"check chat completions failed for rollout worker URL {worker_url}")
+
+ if self.config.check_base_url and not check_chat_completions(self.config.base_url, model_name):
+ raise RuntimeError(f"check chat completions failed for external router URL {self.config.base_url}")
+
+ self.logger.info("Registered rollout worker URLs to routedapiproxy.")
+
+ def stop(self) -> None:
+ self._registered_urls.clear()
diff --git a/xtuner/v1/rl/rollout/_generation/internal_http_entry.py b/xtuner/v1/rl/rollout/_generation/internal_http_entry.py
new file mode 100644
index 0000000000..73b02d2d38
--- /dev/null
+++ b/xtuner/v1/rl/rollout/_generation/internal_http_entry.py
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import socket
+import threading
+from dataclasses import dataclass
+from typing import Any
+from uuid import uuid4
+
+import httpx
+import uvicorn
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import Response, StreamingResponse
+
+from xtuner.v1.utils import get_logger
+
+from .session_worker_selector import RolloutWorkerHandle, RolloutWorkerUrlSource, SessionWorkerSelector
+from ..worker import RolloutConfig
+
+
+@dataclass
+class InternalRolloutHttpEntryConfig:
+ port: int
+ host: str = "0.0.0.0"
+ title: str = "XTuner Internal Rollout Router"
+ version: str = "0.1.0"
+ log_level: str = "warning"
+ request_timeout: float | None = None
+ stream_timeout: float | None = None
+ worker_url_source: RolloutWorkerUrlSource = "backend"
+
+
+class InternalRolloutHttpEntry:
+ def __init__(
+ self,
+ worker_handles: list[RolloutWorkerHandle],
+ rollout_config: RolloutConfig,
+ config: InternalRolloutHttpEntryConfig,
+ ) -> None:
+ self.worker_selector = SessionWorkerSelector(worker_handles)
+ self.rollout_config = rollout_config
+ self.config = config
+ timeout = config.request_timeout or rollout_config.rollout_timeout
+ self.client = httpx.AsyncClient(timeout=timeout)
+ self.stream_timeout = config.stream_timeout or rollout_config.rollout_timeout
+ self.logger = get_logger(log_dir=rollout_config.worker_log_dir, tag="InternalRolloutHttpEntry")
+
+ async def models(self) -> dict[str, Any]:
+ model_id = self.rollout_config.model_name or "xtuner-rollout"
+ return {
+ "object": "list",
+ "data": [
+ {
+ "id": model_id,
+ "object": "model",
+ "created": 0,
+ "owned_by": "xtuner",
+ }
+ ],
+ }
+
+ async def chat_completions(self, request: Request) -> Response:
+ payload = await request.json()
+ session_id = self._extract_session_id(payload, request)
+ worker = await self.worker_selector.select(session_id)
+ if worker is None:
+ raise HTTPException(status_code=503, detail={"error": "No active rollout worker available."})
+
+ if self.config.worker_url_source == "session":
+ payload.setdefault("session_id", session_id)
+ try:
+ worker_base_url = worker.get_generate_url(self.config.worker_url_source)
+ except (RuntimeError, ValueError) as exc:
+ raise HTTPException(status_code=503, detail={"error": str(exc)}) from exc
+ url = f"{worker_base_url.rstrip('/')}/v1/chat/completions"
+ headers = self._forward_headers(request)
+ if payload.get("stream") is True:
+ return await self._stream_chat_completions(url, payload, headers, worker)
+ return await self._post_chat_completions(url, payload, headers, worker)
+
+ def _extract_session_id(self, payload: dict[str, Any], request: Request) -> int:
+ for header_name in ("x-session-uid", "x-session-id", "x-request-id"):
+ header_value = request.headers.get(header_name)
+ if header_value:
+ return self._stable_int(header_value)
+
+ metadata = payload.get("metadata")
+ if isinstance(metadata, dict):
+ for key in ("session_uid", "session_id", "conversation_id", "thread_id"):
+ if key in metadata and metadata[key] is not None:
+ return self._stable_int(metadata[key])
+
+ for key in ("session_uid", "session_id"):
+ if key in payload and payload[key] is not None:
+ return self._stable_int(payload[key])
+
+ return uuid4().int
+
+ def _stable_int(self, value: Any) -> int:
+ if isinstance(value, int):
+ return value
+ if isinstance(value, str):
+ try:
+ return int(value)
+ except ValueError:
+ return uuid4().int if not value else self._hash_to_int(value)
+ return self._hash_to_int(json.dumps(value, sort_keys=True, default=str))
+
+ def _hash_to_int(self, value: str) -> int:
+ return int.from_bytes(hashlib.sha256(value.encode("utf-8")).digest()[:16], byteorder="big")
+
+ def _forward_headers(self, request: Request) -> dict[str, str]:
+ ignored = {
+ "host",
+ "content-length",
+ "connection",
+ "keep-alive",
+ "proxy-authenticate",
+ "proxy-authorization",
+ "te",
+ "trailers",
+ "transfer-encoding",
+ "upgrade",
+ }
+ headers = {key: value for key, value in request.headers.items() if key.lower() not in ignored}
+ if "content-type" not in {key.lower() for key in headers}:
+ headers["content-type"] = "application/json"
+ return headers
+
+ async def _post_chat_completions(
+ self,
+ url: str,
+ payload: dict[str, Any],
+ headers: dict[str, str],
+ worker: RolloutWorkerHandle,
+ ) -> Response:
+ try:
+ response = await self.client.post(url, json=payload, headers=headers)
+ except Exception as exc:
+ raise HTTPException(status_code=502, detail={"error": str(exc)}) from exc
+
+ return Response(
+ content=response.content,
+ status_code=response.status_code,
+ media_type=response.headers.get("content-type", "application/json"),
+ )
+
+ async def _stream_chat_completions(
+ self,
+ url: str,
+ payload: dict[str, Any],
+ headers: dict[str, str],
+ worker: RolloutWorkerHandle,
+ ) -> StreamingResponse:
+ async def stream_response():
+ try:
+ async with self.client.stream("POST", url, json=payload, headers=headers, timeout=self.stream_timeout) as response:
+ if response.status_code >= 500:
+ body = await response.aread()
+ yield body
+ return
+ async for chunk in response.aiter_bytes():
+ yield chunk
+ except Exception as exc:
+ self.logger.error(f"Streaming chat completion failed for worker {worker.rank}: {exc}")
+ yield f'data: {{"error": {json.dumps(str(exc))}}}\n\n'.encode()
+
+ return StreamingResponse(stream_response(), media_type="text/event-stream")
+
+
+def build_internal_rollout_http_entry_app(
+ worker_handles: list[RolloutWorkerHandle],
+ rollout_config: RolloutConfig,
+ config: InternalRolloutHttpEntryConfig,
+) -> FastAPI:
+ entry = InternalRolloutHttpEntry(worker_handles=worker_handles, rollout_config=rollout_config, config=config)
+ app = FastAPI(title=config.title, version=config.version)
+ app.state.internal_rollout_http_entry = entry
+
+ @app.get("/v1/models")
+ async def models():
+ return await entry.models()
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: Request):
+ return await entry.chat_completions(request)
+
+ return app
+
+
+def serve_internal_rollout_http_entry(app: FastAPI, config: InternalRolloutHttpEntryConfig) -> None:
+ _ensure_port_available(config)
+ uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level)
+
+
+def serve_internal_rollout_http_entry_in_thread(
+ app: FastAPI, config: InternalRolloutHttpEntryConfig
+) -> threading.Thread:
+ thread = threading.Thread(
+ target=serve_internal_rollout_http_entry,
+ args=(app, config),
+ daemon=True,
+ name="internal-rollout-http-entry",
+ )
+ thread.start()
+ return thread
+
+
+def _ensure_port_available(config: InternalRolloutHttpEntryConfig) -> None:
+ host = "127.0.0.1" if config.host in ("", "0.0.0.0") else config.host
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.settimeout(1.0)
+ if sock.connect_ex((host, config.port)) == 0:
+ raise OSError(f"Internal rollout HTTP entry port already in use: {config.host}:{config.port}")
diff --git a/xtuner/v1/rl/rollout/_generation/rollout_worker_generator.py b/xtuner/v1/rl/rollout/_generation/rollout_worker_generator.py
new file mode 100644
index 0000000000..5588f95eb0
--- /dev/null
+++ b/xtuner/v1/rl/rollout/_generation/rollout_worker_generator.py
@@ -0,0 +1,658 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import copy
+import json
+import threading
+import traceback
+from typing import Any, cast
+
+import httpx
+import numpy as np
+import ray
+from transformers import AutoConfig, AutoTokenizer
+
+from xtuner.v1.data_proto.rl_data import (
+ RolloutState,
+ SampleParams,
+ Status,
+ reset_rollout_response,
+ update_status_from_finish_reason,
+)
+from xtuner.v1.rl.utils import cancel_and_drain, get_eos_token
+from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger
+from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
+
+from ..utils import PartialRolloutHandler
+from ..worker import RolloutConfig
+
+
+LMDEPLOY_SHARED_STORE = "shared_store"
+LMDEPLOY_SHARED_STORE_NAMESPACE = "lmdeploy"
+
+
+class RolloutWorkerGenerator:
+ """Generator bound to one rollout worker backend URL."""
+
+ def __init__(self, config: RolloutConfig, rank: int, server_url: str) -> None:
+ self.config = config
+ self.rank = rank
+ self.server_url = server_url
+ self.backend = config.rollout_backend
+ self.logger = get_logger(log_dir=config.worker_log_dir, tag=f"RolloutWorkerGenerator-{rank}")
+ self.endpoints = self._build_endpoints()
+ tokenizer_path = config.tokenizer_path or config.model_path
+ self.tokenizer_path = tokenizer_path
+ self._tokenizer = None
+ self.model_name = config.model_name
+ self.enable_return_routed_experts = config.enable_return_routed_experts
+ self._partial_rollout_handler: PartialRolloutHandler | None = None
+ self.receive_abort_request = threading.Event()
+ self.abort_timeout = 10.0
+ http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio
+ limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100)
+ self.client = httpx.AsyncClient(limits=limits, timeout=config.rollout_timeout)
+ eos_token = get_eos_token(config.model_path)
+ self.eos_token: list[int] = [eos_token] if isinstance(eos_token, int) else eos_token
+ self.lmdeploy_actor = None
+ self.routed_experts_num_hidden_layers = None
+ self.routed_experts_num_experts_per_tok = None
+ if self.backend == "sglang":
+ model_config = AutoConfig.from_pretrained(config.model_path, trust_remote_code=True)
+ text_config = getattr(model_config, "text_config", model_config)
+ self.routed_experts_num_hidden_layers = getattr(text_config, "num_hidden_layers", None)
+ self.routed_experts_num_experts_per_tok = getattr(text_config, "num_experts_per_tok", None)
+
+ @property
+ def tokenizer(self):
+ if self._tokenizer is None:
+ self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
+ return self._tokenizer
+
+ @property
+ def partial_rollout_handler(self) -> PartialRolloutHandler:
+ if self._partial_rollout_handler is None:
+ self._partial_rollout_handler = PartialRolloutHandler()
+ return self._partial_rollout_handler
+
+ def _build_endpoints(self) -> dict[str, str]:
+ if self.backend == "lmdeploy":
+ return {
+ "generate": "generate",
+ "v1/chat/completions": "v1/chat/completions",
+ "abort_request": "abort_request",
+ }
+ if self.backend == "sglang":
+ return {
+ "generate": "generate",
+ "v1/chat/completions": "v1/chat/completions",
+ "abort_request": "abort_request",
+ }
+ if self.backend == "vllm":
+ return {
+ "generate": "v1/chat/completions",
+ "v1/chat/completions": "v1/chat/completions",
+ "abort_request": "abort_request",
+ }
+ raise ValueError(f"Unsupported rollout backend: {self.backend}")
+
+ def update_server_url(self, server_url: str) -> None:
+ self.server_url = server_url
+ self.receive_abort_request.clear()
+
+ async def pause_generation(self) -> bool:
+ self.receive_abort_request.set()
+ return await self._send_abort_request()
+
+ def continue_generation(self) -> None:
+ self.receive_abort_request.clear()
+
+ async def _send_abort_request(self) -> bool:
+ endpoint = self.endpoints.get("abort_request", "abort_request")
+ url = f"{self.server_url}/{endpoint}"
+ try:
+ response = await self.client.post(url)
+ response.raise_for_status()
+ return True
+ except Exception as exc:
+ self.logger.warning(f"Failed to send abort request to {url}: {exc}")
+ return False
+
+ async def _wait_abort_request(self) -> None:
+ while not self.receive_abort_request.is_set():
+ await asyncio.sleep(1)
+
+ async def generate(self, rollout_state: RolloutState, *, enable_partial_rollout: bool = False) -> RolloutState:
+ if self.receive_abort_request.is_set():
+ rollout_state.finish_reason = "abort"
+ rollout_state.status = Status.ABORTED
+ return rollout_state
+
+ uid = rollout_state.uid
+ sample_params: SampleParams = rollout_state.sample_params
+ max_tokens = sample_params.max_tokens
+ if sample_params.return_token_ids:
+ endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
+ else:
+ endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
+
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.config.api_key}",
+ }
+
+ if enable_partial_rollout:
+ rollout_state = self.partial_rollout_handler.preprocess(rollout_state, max_tokens)
+ elif rollout_state.status == Status.ABORTED:
+ rollout_state = reset_rollout_response(rollout_state)
+ rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens})
+ rollout_state.status = Status.INIT
+
+ payload = self._get_request_payload(rollout_state)
+ max_retries = self.config.max_retry_per_sample
+
+ if rollout_state.status == Status.COMPLETED:
+ self.logger.debug(f"Request {uid} is already marked as COMPLETED, skipping generation.")
+ return rollout_state
+
+ input_ids = payload.get("input_ids", [])
+ max_tokens = payload.get("max_tokens") or payload.get("max_new_tokens")
+ sampling_params = payload.get("sampling_params")
+ if max_tokens is None and isinstance(sampling_params, dict):
+ max_tokens = sampling_params.get("max_tokens") or sampling_params.get("max_new_tokens")
+ max_tokens = cast(int | None, max_tokens)
+ last_id = input_ids[-1] if len(input_ids) > 0 else "None"
+ is_max_tokens_zero = max_tokens is not None and max_tokens <= 0
+ is_eos_reached = len(input_ids) > 0 and input_ids[-1] in self.eos_token
+ if is_max_tokens_zero or is_eos_reached:
+ self.logger.debug(
+ f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token."
+ )
+ rollout_state.finish_reason = "stop" if is_eos_reached else "length"
+ rollout_state.status = Status.COMPLETED
+ return rollout_state
+
+ for attempt in range(max_retries + 1):
+ is_last_attempt = attempt == max_retries
+ http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload)
+ if http_result.response:
+ rollout_state = await self._safe_handle_response(rollout_state, http_result.response)
+ if rollout_state.status in [Status.COMPLETED, Status.ABORTED]:
+ return rollout_state
+ if is_last_attempt:
+ self.logger.warning(
+                        f"Invalid rollout response for request {uid} after {max_retries} retries, marking as FAILED."
+                    )
+                    rollout_state.status = Status.FAILED
+                    rollout_state.error_msg = f"Invalid rollout response after {max_retries} retries."
+ return rollout_state
+ self.logger.warning(
+ f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}."
+ )
+ await asyncio.sleep(0.1)
+ continue
+
+ if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED:
+ rollout_state.finish_reason = "abort"
+ rollout_state.status = update_status_from_finish_reason("abort")
+ return rollout_state
+
+ if http_result.is_client_error:
+ self.logger.warning(
+ f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}"
+ )
+ rollout_state.error_msg = (
+ f"Client error {http_result.error_type} with message: {http_result.error_msg}"
+ )
+ rollout_state.status = Status.FAILED
+ return rollout_state
+
+ if http_result.is_server_error:
+ self.logger.warning(
+ f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}"
+ )
+ rollout_state.error_msg = (
+ f"Server error {http_result.error_type} with message: {http_result.error_msg}"
+ )
+ rollout_state.status = Status.FAILED
+ return rollout_state
+
+ if http_result.is_retryable:
+ if is_last_attempt:
+ self.logger.warning(
+                        f"rollout request {uid} to {http_result.url} failed after {max_retries} retries due to retryable error {http_result.error_type} with {http_result.error_msg}"
+                    )
+                    rollout_state.error_msg = f"Request failed after {max_retries} retries due to retryable error {http_result.error_type} with message: {http_result.error_msg}"
+ rollout_state.status = Status.FAILED
+ return rollout_state
+ self.logger.warning(
+ f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}."
+ )
+ await asyncio.sleep(0.1)
+ continue
+
+ if http_result.is_unknown_error:
+ raise RuntimeError(
+ f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}"
+ )
+ return rollout_state
+
+ async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+ send_task = None
+ abort_task = None
+ try:
+ if self.receive_abort_request.is_set():
+ self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.")
+ return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
+ req = self.client.build_request("POST", url, headers=headers, json=payload)
+ send_task = asyncio.create_task(self.client.send(req))
+ abort_task = asyncio.create_task(self._wait_abort_request())
+ done, _ = await asyncio.wait({send_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)
+ if send_task in done:
+ response = await send_task
+ else:
+ try:
+ response = await asyncio.wait_for(asyncio.shield(send_task), timeout=self.abort_timeout)
+ except asyncio.TimeoutError:
+ self.logger.debug(
+ f"Request to {url} did not return within {self.abort_timeout:.2f}s after abort signal."
+ )
+ await cancel_and_drain([send_task])
+ return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
+ response.raise_for_status()
+ return HttpRequestResult(response=response)
+ except asyncio.CancelledError:
+ self.logger.debug(f"Request to {url} was cancelled while waiting for the response.")
+ await cancel_and_drain([send_task, abort_task])
+ self.receive_abort_request.set()
+ return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
+ except Exception as exc:
+ error_type = HttpRequestErrorType.from_exception(exc)
+ return HttpRequestResult(error_type=error_type, exception=exc, url=url, payload=payload)
+ finally:
+ await cancel_and_drain([abort_task])
+
+ def _get_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]:
+ if self.backend == "lmdeploy":
+ return self._get_lmdeploy_request_payload(rollout_state)
+ if self.backend == "sglang":
+ return self._get_sglang_request_payload(rollout_state)
+ if self.backend == "vllm":
+ return self._get_vllm_request_payload(rollout_state)
+ raise ValueError(f"Unsupported rollout backend: {self.backend}")
+
+ def _get_lmdeploy_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]:
+ sample_params = rollout_state.sample_params
+ optional_fields: dict[str, object] = {}
+ if rollout_state.tools is not None:
+ optional_fields["tools"] = rollout_state.tools
+ if rollout_state.tool_choice is not None:
+ optional_fields["tool_choice"] = rollout_state.tool_choice
+
+ if sample_params.return_token_ids:
+ payload: dict[str, Any] = {"model": self.model_name, **optional_fields}
+ if "image_data" in rollout_state.extra_fields:
+                assert rollout_state.tokens is not None, "input_ids is required when image_data is provided."
+ payload["image_data"] = rollout_state.extra_fields["image_data"]
+ if rollout_state.tokens is not None:
+ payload["input_ids"] = rollout_state.tokens
+ else:
+ text_prompt = self.tokenizer.apply_chat_template(
+ rollout_state.message, tokenize=False, add_generation_prompt=True
+ )
+ payload["input_ids"] = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
+            sample_params.return_routed_experts = bool(self.enable_return_routed_experts)
+ payload.update(sample_params.model_dump(exclude_none=True))
+ return payload
+
+ payload = {"model": self.model_name, "messages": rollout_state.message, **optional_fields}
+ lmdeploy_sample_params: dict[str, Any] = {
+ "temperature": sample_params.temperature,
+ "top_p": sample_params.top_p,
+ "n": sample_params.n,
+ "stream": sample_params.stream,
+ "max_tokens": sample_params.max_tokens,
+ "repetition_penalty": sample_params.repetition_penalty,
+ "top_k": sample_params.top_k,
+ "skip_special_tokens": sample_params.skip_special_tokens,
+ }
+ if sample_params.stops:
+ lmdeploy_sample_params["stop"] = sample_params.stops
+ if sample_params.min_tokens > 0:
+ lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens
+ payload.update(lmdeploy_sample_params)
+ return payload
+
+ def _get_sglang_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]:
+ sample_params = rollout_state.sample_params
+ payload: dict[str, Any] = {"model": self.model_name}
+ if rollout_state.tools is not None:
+ payload["tools"] = rollout_state.tools
+ if rollout_state.tool_choice is not None:
+ payload["tool_choice"] = rollout_state.tool_choice
+
+ sample_params_dict = sample_params.model_dump()
+ sglang_sample_params = self._transform_sglang_sample_params(sample_params_dict)
+ sglang_extra_params = self._transform_sglang_extra_params(sample_params_dict)
+ payload.update(sglang_extra_params)
+ if self.enable_return_routed_experts and not rollout_state.extra_fields.get("disable_routed_experts", False):
+ payload["return_routed_experts"] = True
+
+ if sample_params.return_token_ids:
+ if "image_data" in rollout_state.extra_fields:
+ assert rollout_state.tokens is not None, "input_ids is required when image_data is provided."
+ payload["image_data"] = rollout_state.extra_fields["image_data"]
+ if rollout_state.tokens is not None:
+ payload["input_ids"] = rollout_state.tokens
+ else:
+ text_prompt = self.tokenizer.apply_chat_template(
+ rollout_state.message, tokenize=False, add_generation_prompt=True
+ )
+ payload["input_ids"] = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
+ payload["sampling_params"] = sglang_sample_params
+ return payload
+
+ payload["messages"] = rollout_state.message
+ payload.update(sglang_sample_params)
+ payload["max_tokens"] = sglang_sample_params["max_new_tokens"]
+ payload["min_tokens"] = sglang_sample_params["min_new_tokens"]
+ payload.pop("max_new_tokens", None)
+ payload.pop("min_new_tokens", None)
+ return payload
+
+ def _get_vllm_request_payload(self, rollout_state: RolloutState) -> dict[str, Any]:
+ sample_params = rollout_state.sample_params
+ prompt = copy.deepcopy(rollout_state.message)
+ extra_fields = rollout_state.extra_fields
+ if "image_data" in extra_fields:
+ image_index = 0
+ for message in prompt:
+ if not isinstance(message, dict) or message.get("role") != "user":
+ continue
+ new_content = []
+ for content_part in message.get("content", []):
+ if not isinstance(content_part, dict):
+ new_content.append(content_part)
+ continue
+ if content_part.get("type") == "image_url":
+ content_part["image_url"]["url"] = f"file://{extra_fields['image_data'][image_index]}"
+ content_part["image_url"].pop("image_wh", None)
+ image_index += 1
+ new_content.append(content_part)
+ message["content"] = new_content
+ assert image_index == len(extra_fields["image_data"]), (
+ f"Expected {len(extra_fields['image_data'])} images, but processed {image_index}."
+ )
+
+ payload: dict[str, Any] = {
+ "model": self.config.model_path,
+ "messages": prompt,
+ "stream": sample_params.stream,
+ }
+ if rollout_state.tokens is not None:
+ payload["input_ids"] = rollout_state.tokens
+ elif "train_prompt_ids" in extra_fields:
+ payload["input_ids"] = extra_fields["train_prompt_ids"]
+        payload.update(self._transform_vllm_sample_params(sample_params.model_dump()))
+ return payload
+
+ def _transform_sglang_sample_params(self, sample_params: dict[str, Any]) -> dict[str, Any]:
+ if sample_params["top_p"] > 0:
+ sample_params["top_k"] = -1
+ sglang_sample_params = {
+ "n": sample_params["n"],
+ "top_k": sample_params["top_k"],
+ "top_p": sample_params["top_p"],
+ "temperature": sample_params["temperature"],
+ "repetition_penalty": sample_params["repetition_penalty"],
+ "presence_penalty": sample_params["presence_penalty"],
+ "frequency_penalty": sample_params["frequency_penalty"],
+ "max_new_tokens": sample_params["max_tokens"],
+ "min_new_tokens": sample_params["min_tokens"],
+ "stop": sample_params["stops"],
+ "stop_token_ids": sample_params["stop_token_ids"],
+ "skip_special_tokens": sample_params["skip_special_tokens"],
+ }
+ sampling_seed = sample_params.get("sampling_seed")
+ if sampling_seed is None and XTUNER_DETERMINISTIC:
+ sampling_seed = self.config.random_seed
+ if sampling_seed is not None:
+ sglang_sample_params["sampling_seed"] = sampling_seed
+ return sglang_sample_params
+
+ def _transform_sglang_extra_params(self, extra_params: dict[str, Any]) -> dict[str, Any]:
+ return {
+ "stream": extra_params["stream"],
+ "return_logprob": extra_params["return_logprob"],
+ "include_stop_str_in_output": extra_params["include_stop_str_in_output"],
+ "no_stop_trim": extra_params.get("no_stop_trim", False),
+ "spaces_between_special_tokens": extra_params.get("spaces_between_special_tokens", False),
+ }
+
+ def _transform_vllm_sample_params(
+ self, sample_params: dict[str, Any], extra_params: dict[str, Any] | None = None
+ ) -> dict[str, Any]:
+ vllm_sample_params = copy.deepcopy(sample_params)
+ if extra_params:
+ vllm_sample_params.update(extra_params)
+ if "stops" in vllm_sample_params:
+ vllm_sample_params["stop"] = vllm_sample_params.pop("stops")
+ if "no_stop_trim" in vllm_sample_params:
+ vllm_sample_params["include_stop_str_in_output"] = vllm_sample_params.pop("no_stop_trim")
+ if "top_logprobs" in vllm_sample_params and "return_logprob" in vllm_sample_params:
+ vllm_sample_params["logprobs"] = vllm_sample_params.pop("return_logprob")
+ return vllm_sample_params
+
+ async def _decode_routed_experts(self, routed_experts: Any) -> Any:
+ if self.backend == "lmdeploy":
+ if isinstance(routed_experts, str):
+ if self.lmdeploy_actor is None:
+ self.lmdeploy_actor = ray.get_actor(LMDEPLOY_SHARED_STORE, namespace=LMDEPLOY_SHARED_STORE_NAMESPACE)
+ routed_experts_data = await self.lmdeploy_actor.get.remote(routed_experts)
+ return ray.put(np.asarray(routed_experts_data))
+ return np.asarray(routed_experts)
+ if self.backend == "sglang":
+ if isinstance(routed_experts, str):
+ routed_experts_flat = np.frombuffer(base64.b64decode(routed_experts), dtype=np.int32)
+ routed_experts_array = routed_experts_flat.reshape(
+ -1,
+ self.routed_experts_num_hidden_layers,
+ self.routed_experts_num_experts_per_tok,
+ )
+ return routed_experts_array.copy()
+ return np.asarray(routed_experts)
+ if self.backend == "vllm":
+ if isinstance(routed_experts, str):
+ routed_experts = ray.cloudpickle.loads(base64.b64decode(routed_experts))
+ return np.asarray(routed_experts)
+ return routed_experts
+
+ async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState:
+ if self.backend == "vllm":
+ return await self._safe_handle_vllm_response(rollout_state, http_response)
+ return await self._safe_handle_openai_or_token_response(rollout_state, http_response)
+
+ async def _safe_handle_openai_or_token_response(
+ self, rollout_state: RolloutState, http_response: httpx.Response
+ ) -> RolloutState:
+ uid = rollout_state.message_uid
+ sample_params = rollout_state.sample_params
+ response = http_response.json()
+
+ if sample_params.return_token_ids:
+ response_ids: list[int] = []
+ logprobs: list[float] = []
+ routed_experts = None
+ returned_response = ""
+ try:
+ meta_info = response.get("meta_info") or {}
+ finish_reason_info = meta_info.get("finish_reason") or {}
+ finish_reason = finish_reason_info.get("type")
+ if finish_reason is None:
+ if self.receive_abort_request.is_set():
+ rollout_state.finish_reason = "abort"
+ rollout_state.status = Status.ABORTED
+ else:
+ rollout_state.finish_reason = "error"
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = "Missing finish_reason in response meta_info"
+ return rollout_state
+ returned_response = response.get("text", "")
+ if meta_info.get("output_token_logprobs") is not None:
+ response_ids = [item[1] for item in meta_info["output_token_logprobs"]]
+ logprobs = [item[0] for item in meta_info["output_token_logprobs"]]
+ else:
+ num_return_tokens = meta_info.get("completion_tokens", 0)
+ response_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else []
+
+ if self.enable_return_routed_experts:
+ assert "routed_experts" in meta_info, (
+ "enable_return_routed_experts is True, but routed_experts is not in meta_info"
+ )
+ routed_experts = meta_info["routed_experts"]
+ if routed_experts is not None:
+ routed_experts = await self._decode_routed_experts(routed_experts)
+ if not isinstance(routed_experts, ray.ObjectRef):
+ routed_experts = ray.put(routed_experts)
+
+ rollout_status = update_status_from_finish_reason(finish_reason)
+ if rollout_status == Status.COMPLETED:
+ validation_errors = []
+ if not response_ids:
+ validation_errors.append("empty response_ids")
+ if not returned_response:
+ validation_errors.append("empty response text")
+ if sample_params.return_logprob and not logprobs:
+ validation_errors.append("missing logprobs")
+ if self.enable_return_routed_experts and routed_experts is None:
+ validation_errors.append("missing routed_experts")
+ if validation_errors:
+ error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}"
+ self.logger.error(error_msg)
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = error_msg
+ return rollout_state
+ elif rollout_status == Status.FAILED:
+ error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}"
+ self.logger.error(error_msg)
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = error_msg
+ return rollout_state
+
+ if self.enable_partial_rollout:
+ expect_len = meta_info.get("prompt_tokens", 0) + meta_info.get("completion_tokens", 0) - 1
+ rollout_state = await self.partial_rollout_handler.postprocess(
+ rollout_state,
+ response=returned_response,
+ response_ids=response_ids,
+ logprobs=logprobs,
+ routed_experts=routed_experts,
+ finish_reason=finish_reason,
+ status=rollout_status,
+ routed_experts_expect_len=expect_len,
+ )
+ else:
+ rollout_state.response = returned_response
+ rollout_state.response_ids = response_ids
+ rollout_state.logprobs = logprobs
+ rollout_state.routed_experts = routed_experts
+ rollout_state.finish_reason = finish_reason
+ rollout_state.status = rollout_status
+ return rollout_state
+ except Exception as exc:
+ raise self._response_error(exc, response, uid)
+
+ try:
+ returned_response = response["choices"][0]["message"]["content"]
+ finish_reason = response["choices"][0]["finish_reason"]
+ rollout_status = update_status_from_finish_reason(finish_reason)
+ if rollout_status == Status.COMPLETED and not returned_response:
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = "Empty response text"
+ return rollout_state
+ rollout_state.response = returned_response
+ rollout_state.finish_reason = finish_reason
+ rollout_state.status = rollout_status
+ return rollout_state
+ except Exception as exc:
+ raise self._response_error(exc, response, uid)
+
+ async def _safe_handle_vllm_response(
+ self, rollout_state: RolloutState, http_response: httpx.Response
+ ) -> RolloutState:
+ uid = rollout_state.uid or rollout_state.message_uid
+ sample_params = rollout_state.sample_params
+ last_token_ids: list[int] = []
+ last_logprobs: list[float] = []
+ routed_experts = None
+
+ response_json = http_response.json()
+ try:
+ response_choice = response_json["choices"][0]
+ if response_choice.get("logprobs") is not None:
+ last_token_ids = response_choice.get("token_ids", response_json.get("token_ids", []))
+ last_logprobs = [
+ item["logprob"] for item in response_choice["logprobs"].get("content", []) if "logprob" in item
+ ]
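+ assert len(last_token_ids) == len(last_logprobs), (
+ f"Token/logprob length mismatch: {len(last_token_ids)} token ids vs {len(last_logprobs)} logprobs"
+ )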
+ assert len(last_token_ids) <= sample_params.max_tokens, (
+ f"Generation length exceeds limit: generated {len(last_token_ids)}, limit {sample_params.max_tokens}"
+ )
+
+ last_trajectory = response_choice["message"].get("content") or ""
+ finish_reason = response_choice.get("finish_reason")
+ if finish_reason == "abort" and not self.receive_abort_request.is_set():
+ self.receive_abort_request.set()
+ self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}")
+
+ if self.enable_return_routed_experts:
+ routed_experts = response_choice.get("routed_experts", response_json.get("routed_experts"))
+ if routed_experts is not None:
+ routed_experts = await self._decode_routed_experts(routed_experts)
+ if not isinstance(routed_experts, ray.ObjectRef):
+ routed_experts = ray.put(routed_experts)
+
+ rollout_status = update_status_from_finish_reason(finish_reason)
+ if rollout_status == Status.COMPLETED:
+ validation_errors = []
+ if sample_params.return_token_ids and len(last_token_ids) == 0:
+ validation_errors.append("empty response_ids")
+ if sample_params.return_logprob and len(last_logprobs) == 0:
+ validation_errors.append("missing logprobs")
+ if not last_trajectory:
+ validation_errors.append("empty response text")
+ if self.enable_return_routed_experts and routed_experts is None:
+ validation_errors.append("missing routed_experts")
+ if validation_errors:
+ error_msg = f"Incomplete rollout data for request {uid}: {', '.join(validation_errors)}"
+ self.logger.error(f"{error_msg}. Raw response: {response_json}")
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = error_msg
+ return rollout_state
+
+ rollout_state.response = last_trajectory
+ rollout_state.response_ids = last_token_ids if len(last_token_ids) > 0 else None
+ rollout_state.logprobs = last_logprobs if len(last_logprobs) > 0 else None
+ rollout_state.routed_experts = routed_experts
+ rollout_state.finish_reason = finish_reason
+ rollout_state.status = rollout_status
+ return rollout_state
+ except Exception as exc:
+ raise self._response_error(exc, response_json, uid)
+
+ def _response_error(self, exc: Exception, response: dict[str, Any], uid: Any) -> RuntimeError:
+ response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+ if isinstance(exc, KeyError):
+ return RuntimeError(f"Missing expected key {exc} in response {response_for_log} for {uid}")
+ if isinstance(exc, IndexError):
+ return RuntimeError(f"Index error {exc} while processing response {response_for_log} for {uid}")
+ if isinstance(exc, AssertionError):
+ return RuntimeError(f"AssertionError: {exc} when processing response {response_for_log} for {uid}")
+ if isinstance(exc, json.JSONDecodeError):
+ return RuntimeError(f"JSONDecodeError: {exc} when processing response {response_for_log} for {uid}")
+ if isinstance(exc, TypeError):
+ return RuntimeError(f"TypeError: {exc} when processing response {response_for_log} for {uid}")
+ return RuntimeError(
+ f"Unexpected error: {exc} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}"
+ )
+
+RayRolloutWorkerGenerator = ray.remote(RolloutWorkerGenerator)
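
Review note on `_transform_vllm_sample_params`: the key renaming it performs can be checked in isolation. The sketch below is a standalone approximation, not the code in this PR; `to_vllm_sample_params` and the sample dictionary are hypothetical names introduced only for illustration. It assumes the same three renames seen in the diff (`stops` to `stop`, `no_stop_trim` to `include_stop_str_in_output`, and `return_logprob` to `logprobs` when `top_logprobs` is also present).

```python
from __future__ import annotations

import copy
from typing import Any


def to_vllm_sample_params(
    sample_params: dict[str, Any], extra_params: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Return a renamed copy of the parameters; the caller's dicts are not mutated."""
    params = copy.deepcopy(sample_params)
    if extra_params:
        params.update(extra_params)
    # Rename xtuner-style keys to the names assumed for the vLLM request body.
    if "stops" in params:
        params["stop"] = params.pop("stops")
    if "no_stop_trim" in params:
        params["include_stop_str_in_output"] = params.pop("no_stop_trim")
    if "top_logprobs" in params and "return_logprob" in params:
        params["logprobs"] = params.pop("return_logprob")
    return params


# Hypothetical input; real values come from sample_params.model_dump().
original = {
    "stops": ["</s>"],
    "no_stop_trim": True,
    "return_logprob": True,
    "top_logprobs": 1,
    "temperature": 0.7,
}
converted = to_vllm_sample_params(original)
print(sorted(converted))
```

Because the function works on a deep copy, a single dump of the sample params is enough as input; passing the same dictionary again as `extra_params` does not change the result.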
diff --git a/xtuner/v1/rl/rollout/_generation/session_worker_selector.py b/xtuner/v1/rl/rollout/_generation/session_worker_selector.py
new file mode 100644
index 0000000000..e563637305
--- /dev/null
+++ b/xtuner/v1/rl/rollout/_generation/session_worker_selector.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import asyncio
+import time
+from collections import OrderedDict
+from dataclasses import dataclass
+from itertools import cycle
+from typing import TYPE_CHECKING, Any, Literal
+
+
+if TYPE_CHECKING:
+ from .rollout_worker_generator import RolloutWorkerGenerator
+
+
+RolloutWorkerUrlSource = Literal["backend", "session"]
+
+
+@dataclass
+class RolloutWorkerHandle:
+ """Runtime handles for one active rollout worker.
+
+ This object intentionally groups the control-plane worker actor with the
+ data-plane generation entries selected from that worker.
+ """
+
+ # Stable active-worker rank. Used for session affinity, logging, and
+ # registering worker entries into HTTP routers.
+ rank: int
+
+ # Original RolloutWorker Ray actor. Controller control-plane operations
+ # such as offload/onload/pause/shutdown are sent to this actor.
+ worker_actor: Any
+
+ # Raw backend server URL exposed by lmdeploy/vLLM/SGLang.
+ backend_url: str
+
+ # Per-active-worker generation actor used by local Python/Ray generation.
+ generator_actor: RolloutWorkerGenerator
+
+ # Optional SessionServer proxy URL wrapping backend_url. It is only used
+ # when HTTP generation is configured to require session/cache/trace logic.
+ session_server_url: str | None = None
+
+ def require_session_server_url(self) -> str:
+ if self.session_server_url is None:
+ raise RuntimeError(f"Rollout worker {self.rank} does not have a SessionServer URL.")
+ return self.session_server_url
+
+ def get_generate_url(self, source: RolloutWorkerUrlSource) -> str:
+ if source == "backend":
+ return self.backend_url
+ if source == "session":
+ return self.require_session_server_url()
+ raise ValueError(f"Unsupported rollout worker URL source: {source!r}")
+
+
+class SessionWorkerSelector:
+ def __init__(
+ self,
+ workers: list[RolloutWorkerHandle],
+ *,
+ max_sessions: int = 10000,
+ max_idle_seconds: float | None = 3600.0,
+ ) -> None:
+ self._workers = {worker.rank: worker for worker in workers}
+ self._rank_cycle = cycle(self._workers)
+ self._max_sessions = max_sessions
+ self._max_idle_seconds = max_idle_seconds
+ self._sessions: OrderedDict[int, tuple[int, float]] = OrderedDict()
+ self._lock = asyncio.Lock()
+
+ async def select(self, session_id: int) -> RolloutWorkerHandle | None:
+ async with self._lock:
+ self._evict_expired()
+ if session_id in self._sessions:
+ rank, _ = self._sessions.pop(session_id)
+ worker = self._workers.get(rank)
+ if worker is not None:
+ self._sessions[session_id] = (rank, self._now())
+ return worker
+
+ worker = self._next_worker()
+ if worker is None:
+ return None
+ self._sessions[session_id] = (worker.rank, self._now())
+ self._evict_to_capacity()
+ return worker
+
+ def _next_worker(self) -> RolloutWorkerHandle | None:
+ if not self._workers:
+ return None
+ return self._workers[next(self._rank_cycle)]
+
+ def _evict_expired(self) -> None:
+ if self._max_idle_seconds is None:
+ return
+ now = self._now()
+ expired = []
+ for session_id, (_, last_used_at) in self._sessions.items():
+ if now - last_used_at > self._max_idle_seconds:
+ expired.append(session_id)
+ else:
+ break
+ for session_id in expired:
+ self._sessions.pop(session_id, None)
+
+ def _evict_to_capacity(self) -> None:
+ while len(self._sessions) > self._max_sessions:
+ self._sessions.popitem(last=False)
+
+ def _now(self) -> float:
+ # Monotonic clock: idle-time comparisons must not be affected by wall-clock adjustments.
+ return time.monotonic()
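
Review note on `SessionWorkerSelector`: the selection policy combines sticky session affinity, round-robin assignment for new sessions, and least-recently-used eviction. The sketch below is a simplified, synchronous stand-in written for illustration only; `StickyRoundRobin` is a hypothetical class that tracks bare ranks instead of `RolloutWorkerHandle` objects, omits the `asyncio.Lock`, and assumes a non-empty rank list.

```python
from __future__ import annotations

import time
from collections import OrderedDict
from itertools import cycle


class StickyRoundRobin:
    """Hypothetical simplified selector: sticky sessions with LRU eviction."""

    def __init__(self, ranks: list[int], max_sessions: int = 2, max_idle_seconds: float = 3600.0) -> None:
        self._ranks = cycle(ranks)  # assumes ranks is non-empty
        self._max_sessions = max_sessions
        self._max_idle_seconds = max_idle_seconds
        # session_id -> (rank, last_used_at), ordered from least to most recently used.
        self._sessions: OrderedDict[int, tuple[int, float]] = OrderedDict()

    def select(self, session_id: int) -> int:
        now = time.monotonic()
        # The oldest entry is at the front, so expiry stops at the first fresh one.
        while self._sessions:
            oldest_id, (_, last_used_at) = next(iter(self._sessions.items()))
            if now - last_used_at <= self._max_idle_seconds:
                break
            del self._sessions[oldest_id]
        if session_id in self._sessions:
            rank, _ = self._sessions.pop(session_id)  # known session keeps its rank
        else:
            rank = next(self._ranks)  # new session takes the next rank in the cycle
        self._sessions[session_id] = (rank, now)  # re-insert at the most-recent end
        while len(self._sessions) > self._max_sessions:
            self._sessions.popitem(last=False)  # drop the least recently used session
        return rank


selector = StickyRoundRobin(ranks=[0, 1, 2], max_sessions=2)
first = selector.select(101)           # new session: rank 0
second = selector.select(202)          # new session: rank 1
again = selector.select(101)           # known session: stays on rank 0
third = selector.select(303)           # new session: rank 2; 202 is evicted as LRU
after_eviction = selector.select(202)  # 202 was forgotten, so it is reassigned
print(first, second, again, third, after_eviction)
```

The last call shows the cost of a small `max_sessions`: an evicted session loses its affinity and may land on a different worker, which matters when the backend keeps per-session KV cache.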
diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py
index 86137e9167..0ce82d7b7d 100644
--- a/xtuner/v1/rl/rollout/controller.py
+++ b/xtuner/v1/rl/rollout/controller.py
@@ -1,40 +1,21 @@
-import asyncio
-import math
-import os
-import threading
-from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, TypedDict
-from uuid import uuid4
import ray
from ray.actor import ActorProxy
-from ray.util.placement_group import PlacementGroup
-from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import RolloutState, Status
-from xtuner.v1.rl.utils import AutoAcceleratorWorkers
-from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger
+from xtuner.v1.utils import get_logger
-from .parser.factory import build_reasoning_parser, build_tool_call_parser
-from .parser.reasoning_parser import ReasoningParser
-from .parser.tool_parser import ToolCallParser
-from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter
-from .worker import ROLLOUT_CONCURRENCY_GROUP_GENERATE, RolloutConfig, RolloutWorker
+from ._generation.session_worker_selector import RolloutWorkerHandle
+from .rollout_worker_build import RolloutRuntime
+from .utils import ROLLOUT_RAY_GET_TIMEOUT
+from .worker import RolloutConfig
if TYPE_CHECKING:
from xtuner.v1.rl.gateway.config import GatewayConfig
-
-
-@dataclass
-class WorkerInfo:
- """A data class to hold all state information for a single worker."""
-
- actor: RolloutWorker
- url: str
- session_url: str | None = None
- is_active: bool = True
+ from xtuner.v1.rl.rollout._generation.external_http_entry import ExternalRolloutHttpEntryConfig
+ from xtuner.v1.rl.rollout._generation.internal_http_entry import InternalRolloutHttpEntryConfig
class RolloutWorkerMetadata(TypedDict):
@@ -59,24 +40,28 @@ class RolloutWorkerMetadata(TypedDict):
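# Includes: parallel strategy (TP/EP), timeout settings, backend type (LMDeploy/vLLM/SGLang), etc.
rollout_config: RolloutConfig
- # Current active status of each worker server URL.
- # Key: server URL string.
- # Value: bool; True means the worker is active, False means it has failed or been deactivated.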
+ # Worker server URL map used by trainer-side control paths. First-version
+ # rollout refactor assumes all workers remain available.
worker_server_urls_status: Dict[str, bool]
# Gateway HTTP server URL (e.g. "http://1.2.3.4:8080").
# Set after start_gateway() is called; None if the gateway has not been started.
api_server_url: Optional[str]
- # worker rank -> SessionServer proxy URL. These are the externally
- # registered URLs for routedapiproxy; server_url_dict keeps the original
- # worker URLs for trainer-side weight update / backend control paths.
+ # Internal rollout HTTP entry URL. Set after start_internal_http_entry() is called.
+ internal_http_entry_url: Optional[str]
+
+ # worker rank -> SessionServer proxy URL. server_url_dict keeps the
+ # original worker URLs for trainer-side weight update / backend control paths.
worker_session_url_dict: Dict[int, str]
- # SessionServer URL -> active status. This mirrors worker_server_urls_status
- # but is keyed by the proxy URL that external traffic uses.
+ # SessionServer URL -> availability status. First-version rollout refactor
+ # does not deactivate workers, so these values are always True.
worker_session_urls_status: Dict[str, bool]
+ # Runtime worker handles consumed by local/http/external generation paths.
+ worker_handles: List[RolloutWorkerHandle]
+
# Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller
# state; passing a normal Python object would serialize a separate copy into each actor.
@@ -87,37 +72,30 @@ class RolloutController:
def __init__(
self,
infer_config: RolloutConfig,
- placement_group: PlacementGroup,
+ runtime: RolloutRuntime,
):
"""Initialize the RolloutController.
Args:
infer_config (RolloutConfig): The configuration for the rollout.
- placement_group (PlacementGroup): The placement group for the
- RolloutWorker actors.
+ runtime: Pre-built rollout runtime, including worker handles and
+ backend server URLs.
"""
self.config = infer_config
self.num_gpus_per_engine = self.config.num_gpus_per_engine
self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController")
- self.engine_rank_mesh_array: List[List[int]] = []
- self.worker_server_urls_map: dict[int, str] = {}
- self.rank2info: dict[int, WorkerInfo] = {}
- self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group)
- self.num_active_workers = len(self.rank2info)
- self.worker_info_lock = threading.RLock()
+ self.engine_rank_mesh_array = runtime.engine_rank_mesh_array
+ self.worker_server_urls_map = runtime.worker_server_urls_map
+ self.rank2worker = runtime.rank2worker
+ self.worker_handles = runtime.worker_handles
+ self.num_rollout_workers = len(self.rank2worker)
# The timeout for the environment to wait for the rollout controller's response.
# This should be longer than the controller's internal timeout (`rollout_timeout`)
# to account for potential queuing delays and other overheads.
self.timeout_multiplier = 2.0
- self.router = SessionRouter(self.rank2info, worker_infos_lock=self.worker_info_lock)
- self.health_checker = RolloutHealthChecker(
- config=self.config,
- workers_info=self.rank2info,
- worker_infos_lock=self.worker_info_lock,
- )
- self.health_checker.start()
- self._tool_call_parser, self._reasoning_parser = self._build_output_parsers()
self._gateway_url: str | None = None
+ self._internal_http_entry_url: str | None = None
+ self._external_http_entries: list[Any] = []
def start_gateway(self, config: "GatewayConfig") -> str | None:
"""Start the gateway HTTP server in a daemon thread and return its URL.
@@ -152,6 +130,36 @@ def start_gateway(self, config: "GatewayConfig") -> str | None:
self.logger.info(f"Gateway server started at {url}, capture_folder: {config.capture_folder}")
return url
+ def start_internal_http_entry(self, config: "InternalRolloutHttpEntryConfig") -> str:
+ from xtuner.v1.rl.rollout._generation.internal_http_entry import (
+ build_internal_rollout_http_entry_app,
+ serve_internal_rollout_http_entry_in_thread,
+ )
+
+ app = build_internal_rollout_http_entry_app(
+ worker_handles=self.worker_handles,
+ rollout_config=self.config,
+ config=config,
+ )
+ serve_internal_rollout_http_entry_in_thread(app, config)
+ host = ray.util.get_node_ip_address() if config.host in ("", "0.0.0.0") else config.host
+ url = f"http://{host}:{config.port}"
+ self._internal_http_entry_url = url
+ self.logger.info(f"Internal rollout HTTP entry started at {url}")
+ return url
+
+ def start_external_http_entry(self, config: "ExternalRolloutHttpEntryConfig") -> None:
+ from xtuner.v1.rl.rollout._generation.external_http_entry import ExternalRolloutHttpEntry
+
+ entry = ExternalRolloutHttpEntry(
+ worker_handles=self.worker_handles,
+ rollout_config=self.config,
+ config=config,
+ log_dir=str(self.config.worker_log_dir),
+ )
+ entry.start()
+ self._external_http_entries.append(entry)
+
def get_rollout_metadata(self) -> RolloutWorkerMetadata:
"""Get information about the current rollout setup.
@@ -159,124 +167,67 @@ def get_rollout_metadata(self) -> RolloutWorkerMetadata:
dict: A dictionary containing the engine mesh list, server URL
dictionary, and the rollout configuration.
"""
- with self.worker_info_lock:
- worker_server_urls_status = {info.url: info.is_active for info in self.rank2info.values()}
- worker_session_url_dict = {
- rank: info.session_url for rank, info in self.rank2info.items() if info.session_url is not None
- }
- worker_session_urls_status = {
- info.session_url: info.is_active for info in self.rank2info.values() if info.session_url is not None
- }
+ worker_server_urls_map = {worker.rank: worker.backend_url for worker in self.worker_handles}
+ worker_server_urls_status = {worker.backend_url: True for worker in self.worker_handles}
+ worker_session_url_dict = {
+ worker.rank: worker.session_server_url for worker in self.worker_handles if worker.session_server_url is not None
+ }
+ worker_session_urls_status = {
+ worker.session_server_url: True for worker in self.worker_handles if worker.session_server_url is not None
+ }
rollout_metadata: RolloutWorkerMetadata = {
"engine_rank_mesh_array": self.engine_rank_mesh_array,
- "server_url_dict": self.worker_server_urls_map,
+ "server_url_dict": worker_server_urls_map,
"rollout_config": self.config,
"worker_server_urls_status": worker_server_urls_status,
- "api_server_url": self._gateway_url,
+ "api_server_url": self._internal_http_entry_url or self._gateway_url,
"worker_session_url_dict": worker_session_url_dict,
"worker_session_urls_status": worker_session_urls_status,
+ "internal_http_entry_url": self._internal_http_entry_url,
+ "worker_handles": list(self.worker_handles),
}
return rollout_metadata
- def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]:
- tool_call_parser = None
- reasoning_parser = None
-
- if self.config.tool_call_parser != "none":
- tool_call_parser = build_tool_call_parser(self.config.tool_call_parser)
-
- if self.config.reasoning_parser != "none":
- tokenizer_path = self.config.tokenizer_path or self.config.model_path
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
- reasoning_parser = build_reasoning_parser(self.config.reasoning_parser, tokenizer)
-
- return tool_call_parser, reasoning_parser
-
- def get_ready_status(self) -> tuple[bool, dict[str, Any]]:
- with self.worker_info_lock:
- active_workers = sum(1 for info in self.rank2info.values() if info.is_active)
- total_workers = len(self.rank2info)
- return active_workers > 0, {
- "active_workers": active_workers,
- "total_workers": total_workers,
+ def get_worker_handles(self) -> list[RolloutWorkerHandle]:
+ return list(self.worker_handles)
+
+ def get_runtime_status(self) -> tuple[bool, dict[str, Any]]:
+ return bool(self.worker_handles), {
+ "rollout_workers": len(self.worker_handles),
+ "total_workers": len(self.worker_handles),
+ "workers": {
+ worker.rank: {
+ "url": worker.backend_url,
+ "session_url": worker.session_server_url,
+ "has_rollout_worker_generator": True,
+ }
+ for worker in self.worker_handles
+ },
}
- @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE)
- async def generate(self, rollout_state: RolloutState) -> RolloutState:
- if XTUNER_DETERMINISTIC:
- sample_params = rollout_state.sample_params.model_copy(deep=True)
- sample_params.sampling_seed = self.config.random_seed + (
- (rollout_state.uid or 0) - (rollout_state.message_uid or 0)
- )
- rollout_state.sample_params = sample_params
-
- session_id = rollout_state.session_uid if rollout_state.session_uid is not None else uuid4().int
- worker = await self.router.get_worker(session_id)
- if worker is None:
- rollout_state.status = Status.FAILED
- rollout_state.error_msg = "No active rollout worker available."
- return rollout_state
-
- response_ref = worker.generate.remote(rollout_state=rollout_state) # type: ignore[attr-defined]
- try:
- response_rollout_state = await asyncio.wait_for(
- response_ref,
- timeout=self.config.rollout_timeout * self.timeout_multiplier,
- )
- self._apply_output_parsers(response_rollout_state)
- return response_rollout_state
- except asyncio.TimeoutError:
- self.logger.error(f"Rollout timeout for worker {worker}. Skipping sample.")
- rollout_state.status = Status.FAILED
- rollout_state.error_msg = (
- f"Rollout request timed out after {self.config.rollout_timeout * self.timeout_multiplier} seconds."
- )
- return rollout_state
-
- def _apply_output_parsers(self, rollout_state: RolloutState) -> None:
- """Apply tool-call and reasoning parsers to the rollout state in-
- place."""
- if self._tool_call_parser is not None:
- parsed = self._tool_call_parser.parse(rollout_state)
- rollout_state.tool_calls = parsed.tool_calls
- rollout_state.response = parsed.remaining_text or None
- if self._reasoning_parser is not None:
- parsed_reasoning = self._reasoning_parser.parse(rollout_state)
- rollout_state.response = parsed_reasoning.remaining_text
- if parsed_reasoning.reasoning_text:
- rollout_state.extra_fields["reasoning_text"] = parsed_reasoning.reasoning_text
- else:
- rollout_state.extra_fields.pop("reasoning_text", None)
-
- def set_enable_partial_rollout(self, enable: bool) -> None:
- """Propagate enable_partial_rollout flag to all active workers."""
- with self.worker_info_lock:
- active_actors = [info.actor for info in self.rank2info.values() if info.is_active]
- ray.get([actor.set_enable_partial_rollout.remote(enable) for actor in active_actors]) # type: ignore[attr-defined]
-
def pause_generation(self):
- self.health_checker.pause()
- self._broadcast_to_active_workers("pause_generation")
+ self._broadcast_to_rollout_worker_generators("pause_generation")
+ self._broadcast_to_workers("pause_generation")
def cleanup_after_pause(self):
- self._broadcast_to_active_workers("cleanup_after_pause")
+ self._broadcast_to_workers("cleanup_after_pause")
def continue_generation(self):
- self.health_checker.resume()
- self._broadcast_to_active_workers("continue_generation")
+ self._broadcast_to_rollout_worker_generators("continue_generation")
+ self._broadcast_to_workers("continue_generation")
def offload(self):
- self._broadcast_to_active_workers("offload")
+ self._broadcast_to_workers("offload")
def onload(self):
- self._broadcast_to_active_workers("onload_weights")
- self._broadcast_to_active_workers("onload_kvcache")
+ self._broadcast_to_workers("onload_weights")
+ self._broadcast_to_workers("onload_kvcache")
def onload_weights(self):
- self._broadcast_to_active_workers("onload_weights")
+ self._broadcast_to_workers("onload_weights")
def onload_kvcache(self):
- self._broadcast_to_active_workers("onload_kvcache")
+ self._broadcast_to_workers("onload_kvcache")
def shutdown(self):
"""Shuts down all active rollout workers.
@@ -284,80 +235,13 @@ def shutdown(self):
Args:
block (bool): Whether to block until the operation completes.
"""
- self.health_checker.stop()
- self._broadcast_to_active_workers("shutdown")
-
- def recover_failed_workers(self):
- """Recovers from worker failures by restarting failed workers and
- reinitializing the rollout setup."""
- self.health_checker.pause()
- with self.worker_info_lock:
- failed_workers = [info for info in self.rank2info.values() if not info.is_active]
- if not failed_workers:
- self.logger.info("No failed workers detected during recovery.")
- return
-
- self.logger.warning(f"Detected {len(failed_workers)} failed workers. Initiating recovery process.")
- for worker in failed_workers:
- if self._restart_failed_workers(worker.actor):
- with self.worker_info_lock:
- rank = self._get_rank_by_actor(worker.actor)
- if rank is not None:
- self.rank2info[rank].is_active = True
- self.health_checker.resume()
-
- def _restart_failed_workers(self, worker: RolloutWorker) -> bool:
- try:
- dist_init_addr = ray.get(worker.init_dist_port.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined]
- _, url = ray.get(worker.init.remote(dist_init_addr), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined]
- _, session_url = ray.get(worker.get_session_server_info.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined]
- is_healthy = ray.get(worker.check_health.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined]
-
- if is_healthy:
- self.logger.info(f"Successfully restarted worker {worker} with URL {url}.")
- with self.worker_info_lock:
- rank = self._get_rank_by_actor(worker)
- if rank is not None:
- self.rank2info[rank].url = url
- self.rank2info[rank].session_url = session_url
- self.worker_server_urls_map[rank] = url
- return True
- else:
- self.logger.error(f"Worker {worker} is still unhealthy after restart.")
- return False
- except Exception as e:
- self.logger.error(f"Failed to restart worker: {e}")
- return False
-
- def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size):
- """Update the distributed initialization addresses for workers.
-
- This is used to group workers that belong to the same inference engine.
+ for entry in self._external_http_entries:
+ entry.stop()
+ self._external_http_entries.clear()
+ self._broadcast_to_workers("shutdown")
- Args:
- nodes_per_engine (int): The number of nodes per inference engine.
- server_urls_per_engine (int): The number of server urls per inference engine.
- dist_init_addrs (list): The list of initial addresses.
- tp_size (int): The tensor parallel size.
-
- Returns:
- list: The updated list of distributed initialization addresses.
- """
- # lmdeploy pytorch ep: server_urls_per_engine > 1
- # sglang cross node engine: nodes_per_engine > 1
- assert server_urls_per_engine == 1 or nodes_per_engine == 1
- if nodes_per_engine > 1:
- index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers]
- for i in range(1, len(index)):
- dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1])
- if server_urls_per_engine > 1:
- activate_servers = len(dist_init_addrs)
- for i in range(0, activate_servers, server_urls_per_engine):
- dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine
- return dist_init_addrs
-
- def _broadcast_to_active_workers(self, method_name: str):
- """Helper function to call a method on all active workers.
+ def _broadcast_to_workers(self, method_name: str):
+ """Helper function to call a method on all rollout workers.
Args:
method_name (str): The name of the method to call.
@@ -366,147 +250,17 @@ def _broadcast_to_active_workers(self, method_name: str):
Returns:
A list of results, one per rollout worker.
"""
- futures = []
- with self.worker_info_lock:
- active_actors = [info.actor for info in self.rank2info.values() if info.is_active]
- futures = [getattr(actor, method_name).remote() for actor in active_actors]
+ worker_actors = [worker.worker_actor for worker in self.worker_handles]
+ futures = [getattr(actor, method_name).remote() for actor in worker_actors]
results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT)
return results
- def _get_worker_cls(self):
- if os.environ.get("XTUNER_USE_LMDEPLOY") == "1":
- from .lmdeploy import LMDeployWorker
-
- worker_cls = LMDeployWorker
- elif os.environ.get("XTUNER_USE_VLLM") == "1":
- from .vllm import vLLMWorker
-
- worker_cls = vLLMWorker
- elif os.environ.get("XTUNER_USE_SGLANG") == "1":
- from .sglang import SGLangWorker
-
- worker_cls = SGLangWorker
- else:
- raise NotImplementedError(
- "Rollout backend is not supported."
- "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM"
- " or XTUNER_USE_SGLANG environment variable."
- )
- assert self.config.rollout_max_batch_size_per_instance is not None, (
- "rollout_max_batch_size_per_instance must be set before building RolloutWorker."
- )
- worker_generate_max_concurrency = max(
- 1000, # Ray async actor default max_concurrency.
- math.ceil(self.config.rollout_max_batch_size_per_instance * self.config.allow_over_concurrency_ratio),
- )
- return ray.remote(
- concurrency_groups={
- ROLLOUT_CONCURRENCY_GROUP_GENERATE: worker_generate_max_concurrency,
- },
- )(worker_cls)
-
- def _get_rank_by_actor(self, actor: RolloutWorker) -> Optional[int]:
- """Get rank by actor object.
-
- Args:
- actor: The RolloutWorker actor object.
-
- Returns:
- The rank of the worker, or None if not found.
- """
- for rank, info in self.rank2info.items():
- if info.actor == actor:
- return rank
- return None
-
- def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map):
- """Update the list of active rollout workers and their server URLs.
-
- When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with
- tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input.
- Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0
- workers and their corresponding URLs.
- """
- if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node:
- return active_rollout_workers, worker_server_urls_map
- else:
- active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node
- active_rank = list(worker_server_urls_map.keys())[::active_worker_interval]
- active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval]
- return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls))
-
- def _init_workers(self, placement_group: PlacementGroup):
- """Initializes and configures the pool of RolloutWorker actors.
-
- This method creates workers from the placement group, configures distributed
- inference engines by grouping workers, where each group forms a tensor-parallel
- inference engine. It determines the `active_workers` to act as the head of each
- engine, constructs the `engine_rank_mesh_array` to define engine topology,
- acquires necessary distributed communication ports, and finally launches servers
- on the `active_workers` to get their addresses.
-
- Returns:
- Tuple[List, Dict]: A tuple where the first element is
- `engine_rank_mesh_array`, a list of lists containing the ranks of workers
- in each engine, and the second element is `worker_server_urls_map`,
- a dictionary mapping the rank of each active worker to its
- corresponding server URL.
- """
- # Create workers from placement group
- workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
- self._get_worker_cls(), self.config, placement_group
- )
- active_servers_count, nodes_per_engine = self.config.get_active_servers_count(len(workers))
- interval = len(workers) // active_servers_count
- active_rollout_workers = workers[::interval]
- server_urls_per_engine = self.config.server_urls_per_engine
-
- set_bundle_idxs_objectref = []
- engine_rank_mesh_array = []
- activate_worker_idx = 0
- for active_worker in active_rollout_workers:
- head_rank, _ = rank_bundle_idx_list[activate_worker_idx]
- engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval]
- engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx)
- set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined]
- engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta])
- activate_worker_idx += interval
- ray.get(set_bundle_idxs_objectref)
- # set engine mesh list for each worker
- ray.get(
- [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers]
- ) # type: ignore[attr-defined]
- # init dist_init_addr for each worker according to parallel settings
- init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined]
- dist_init_addrs = self._update_dist_init_addr(
- nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine
- )
- # launch rollout servers
- init_results = ray.get(
- [worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)]
- )
- worker_server_urls_map = dict(init_results) # rank -> url
- worker_session_url_dict = dict(
- ray.get([worker.get_session_server_info.remote() for worker in active_rollout_workers])
- )
- active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map(
- active_rollout_workers, worker_server_urls_map
- )
- active_ranks = list(worker_server_urls_map.keys())
- worker_session_url_dict = {rank: worker_session_url_dict[rank] for rank in active_ranks}
- workers_info = {}
- for i in range(len(active_rollout_workers)):
- rank = list(worker_server_urls_map.keys())[i]
- url = worker_server_urls_map[rank]
- workers_info[rank] = WorkerInfo(
- actor=active_rollout_workers[i],
- url=url,
- session_url=worker_session_url_dict[rank],
- )
- self.logger.info(f"Rollout worker server URLs: {[info.url for info in workers_info.values()]}")
- self.logger.info(f"Rollout worker session server URLs: {[info.session_url for info in workers_info.values()]}")
- return engine_rank_mesh_array, worker_server_urls_map, workers_info
-
+ def _broadcast_to_rollout_worker_generators(self, method_name: str):
+ generators = [worker.generator_actor for worker in self.worker_handles]
+ futures = [getattr(actor, method_name).remote() for actor in generators]
+ if not futures:
+ return []
+ return ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT)
RayRolloutController = ray.remote(RolloutController)
RolloutControllerProxy: TypeAlias = ActorProxy[RayRolloutController]
diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py
index e8eb241038..d3e2eb7037 100644
--- a/xtuner/v1/rl/rollout/lmdeploy.py
+++ b/xtuner/v1/rl/rollout/lmdeploy.py
@@ -3,14 +3,10 @@
from itertools import chain
from typing import Any, Dict, List
-import numpy as np
import ray
import requests
from ray.util.placement_group import placement_group_table
-from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams
-
from .worker import RolloutConfig, RolloutWorker
@@ -67,14 +63,12 @@ def __init__(
super().__init__(config, rank, master_addr, master_port, world_size, accelerator)
self.server_func = run_lmdeploy_server_wrapper
self.router_func_str = "lmdeploy.serve.proxy.proxy.proxy"
- self.endpoints["health_generate"] = "health"
self.endpoints["generate"] = "generate"
self.endpoints["v1/chat/completions"] = "v1/chat/completions"
self.endpoints["output_ids"] = "output_ids"
self.endpoints["response"] = "text"
self.endpoints["sleep"] = "sleep"
self.endpoints["wake_up"] = "wakeup"
- self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)
self.api_keys = self.config.api_key
self.model_name = self.config.model_name
self.enable_return_routed_experts = self.config.enable_return_routed_experts
@@ -92,58 +86,6 @@ def onload_kvcache(self):
"""Onloads the KV cache by waking up the model."""
return self._wake_up(tags=["kv_cache"])
- def _get_request_payload(self, rollout_state: RolloutState) -> dict:
- tools = rollout_state.tools
- tool_choice = rollout_state.tool_choice
- sample_params = rollout_state.sample_params
- message = rollout_state.message
- input_tokens = rollout_state.tokens
-
- optional_fields: dict[str, object] = {}
- if tools is not None:
- optional_fields["tools"] = tools
- if tool_choice is not None:
- optional_fields["tool_choice"] = tool_choice
-
- if sample_params.return_token_ids:
- payload = {"model": self.model_name, **optional_fields}
-
- if "image_data" in rollout_state.extra_fields:
- assert input_tokens is not None, "input_tokens is required when image_data is provided."
- payload["image_data"] = rollout_state.extra_fields["image_data"]
-
- if input_tokens is not None:
- payload["input_ids"] = input_tokens
- else:
- text_prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
- prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
- payload["input_ids"] = prompt_token_ids
- sample_params.return_routed_experts = True if self.enable_return_routed_experts else False
- lmdeploy_sample_params = self._transform_sample_params(sample_params)
- payload.update(lmdeploy_sample_params)
- else:
- payload = {
- "model": self.model_name,
- "messages": rollout_state.message,
- **optional_fields,
- }
- lmdeploy_sample_params = {
- "temperature": sample_params.temperature,
- "top_p": sample_params.top_p,
- "n": sample_params.n,
- "stream": sample_params.stream,
- "max_tokens": sample_params.max_tokens,
- "repetition_penalty": sample_params.repetition_penalty,
- "top_k": sample_params.top_k,
- "skip_special_tokens": sample_params.skip_special_tokens,
- }
- if sample_params.stops:
- lmdeploy_sample_params["stop"] = sample_params.stops
- if sample_params.min_tokens > 0:
- lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens
- payload.update(lmdeploy_sample_params)
- return payload
-
def _sleep(self, level: int = 1):
"""Put the model into a sleep state to save resources.
@@ -177,15 +119,6 @@ def _wake_up(self, tags: List[str] | None = None):
assert response.status_code == 200, response.status_code
return response.text
- async def _decode_routed_experts(self, routed_experts: Any) -> Any:
- if isinstance(routed_experts, str):
- if self.lmdeploy_actor is None:
- self.lmdeploy_actor = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE)
- assert self.lmdeploy_actor is not None, "LMDeploy actor should be available in the shared store."
- routed_experts_data = await self.lmdeploy_actor.get.remote(routed_experts)
- return ray.put(np.asarray(routed_experts_data))
- return np.asarray(routed_experts)
-
async def cleanup_after_pause(self) -> None:
if not self.enable_return_routed_experts:
return
@@ -385,6 +318,3 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
speculative_config=speculative_config,
**lmdeploy_config_kwargs,
)
-
- def _transform_sample_params(self, sample_params: SampleParams) -> dict:
- return sample_params.model_dump(exclude_none=True)
diff --git a/xtuner/v1/rl/rollout/rollout_generator.py b/xtuner/v1/rl/rollout/rollout_generator.py
new file mode 100644
index 0000000000..39bd18243b
--- /dev/null
+++ b/xtuner/v1/rl/rollout/rollout_generator.py
@@ -0,0 +1,228 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+from typing import Any, Literal
+from uuid import uuid4
+
+import ray
+from pydantic import BaseModel, ConfigDict, Field
+from ray.exceptions import RayActorError
+from transformers import AutoTokenizer
+
+from xtuner.v1.data_proto.rl_data import RolloutState, Status
+from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger
+
+from ._generation.external_http_entry import ExternalRolloutHttpEntryConfig
+from ._generation.internal_http_entry import InternalRolloutHttpEntryConfig
+from ._generation.session_worker_selector import RolloutWorkerHandle, RolloutWorkerUrlSource, SessionWorkerSelector
+from .parser.factory import build_reasoning_parser, build_tool_call_parser
+from .parser.reasoning_parser import ReasoningParser
+from .parser.tool_parser import ToolCallParser
+from .worker import RolloutConfig
+
+
+class LocalRolloutGeneratorConfig(BaseModel):
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+ timeout_multiplier: float = Field(default=2.0, gt=0)
+
+ def build(self, rollout_controller) -> "LocalRolloutGenerator":
+ rollout_metadata = ray.get(rollout_controller.get_rollout_metadata.remote())
+ return LocalRolloutGenerator(
+ worker_handles=rollout_metadata["worker_handles"],
+ rollout_config=rollout_metadata["rollout_config"],
+ timeout_multiplier=self.timeout_multiplier,
+ )
+
+
+class LocalRolloutGenerator:
+ """Local AgentLoop generation path.
+
+ It chooses one active rollout worker by session id, then calls the worker's
+ bound RolloutWorkerGenerator actor directly. The RolloutController is not
+ on the runtime generation path.
+ """
+
+ def __init__(
+ self,
+ worker_handles: list[RolloutWorkerHandle],
+ rollout_config: RolloutConfig,
+ timeout_multiplier: float = 2.0,
+ ) -> None:
+ self.worker_handles = worker_handles
+ self.worker_selector = SessionWorkerSelector(worker_handles)
+ self.config = rollout_config
+ self.timeout_multiplier = timeout_multiplier
+ self.logger = get_logger(log_dir=rollout_config.worker_log_dir, tag="LocalRolloutGenerator")
+ self._tool_call_parser, self._reasoning_parser = self._build_output_parsers()
+
+ def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]:
+ tool_call_parser = None
+ reasoning_parser = None
+
+ if self.config.tool_call_parser != "none":
+ tool_call_parser = build_tool_call_parser(self.config.tool_call_parser)
+
+ if self.config.reasoning_parser != "none":
+ tokenizer_path = self.config.tokenizer_path or self.config.model_path
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
+ reasoning_parser = build_reasoning_parser(self.config.reasoning_parser, tokenizer)
+
+ return tool_call_parser, reasoning_parser
+
+ def _apply_output_parsers(self, rollout_state: RolloutState) -> None:
+ if self._tool_call_parser is not None:
+ parsed = self._tool_call_parser.parse(rollout_state)
+ rollout_state.tool_calls = parsed.tool_calls
+ rollout_state.response = parsed.remaining_text or None
+ if self._reasoning_parser is not None:
+ parsed_reasoning = self._reasoning_parser.parse(rollout_state)
+ rollout_state.response = parsed_reasoning.remaining_text
+ if parsed_reasoning.reasoning_text:
+ rollout_state.extra_fields["reasoning_text"] = parsed_reasoning.reasoning_text
+ else:
+ rollout_state.extra_fields.pop("reasoning_text", None)
+
+ async def generate(self, rollout_state: RolloutState, *, enable_partial_rollout: bool = False) -> RolloutState:
+ if XTUNER_DETERMINISTIC:
+ sample_params = rollout_state.sample_params.model_copy(deep=True)
+ sample_params.sampling_seed = self.config.random_seed + (
+ (rollout_state.uid or 0) - (rollout_state.message_uid or 0)
+ )
+ rollout_state.sample_params = sample_params
+
+ session_id = rollout_state.session_uid if rollout_state.session_uid is not None else uuid4().int
+ worker = await self.worker_selector.select(session_id)
+ if worker is None:
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = "No rollout worker available."
+ return rollout_state
+
+ try:
+ response_rollout_state = await asyncio.wait_for(
+ worker.generator_actor.generate.remote(
+ rollout_state=rollout_state,
+ enable_partial_rollout=enable_partial_rollout,
+ ),
+ timeout=self.config.rollout_timeout * self.timeout_multiplier,
+ )
+ self._apply_output_parsers(response_rollout_state)
+ return response_rollout_state
+ except (asyncio.TimeoutError, RayActorError) as exc:
+ self.logger.error(f"Rollout failed for worker {worker.rank}. Skipping sample. Error: {exc}")
+ rollout_state.status = Status.FAILED
+ if isinstance(exc, asyncio.TimeoutError):
+ rollout_state.error_msg = (
+ f"Rollout request timed out after {self.config.rollout_timeout * self.timeout_multiplier} seconds."
+ )
+ else:
+ rollout_state.error_msg = f"Rollout worker generator actor failed with error: {exc}"
+ return rollout_state
+
+
+RolloutGenerateKind = Literal["local", "http"]
+RolloutHttpEntryKind = Literal["internal", "external"]
+
+
+@dataclass
+class RolloutGenerateHandle:
+ kind: RolloutGenerateKind
+ local_generator: LocalRolloutGenerator | None = None
+ base_url: str | None = None
+ rollout_controller: Any | None = None
+
+ def require_local_generator(self) -> LocalRolloutGenerator:
+ if self.local_generator is None:
+ raise RuntimeError(f"Rollout generate handle {self.kind!r} does not provide a local generator.")
+ return self.local_generator
+
+ def require_base_url(self) -> str:
+ if self.base_url is None:
+ raise RuntimeError(f"Rollout generate handle {self.kind!r} does not provide a base URL.")
+ return self.base_url
+
+
+class RolloutGenerateHandleConfig(BaseModel):
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+ kind: RolloutGenerateKind = "local"
+ http_entry: RolloutHttpEntryKind = "internal"
+ local_generator_config: LocalRolloutGeneratorConfig = Field(default_factory=LocalRolloutGeneratorConfig)
+
+ base_url: str | None = None
+
+ internal_http_entry_host: str = "0.0.0.0"
+ internal_http_entry_port: int = 8081
+ internal_http_entry_title: str = "XTuner Internal Rollout Router"
+ internal_http_entry_version: str = "0.1.0"
+ internal_http_entry_log_level: str = "warning"
+ internal_http_entry_request_timeout: float | None = None
+ internal_http_entry_stream_timeout: float | None = None
+ http_worker_url_source: RolloutWorkerUrlSource = "backend"
+
+ external_http_entry_delete_existing: bool = True
+ external_http_entry_check_worker_urls: bool = True
+ external_http_entry_check_base_url: bool = True
+
+ def build(self, rollout_controller) -> RolloutGenerateHandle:
+ if self.kind == "local":
+ return RolloutGenerateHandle(
+ kind=self.kind,
+ local_generator=self.local_generator_config.build(rollout_controller),
+ rollout_controller=rollout_controller,
+ )
+
+ if self.kind != "http":
+ raise ValueError(f"Unsupported rollout generate kind: {self.kind!r}")
+
+ base_url = self._resolve_http_base_url(rollout_controller)
+ if base_url is None:
+ raise ValueError(f"Rollout generate handle {self.kind!r} requires base_url.")
+ return RolloutGenerateHandle(kind=self.kind, base_url=base_url, rollout_controller=rollout_controller)
+
+ def _resolve_http_base_url(self, rollout_controller) -> str | None:
+ if self.base_url is not None:
+ return self.base_url
+
+ if self.http_entry == "internal":
+ rollout_metadata = ray.get(rollout_controller.get_rollout_metadata.remote())
+ base_url = rollout_metadata.get("internal_http_entry_url")
+ if base_url is None:
+ raise ValueError(
+ "Rollout generate handle kind='http' and http_entry='internal' requires the internal HTTP "
+ "entry to be started before building AgentLoop, or base_url to be provided as a runtime override."
+ )
+ return base_url
+
+ if self.http_entry == "external":
+ raise ValueError("Rollout generate handle kind='http' and http_entry='external' requires base_url.")
+
+ raise ValueError(f"Unsupported rollout HTTP entry: {self.http_entry!r}")
+
+ def build_internal_http_entry_config(self) -> InternalRolloutHttpEntryConfig | None:
+ if self.kind != "http" or self.http_entry != "internal":
+ return None
+ return InternalRolloutHttpEntryConfig(
+ host=self.internal_http_entry_host,
+ port=self.internal_http_entry_port,
+ title=self.internal_http_entry_title,
+ version=self.internal_http_entry_version,
+ log_level=self.internal_http_entry_log_level,
+ request_timeout=self.internal_http_entry_request_timeout,
+ stream_timeout=self.internal_http_entry_stream_timeout,
+ worker_url_source=self.http_worker_url_source,
+ )
+
+ def build_external_http_entry_config(self) -> ExternalRolloutHttpEntryConfig | None:
+ if self.kind != "http" or self.http_entry != "external":
+ return None
+ if self.base_url is None:
+ raise ValueError("Rollout generate handle kind='http' and http_entry='external' requires base_url.")
+ return ExternalRolloutHttpEntryConfig(
+ base_url=self.base_url,
+ worker_url_source=self.http_worker_url_source,
+ delete_existing=self.external_http_entry_delete_existing,
+ check_worker_urls=self.external_http_entry_check_worker_urls,
+ check_base_url=self.external_http_entry_check_base_url,
+ )
diff --git a/xtuner/v1/rl/rollout/rollout_worker_build.py b/xtuner/v1/rl/rollout/rollout_worker_build.py
new file mode 100644
index 0000000000..f7d34bea7e
--- /dev/null
+++ b/xtuner/v1/rl/rollout/rollout_worker_build.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from typing import Any
+
+import ray
+from ray.util.placement_group import PlacementGroup
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+from xtuner.v1.rl.utils import AutoAcceleratorWorkers
+from xtuner.v1.utils import get_logger
+
+from ._generation.rollout_worker_generator import RayRolloutWorkerGenerator
+from ._generation.session_worker_selector import RolloutWorkerHandle
+from .worker import RolloutConfig, RolloutWorker
+
+
+@dataclass
+class RolloutWorkerRuntime:
+ """Runtime handles for one rollout worker."""
+
+ worker_actor: RolloutWorker
+ backend_url: str
+ session_server_url: str | None = None
+ generator_actor: Any | None = None
+ bundle_idx: int | None = None
+
+
+@dataclass
+class RolloutRuntime:
+ engine_rank_mesh_array: list[list[int]]
+ worker_server_urls_map: dict[int, str]
+ rank2worker: dict[int, RolloutWorkerRuntime]
+ worker_handles: list[RolloutWorkerHandle]
+
+
+class RolloutWorkerBuilder:
+ """Bootstrap rollout workers and worker generation actors."""
+
+ def __init__(self, config: RolloutConfig, placement_group: PlacementGroup) -> None:
+ self.config = config
+ self.placement_group = placement_group
+ self.num_gpus_per_engine = config.num_gpus_per_engine
+ self.logger = get_logger(log_dir=config.worker_log_dir, tag="RolloutWorkerBuilder")
+ self.num_active_workers = 0
+
+ def build(self) -> RolloutRuntime:
+ engine_rank_mesh_array, worker_server_urls_map, rank2worker = self._init_workers()
+ self._init_rollout_worker_generators(rank2worker)
+ worker_handles = self._build_worker_handles(rank2worker)
+ return RolloutRuntime(
+ engine_rank_mesh_array=engine_rank_mesh_array,
+ worker_server_urls_map=worker_server_urls_map,
+ rank2worker=rank2worker,
+ worker_handles=worker_handles,
+ )
+
+ def _build_worker_handles(self, rank2worker: dict[int, RolloutWorkerRuntime]) -> list[RolloutWorkerHandle]:
+ worker_handles = []
+ for rank, worker in rank2worker.items():
+ if worker.generator_actor is None:
+ raise RuntimeError(f"Missing RolloutWorkerGenerator for rollout worker rank {rank}.")
+ worker_handles.append(
+ RolloutWorkerHandle(
+ rank=rank,
+ worker_actor=worker.worker_actor,
+ backend_url=worker.backend_url,
+ generator_actor=worker.generator_actor,
+ session_server_url=worker.session_server_url,
+ )
+ )
+ return worker_handles
+
+ def _init_rollout_worker_generators(self, rank2worker: dict[int, RolloutWorkerRuntime]) -> None:
+ for rank, worker in rank2worker.items():
+ if worker.bundle_idx is None:
+ raise RuntimeError(f"Missing placement bundle index for active rollout worker rank {rank}.")
+ scheduling_strategy = PlacementGroupSchedulingStrategy(
+ placement_group=self.placement_group,
+ placement_group_capture_child_tasks=False,
+ placement_group_bundle_index=worker.bundle_idx,
+ )
+ worker.generator_actor = RayRolloutWorkerGenerator.options(
+ scheduling_strategy=scheduling_strategy,
+ num_cpus=0,
+ ).remote(self.config, rank, worker.backend_url)
+
+ def _get_worker_cls(self):
+ if os.environ.get("XTUNER_USE_LMDEPLOY") == "1":
+ from .lmdeploy import LMDeployWorker
+
+ worker_cls = LMDeployWorker
+ elif os.environ.get("XTUNER_USE_VLLM") == "1":
+ from .vllm import vLLMWorker
+
+ worker_cls = vLLMWorker
+ elif os.environ.get("XTUNER_USE_SGLANG") == "1":
+ from .sglang import SGLangWorker
+
+ worker_cls = SGLangWorker
+ else:
+ raise NotImplementedError(
+ "Rollout backend is not supported. "
+ "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM"
+ " or XTUNER_USE_SGLANG environment variable."
+ )
+ return ray.remote(worker_cls)
+
+ def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size):
+ assert server_urls_per_engine == 1 or nodes_per_engine == 1
+ if nodes_per_engine > 1:
+ index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers]
+ for i in range(1, len(index)):
+ dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1])
+ if server_urls_per_engine > 1:
+ active_servers = len(dist_init_addrs)
+ for i in range(0, active_servers, server_urls_per_engine):
+ dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine
+ return dist_init_addrs
+
+ def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map):
+ if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node:
+ return active_rollout_workers, worker_server_urls_map
+
+ active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node
+ active_rank = list(worker_server_urls_map.keys())[::active_worker_interval]
+ active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval]
+ return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls))
+
+ def _init_workers(self):
+ workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
+ self._get_worker_cls(), self.config, self.placement_group
+ )
+ active_servers_count, nodes_per_engine = self.config.get_active_servers_count(len(workers))
+ interval = len(workers) // active_servers_count
+ active_rollout_workers = workers[::interval]
+ self.num_active_workers = len(active_rollout_workers)
+ server_urls_per_engine = self.config.server_urls_per_engine
+
+ set_bundle_idxs_objectref = []
+ engine_rank_mesh_array = []
+ activate_worker_idx = 0
+ for active_worker in active_rollout_workers:
+ head_rank, _ = rank_bundle_idx_list[activate_worker_idx]
+ engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval]
+ engine_bundle_idxs = [meta[1] for meta in engine_workers_meta]
+ set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined]
+ engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta])
+ activate_worker_idx += interval
+ ray.get(set_bundle_idxs_objectref)
+ ray.get(
+ [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers]
+ ) # type: ignore[attr-defined]
+
+ init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined]
+ dist_init_addrs = self._update_dist_init_addr(
+ nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine
+ )
+ init_results = ray.get(
+ [worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)]
+ )
+ worker_server_urls_map = dict(init_results)
+ worker_session_url_dict = dict(
+ ray.get([worker.get_session_server_info.remote() for worker in active_rollout_workers])
+ )
+ active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map(
+ active_rollout_workers, worker_server_urls_map
+ )
+ active_ranks = list(worker_server_urls_map.keys())
+ worker_session_url_dict = {rank: worker_session_url_dict[rank] for rank in active_ranks}
+
+ worker_runtimes = {}
+ rank_to_bundle_idx = {rank: bundle_idx for rank, bundle_idx in rank_bundle_idx_list}
+ for i, rank in enumerate(worker_server_urls_map.keys()):
+ worker_runtimes[rank] = RolloutWorkerRuntime(
+ worker_actor=active_rollout_workers[i],
+ backend_url=worker_server_urls_map[rank],
+ session_server_url=worker_session_url_dict[rank],
+ bundle_idx=rank_to_bundle_idx[rank],
+ )
+ self.logger.info(f"Rollout worker server URLs: {[worker.backend_url for worker in worker_runtimes.values()]}")
+ self.logger.info(
+ f"Rollout worker session server URLs: {[worker.session_server_url for worker in worker_runtimes.values()]}"
+ )
+ return engine_rank_mesh_array, worker_server_urls_map, worker_runtimes
+
+
+def build_rollout_runtime(config: RolloutConfig, placement_group: PlacementGroup) -> RolloutRuntime:
+ return RolloutWorkerBuilder(config, placement_group).build()
diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py
index aedd490e23..61763b2a82 100644
--- a/xtuner/v1/rl/rollout/sglang.py
+++ b/xtuner/v1/rl/rollout/sglang.py
@@ -28,7 +28,6 @@ def __init__(
from sglang.srt.entrypoints.http_server import launch_server
self.server_func = launch_server
- self.endpoints["health_generate"] = "health"
self.endpoints["generate"] = "generate"
self.endpoints["v1/chat/completions"] = "v1/chat/completions"
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path, trust_remote_code=True)
@@ -136,14 +135,6 @@ def _make_request(self, endpoint: str, payload=None):
response.raise_for_status()
return response.json()
- def check_health(self) -> bool:
- try:
- response = requests.get(f"{self.server_url}/{self.endpoints['health_generate']}", timeout=5.0)
- return response.status_code == 200
- except requests.RequestException as e:
- self.logger.error(f"Health check failed for server {self.server_url}: {e}")
- return False
-
def flush_cache(self):
"""Flush the cache of the server."""
# TODO: support tp
diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py
index 69fdf12f52..cf7039822c 100644
--- a/xtuner/v1/rl/rollout/utils.py
+++ b/xtuner/v1/rl/rollout/utils.py
@@ -1,254 +1,20 @@
-import asyncio
import os
-import threading
import time
-from collections import OrderedDict
-from itertools import cycle
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import cast
import numpy as np
import ray
from ray import ObjectRef as RayObjectRef
from xtuner.v1.data_proto.rl_data import RolloutState, Status
-from xtuner.v1.rl.utils import asyncio_run
+from xtuner.v1.rl.utils import free_object_refs
from xtuner.v1.utils import get_logger
-if TYPE_CHECKING:
- from .controller import WorkerInfo
- from .worker import RolloutConfig, RolloutWorker
-
ROLLOUT_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", str(5 * 3600))) # default 5 hours
logger = get_logger()
-class SessionRouter:
- def __init__(
- self,
- worker_infos: dict[int, "WorkerInfo"], # worker: worker_status
- worker_infos_lock: Optional[threading.RLock] = None,
- max_sessions: int = 10000,
- max_idle_seconds: Optional[float] = 3600.0,
- ):
- self._worker_infos = worker_infos
- self._worker_infos_lock = worker_infos_lock
- self._max_sessions = max_sessions
- self._max_idle = max_idle_seconds
-
- # OrderedDict: key=session_id -> value=(worker_rank, last_used_ts)
- self._map: OrderedDict[int, tuple[int, float]] = OrderedDict()
-
- self._worker_cycler = cycle(worker_infos.keys())
- self._lock = asyncio.Lock()
- self.logger = get_logger()
-
- def _now(self) -> float:
- return time.time()
-
- def _evict_expired(self):
- if self._max_idle is None:
- return
- now = self._now()
-
- to_delete = []
- for sid, (_, last_used) in self._map.items():
- if now - last_used > self._max_idle:
- to_delete.append(sid)
- else:
- break
- for sid in to_delete:
- self._map.pop(sid, None)
-
- def _evict_lru_to_capacity(self):
- while len(self._map) > self._max_sessions:
- self._map.popitem(last=False)
-
- def _choose_next_active_worker(self) -> tuple[int, Any]:
- n = len(self._worker_infos)
- for _ in range(n):
- rank = next(self._worker_cycler)
- if self._worker_infos_lock is None:
- info = self._worker_infos[rank]
- if info and info.is_active:
- return rank, info.actor
- else:
- with self._worker_infos_lock:
- info = self._worker_infos[rank]
- if info and info.is_active:
- return rank, info.actor
- return -1, None
-
- async def get_worker(self, session_id: int) -> Optional[Any]:
- async with self._lock:
- self._evict_expired()
-
- if session_id in self._map:
- worker_rank, _ = self._map.pop(session_id)
- if self._worker_infos_lock is None:
- info = self._worker_infos.get(worker_rank)
- else:
- with self._worker_infos_lock:
- info = self._worker_infos.get(worker_rank)
- if info and info.is_active:
- self._map[session_id] = (worker_rank, self._now())
- return info.actor
-
- rank, worker = self._choose_next_active_worker()
- if rank == -1:
- return None
- self._map[session_id] = (rank, self._now())
- self._evict_lru_to_capacity()
- return worker
-
-
-class RolloutHealthChecker:
- def __init__(
- self,
- config: "RolloutConfig",
- workers_info: dict[int, "WorkerInfo"],
- worker_infos_lock: Optional[threading.RLock] = None,
- ):
- self._workers_info = workers_info
- self._worker_infos_lock = worker_infos_lock
- self._check_interval = config.health_check_interval_seconds
- self._check_failure_threshold = config.health_check_failure_threshold
- self._stop_event: Optional[threading.Event] = None
- self._pause_event: Optional[threading.Event] = None
- self._thread: Optional[threading.Thread] = None
-
- def start(self) -> None:
- if self._thread and self._thread.is_alive():
- return
-
- self._stop_event = threading.Event()
- self._pause_event = threading.Event()
- self._pause_event.set() # Start in the paused state; call resume() once generation has started
- self._thread = threading.Thread(target=self._run_loop, daemon=True)
- self._thread.start()
- logger.info("RolloutHealthChecker started.")
-
- def stop(self) -> None:
- if not self._thread:
- return
-
- assert self._stop_event is not None
- self._stop_event.set()
- if self._pause_event:
- self._pause_event.clear()
- self._thread.join(timeout=5)
- self._thread = None
- self._stop_event = None
- logger.info("RolloutHealthChecker stopped.")
-
- def pause(self) -> None:
- if self._pause_event is None:
- return
- self._pause_event.set()
- logger.info("RolloutHealthChecker paused.")
-
- def resume(self) -> None:
- if self._pause_event is None:
- return
- self._pause_event.clear()
- logger.info("RolloutHealthChecker restarted.")
-
- def run_once(self) -> None:
- logger.debug("RolloutHealthChecker running health checks for all workers.")
- if self._worker_infos_lock is None:
- workers_snapshot = {
- rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items()
- }
- else:
- with self._worker_infos_lock:
- workers_snapshot = {
- rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items()
- }
-
- workers_to_check = [
- (rank, actor, url, is_active) for rank, (actor, url, is_active) in workers_snapshot.items() if is_active
- ]
- if not workers_to_check:
- return
-
- tasks = [
- check_worker_health(actor, rank, url, is_active, self._check_failure_threshold)
- for rank, actor, url, is_active in workers_to_check
- ]
-
- async def _run_checks() -> list[bool]:
- return await asyncio.gather(*tasks)
-
- check_results = asyncio_run(_run_checks())
- inactive_workers = []
- for (rank, _, _, _), is_healthy in zip(workers_to_check, check_results):
- if not is_healthy:
- logger.warning(f"Worker {rank} failed health check. Marking as inactive.")
- if self._worker_infos_lock is None:
- self._workers_info[rank].is_active = False
- inactive_worker = self._workers_info[rank].actor
- else:
- with self._worker_infos_lock:
- self._workers_info[rank].is_active = False
- inactive_worker = self._workers_info[rank].actor
- if inactive_worker is None:
- logger.error(f"[RolloutHealthChecker] Worker {rank} has no actor reference. Skipping shutdown.")
- continue
- inactive_workers.append((rank, inactive_worker))
- else:
- logger.debug(f"[RolloutHealthChecker] Worker {rank} passed health check.")
-
- for rank, inactive_worker in inactive_workers:
- try:
- ray.get(inactive_worker.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined]
- except Exception as e:
- logger.error(f"Exception while offloading worker {rank}: {e}")
-
- try:
- ray.get(inactive_worker.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined]
- except Exception as e:
- logger.error(f"Exception while shutting down worker {rank}: {e}")
-
- def _run_loop(self) -> None:
- assert self._stop_event is not None and self._pause_event is not None
- logger.info("RolloutHealthChecker loop started.")
-
- while not self._stop_event.is_set():
- while self._pause_event.is_set() and not self._stop_event.is_set():
- self._stop_event.wait(timeout=0.5)
-
- if self._stop_event.is_set():
- break
-
- if not self._pause_event.is_set() and not self._stop_event.is_set():
- self.run_once()
-
- if self._stop_event.wait(self._check_interval):
- break
-
-
-async def check_worker_health(
- worker: "RolloutWorker", rank: int, url: str, is_active: bool, failure_threshold: int = 3
-) -> bool:
- if worker is None or not is_active:
- logger.warning("Worker has no actor reference or is marked inactive.")
- return False
- failing_count = 0
- while failing_count < failure_threshold:
- try:
- health_status = await worker.check_health.remote() # type: ignore[attr-defined]
- if health_status:
- return True
- failing_count += 1
- logger.warning(f"Health check failed for worker {rank} at {url}. Failure count: {failing_count}")
- except Exception as e:
- failing_count += 1
- logger.error(
- f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}"
- )
- return False
-
-
async def _resolve_routed_experts(routed_experts: np.ndarray | RayObjectRef) -> np.ndarray:
if isinstance(routed_experts, RayObjectRef):
routed_experts = await routed_experts
@@ -294,7 +60,8 @@ async def postprocess(
routed_experts: np.ndarray | RayObjectRef | None,
finish_reason: str,
status: Status,
- routed_experts_expect_len: int,
+ routed_experts_expect_len: int | None = None,
+ enable_partial_rollout: bool | None = None,
) -> RolloutState:
rollout_state.finish_reason = finish_reason
rollout_state.status = status
@@ -312,7 +79,9 @@ async def postprocess(
if history_routed_experts is not None and routed_experts is not None:
# case 1: the previous rollout state already has a response and this inference also returned one, so the routed experts must be concatenated
start_time = time.perf_counter()
- history_routed_experts = await _resolve_routed_experts(history_routed_experts) # type: ignore[assignment]
+ history_routed_experts_ref = history_routed_experts
+ cur_routed_experts_ref = routed_experts
+ history_routed_experts = await _resolve_routed_experts(history_routed_experts_ref) # type: ignore[assignment]
cur_routed_experts = await _resolve_routed_experts(routed_experts) # type: ignore[assignment]
history_routed_experts_len = len(history_routed_experts)
cur_routed_experts_len = len(cur_routed_experts)
@@ -320,18 +89,26 @@ async def postprocess(
f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}, history_response_ids len: {len(history_response_ids)}, current response_ids len: {len(response_ids)}"
)
cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
- concat_routed_experts = np.concatenate([history_routed_experts, cur_routed_experts], axis=0)
+ if isinstance(history_routed_experts, list) and isinstance(cur_routed_experts, list):
+ concat_routed_experts = history_routed_experts + cur_routed_experts
+ else:
+ concat_routed_experts = np.concatenate([history_routed_experts, cur_routed_experts], axis=0)
rollout_state.routed_experts = ray.put(concat_routed_experts)
- expected_len = len(cast(list[int], rollout_state.prompt_ids)) + len(rollout_state.response_ids) - 1
- assert expected_len == routed_experts_expect_len, (
- f"Expected routed_experts len: {expected_len}, routed_experts_expect_len: {routed_experts_expect_len}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}"
- )
- assert len(concat_routed_experts) == routed_experts_expect_len, (
- f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected len: {expected_len}, history_routed_experts_len: {history_routed_experts_len}, current_routed_experts_len: {len(cur_routed_experts)}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}"
+ if routed_experts_expect_len is not None:
+ expected_len = len(cast(list[int], rollout_state.prompt_ids)) + len(rollout_state.response_ids) - 1
+ assert expected_len == routed_experts_expect_len, (
+ f"Expected routed_experts len: {expected_len}, routed_experts_expect_len: {routed_experts_expect_len}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}"
+ )
+ assert len(concat_routed_experts) == routed_experts_expect_len, (
+ f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected len: {expected_len}, history_routed_experts_len: {history_routed_experts_len}, current_routed_experts_len: {len(cur_routed_experts)}, prompt_ids_len: {len(cast(list[int], rollout_state.prompt_ids))}, response_ids_len: {len(rollout_state.response_ids)}"
+ )
+ free_object_refs(
+ [
+ ref
+ for ref in (history_routed_experts_ref, cur_routed_experts_ref)
+ if isinstance(ref, RayObjectRef)
+ ]
)
- # free_object_refs(
- # [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)]
- # )
end_time = time.perf_counter()
self.logger.debug(
f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
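Review note on the `postprocess` change above: the concatenation step can be read in isolation as the sketch below. It is a minimal illustration, not the code in `utils.py`. The helper name `concat_routed_experts` and the `ValueError` are assumptions made for this example (the diff uses `assert`), and the Ray object-ref resolution and `free_object_refs` cleanup are left out. The diff slices off the first `history_len` entries of the current result, which suggests that the current result also covers the tokens already present in the history.

```python
import numpy as np


def concat_routed_experts(history, current, expect_len=None):
    """Hypothetical helper mirroring the concatenation step in postprocess."""
    history_len = len(history)
    # Drop the prefix of the current result that the history already covers.
    new_part = current[history_len:]
    if isinstance(history, list) and isinstance(new_part, list):
        # Plain lists are joined without converting to ndarray.
        merged = history + new_part
    else:
        merged = np.concatenate([history, new_part], axis=0)
    # The length check is optional, matching `routed_experts_expect_len: int | None`.
    if expect_len is not None and len(merged) != expect_len:
        raise ValueError(f"routed_experts len {len(merged)} != expected {expect_len}")
    return merged
```

Passing `expect_len=None` skips the validation, which matches the new default of the `routed_experts_expect_len` parameter.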
diff --git a/xtuner/v1/rl/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py
index b1ff1794bb..6b5c211d63 100644
--- a/xtuner/v1/rl/rollout/vllm.py
+++ b/xtuner/v1/rl/rollout/vllm.py
@@ -121,13 +121,13 @@ def __init__(self, server_namespace: Namespace):
print(stack_trace)
raise # Re-raise the exception to prevent silent failure
- def actor_health(self):
- return "healthy"
+ def started(self):
+ return "started"
# Add a dummy task.
def run_lmdeploy_server_wrapper(server_namespace: Namespace):
- return ray.get(VllmServerWrapper.remote(server_namespace).actor_health.remote()) # type: ignore
+ return ray.get(VllmServerWrapper.remote(server_namespace).started.remote()) # type: ignore
class vLLMWorker(RolloutWorker):
@@ -143,7 +143,6 @@ def __init__(
super().__init__(config, rank, master_addr, master_port, world_size, accelerator)
self.router_func = ""
self.server_func = run_lmdeploy_server_wrapper
- self.endpoints["health_generate"] = "health"
self.endpoints["v1/chat/completions"] = "v1/chat/completions"
self.endpoints["generate"] = "v1/chat/completions"
self.endpoints["sleep"] = "sleep"
diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py
index 4dabb19f7c..7655a83022 100644
--- a/xtuner/v1/rl/rollout/worker.py
+++ b/xtuner/v1/rl/rollout/worker.py
@@ -1,48 +1,34 @@
import asyncio
-import copy
import json
-import math
import multiprocessing
import os
import socket
import threading
import time
-import traceback
from abc import abstractmethod
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
import httpx
import ray
-import requests # type: ignore[import-untyped]
from cyclopts import Group, Parameter
-from packaging.version import Version
from pydantic import BaseModel, ConfigDict, PrivateAttr
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from typing_extensions import Annotated
from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import (
- RolloutState,
- SampleParams,
- Status,
- reset_rollout_response,
- update_status_from_finish_reason,
-)
from xtuner.v1.rl.utils import (
AutoAcceleratorWorkers,
CPUResourcesConfig,
SingleAcceleratorWorker,
- cancel_and_drain,
find_master_addr_and_port,
get_eos_token,
register_cpu_resources,
)
from xtuner.v1.utils import get_logger
-from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
from .session_server import SessionServerActor
-from .utils import ROLLOUT_RAY_GET_TIMEOUT, PartialRolloutHandler
+from .utils import ROLLOUT_RAY_GET_TIMEOUT
if TYPE_CHECKING:
@@ -50,7 +36,6 @@
infer_group = Group("inference", help="Inference worker configuration.")
-ROLLOUT_CONCURRENCY_GROUP_GENERATE = "generate"
class RolloutConfig(BaseModel):
@@ -272,13 +257,6 @@ class RolloutConfig(BaseModel):
help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.',
),
] = {}
- max_retry_per_worker: Annotated[
- Optional[int],
- Parameter(
- group=infer_group,
- help="Maximum number of retries per rollout worker before deactivation.",
- ),
- ] = None
max_retry_per_sample: Annotated[
int,
Parameter(
@@ -308,20 +286,6 @@ class RolloutConfig(BaseModel):
),
] = False
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
- health_check_interval_seconds: Annotated[
- float,
- Parameter(
- group=infer_group,
- help="Interval in seconds between rollout worker health checks.",
- ),
- ] = 30.0
- health_check_failure_threshold: Annotated[
- int,
- Parameter(
- group=infer_group,
- help="Number of consecutive health check failures required before marking a worker inactive.",
- ),
- ] = 3
_logged_server_urls_per_engine: bool = PrivateAttr(default=False)
@property
@@ -375,17 +339,6 @@ def get_active_servers_count(self, num_rollout_workers: int) -> tuple[int, int]:
)
return active_servers_count, nodes_per_engine
- def get_controller_generate_concurrency(self, placement_group: "PlacementGroup") -> int:
- active_worker_count, _ = self.get_active_servers_count(len(placement_group.bundle_specs))
- assert self.rollout_max_batch_size_per_instance is not None, (
- "rollout_max_batch_size_per_instance must be set before building RolloutController."
- )
- concurrency_per_worker = math.ceil(
- self.rollout_max_batch_size_per_instance * self.allow_over_concurrency_ratio
- )
- generate_max_concurrency = active_worker_count * concurrency_per_worker
- return generate_max_concurrency
-
def model_post_init(self, __context: Any) -> None:
if self.model_name is None:
model_name_from_config = None
@@ -433,9 +386,6 @@ def model_post_init(self, __context: Any) -> None:
else:
self.rollout_max_batch_size_per_instance = 128
- if self.max_retry_per_worker is None:
- self.max_retry_per_worker = self.rollout_max_batch_size_per_instance
-
self.worker_log_dir.mkdir(parents=True, exist_ok=True)
def build(self, placement_group: "PlacementGroup"):
@@ -450,31 +400,22 @@ def build(self, placement_group: "PlacementGroup"):
import ray
from xtuner.v1.rl.rollout.controller import RolloutController
+ from xtuner.v1.rl.rollout.rollout_worker_build import build_rollout_runtime
num_workers = 1
register_cpu_resources(
name="rollout_controller",
cpu_resources=CPUResourcesConfig(num_workers=num_workers),
)
- generate_max_concurrency = self.get_controller_generate_concurrency(placement_group)
- get_logger().info(f"Calculated RolloutController generate concurrency: {generate_max_concurrency}")
- return (
- ray.remote(
- concurrency_groups={
- ROLLOUT_CONCURRENCY_GROUP_GENERATE: generate_max_concurrency,
- },
- )(RolloutController)
- .options(num_cpus=num_workers)
- .remote(self, placement_group)
- )
+ runtime = build_rollout_runtime(self, placement_group)
+ return ray.remote(RolloutController).options(num_cpus=num_workers).remote(self, runtime=runtime)
class RolloutWorker(SingleAcceleratorWorker):
"""Base class for a rollout worker that runs an inference server.
- This class manages the lifecycle of a distributed inference server, including initialization, launching, and
- handling generation requests. It is designed to be subclassed for specific inference backends like LMDeploy, vLLM
- or SGLang.
+ This class manages the lifecycle of a distributed inference server, including initialization, launching, weight
+ updates, and backend control. Runtime generation is handled by RolloutWorkerGenerator.
"""
def __init__(
@@ -506,13 +447,9 @@ def __init__(
self.server_func: Callable
self.endpoints: dict[str, str] = dict()
self.engine_rank_mesh_array: list[list[int]]
- # http_concurrency is calculated based on the max batch size per engine and the total number of engines
assert config.rollout_max_batch_size_per_instance, (
"rollout_max_batch_size_per_instance must be set in RolloutConfig"
)
- http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio
- limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100)
- self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout)
self.server_task = None
self.engine_bundle_idxs: list[int] = []
self.server_process: Optional[multiprocessing.Process] = None
@@ -533,11 +470,6 @@ def __init__(
self.abort_timeout = 10.0
self.dist_init_addr: str = ""
self.serverl_url: str = ""
- self.partial_rollout_handler = PartialRolloutHandler()
- self.enable_partial_rollout: bool = False
-
- def set_enable_partial_rollout(self, enable: bool) -> None:
- self.enable_partial_rollout = enable
def init(self, dist_init_addr: str) -> tuple[int, str]:
"""Initialize the worker and launch the server.
@@ -686,185 +618,9 @@ def continue_generation(self):
"""Resume the worker's generation process."""
self.receive_abort_request.clear()
- def check_health(self) -> bool:
- """Check the health of the worker's server.
-
- Returns:
- bool: True if the server is healthy, False otherwise.
- """
- try:
- headers = {
- "Content-Type": "application/json; charset=utf-8",
- "Authorization": f"Bearer {self.config.api_key}",
- }
- response = requests.get(
- f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0
- )
- return response.status_code == 200
- except requests.RequestException as e:
- self.logger.error(f"Health check failed for server {self.server_url}: {e}")
- return False
-
- async def _decode_routed_experts(self, routed_experts: Any) -> Any:
- return routed_experts
-
- @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE)
- async def generate(self, rollout_state: RolloutState) -> RolloutState:
- # TODO(@duanyanhui):
- # 1. support claude format input
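- # 2. figure out how the new input/output (RolloutState) fits the PartialRollout logic; get it running first
- # 3. streaming responses are dropped for now because they are not used yet; add them back when needed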
-
- if self.receive_abort_request.is_set():
- rollout_state.finish_reason = "abort"
- rollout_state.status = Status.ABORTED
- return rollout_state
-
- uid = rollout_state.uid
- sample_params: SampleParams = rollout_state.sample_params
- max_tokens = sample_params.max_tokens
- enable_partial_rollout = self.enable_partial_rollout
- if sample_params.return_token_ids:
- endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
- else:
- endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.config.api_key}",
- }
-
- if enable_partial_rollout:
- rollout_state = self.partial_rollout_handler.preprocess(rollout_state, max_tokens)
- elif rollout_state.status == Status.ABORTED:
- # ABORTED samples can be replayed; without partial rollout, rerun from the original prompt.
- rollout_state = reset_rollout_response(rollout_state)
- rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens})
- rollout_state.status = Status.INIT
- payload = self._get_request_payload(rollout_state)
- max_retries = self.config.max_retry_per_sample
-
- # Early exit 1: check whether the request is already marked as completed
- if rollout_state.status == Status.COMPLETED:
- self.logger.debug(f"Request {uid} is already marked as COMPLETED, skipping generation.")
- return rollout_state
-
- # Early exit 2: check whether the input still needs generation (read the variables safely)
- input_ids = payload.get("input_ids", [])
- max_tokens = cast(int, payload.get("max_tokens"))
-
- last_id = input_ids[-1] if len(input_ids) > 0 else "None"
- is_max_tokens_zero = max_tokens is not None and max_tokens <= 0
- is_eos_reached = len(input_ids) > 0 and input_ids[-1] in self.eos_token
-
- if is_max_tokens_zero or is_eos_reached:
- self.logger.debug(
- f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token."
- )
- finish_reason = "stop" if is_eos_reached else "length"
- # With or without partial rollout, mark as completed and return directly: this rollout round has not started, so nothing needs to be concatenated
- rollout_state.finish_reason = finish_reason
- rollout_state.status = Status.COMPLETED
- return rollout_state
-
- for attempt in range(max_retries + 1):
- is_last_attempt = attempt == max_retries
- http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload)
-
- # Case 1: HTTP Request is Successful
- if http_result.response:
- # Case 1.1: Valid rollout response
- rollout_state = await self._safe_handle_response(rollout_state, http_result.response)
- if rollout_state.status in [Status.COMPLETED, Status.ABORTED]:
- return rollout_state
-
- if is_last_attempt:
- # Case 1.2: Invalid rollout response and no retries left, so we return FAILED
- self.logger.warning(
- f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED."
- )
- rollout_state.status = Status.FAILED
- rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts."
- return rollout_state
-
- # Case 1.3: Invalid rollout response but we have retries left
- self.logger.warning(
- f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}."
- )
- await asyncio.sleep(0.1)
- continue
-
- # Case 2: Error occurred during HTTP Request
- if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED:
- # Case 2.1: The request was aborted due to a signal set by `receive_abort_request`
- rollout_state.finish_reason = "abort"
- rollout_state.status = update_status_from_finish_reason("abort")
- return rollout_state
-
- if http_result.is_client_error:
- # Case 2.2: A non-retryable client error occurred (such as 4xx HTTP status)
- self.logger.warning(
- f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}"
- )
- rollout_state.error_msg = (
- f"Client error {http_result.error_type} with message: {http_result.error_msg}"
- )
- rollout_state.status = Status.FAILED
- return rollout_state
-
- if http_result.is_server_error:
- # Case 2.3: A non-retryable server error occurred (such as 5xx HTTP status)
- self.logger.warning(
- f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}"
- )
- rollout_state.error_msg = (
- f"Server error {http_result.error_type} with message: {http_result.error_msg}"
- )
- rollout_state.status = Status.FAILED
- return rollout_state
-
- # Case 3: Retryable error occurred during HTTP Request
- if http_result.is_retryable:
- if is_last_attempt:
- self.logger.warning(
- f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}"
- )
- rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}"
- rollout_state.status = Status.FAILED
- return rollout_state
-
- self.logger.warning(
- f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}."
- )
- await asyncio.sleep(0.1)
- continue
-
- # Case 4: Unknown error occurred during HTTP Request and stop the rollout
- if http_result.is_unknown_error:
- raise RuntimeError(
- f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}"
- )
- return rollout_state
-
def _launch_server(self):
- """Launch the inference server as a separate process or Ray task.
-
- It waits for the server to become healthy before returning.
-
- Raises:
- TimeoutError: If the server fails to start within the specified
- timeout.
- Exception: If the server task terminates unexpectedly.
- """
+ """Launch the inference server as a separate process or Ray task."""
server_configs = self._transform_rollout_config_to_server_configs()
- timeout = 3600.0 # Increased timeout to 1 hour for downloading large models
- start_time = time.perf_counter()
- last_log_time = start_time
- headers = {
- "Content-Type": "application/json; charset=utf-8",
- "Authorization": f"Bearer {server_configs.api_key}",
- }
-
self.logger.info(f"Launch server task on server_url: {self.server_url}")
# note(@duanyanhui): launch server as multiprocessing for sglang temporarily
@@ -873,30 +629,8 @@ def _launch_server(self):
process = mp_ctx.Process(target=self.server_func, args=(server_configs,))
process.start()
self.server_process = process
- time.sleep(60) # Wait for the server to start
- with requests.Session() as session:
- while time.perf_counter() - start_time < timeout:
- try:
- response = session.get(
- f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers
- )
- if response.status_code == 200:
- return
- except requests.RequestException as e:
- self.logger.error(
- f"can't connect to server url {self.server_url}/{self.endpoints['health_generate']} because {e}"
- )
-
- current_time = time.perf_counter()
- if current_time - last_log_time >= 15:
- self.logger.info(
- f"Waiting for server to start, Elapsed time: {current_time - start_time:.2f}s"
- )
- last_log_time = current_time
-
- time.sleep(5)
- process.terminate()
- raise TimeoutError("Server failed to start within the timeout period.")
+ time.sleep(60)  # Fixed startup wait; the first version does not poll the server for readiness
+ return
else:
# launch the server as ray task
# so that the lmdeploy backend can get the external pg
@@ -919,312 +653,7 @@ def _launch_server(self):
)
.remote(server_configs)
)
-
- with requests.Session() as session:
- while time.perf_counter() - start_time < timeout:
- try:
- response = session.get(
- f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers
- )
- if response.status_code == 200:
- return
- except requests.RequestException:
- pass
-
- try:
- ray.get(self.server_task, timeout=0.1)
- raise Exception("Server task terminated unexpectedly.")
- except ray.exceptions.GetTimeoutError:
- pass
- except Exception as e:
- raise e
-
- current_time = time.perf_counter()
- if current_time - last_log_time >= 15:
- self.logger.info(
- f"Waiting for server to start... Elapsed time: {current_time - start_time:.2f}s"
- )
- last_log_time = current_time
-
- ray.cancel(self.server_task)
- raise TimeoutError("Server failed to start within the timeout period.")
-
- async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
- send_task = None
- abort_task = None
-
- try:
- if self.receive_abort_request.is_set():
- self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.")
- return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
- req = self.client.build_request(
- "POST",
- url,
- headers=headers,
- json=payload,
- )
- send_task = asyncio.create_task(self.client.send(req))
- abort_task = asyncio.create_task(self._wait_abort_request())
- done, _ = await asyncio.wait(
- {send_task, abort_task},
- return_when=asyncio.FIRST_COMPLETED,
- )
- if send_task in done:
- r = await send_task
- else:
- try:
- r = await asyncio.wait_for(asyncio.shield(send_task), timeout=self.abort_timeout)
- except asyncio.TimeoutError:
- self.logger.debug(
- f"Request to {url} did not return within {self.abort_timeout:.2f}s after abort signal."
- )
- await cancel_and_drain([send_task])
- return HttpRequestResult(
- error_type=HttpRequestErrorType.REQUEST_ABORTED,
- url=url,
- payload=payload,
- )
- r.raise_for_status()
- return HttpRequestResult(response=r)
-
- except asyncio.CancelledError:
- self.logger.debug(f"Request to {url} was cancelled while waiting for the response.")
- await cancel_and_drain([send_task, abort_task])
- self.receive_abort_request.set()
- return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
- except Exception as e:
- error_type = HttpRequestErrorType.from_exception(e)
- result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
- return result
- finally:
- await cancel_and_drain([abort_task])
-
- async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState:
- uid = rollout_state.message_uid
- sample_params = rollout_state.sample_params
- is_token_out = sample_params.return_token_ids
- response = http_response.json()
-
- if is_token_out:
- response_ids: list[int] = []
- logprobs: list[float] = []
- routed_experts = None
- returned_response = ""
- try:
- meta_info = response.get("meta_info") or {}
- finish_reason_info = meta_info.get("finish_reason") or {}
- finish_reason = finish_reason_info.get("type")
- if finish_reason is None:
- if self.receive_abort_request.is_set():
- rollout_state.finish_reason = "abort"
- rollout_state.status = Status.ABORTED
- self.logger.warning(
- f"finish_reason is missing in response meta_info when waiting for aborted message {uid}, defaulting to 'abort'. Response: {response}"
- )
- else:
- rollout_state.finish_reason = "error"
- rollout_state.status = Status.FAILED
- self.logger.warning(
- f"finish_reason is missing in response meta_info for message {uid}, defaulting to 'error'. Response: {response}"
- )
- rollout_state.error_msg = "Missing finish_reason in response meta_info"
- return rollout_state
- returned_response = response.get("text", "")
- # Get response_ids and logprobs
- if (
- "output_token_logprobs" in response["meta_info"]
- and response["meta_info"]["output_token_logprobs"] is not None
- ):
- response_ids = [item[1] for item in response["meta_info"]["output_token_logprobs"]]
- logprobs = [item[0] for item in response["meta_info"]["output_token_logprobs"]]
- else:
- num_return_tokens = response["meta_info"].get("completion_tokens", 0)
- response_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else []
-
- # Get routed_experts
- if self.enable_return_routed_experts:
- assert "routed_experts" in response["meta_info"], (
- "enable_return_routed_experts is True, but routed_experts is not in meta_info"
- )
- routed_experts = response["meta_info"]["routed_experts"] # token[layer[expert]]
- if routed_experts is not None:
- routed_experts = await self._decode_routed_experts(routed_experts)
- if not isinstance(routed_experts, ray.ObjectRef):
- routed_experts = ray.put(routed_experts)
-
- # Derive status from finish_reason
- rollout_status = update_status_from_finish_reason(finish_reason)
-
- # Validate the output
- if rollout_status == Status.COMPLETED:
- validation_errors = []
-
- if not response_ids:
- validation_errors.append("empty response_ids")
-
- if not response:
- validation_errors.append("empty response text")
-
- if sample_params.return_logprob and not logprobs:
- validation_errors.append("missing logprobs")
-
- if self.enable_return_routed_experts and routed_experts is None:
- validation_errors.append("missing routed_experts")
-
- if validation_errors:
- error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}"
- self.logger.error(error_msg)
- rollout_state.status = Status.FAILED
- rollout_state.error_msg = error_msg
- return rollout_state
- elif rollout_status == Status.FAILED:
- error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}"
- self.logger.error(error_msg)
- rollout_state.status = Status.FAILED
- rollout_state.error_msg = error_msg
- return rollout_state
-
- if self.enable_partial_rollout:
- expect_len = (
- response["meta_info"]["prompt_tokens"] + response["meta_info"]["completion_tokens"] - 1
- )
- rollout_state = await self.partial_rollout_handler.postprocess(
- rollout_state,
- response=returned_response,
- response_ids=response_ids,
- logprobs=logprobs,
- routed_experts=routed_experts,
- finish_reason=finish_reason,
- status=rollout_status,
- routed_experts_expect_len=expect_len,
- )
- else:
- rollout_state.response = returned_response
- rollout_state.response_ids = response_ids
- rollout_state.logprobs = logprobs
- rollout_state.routed_experts = routed_experts
- rollout_state.finish_reason = finish_reason
- rollout_state.status = rollout_status
- return rollout_state
- except KeyError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except IndexError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"Index error {e} while processing response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except AssertionError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except json.JSONDecodeError as e:
- error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}"
- raise RuntimeError(error_msg)
- except TypeError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except Exception as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}"
- raise RuntimeError(error_msg)
- else:
- # v1/chat/completions API response
- try:
- returned_response = response["choices"][0]["message"]["content"]
- finish_reason = response["choices"][0]["finish_reason"]
- rollout_status = update_status_from_finish_reason(finish_reason)
- if rollout_status == Status.COMPLETED and not returned_response:
- self.logger.error(f"Empty response text for msg {uid} with finish_reason {finish_reason}")
- rollout_state.status = Status.FAILED
- rollout_state.error_msg = "Empty response text"
- return rollout_state
-
- rollout_state.response = returned_response
- rollout_state.finish_reason = finish_reason
- rollout_state.status = rollout_status
- return rollout_state
- except KeyError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except IndexError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"Index error {e} while processing response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except AssertionError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except json.JSONDecodeError as e:
- error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}"
- raise RuntimeError(error_msg)
- except TypeError as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}"
- raise RuntimeError(error_msg)
- except Exception as e:
- response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
- error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}"
- raise RuntimeError(error_msg)
-
- def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice):
- openai_prompts = []
- openai_tools = []
- # transform claude spec to openai spec
- # 1. transform system prompt: concat provided system_prompt to input prompt
- system_prompt = self.config.system_prompt
- if system_prompt:
- system_prompt_json = {"role": "system", "content": f"{system_prompt}"}
- prompts.insert(0, system_prompt_json)
- # 2. transform multi-modal usage
- for prompt in prompts:
- content = prompt["content"]
- openai_content = []
- for item in content:
- if item["type"] == "image":
- if item["source"]["type"] == "base64":
- openai_url = f"data:{item['source']['media_type']};base64,{item['source']['data']}"
- if item["source"]["type"] == "url":
- openai_url = item["source"]["url"]
- new_prompt = {"type": "image_url", "image_url": {"url": openai_url}}
- openai_content.append(new_prompt)
- elif item["type"] == "text":
- openai_content.append(item)
- new_prompt = copy.deepcopy(prompt)
- new_prompt["content"] = openai_content
- openai_prompts.append(new_prompt)
- # 3. transform tool use
- for tool in tools:
- openai_tool = {
- "type": "function",
- "function": {
- "name": tool["name"],
- "description": tool["description"],
- "parameters": tool["input_schema"],
- },
- }
- openai_tools.append(openai_tool)
- return openai_prompts, openai_tools
-
- def _check_infer_engine_version(self, return_token_ids: bool):
- # TODO(@duanyanhui): remove this check when all backends support return_token_ids
- if self.check_flag:
- if os.environ.get("XTUNER_USE_VLLM", "0") == "1":
- if return_token_ids:
- self.logger.error(
- "VLLM backend does not support return_token_ids or generate with input_ids as input in Xtuner now"
- )
- elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1":
- import lmdeploy
-
- lmdeploy_version = lmdeploy.__version__
- if return_token_ids and Version(lmdeploy_version) < Version("0.10.2"):
- self.logger.error(
- f"You should use lmdeploy >= v0.10.2 to support return_token_ids, but current version is {lmdeploy_version}"
- )
- self.check_flag = False
+ return
def _set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]):
self.engine_rank_mesh_array = engine_rank_mesh_array
@@ -1240,14 +669,6 @@ def _set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]):
"""
self.engine_bundle_idxs = engine_bundle_idxs
- @abstractmethod
- def _get_request_payload(self, rollout_state: RolloutState) -> dict:
- """Abstract method to create a generation request.
-
- Must be implemented by subclasses.
- """
- raise NotImplementedError("_create_request must be implemented in subclass")
-
@abstractmethod
def _transform_rollout_config_to_server_configs(self):
"""Abstract method to transform rollout config to server configs.
@@ -1256,14 +677,6 @@ def _transform_rollout_config_to_server_configs(self):
"""
raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass")
- @abstractmethod
- def _transform_sample_params(self, sample_params: SampleParams) -> dict:
- """Abstract method to transform rollout config to server configs.
-
- Must be implemented by subclasses.
- """
- raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass")
-
@abstractmethod
def offload(self):
"""Abstract method to offload the model and KVcache.
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index c713d30c3f..0babc22ecf 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -36,6 +36,7 @@
_snapshot_nested_objectrefs,
)
from xtuner.v1.rl.rollout.controller import RolloutControllerProxy
+from xtuner.v1.rl.rollout.rollout_generator import RolloutGenerateHandleConfig
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.trainer.controller import TrainingController
from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
@@ -48,7 +49,6 @@
set_cpu_resource_manager,
sort_rollout_state_for_deterministic,
)
-from xtuner.v1.rl.utils.misc import check_chat_completions, delete_from_routedapiproxy, register_to_routedapiproxy
from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer
from xtuner.v1.utils.device import get_device, get_torch_device_module
@@ -283,6 +283,7 @@ class BaseRLTrainerConfig(BaseModel):
advantage_estimator_config: BaseAdvantageConfig = Field(default_factory=GRPOAdvantageConfig)
sync_weights_interval: int = 1
gateway_config: GatewayConfig | None = None
+ rollout_generator_config: RolloutGenerateHandleConfig = Field(default_factory=RolloutGenerateHandleConfig)
enable_evaluate: bool = True
enable_initial_evaluate: bool = False
@@ -623,10 +624,23 @@ def _maybe_start_gateway(self, cfg: BaseRLTrainerConfig) -> None:
# The gateway depends on the rollout controller, so it is started only after the rollout controller has been built.
ray.get(self.rollout_controller.start_gateway.remote(cfg.gateway_config))
+ def _build_rollout_generate_handle(self, cfg: BaseRLTrainerConfig):
+ internal_http_config = cfg.rollout_generator_config.build_internal_http_entry_config()
+ if internal_http_config is not None:
+ ray.get(self.rollout_controller.start_internal_http_entry.remote(internal_http_config))
+
+ external_http_entry_config = cfg.rollout_generator_config.build_external_http_entry_config()
+ if external_http_entry_config is not None:
+ ray.get(self.rollout_controller.start_external_http_entry.remote(external_http_entry_config))
+
+ return cfg.rollout_generator_config.build(self.rollout_controller)
+
def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True)
+ rollout_generator = self._build_rollout_generate_handle(cfg)
self.agent_loop_manager = cfg.agent_loop_manager_cfg.build(
rollout_controller=self.rollout_controller,
+ rollout_generator=rollout_generator,
tokenizer=self.tokenizer,
replay_buffer=replay_buffer,
logger=self.logger,
@@ -637,6 +651,7 @@ def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer)
assert cfg.eval_agent_loop_manager_cfg is not None
self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build(
rollout_controller=self.rollout_controller,
+ rollout_generator=rollout_generator,
tokenizer=self.tokenizer,
replay_buffer=replay_buffer,
logger=self.logger,
@@ -1320,37 +1335,6 @@ def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]):
)
self._global_train_step += len(workers_log_item[0]["train_metrics"])
-
-def add_apiproxy(self):
- info_dict = ray.get(self.rollout_controller.get_rollout_metadata.remote())
- model_name = info_dict["rollout_config"].model_name
-
- delete_from_routedapiproxy(model_name)
- self.logger.info(f"deleted {model_name} from routedapiproxy")
- self.logger.info("registering to routedapiproxy")
-
- worker_session_url_dict = info_dict["worker_session_url_dict"]
- worker_session_urls_status = info_dict["worker_session_urls_status"]
- for _, worker_session_url in sorted(worker_session_url_dict.items()):
- if not worker_session_urls_status.get(worker_session_url, False):
- continue
- register_to_routedapiproxy(model_name, worker_session_url)
-
- # test server url
- recheck_status_orig = check_chat_completions(worker_session_url, model_name)
- if not recheck_status_orig:
- raise ValueError(f"check chat completions failed for {worker_session_url}")
-
- # test routed url
- routed_url = "http://s-20260104203038-22bhb.ailab-evalservice.pjh-service.org.cn/v1"
- recheck_status_routed = check_chat_completions(routed_url, model_name)
- if not recheck_status_routed:
- raise ValueError(f"check chat completions failed for {routed_url}")
- self.logger.info("registered to routedapiproxy")
- # import time
- # time.sleep(1000000)
-
-
class RLColocateTrainer(BaseRLTrainer):
_META_PATH = ".xtuner_rl_colocate_trainer"
@@ -1373,7 +1357,6 @@ def __init__(self, cfg: RLColocateTrainerConfig):
self._rollout_config.skip_load_weights = False
self.rollout_controller = self._rollout_config.build(self._pg)
# self._maybe_start_gateway(cfg)
- add_apiproxy(self)
replay_buffer = cfg.replay_buffer_config.build()
self._build_agent_loop_components(cfg, replay_buffer)
@@ -1408,8 +1391,6 @@ def __init__(self, cfg: RLColocateTrainerConfig):
# self._maybe_start_gateway(cfg)
bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
- add_apiproxy(self)
-
replay_buffer = cfg.replay_buffer_config.build()
self._build_agent_loop_components(cfg, replay_buffer)
if checkpoint_path is not None:
@@ -1540,7 +1521,6 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool
self._maybe_save_checkpoint(train_step)
self._maybe_save_hf(train_step)
- ray.get(self.rollout_controller.recover_failed_workers.remote())
timer_name = "sync_weight" if should_sync_weights else "switch_to_rollout"
with timer(timer_name, step_timer_dict):
if should_sync_weights:
@@ -1714,7 +1694,6 @@ async def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict):
self._maybe_save_checkpoint(train_step)
self._maybe_save_hf(train_step)
- ray.get(self.rollout_controller.recover_failed_workers.remote())
with timer("sync_weight", step_timer_dict):
bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
self.update_weights()