refactor rollout weight update flow #1828
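Transport / Adapter refactoring sketch

Current problems

In the current PR the Adapter naming is right, but the Interface is still too shallow:

All of these make the Adapter know too much about the Transport. The fix is to move generic transport information up into the Transport, and to push inference-engine protocol information down into the Adapter.

Module responsibilities

UpdateWeighter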
Example:

    class UpdateWeighter:
        def update_weights(self) -> None:
            exporter = WeightExporter(
                config=self.config,
                engine=self.engine,
                rollout_info=self.rollout_info,
            )
            transport = self._get_transport()
            # Do not call list(exporter.iter_batches()): weight batches must be
            # sent while they are being exported.
            transport.send_all(exporter.iter_batches())
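The streaming requirement in the comment above can be illustrated with a small, self-contained sketch. The names `export_batches`, `send_all`, and the `events` log are hypothetical stand-ins, not code from this PR. The sketch only shows that passing the generator straight through interleaves export and send, whereas `list(...)` exports everything before the first send.

```python
from typing import Iterable, Iterator

events: list[str] = []


def export_batches(n: int) -> Iterator[int]:
    """Hypothetical exporter: yields batches lazily, one at a time."""
    for i in range(n):
        events.append(f"export {i}")
        yield i


def send_all(batches: Iterable[int]) -> None:
    """Hypothetical transport: consumes batches as they arrive."""
    for batch in batches:
        events.append(f"send {batch}")


# Streaming: each batch is sent before the next one is exported.
send_all(export_batches(2))
streamed = list(events)

# Materializing: list() exports every batch before the first send,
# so all batches would be held in memory at the same time.
events.clear()
send_all(list(export_batches(2)))
materialized = list(events)
```

With real weight tensors, the second ordering would keep every exported batch alive until sending starts, which is what the comment warns against.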
WeightExporter

Example:

    class WeightExporter:
        def __init__(self, *, config, engine, rollout_info):
            self.config = config
            self.engine = engine
            self.rollout_info = rollout_info

        def iter_batches(self) -> Iterable[WeightUpdateBatch]:
            ...

The key point is to keep

Transport

Transport owns the generic transport lifecycle and may hold the complete
Example:

    class IPCWeightTransport:
        def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
            self.adapter.before_update()
            try:
                for batch in batches:
                    local_payload = self.adapter.build_local_payload(batch)
                    tensors = self._gather_if_needed(local_payload.data)
                    request = self.adapter.build_request(
                        batch,
                        tp=self.rollout_info.tp,
                        serialized_named_tensors=tensors,
                        load_format=local_payload.load_format,
                    )
                    if self._is_engine_parallel_head():
                        self._post_local_rollout(request)
                    if request.needs_engine_parallel_barrier:
                        self._barrier_engine_parallel()
            finally:
                self.adapter.after_update()
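The `try` / `finally` lifecycle in `send_all` can be checked in isolation. The sketch below uses a hypothetical `RecordingAdapter` and a `failing_send` callback, neither of which is part of the PR, to show that `after_update()` still runs when a send raises.

```python
class RecordingAdapter:
    """Hypothetical adapter that only records lifecycle calls."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def before_update(self) -> None:
        self.calls.append("before")

    def after_update(self) -> None:
        self.calls.append("after")


def send_all(adapter, batches, send) -> None:
    """Same lifecycle shape as the transport sketch: open, stream, always close."""
    adapter.before_update()
    try:
        for batch in batches:
            send(batch)
    finally:
        adapter.after_update()


def failing_send(batch) -> None:
    raise RuntimeError("simulated transport failure")


adapter = RecordingAdapter()
try:
    send_all(adapter, [1], failing_send)
except RuntimeError:
    pass  # the failure still propagates to the caller
```

This is why backend-specific cleanup, such as closing a cache, can sit behind `after_update()` without the Transport knowing what is being cleaned up.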
Example:

    class NCCLWeightTransport:
        def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
            for batch in batches:
                if not batch.state_dict:
                    continue
                if not self._is_train_head_rank():
                    self._barrier_train_update_group()
                    continue
                self.nccl_group.ensure_started()
                request = self.adapter.build_request(batch, group=self.nccl_group)
                for endpoint in self.rollout_info.active_engine_endpoints:
                    post_json(
                        endpoint.url,
                        request.endpoint,
                        request.body,
                        api_key=self.rollout_info.api_key,
                    )
                self._barrier_train_update_group()

Adapter

An Adapter expresses only the differences between inference-engine protocols.

LMDeployIPCAdapter

What is specific to LMDeploy is the flattened bucket and the IPC tensor cache. These should not live in the generic IPC transport; they should be pushed down into

Example:

    class LMDeployIPCAdapter(IPCBackendAdapter):
        def __init__(self, bucket_size_bytes: int):
            self.flattened_bucket_cache = LMDeployFlattenedBucketCache(bucket_size_bytes)

        def before_update(self) -> None:
            self.flattened_bucket_cache.open()

        def after_update(self) -> None:
            self.flattened_bucket_cache.close()

        def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
            if batch.state_dict and lmdeploy_supports_flattened_bucket():
                data = self.flattened_bucket_cache.flatten(batch.state_dict)
                return LocalPayload(
                    data=lmdeploy_serialize_state_dict(data),
                    load_format="flattened_bucket",
                )
            return LocalPayload(data=lmdeploy_serialize_state_dict(batch.state_dict))

        def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
            body = {"serialized_named_tensors": serialized_named_tensors, "finished": batch.finished}
            if load_format is not None:
                body["load_format"] = load_format
            return IPCRequest(
                endpoint="update_weights",
                body=body,
                needs_engine_parallel_barrier=batch.finished or (batch.train_enable_ep and tp > 1),
            )

Here

SGLangIPCAdapter

What is specific to SGLang colocate IPC is the torch reductions patch, the SGLang serializer,

Example:

    class SGLangIPCAdapter(IPCBackendAdapter):
        def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
            with patched_sglang_torch_reductions():
                if sglang_supports_flattened_bucket() and batch.state_dict:
                    flattened = sglang_flattened_bucket(batch.state_dict.items())
                    return LocalPayload(
                        data=sglang_serialize({
                            "flattened_tensor": flattened.tensor,
                            "metadata": flattened.metadata,
                        }),
                        load_format="flattened_bucket",
                    )
                return LocalPayload(data=sglang_serialize(batch.state_dict.items()))

        def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
            if tp == 1:
                serialized_named_tensors = [serialized_named_tensors]
            body = {"serialized_named_tensors": serialized_named_tensors, "flush_cache": False}
            if load_format is not None:
                body["load_format"] = load_format
            return IPCRequest(endpoint="update_weights_from_tensor", body=body)

SGLangNCCLAdapter

SGLang disaggregated NCCL has two specifics:
These endpoints should not live in the client or in a generic group helper; they should be pushed down into the SGLang adapter.

Example:

    class SGLangNCCLAdapter(NCCLBackendAdapter):
        def build_group_init_requests(self, *, active_engine_endpoints, rendezvous):
            rank_offset = 1
            requests = []
            for endpoint in active_engine_endpoints:
                requests.append(NCCLGroupInitRequest(
                    url=endpoint.url,
                    endpoint="init_weights_update_group",
                    body={
                        "master_address": rendezvous.master_address,
                        "master_port": rendezvous.master_port,
                        "rank_offset": rank_offset,
                        "world_size": rendezvous.world_size,
                        "group_name": rendezvous.group_name,
                        "backend": rendezvous.backend,
                    },
                ))
                rank_offset += endpoint.engine_size
            return requests

        def build_request(self, batch, *, group) -> NCCLRequest:
            flattened = sglang_flattened_bucket(batch.state_dict.items())
            group.broadcast_tensor(flattened.tensor)
            return NCCLRequest(
                endpoint="update_weights_from_distributed",
                body={
                    "names": flattened.names,
                    "dtypes": flattened.dtypes,
                    "shapes": flattened.shapes,
                    "group_name": group.group_name,
                    "load_format": "flattened_bucket",
                },
            )

HTTP helper principles

Do not let the client provide these methods:

    client.collective_rpc(...)
    client.update_weights(...)
    client.update_weights_from_tensor(...)
    client.init_weights_update_group(...)
    client.update_weights_from_distributed(...)

These method names are all inference-engine protocol. Putting them on the client mixes the vLLM, LMDeploy, and SGLang Interfaces. A low-level helper is simpler:

    import requests

    def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:
        headers = {"Content-Type": "application/json"}
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        response = requests.post(f"{url}/{endpoint}", headers=headers, json=payload)
        response.raise_for_status()
        return response.json()

The endpoint and body are produced by the Adapter:

    request = adapter.build_request(...)
    post_json(url, request.endpoint, request.body, api_key=rollout_info.api_key)

This way the helper is responsible only for HTTP, and the Adapter is responsible for the inference-engine protocol.

Move-up / push-down rules
The final Interface is smaller:

    class IPCBackendAdapter:
        def before_update(self) -> None: ...
        def after_update(self) -> None: ...
        def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload: ...
        def build_request(self, batch, *, tp, serialized_named_tensors, load_format) -> IPCRequest: ...

    class NCCLBackendAdapter:
        def build_group_init_requests(self, *, active_engine_endpoints, rendezvous) -> list[NCCLGroupInitRequest]: ...
        def check_group_init_result(self, result: dict) -> None: ...
        def build_request(self, batch, *, group) -> NCCLRequest: ...

This lets the Adapter focus only on what differs between inference engines, while the Transport keeps generic transport semantics.
408ac90 to 5fe2d53
        ) -> list[Any]:
            raise NotImplementedError

        def after_update_per_batch(

Since this is a base method, its input parameters are too hard-coded. I think it should take batch: WeightUpdateBatch, cpu_group: dist.ProcessGroup, and **kwargs like the preceding interfaces, so that the signatures line up.

This was not changed. The function handles the synchronization each backend needs after a batch has been updated. It decides mainly from finished and train_enable_ep and does not depend on the batch.
        state_dict = batch.state_dict
        # Pytorch backend will send empty state_dict when finished.
        if state_dict is not None:

WeightIterator uses an empty batch as the completion marker:

    yield WeightUpdateBatch({}, finished=True)

But NCCLWeightTransport.send() skips an empty state_dict outright:

    if not state_dict:
        return

At the same time, LMDeployNCCLBackendAdapter has a finalize-only branch, but its condition is state_dict is not None. WeightUpdateBatch.state_dict is always a dict, so an empty dict also takes the normal branch and the else is effectively unreachable. This is especially dangerous for the pytorch disaggregated path. It looks like a bug, but how large the impact is remains unclear and needs testing.

Fixed. Deciding by state_dict was not reasonable here; the check now uses the finished flag.
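The difference between the two conditions can be reproduced with a simplified stand-in for the batch type. The `WeightUpdateBatch` below only mirrors the two fields mentioned in this thread, and the two `branch_by_*` helpers are hypothetical names used for illustration, not the adapter's real methods.

```python
from dataclasses import dataclass, field


@dataclass
class WeightUpdateBatch:
    """Simplified stand-in: only the fields discussed in this thread."""

    state_dict: dict = field(default_factory=dict)
    finished: bool = False


def branch_by_state_dict(batch: WeightUpdateBatch) -> str:
    # Original condition: an empty dict is still "not None".
    return "normal" if batch.state_dict is not None else "finalize"


def branch_by_finished(batch: WeightUpdateBatch) -> str:
    # Condition after the fix described above: rely on the explicit flag.
    return "finalize" if batch.finished else "normal"


end_marker = WeightUpdateBatch({}, finished=True)
old_branch = branch_by_state_dict(end_marker)
new_branch = branch_by_finished(end_marker)
```

The sketch assumes the end marker carries no tensors; a final batch that has both tensors and `finished=True` would need its own handling.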
@claude review
            pass

        def build_request(
            self,

Claude: build_weight_update_payload returns None implicitly (via pass). In NCCLWeightTransport.send, the result is unpacked as:

    payload, flattened_tensor, weight_names = self._adapter.build_weight_update_payload(batch, self.group_name)

If a subclass forgets to override it, or it is called on the base class, this crashes with TypeError: cannot unpack non-iterable NoneType. Should be raise NotImplementedError to match build_request below it.

            self,
        def build_weight_update_payload(self, batch: WeightUpdateBatch, group_name: str):
            raise NotImplementedError
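The two failure modes can be compared directly. `BaseWithPass` and `BaseWithRaise` are hypothetical minimal classes, not the PR's adapters; the sketch only records which exception the caller sees when it unpacks the result.

```python
class BaseWithPass:
    def build_weight_update_payload(self, batch, group_name):
        pass  # implicitly returns None


class BaseWithRaise:
    def build_weight_update_payload(self, batch, group_name):
        raise NotImplementedError


def unpack(adapter) -> str:
    """Mimic the three-way unpacking done by the caller and report the outcome."""
    try:
        payload, tensor, names = adapter.build_weight_update_payload({}, "group")
    except Exception as exc:
        return type(exc).__name__
    return "ok"


pass_outcome = unpack(BaseWithPass())
raise_outcome = unpack(BaseWithRaise())
```

With `pass`, the error surfaces at the call site as a `TypeError` about unpacking; with `raise NotImplementedError`, it points at the method that was never overridden.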
                monkey_unpatch_torch_reductions()


    class NCCLBackendAdapter:

Claude: monkey_unpatch_torch_reductions() is called unconditionally in the finally block for all IPC backends (vLLM, LMDeploy, SGLang), but monkey_patch_torch_reductions() is only called inside SGLangIPCBackendAdapter.serialize. For non-SGLang backends, the unpatch runs on every batch without a prior patch.
If monkey_unpatch_torch_reductions is not idempotent (e.g. it errors or corrupts state when no prior patch was applied), this is a latent bug. Consider either:
- Moving the unpatch into the SGLang adapter (making it responsible for its own cleanup), or
- Guarding with a flag that tracks whether the patch was applied.
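The first option, letting the SGLang adapter own both the patch and its cleanup, could look roughly like the sketch below. `fake_patch` and `fake_unpatch` are hypothetical stand-ins for the real patch functions, whose behavior is not shown in this thread; only the pairing via a context manager is the point.

```python
from contextlib import contextmanager

state = {"patched": False, "unpatch_calls": 0}


def fake_patch() -> None:
    """Stand-in for the real patch function (assumption)."""
    state["patched"] = True


def fake_unpatch() -> None:
    """Stand-in for the real unpatch function (assumption)."""
    state["unpatch_calls"] += 1
    state["patched"] = False


@contextmanager
def patched_reductions():
    """Pair every patch with exactly one unpatch, even on error."""
    fake_patch()
    try:
        yield
    finally:
        fake_unpatch()


def serialize_sglang(obj) -> str:
    with patched_reductions():
        return repr(obj)


def serialize_other_backend(obj) -> str:
    return repr(obj)  # never patches, so never unpatches


serialize_other_backend(1)
calls_after_other = state["unpatch_calls"]
serialize_sglang(1)
calls_after_sglang = state["unpatch_calls"]
```

Under this shape the generic transport no longer needs an unconditional unpatch in its `finally` block.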
            return

        train_sync_group = self.get_train_update_sync_group()

Claude: NCCLWeightTransport.send returns early when not state_dict, which means LMDeployNCCLBackendAdapter.build_weight_update_payload's finalize-only path (the if state_dict is not None: ... else: ... branch that returns (payload, None, None)) is unreachable.
Since WeightIterator.iter_hf_batches yields WeightUpdateBatch({}, finished=True) as the end marker for pytorch/vllm backends, state_dict is {} (empty dict, which is falsy). The early return here skips the finalize signal entirely.
Either:
- This early return should check not state_dict and not batch.finished (allow finalize-only batches through), or
- The LMDeployNCCLBackendAdapter finalize path is dead code that should be removed.
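The first option amounts to a one-line predicate. `should_skip` is a hypothetical helper name used only to make the truth table explicit.

```python
def should_skip(state_dict: dict, finished: bool) -> bool:
    """Skip only batches that carry neither tensors nor the finish signal."""
    return not state_dict and not finished


finalize_only = should_skip({}, finished=True)       # must go through
empty_midstream = should_skip({}, finished=False)    # safe to skip
regular_batch = should_skip({"w": 1}, finished=False)
```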
            raise ValueError(
                f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'."

Claude: Dead commented-out code. Should be removed.

            raise ValueError(
                f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'."
        assert train_rollout_mode is not None, "update_rollout_info() must set train_rollout_mode."
        tp = rollout_config.tensor_parallel_size
        ep = rollout_config.expert_parallel_size

Claude: Type mismatch: worker_server_urls_status is typed dict[str, bool], but the default value passed to .get() is the string "False" (truthy), not the boolean False. The is False check will always be False when the URL is absent from the dict, causing missing URLs to be silently treated as active.

        ep = rollout_config.expert_parallel_size
        if worker_server_urls_status.get(rollout_server_url, False) is False:
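The effect of the string default is easy to reproduce with a plain dict. The URLs below are made-up placeholders.

```python
# Hypothetical status map; the URLs are placeholders.
status: dict[str, bool] = {"http://a": True, "http://b": False}
missing_url = "http://c"

# String default: "False" is a non-empty str, and it is not the object False.
string_default_inactive = status.get(missing_url, "False") is False

# Boolean default: a missing URL is correctly reported as inactive.
bool_default_inactive = status.get(missing_url, False) is False

# A URL that is present and False is handled the same way by both forms.
explicitly_down = status.get("http://b", "False") is False
```

Only the missing-key case differs, which is why this kind of bug can stay hidden until a URL drops out of the status map.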
        # Deduplicate rollout engine URLs while keeping the first rank associated
        # with each URL as the representative rank for that engine.
        url_to_rank: dict[str, int] = {}
        for rank, url in sorted(
            self.rollout_info.rollout_server_url_dict.items(),
            key=lambda item: int(item[0]),
        ):
            if url:
                url_to_rank.setdefault(url, int(rank))

        # Collect the representative rank, URL, and engine size needed to create

Claude: ensure_nccl_weight_update_group builds url_to_rank by filtering only on if url (non-empty), without checking worker_server_urls_status. However, UpdateWeighter._build_transport_signature filters active URLs by both url and worker_server_urls_status.get(url, False).
This mismatch means: when a rollout worker goes unhealthy (status=False), the signature changes (transport is reset), but the new transport still connects to the unhealthy URL, likely hanging during NCCL group init.
Consider filtering by worker_server_urls_status here as well:

    url_to_rank: dict[str, int] = {}
    for rank, url in sorted(...):
        if url and self.rollout_info.worker_server_urls_status.get(url, False):
            url_to_rank.setdefault(url, int(rank))

            self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip")

        @staticmethod
        def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:

Claude: Typo: "url in None" should be "url is None". Same typo appears in IPCWeightTransport.send as well.

        def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:
            self.logger.error(f"rank {self.rank} url is None, cannot update weights and skip")
        if should_gather_train_ep_shards and model_ep_size > 1:
            assert model.ep_mesh is not None
            ep_group = model.ep_mesh.get_group()
            gathered_tensors = [torch.empty_like(fused_full_tensor) for _ in range(model_ep_size)]
            dist.all_gather(
                gathered_tensors,
                fused_full_tensor.contiguous(),
                group=ep_group,
            )
            fused_full_tensor = torch.cat(gathered_tensors, dim=dim)

Claude: In the disaggregated + train EP > 1 path, this allocates model_ep_size full copies of the tensor for the gather, then concatenates them into a single tensor. Peak memory per rank = (2 × model_ep_size + 1) × local_shard_numel × dtype_bytes.
For large MoE models with ep_size=8, this could be very memory-intensive. Consider using dist.all_gather_into_tensor with a pre-allocated output buffer of the final size to avoid the intermediate list allocation:

    output = torch.empty(
        *fused_full_tensor.shape[:dim],
        fused_full_tensor.shape[dim] * model_ep_size,
        *fused_full_tensor.shape[dim+1:],
        dtype=fused_full_tensor.dtype, device=fused_full_tensor.device,
    )
    dist.all_gather_into_tensor(output, fused_full_tensor.contiguous(), group=ep_group)
    fused_full_tensor = output

Very helpful. The original code used more memory; changed as suggested.
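To put numbers on the memory estimate, the sketch below evaluates the formula from the review comment next to a rough estimate for the pre-allocated variant. The second formula, (ep_size + 1) shards, is an assumption that only the output buffer and the local shard are alive at the peak; the shard size is a made-up example.

```python
def peak_bytes_list_gather(shard_numel: int, dtype_bytes: int, ep_size: int) -> int:
    # Formula from the review comment: gathered list + concatenated result + local shard.
    return (2 * ep_size + 1) * shard_numel * dtype_bytes


def peak_bytes_into_tensor(shard_numel: int, dtype_bytes: int, ep_size: int) -> int:
    # Assumption: only the pre-allocated output and the local shard are alive.
    return (ep_size + 1) * shard_numel * dtype_bytes


shard_numel = 1024 * 1024  # hypothetical local shard size in elements
list_peak = peak_bytes_list_gather(shard_numel, dtype_bytes=2, ep_size=8)
into_peak = peak_bytes_into_tensor(shard_numel, dtype_bytes=2, ep_size=8)
```

One caveat worth checking against the PyTorch documentation: `all_gather_into_tensor` lays the per-rank inputs out along dimension 0 (concatenated or stacked), so the suggested buffer reproduces `torch.cat(..., dim=dim)` only when `dim` is 0. For other values of `dim` the gathered data would presumably need a reshape or permute afterwards.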
Claude: ## Review Summary
This PR refactors the monolithic
ProduceBatchResult impact: not affected
Issues
Critical
Warning
Nit
Verdict
REQUEST_CHANGES: the finalize-only path bug (critical #2) can silently break LMDeploy disaggregated weight updates. The other critical and warning items should also be addressed before merge.
5fe2d53 to 09389d5
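Refactored the weight update flow. The refactor adds the following files:
- data.py: defines the shared data structures, type aliases, and the update batch.
- client.py: wraps the rollout engine's HTTP update interface.
- exporter.py: exports HuggingFace-style weight batches from the training model.
- transport.py: implements IPC/NCCL transport and the adapters for the different rollout backends.
- update_weighter.py: exposes the update-weight orchestration logic and is the main entry point for upstream callers.
- __init__.py: exports the public interface.

Adding an RDMA-based update path later requires a new RDMAWeightTransport; adding LMDeploy support requires a new LMdeployNCCLBackendAdapter.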