Skip to content
14 changes: 10 additions & 4 deletions tests/rl/test_rl_trainer_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,17 @@ def __init__(self):
self.update_weights_count = 0
self.rollout_info = None

def set_train_rollout_mode(self, mode: str):
    """Store ``mode`` on this object as ``train_rollout_mode``."""
    self.train_rollout_mode = mode

def update_rollout_info(self, info):
def update_rollout_info(
    self,
    info,
    train_rollout_mode,
    weight_update_host=None,
    weight_update_port=None,
):
    """Record the arguments of the call on this object for later inspection.

    Args:
        info: Rollout info payload; stored as ``rollout_info``.
        train_rollout_mode: Stored as ``train_rollout_mode``.
        weight_update_host: Stored as ``weight_update_host``. Defaults to
            ``None`` so callers may omit it.
        weight_update_port: Stored as ``weight_update_port``. Defaults to
            ``None`` so callers may omit it.
    """
    self.rollout_info = info
    self.train_rollout_mode = train_rollout_mode
    self.weight_update_host = weight_update_host
    self.weight_update_port = weight_update_port

def onload(self, target="all"):
    """Return the marker string ``"onload:<target>"`` for the requested target."""
    marker = f"onload:{target}"
    return marker
Expand Down
384 changes: 109 additions & 275 deletions tests/rl/test_update_weight_disaggregated.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions xtuner/v1/rl/rollout/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ class RolloutConfig(BaseModel):
gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85.
random_seed (int): Random seed for reproducible generation. Defaults to 1024.
rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False.
weight_update_host (Optional[str]): Host used by train rank 0 to initialize the external NCCL weight update
group. Defaults to None.
weight_update_port (Optional[int]): Port used by train rank 0 to initialize the external NCCL weight update
group. Defaults to 30000.
rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it
will be determined automatically based on `context_length`. Defaults to 512.
allow_over_concurrency_ratio (float): Deprecated compatibility option. Rollout runtime concurrency is
Expand Down Expand Up @@ -223,6 +227,26 @@ class RolloutConfig(BaseModel):
help="Base port number for distributed communication among rollout workers.",
),
] = 25000
weight_update_host: Annotated[
Optional[str],
Parameter(
group=infer_group,
help=(
"Host used by train rank 0 to initialize the external NCCL weight update group. "
"Only used for NCCL weight update."
),
),
] = None
weight_update_port: Annotated[
Optional[int],
Parameter(
group=infer_group,
help=(
"Port used by train rank 0 to initialize the external NCCL weight update group. "
"Only used for NCCL weight update."
),
),
] = 30000
rollout_max_batch_size_per_instance: Annotated[
Optional[int],
Parameter(
Expand Down
17 changes: 12 additions & 5 deletions xtuner/v1/rl/trainer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,18 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"):
ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore
return

def update_rollout_info(self, info_dict):
    """Call ``update_rollout_info`` remotely on every worker and wait for all of them.

    Args:
        info_dict: Mapping expanded into keyword arguments of each worker call.
    """
    pending = [worker.update_rollout_info.remote(**info_dict) for worker in self.workers]  # type: ignore[attr-defined]
    ray.get(pending)

def set_train_rollout_mode(self, train_rollout_mode: str):
    """Call ``set_train_rollout_mode`` remotely on every worker and wait for all of them.

    Args:
        train_rollout_mode: Value passed positionally to each worker call.
    """
    pending = [worker.set_train_rollout_mode.remote(train_rollout_mode) for worker in self.workers]
    ray.get(pending)
def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None):
    """Call ``update_rollout_info`` remotely on every worker and wait for all of them.

    Args:
        info_dict: Mapping expanded into keyword arguments of each worker call.
        train_rollout_mode: Forwarded to each worker as ``train_rollout_mode``.
        weight_update_host: Forwarded to each worker as ``weight_update_host``.
        weight_update_port: Forwarded to each worker as ``weight_update_port``.
    """
    pending = []
    for worker in self.workers:
        # The explicit keywords sit next to ``**info_dict`` on purpose: a key
        # in ``info_dict`` that collides with one of them raises TypeError.
        ref = worker.update_rollout_info.remote(
            **info_dict,
            train_rollout_mode=train_rollout_mode,
            weight_update_host=weight_update_host,
            weight_update_port=weight_update_port,
        )
        pending.append(ref)
    ray.get(pending)

def update_weights(self):
"""Update the weights of the training workers."""
Expand Down
Loading
Loading