Skip to content

refactor rollout weight update flow#1828

Open
PengchengShi00 wants to merge 11 commits into
InternLM:mainfrom
PengchengShi00:refactor-update-weight
Open

refactor rollout weight update flow#1828
PengchengShi00 wants to merge 11 commits into
InternLM:mainfrom
PengchengShi00:refactor-update-weight

Conversation

@PengchengShi00

Copy link
Copy Markdown
Contributor

重构了权重更新流程,重构新增了下面几个文件

  • data.py:定义共享数据结构、类型别名和` update batch。
  • client.py:封装 rollout engine 的 HTTP update 接口。
  • exporter.py:从训练模型中导出 HuggingFace 风格权重 batch。
  • transport.py:实现 IPC/NCCL 传输,以及不同 rollout backend 的适配器。
  • update_weighter.py:对外提供 update weight 编排逻辑,是上游主要调用入口。
  • __init__.py:导出公共接口。

后续新增RDMA的更新方式需要新增 RDMAWeightTransport,新增对LMdeploy支持需要增加LMdeployNCCLBackendAdapter

@jayhenry

jayhenry commented May 27, 2026

Copy link
Copy Markdown
Collaborator

Transport / Adapter 重构草图

当前问题

当前 PR 里 Adapter 的名字是对的,但 Interface 还偏浅:

  • IPCBackendAdapter 接收完整 IPCWeightTransport,能访问 IPC event、buffer cache、rollout info、client、process group。
  • SGLangNCCLBackendAdapter 依赖完整 NCCLWeightTransport,能访问 external group、executor、group name、engine urls、client。
  • RolloutWeightUpdateClient 混合了不同推理引擎接口:collective_rpcupdate_weightsupdate_weights_from_tensorinit_weights_update_groupupdate_weights_from_distributed

这些都让 Adapter 知道太多 Transport 信息。改法是:该上升到 Transport 的通用传输信息上升;该下沉到 Adapter 的推理引擎协议信息下沉。

Module 分工

UpdateWeighter

UpdateWeighter 只负责编排:

  1. 用现有 rollout_info 创建 WeightExporter
  2. 用同一个 rollout_info 创建 Transport
  3. 把 exporter 产出的 batch iterable 交给 transport

示例:

class UpdateWeighter:
    def update_weights(self) -> None:
        exporter = WeightExporter(
            config=self.config,
            engine=self.engine,
            rollout_info=self.rollout_info,
        )
        transport = self._get_transport()

        # 不要 list(exporter.iter_batches()),权重 batch 需要边导出边发送。
        transport.send_all(exporter.iter_batches())

UpdateWeighter 不知道:

  • LMDeploy IPC event 怎么复用
  • SGLang NCCL group 怎么初始化
  • HTTP endpoint 名字是什么
  • payload 怎么组织

WeightExporter

WeightExporter 仍然可以接收当前 rollout_info,减少重构面。

示例:

class WeightExporter:
    def __init__(self, *, config, engine, rollout_info):
        self.config = config
        self.engine = engine
        self.rollout_info = rollout_info

    def iter_batches(self) -> Iterable[WeightUpdateBatch]:
        ...

重点是保持 iter_batches() 流式产出。这样 transport 可以边拿 batch 边发送,避免权重 batch 提前全部驻留显存/内存。

Transport

Transport 持有通用传输生命周期,可以持有完整 rollout_info。但 Transport 调 Adapter 时,只传 Adapter 需要的字段,不把完整 rollout_info 传下去。

IPCWeightTransport 负责:

  • 调 adapter 的 update 生命周期 hook
  • 每个 batch 调 adapter 生成本地 payload
  • TP gather
  • head rank post
  • engine parallel barrier

示例:

class IPCWeightTransport:
    def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
        self.adapter.before_update()
        try:
            for batch in batches:
                local_payload = self.adapter.build_local_payload(batch)
                tensors = self._gather_if_needed(local_payload.data)
                request = self.adapter.build_request(
                    batch,
                    tp=self.rollout_info.tp,
                    serialized_named_tensors=tensors,
                    load_format=local_payload.load_format,
                )
                if self._is_engine_parallel_head():
                    self._post_local_rollout(request)
                if request.needs_engine_parallel_barrier:
                    self._barrier_engine_parallel()
        finally:
            self.adapter.after_update()

NCCLWeightTransport 负责:

  • train head rank 判断
  • external NCCL group 生命周期
  • 调 adapter 完成 backend-specific broadcast/request
  • 向 rollout engine post
  • train update barrier

示例:

class NCCLWeightTransport:
    def send_all(self, batches: Iterable[WeightUpdateBatch]) -> None:
        for batch in batches:
            if not batch.state_dict:
                continue
            if not self._is_train_head_rank():
                self._barrier_train_update_group()
                continue

            self.nccl_group.ensure_started()
            request = self.adapter.build_request(batch, group=self.nccl_group)
            for endpoint in self.rollout_info.active_engine_endpoints:
                post_json(endpoint.url, request.endpoint, request.body, api_key=self.rollout_info.api_key)
            self._barrier_train_update_group()

Adapter

Adapter 只表达推理引擎协议差异。

LMDeployIPCAdapter

LMDeploy 的特殊点是 flattened bucket 和 IPC tensor cache。这些不应该放在通用 IPC transport 里,而应该下沉到 LMDeployIPCAdapter

示例:

class LMDeployIPCAdapter(IPCBackendAdapter):
    def __init__(self, bucket_size_bytes: int):
        self.flattened_bucket_cache = LMDeployFlattenedBucketCache(bucket_size_bytes)

    def before_update(self) -> None:
        self.flattened_bucket_cache.open()

    def after_update(self) -> None:
        self.flattened_bucket_cache.close()

    def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
        if batch.state_dict and lmdeploy_supports_flattened_bucket():
            data = self.flattened_bucket_cache.flatten(batch.state_dict)
            return LocalPayload(
                data=lmdeploy_serialize_state_dict(data),
                load_format="flattened_bucket",
            )
        return LocalPayload(data=lmdeploy_serialize_state_dict(batch.state_dict))

    def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
        body = {"serialized_named_tensors": serialized_named_tensors, "finished": batch.finished}
        if load_format is not None:
            body["load_format"] = load_format
        return IPCRequest(
            endpoint="update_weights",
            body=body,
            needs_engine_parallel_barrier=batch.finished or (batch.train_enable_ep and tp > 1),
        )

这里 LMDeployFlattenedBucketCache 承接旧实现里的 _update_params_ipc_event、per-dtype buffer cache、event wait/record、ipc handle 发送策略。通用 IPC transport 不知道这些细节。

SGLangIPCAdapter

SGLang colocate IPC 的特殊点是 torch reductions patch、SGLang serializer、update_weights_from_tensor payload。

示例:

class SGLangIPCAdapter(IPCBackendAdapter):
    def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload:
        with patched_sglang_torch_reductions():
            if sglang_supports_flattened_bucket() and batch.state_dict:
                flattened = sglang_flattened_bucket(batch.state_dict.items())
                return LocalPayload(
                    data=sglang_serialize({
                        "flattened_tensor": flattened.tensor,
                        "metadata": flattened.metadata,
                    }),
                    load_format="flattened_bucket",
                )
            return LocalPayload(data=sglang_serialize(batch.state_dict.items()))

    def build_request(self, batch, *, tp, serialized_named_tensors, load_format):
        if tp == 1:
            serialized_named_tensors = [serialized_named_tensors]
        body = {"serialized_named_tensors": serialized_named_tensors, "flush_cache": False}
        if load_format is not None:
            body["load_format"] = load_format
        return IPCRequest(endpoint="update_weights_from_tensor", body=body)

SGLangNCCLAdapter

SGLang disaggregated NCCL 的特殊点有两个:

  • rollout side group init endpoint 是 init_weights_update_group
  • weight update endpoint 是 update_weights_from_distributed

这些 endpoint 不应该放在 client 或通用 group helper 里,而应该下沉到 SGLang adapter。

示例:

class SGLangNCCLAdapter(NCCLBackendAdapter):
    def build_group_init_requests(self, *, active_engine_endpoints, rendezvous):
        rank_offset = 1
        requests = []
        for endpoint in active_engine_endpoints:
            requests.append(NCCLGroupInitRequest(
                url=endpoint.url,
                endpoint="init_weights_update_group",
                body={
                    "master_address": rendezvous.master_address,
                    "master_port": rendezvous.master_port,
                    "rank_offset": rank_offset,
                    "world_size": rendezvous.world_size,
                    "group_name": rendezvous.group_name,
                    "backend": rendezvous.backend,
                },
            ))
            rank_offset += endpoint.engine_size
        return requests

    def build_request(self, batch, *, group) -> NCCLRequest:
        flattened = sglang_flattened_bucket(batch.state_dict.items())
        group.broadcast_tensor(flattened.tensor)
        return NCCLRequest(
            endpoint="update_weights_from_distributed",
            body={
                "names": flattened.names,
                "dtypes": flattened.dtypes,
                "shapes": flattened.shapes,
                "group_name": group.group_name,
                "load_format": "flattened_bucket",
            },
        )

HTTP helper 原则

不要让 client 提供这些方法:

client.collective_rpc(...)
client.update_weights(...)
client.update_weights_from_tensor(...)
client.init_weights_update_group(...)
client.update_weights_from_distributed(...)

这些方法名全是推理引擎协议,放在 client 会混合 vLLM、LMDeploy、SGLang 的 Interface。

更简洁的是低层 helper:

def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    response = requests.post(f"{url}/{endpoint}", headers=headers, json=payload)
    response.raise_for_status()
    return response.json()

endpoint 和 body 由 Adapter 产出:

request = adapter.build_request(...)
post_json(url, request.endpoint, request.body, api_key=rollout_info.api_key)

这样 helper 只负责 HTTP,Adapter 负责推理引擎协议。

上升 / 下沉规则

  • 上升到 Transport:TP gather、head rank post、barrier、external NCCL group 生命周期、HTTP post 执行。
  • 留在 rollout_info:backend、transport type、tp/ep、api key、rollout URL、engine endpoints、device mesh。
  • 下沉到 Adapter:推理引擎 endpoint、payload schema、序列化格式、flattened bucket、SGLang group init 请求。
  • 下沉到 backend-specific helper:LMDeploy IPC event/cache、SGLang torch patch、SGLang FlattenedTensorBucket。

最终的 Interface 更小:

class IPCBackendAdapter:
    def before_update(self) -> None: ...
    def after_update(self) -> None: ...
    def build_local_payload(self, batch: WeightUpdateBatch) -> LocalPayload: ...
    def build_request(..., tp, serialized_named_tensors, load_format) -> IPCRequest: ...


class NCCLBackendAdapter:
    def build_group_init_requests(..., active_engine_endpoints, rendezvous) -> list[NCCLGroupInitRequest]: ...
    def check_group_init_result(self, result: dict) -> None: ...
    def build_request(self, batch, *, group) -> NCCLRequest: ...

这能让 Adapter 只关注不同推理引擎的差异点,同时让 Transport 保持通用传输语义。

Comment thread xtuner/v1/rl/weight_update/transport.py Outdated
Comment thread xtuner/v1/rl/weight_update/transport.py Outdated
Comment thread xtuner/v1/rl/weight_update/transport.py Outdated
Comment thread xtuner/v1/rl/weight_update/exporter.py Outdated
Comment thread xtuner/v1/rl/weight_update/transport.py Outdated
Comment thread xtuner/v1/rl/weight_update/update_weighter.py
Comment thread xtuner/v1/rl/weight_update/update_weighter.py
@PengchengShi00 PengchengShi00 force-pushed the refactor-update-weight branch 3 times, most recently from 408ac90 to 5fe2d53 Compare June 24, 2026 02:18
Comment thread xtuner/v1/train/rl_trainer.py Outdated
Comment thread xtuner/v1/rl/trainer/worker.py Outdated
Comment thread xtuner/v1/rl/weight_update/update_weighter.py Outdated
Comment thread xtuner/v1/rl/weight_update/data.py
Comment thread xtuner/v1/rl/weight_update/update_weighter.py Outdated
Comment thread xtuner/v1/rl/weight_update/update_weighter.py Outdated
) -> list[Any]:
raise NotImplementedError

def after_update_per_batch(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个既然作为一个 base 方法,输入参考有点太 hard code。我感觉应该和前面接口一样,传入 batch: WeightUpdateBatch,
cpu_group: dist.ProcessGroup 和 **kwargs 这样接口才对齐。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里并没有改,这个函数作用是处理不同backend更新完batch后的同步,主要是根据finishedtrain_enable_ep进行判断,与batch无关。

Comment thread xtuner/v1/rl/weight_update/update_weighter.py
Comment thread xtuner/v1/rl/weight_update/update_weighter.py Outdated
Comment thread xtuner/v1/rl/weight_update/transport.py Outdated

state_dict = batch.state_dict
# Pytorch backend will send empty state_dict when finished.
if state_dict is not None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WeightIterator 会用空 batch 表示完成标记 yield WeightUpdateBatch({}, finished=True)

但 NCCLWeightTransport.send() 直接跳过空 state_dict:

if not state_dict:
return

同时 LMDeployNCCLBackendAdapter 里又写了 finalize-only 分支,但判断是 state_dict is not None,而 WeightUpdateBatch.state_dict 类型永远是 dict,所以空 dict 也会走正常分支,else 基本不可达。这个对 pytorch disaggregated
路径尤其危险。这个似乎是一个 bug,但是影响有多大不清楚,需要测试

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,这里通过state_dict来判断不太合理,已改为通过finished标志来进行判断

Comment thread xtuner/v1/rl/weight_update/data.py Outdated
@hhaAndroid

Copy link
Copy Markdown
Collaborator

@claude review

pass

def build_request(
self,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: build_weight_update_payload returns None implicitly (via pass). In NCCLWeightTransport.send, the result is unpacked as:

payload, flattened_tensor, weight_names = self._adapter.build_weight_update_payload(batch, self.group_name)

If a subclass forgets to override, or is called on the base, this crashes with TypeError: cannot unpack non-iterable NoneType. Should be raise NotImplementedError to match build_request below it.

Suggested change
self,
def build_weight_update_payload(self, batch: WeightUpdateBatch, group_name: str):
raise NotImplementedError

Comment on lines +487 to +490
monkey_unpatch_torch_reductions()


class NCCLBackendAdapter:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: monkey_unpatch_torch_reductions() is called unconditionally in the finally block for all IPC backends (vLLM, LMDeploy, SGLang), but monkey_patch_torch_reductions() is only called inside SGLangIPCBackendAdapter.serialize. For non-SGLang backends, the unpatch runs on every batch without a prior patch.

If monkey_unpatch_torch_reductions is not idempotent (e.g. it errors or corrupts state when no prior patch was applied), this is a latent bug. Consider either:

  • Moving the unpatch into the SGLang adapter (making it responsible for its own cleanup), or
  • Guarding with a flag that tracks whether the patch was applied.

Comment on lines +739 to +741
return

train_sync_group = self.get_train_update_sync_group()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: NCCLWeightTransport.send returns early when not state_dict, which means LMDeployNCCLBackendAdapter.build_weight_update_payload's finalize-only path (the if state_dict is not None: ... else: ... branch that returns (payload, None, None)) is unreachable.

Since WeightIterator.iter_hf_batches yields WeightUpdateBatch({}, finished=True) as the end marker for pytorch/vllm backends, state_dict is {} (empty dict, which is falsy). The early return here skips the finalize signal entirely.

Either:

  • This early return should check not state_dict and not batch.finished (allow finalize-only batches through), or
  • The LMDeployNCCLBackendAdapter finalize path is dead code that should be removed.

Comment on lines +154 to +155
raise ValueError(
f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Dead commented-out code. Should be removed.

Suggested change
raise ValueError(
f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'."


assert train_rollout_mode is not None, "update_rollout_info() must set train_rollout_mode."
tp = rollout_config.tensor_parallel_size
ep = rollout_config.expert_parallel_size

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Type mismatch: worker_server_urls_status is typed dict[str, bool], but the default value passed to .get() is the string "False" (truthy), not the boolean False. The is False check will always be False when the URL is absent from the dict, causing missing URLs to be silently treated as active.

Suggested change
ep = rollout_config.expert_parallel_size
if worker_server_urls_status.get(rollout_server_url, False) is False:

Comment on lines +657 to +668

# Deduplicate rollout engine URLs while keeping the first rank associated
# with each URL as the representative rank for that engine.
url_to_rank: dict[str, int] = {}
for rank, url in sorted(
self.rollout_info.rollout_server_url_dict.items(),
key=lambda item: int(item[0]),
):
if url:
url_to_rank.setdefault(url, int(rank))

# Collect the representative rank, URL, and engine size needed to create

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: ensure_nccl_weight_update_group builds url_to_rank by filtering only on if url (non-empty), without checking worker_server_urls_status. However, UpdateWeighter._build_transport_signature filters active URLs by both url and worker_server_urls_status.get(url, False).

This mismatch means: when a rollout worker goes unhealthy (status=False), the signature changes (transport is reset), but the new transport still connects to the unhealthy URL — likely hanging during NCCL group init.

Consider filtering by worker_server_urls_status here as well:

url_to_rank: dict[str, int] = {}
for rank, url in sorted(...):
    if url and self.rollout_info.worker_server_urls_status.get(url, False):
        url_to_rank.setdefault(url, int(rank))

self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip")

@staticmethod
def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Typo: "url in None" should be "url is None". Same typo appears in IPCWeightTransport.send as well.

Suggested change
def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict:
self.logger.error(f"rank {self.rank} url is None, cannot update weights and skip")

Comment on lines +91 to +100
if should_gather_train_ep_shards and model_ep_size > 1:
assert model.ep_mesh is not None
ep_group = model.ep_mesh.get_group()
gathered_tensors = [torch.empty_like(fused_full_tensor) for _ in range(model_ep_size)]
dist.all_gather(
gathered_tensors,
fused_full_tensor.contiguous(),
group=ep_group,
)
fused_full_tensor = torch.cat(gathered_tensors, dim=dim)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: In the disaggregated + train EP > 1 path, this allocates model_ep_size full copies of the tensor for the gather, then concatenates them into a single tensor. Peak memory per rank = (2 × model_ep_size + 1) × local_shard_numel × dtype_bytes.

For large MoE models with ep_size=8, this could be very memory-intensive. Consider using dist.all_gather_into_tensor with a pre-allocated output buffer of the final size to avoid the intermediate list allocation:

output = torch.empty(
    *fused_full_tensor.shape[:dim],
    fused_full_tensor.shape[dim] * model_ep_size,
    *fused_full_tensor.shape[dim+1:],
    dtype=fused_full_tensor.dtype, device=fused_full_tensor.device,
)
dist.all_gather_into_tensor(output, fused_full_tensor.contiguous(), group=ep_group)
fused_full_tensor = output

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

很有帮助,原来的代码会占用更多内存,已按照意见修改

@claude

claude Bot commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Claude: ## Review Summary

This PR refactors the monolithic UpdateWeighter mixin into a well-structured package (xtuner/v1/rl/weight_update/) with clear separation between orchestration (UpdateWeighter), weight export (WeightIterator), transport (IPCWeightTransport, NCCLWeightTransport), and backend adapters. The layered design aligns well with the architecture sketch in the PR discussion and will make adding new backends (RDMA, LMDeploy NCCL) straightforward.

ProduceBatchResult impact: not affected
RoutedExperts impact: not affected

Issues

Critical

  • transport.py:498NCCLBackendAdapter.build_weight_update_payload is pass (returns None), should be raise NotImplementedError
  • transport.py:739-741NCCLWeightTransport.send skips when not state_dict, making the LMDeploy finalize-only path unreachable (empty dict {} is falsy)

Warning

  • transport.py:487-490monkey_unpatch_torch_reductions() called unconditionally in finally for all IPC backends, but only SGLang patches
  • transport.py:657-668ensure_nccl_weight_update_group doesn't filter by worker_server_urls_status, mismatching _build_transport_signature
  • update_weighter.py:75.get(url, "False") uses string default instead of bool False

Nit

  • transport.py:67 — Typo: "url in None" → "url is None"
  • update_weighter.py:154-155 — Dead commented-out code
  • weight_iterator.py:91-100all_gather + torch.cat peak memory concern for large MoE
  • transport.py:110-113 — Dead commented-out code in IPCBackendAdapter.__init__
  • data.py:31rollout_cfg_info: dict field appears unused in the refactored code

Verdict

REQUEST_CHANGES — the finalize-only path bug (critical #2) can silently break LMDeploy disaggregated weight updates. The other critical and warning items should also be addressed before merge.

@PengchengShi00 PengchengShi00 force-pushed the refactor-update-weight branch from 5fe2d53 to 09389d5 Compare June 24, 2026 08:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants