diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 74a56d4643..bb0d3b5441 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -569,6 +569,9 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor: loss = torch.tensor(0.0, device=DEVICE) for key in model_outputs.model_fields: value = getattr(model_outputs, key) - if "loss" in key and isinstance(value, torch.Tensor): - loss += value + if "loss" in key: + loss_values = list(value.values()) if isinstance(value, dict) else [value] + loss_values = [i for i in loss_values if isinstance(i, torch.Tensor)] + for value in loss_values: + loss += value return loss diff --git a/xtuner/v1/loss/__init__.py b/xtuner/v1/loss/__init__.py index d2f20b3a16..099d735640 100644 --- a/xtuner/v1/loss/__init__.py +++ b/xtuner/v1/loss/__init__.py @@ -10,7 +10,7 @@ ZLossContext, ZLossKwargs, ) -from .mtp_loss import MTPLossContext +from .mtp_loss import MTPLossContext, SciMTPLossContext, MTPLossConfig, SciMTPLossConfig from .rl_loss import LogProbConfig, LogProbContext @@ -31,6 +31,9 @@ "BaseLossKwargs", "LMHeadLossContext", "MTPLossContext", + "MTPLossConfig", + "SciMTPLossContext", + "SciMTPLossConfig", "LogProbConfig", "LogProbContext", ] diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index a5aebc3b19..6c6e2bf2d8 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any +from typing import Any, Optional import torch import torch.nn.functional as F @@ -53,12 +53,11 @@ class MTPLossConfig(CELossConfig): Args: mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses - ``mtp_depth=1`` (shift=-1 on top of the existing label shift). + ``mtp_depth=1`` (shift=-1 on top of the existing label shift). Default: 1. detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. This is used in RL training. Default is False. """ - mtp_depth: int detach_mtp_lm_head_weight: bool = False @property @@ -88,6 +87,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex MTPLossContext | None: Built loss context, or ``None`` if ``shifted_labels`` is not present in ``data``. """ + # TODO: Should move the common utils function to public package to avoid from circular import. from xtuner.v1.module.mtp.utils import roll_packed_tensor @@ -96,6 +96,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex shifted_labels = data["shifted_labels"] cu_seq_lens = data["seq_ctx"].cu_seq_lens_k + mtp_depth = data["mtp_depth"] # cu_seq_lens[-1] may be larger than shifted_labels.shape[-1] when seq_ctx # was split for sequence parallelism (padding is added to make the sequence @@ -112,7 +113,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex ) shifted_labels = torch.cat([shifted_labels, pad], dim=-1) - rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100) + rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-mtp_depth, dim=-1, fill_value=-100) # Roll logprobs by the same amount as shifted_labels logprobs = data.get("logprobs", None) @@ -126,7 +127,7 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex device=logprobs.device, ) logprobs = torch.cat([logprobs, rp_pad], dim=-1) - rolled_logprobs = roll_packed_tensor(logprobs, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=0) + rolled_logprobs = roll_packed_tensor(logprobs, cu_seq_lens, shifts=-mtp_depth, dim=-1, fill_value=0) loss_kwargs = MTPLossKwargs( shifted_labels=rolled, @@ -135,7 +136,28 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex if sp_mesh is not None and sp_mesh.size() > 1: loss_kwargs = loss_kwargs.sp_split(sp_mesh) - return MTPLossContext(self, loss_kwargs) + loss_context = self.loss_ctx_cls(self, loss_kwargs) + loss_context.bind_mtp_depth(mtp_depth) + return loss_context + + +class SciMTPLossConfig(MTPLossConfig): + """Loss configuration for Multi-Token Prediction (MTP). + + Extends ``MTPLossConfig`` with a ``mask_type`` field that controls how to mask + ``loss_kwargs`` when calculating loss. + + Args: + detach_mtp_lm_head_weight (bool): Whether to detach the LM head weight. + This is used in RL training. Default is False. + mask_type (str | None): Mask method when calculating Science MTP. + """ + + mask_type: Optional[str] = None + + @property + def loss_ctx_cls(self) -> type["SciMTPLossContext"]: + return SciMTPLossContext class MTPLossContext(LMHeadLossContext): @@ -156,6 +178,10 @@ class MTPLossContext(LMHeadLossContext): loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss computation. """ + def __init__(self, loss_cfg: MTPLossConfig, loss_kwargs: MTPLossKwargs): + super().__init__(loss_cfg, loss_kwargs) + + self.mtp_depth = None def forward( self, @@ -163,6 +189,7 @@ def forward( head_weight: torch.Tensor, head_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + assert self.mtp_depth is not None, "Please bind mtp depth for MTPLossContext!" if self.loss_cfg.detach_mtp_lm_head_weight: head_weight = head_weight.detach() head_bias = head_bias.detach() if head_bias is not None else None @@ -214,3 +241,76 @@ def _kl_loss_fn( ) return kl_loss, (None, {}) + + def bind_mtp_depth(self, depth: int) -> None: + """Bind MTP depth to the given index. + + Args: + depth (int): 1-indexed MTP layer depth to bind. + """ + self.mtp_depth = depth + + +class SciMTPLossContext(MTPLossContext): + """Loss context for Science Multi-Token Prediction (MTP). + + Supports two modes: + - **CE mode** (default): Standard cross-entropy loss on rolled labels. + Used during SFT/pretraining. + - **KL mode**: When ``logprobs`` is available (RL training), + computes KL divergence between MTP's log-probabilities and the + rolled rollout log-probabilities. + + Both modes support chunk mode for memory-efficient computation via the + base class's ``forward() → eager_mode()/chunk_mode() → loss_fn()`` dispatch. + + Args: + loss_cfg (MTPLossConfig): The MTP loss configuration. + loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss + computation. + """ + + def forward( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + mask_type = self.loss_cfg.mask_type + if mask_type == "v1": + self.process_loss_weight_v1() + elif mask_type is not None: + raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}") + + return super().forward(hidden_states, head_weight, head_bias) + + def process_loss_weight_v1(self): + layer_idx = self.mtp_depth - 1 + shifted_labels = self.loss_kwargs.shifted_labels + loss_weight = self.loss_kwargs.loss_weight + sum_loss_weight = loss_weight.sum() + + easy_to_use = torch.cat( + [ + shifted_labels, + torch.zeros((shifted_labels.size(0), 1), dtype=shifted_labels.dtype, device=shifted_labels.device), + ], + dim=-1, + ) + + # TODO: digit and dot token config + is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0) + is_dot = torch.where(easy_to_use == 13, 1, 0) + is_digit_or_dot = is_digit | is_dot + + mask = is_digit_or_dot.clone() + for i in range(layer_idx + 1): + mask |= torch.roll(is_digit_or_dot, shifts=i + 1, dims=-1) + + mtp_mask = mask.bool()[:, :-1] + + loss_weight[mtp_mask == 0.0] = 0.0 + if loss_weight.sum().item() != 0: + loss_weight = loss_weight * sum_loss_weight / loss_weight.sum() + + self.loss_kwargs.loss_weight = loss_weight diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index c02c4dc9dd..32aef6f355 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -1313,7 +1313,18 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat output_copy = output.model_copy() for name in output_copy.model_fields: obj = getattr(output_copy, name) - if "loss" in name and isinstance(obj, torch.Tensor): + if name == "mtp_loss" and isinstance(obj, dict): + for key, value in obj.items(): + loss_item = value.item() + local_total_loss += loss_item + reduced_name = f"{key}_reduced_mtp_loss" + + if reduced_name not in reduced_other_losses: + reduced_other_losses[reduced_name] = loss_item + else: + reduced_other_losses[reduced_name] += loss_item + + elif "loss" in name and isinstance(obj, torch.Tensor): loss_item = obj.item() local_total_loss += loss_item reduced_name = f"reduced_{name}" diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 061e2f6e29..50969f8d64 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from __future__ import annotations + +import copy import os import types from pathlib import Path @@ -36,7 +39,7 @@ ZLossConfig, ZLossContext, ) -from xtuner.v1.loss.mtp_loss import MTPLossConfig +from xtuner.v1.loss.mtp_loss import MTPLossConfig, SciMTPLossConfig from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -102,7 +105,7 @@ class MoEModelOutputs(ModelOutputs): balancing_loss: torch.Tensor | None = None z_loss: torch.Tensor | None = None tokens_per_expert_global: torch.Tensor - mtp_loss: torch.Tensor | None = None + mtp_loss: dict[str, torch.Tensor] | None = None def free_nongrad_feature(self): """Release large intermediate tensors not needed for backward or @@ -129,7 +132,7 @@ class MoELossContextDict(TypedDict): lm: BaseLossContext balancing: BalancingLossContext | None z_loss: ZLossContext | None - mtp: list[BaseLossContext] | None + mtp: dict[str, list[BaseLossContext]] | None class MoEConfig(TransformerConfig): @@ -151,7 +154,7 @@ class MoEConfig(TransformerConfig): router_compute_dtype: Literal["float32", "native"] = "float32" moe_bias: bool = False moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig() - mtp_config: MTPConfig | None = None + mtp_config: list[MTPConfig] | MTPConfig | None = None freeze_routers: bool = False router_async_offload: bool = False aux_loss_cfg: AuxLossConfig = AuxLossConfig() @@ -186,6 +189,11 @@ class MoE(BaseModel): def __init__(self, config: MoEConfig): super().__init__(config) + + # Normalize mtp_config to always be a list or None for consistent handling + if config.mtp_config is not None and not isinstance(config.mtp_config, list): + config.mtp_config = [config.mtp_config] + if config.ep_size is not None and config.ep_size > 1: world_size = dist.get_world_size() self.ep_mesh = init_device_mesh( @@ -202,7 +210,7 @@ def __init__(self, config: MoEConfig): self.layers = self.build_layers(config) self.rotary_emb = self.build_rotary_embedding(config) self.embed_tokens = self.build_embeddings(config) - self.mtp_block = self.build_mtp_block(config) if config.mtp_config is not None else None + self.mtp_block = self.build_mtp_block_list(config) if config.mtp_config is not None else None self.fp32_layers = [self.rotary_emb] @@ -338,23 +346,39 @@ def build_loss_ctx_batch( # type: ignore[override] # Add MTP loss contexts if MTP is enabled if self.config.mtp_config is not None: - for mtp_idx in range(self.config.mtp_config.num_layers): - mtp_loss_cfg = MTPLossConfig( - **self.config.lm_loss_cfg.model_dump(), - mtp_depth=mtp_idx + 1, - detach_mtp_lm_head_weight=self.config.mtp_config.detach_mtp_lm_head_weight, - ) - mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) - if mtp_loss_ctx_list is not None: - mtp_loss_ctx_list = MTPLossContext.build_batches( # type: ignore[assignment] - cast(list[MTPLossContext], mtp_loss_ctx_list), # type: ignore[arg-type] - cu_seq_lens_list=cu_seq_lens_list, - sp_mesh=sp_mesh, - ) - for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list): - if "mtp" not in res[i]: - res[i]["mtp"] = [] - res[i]["mtp"].append(mtp_loss_ctx) # type: ignore[union-attr] + # Build MTP loss contexts using the same approach as LM loss + # Each MTP depth needs its own loss context + for mtp_config in self.config.mtp_config: + for mtp_idx in range(mtp_config.num_layers): + # Get loss_cfg from mtp_config, or create a default one if not provided + if mtp_config.loss_cfg is not None: + # Create a copy to avoid modifying the original config + mtp_loss_cfg = mtp_config.loss_cfg.model_copy() + else: + # Create default MTPLossConfig from model's lm_loss_cfg + mtp_loss_cfg = MTPLossConfig( + **self.config.lm_loss_cfg.model_dump(), + detach_mtp_lm_head_weight=mtp_config.detach_mtp_lm_head_weight, + ) + + # copy data_batch to insert mtp_depth + _new_data_batch = copy.copy(_data_batch) + for _data in _new_data_batch: + _data["mtp_depth"] = mtp_idx + 1 + # MTP needs to shift labels multiple times. Since rebuild the `shifted_labels` in data_batch + mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _new_data_batch, sp_mesh) + if mtp_loss_ctx_list is not None: + mtp_loss_ctx_list = type(mtp_loss_ctx_list[0]).build_batches( # type: ignore[assignment] + mtp_loss_ctx_list, # type: ignore[arg-type] + cu_seq_lens_list=cu_seq_lens_list, + sp_mesh=sp_mesh, + ) + for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list): + if "mtp" not in res[i]: + res[i]["mtp"] = {} + if mtp_config.name not in res[i]["mtp"]: + res[i]["mtp"][mtp_config.name] = [] + res[i]["mtp"][mtp_config.name].append(mtp_loss_ctx) # type: ignore[union-attr] # Ensure all microbatches have mtp key for loss_ctx_dict in res: @@ -571,34 +595,53 @@ def _micro_batch_forward( ) ) - mtp_outputs_per_mb = self.mtp_block( - *hidden_states_list, - embed_tokens_fn=self.embed_tokens, - position_embeddings=position_embeddings_list, - seq_ctx=mtp_seq_ctx_list, - ) + # Initialize mtp_losses dict to store losses for each mtp_config + mtp_losses_dict: dict[str, torch.Tensor] = {} - mtp_losses = torch.tensor(0.0, device=DEVICE) - has_mtp_loss = False - for micro_batch_idx, (loss_ctx_dict, mtp_outputs) in enumerate(zip(loss_ctx_list, mtp_outputs_per_mb)): - mtp_loss_ctx_list = loss_ctx_dict.get("mtp") - if mtp_loss_ctx_list is None: - continue + # Loop through each mtp_config + for mtp_block in self.mtp_block: + mtp_config = mtp_block.mtp_config + name = mtp_config.name - micro_batch_mtp_losses = torch.tensor(0.0, device=DEVICE) - for mtp_idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): - mtp_hidden_states, mtp_router_results, _ = mtp_hidden - mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) - micro_batch_mtp_losses += mtp_loss + # Get the MTP block for this config by name + mtp_outputs_per_mb = mtp_block( + *hidden_states_list, + embed_tokens_fn=self.embed_tokens, + position_embeddings=position_embeddings_list, + seq_ctx=mtp_seq_ctx_list, + ) - if keep_router: - router_logits_list[micro_batch_idx][f"mtp_layer{mtp_idx}"] = mtp_router_results + mtp_losses = torch.tensor(0.0, device=DEVICE) + has_mtp_loss = False + for micro_batch_idx, (loss_ctx_dict, mtp_outputs) in enumerate(zip(loss_ctx_list, mtp_outputs_per_mb)): + # Get the mtp loss context dict + mtp_loss_ctx_dict = loss_ctx_dict.get("mtp") + if mtp_loss_ctx_dict is None or name not in mtp_loss_ctx_dict: + continue + + # Get the loss context list for this mtp_config name + mtp_loss_ctx_list = mtp_loss_ctx_dict[name] + + micro_batch_mtp_losses = torch.tensor(0.0, device=DEVICE) + for mtp_idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): + mtp_hidden_states, mtp_router_results, _ = mtp_hidden + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) + micro_batch_mtp_losses += mtp_loss + + if keep_router: + # Add name prefix to router logits key + router_logits_list[micro_batch_idx][f"{name}_mtp_layer{mtp_idx}"] = mtp_router_results - mtp_losses += micro_batch_mtp_losses / len(mtp_loss_ctx_list) - has_mtp_loss = True + mtp_losses += micro_batch_mtp_losses / len(mtp_loss_ctx_list) + has_mtp_loss = True - if has_mtp_loss: - output["mtp_loss"] = mtp_losses * self.config.mtp_config.loss_scaling_factor + if has_mtp_loss: + # Use the loss_scaling_factor from current mtp_config + mtp_losses_dict[name] = mtp_losses * mtp_config.loss_scaling_factor + + # Store mtp losses as dict + if mtp_losses_dict: + output["mtp_loss"] = mtp_losses_dict # Apply final norm to all micro-batches cat_hidden_states = torch.cat(hidden_states_list, dim=1) @@ -751,59 +794,64 @@ def _forward( if ( self.mtp_block is not None and loss_ctx is not None - and (mtp_loss_ctx_list := loss_ctx.get("mtp")) is not None + and (mtp_loss_ctx_dict := loss_ctx.get("mtp")) is not None ): + output["mtp_loss"] = {} mtp_seq_ctx = seq_ctx.copy( input_ids=input_ids.clone() if input_ids is not None else None, position_ids=position_ids.clone(), inputs_embeds=seq_ctx.inputs_embeds.clone() if seq_ctx.inputs_embeds is not None else None, ) - # MTP uses its own mask; main mask's non-pad indices do not apply. - mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] - mtp_non_pad_token = mtp_nonpad_indices.numel() - mtp_num_tokens_global, mtp_z_world_size = self._z_loss_dist_token_count( - z_ctx, mtp_non_pad_token, mtp_seq_ctx.mask.device - ) - # Forward through MTP block - mtp_outputs = self.mtp_block( - layer_hidden_states, - embed_tokens_fn=self.embed_tokens, - position_embeddings=position_embeddings, - seq_ctx=mtp_seq_ctx, - ) - - # Compute MTP losses for each depth - mtp_losses = torch.tensor(0.0, device=DEVICE) - for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): - mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden + for mtp_block in self.mtp_block: + mtp_config = mtp_block.mtp_config + name = mtp_config.name + mtp_nonpad_indices = torch.nonzero(mtp_seq_ctx.mask, as_tuple=True)[1] + mtp_non_pad_token = mtp_nonpad_indices.numel() + mtp_num_tokens_global, mtp_z_world_size = self._z_loss_dist_token_count( + z_ctx, mtp_non_pad_token, mtp_seq_ctx.mask.device + ) - if keep_router: - output["router_logits"][f"mtp_layer{idx}"] = mtp_router_results - output["router_weights"][f"mtp_layer{idx}"] = mtp_router_weights - # Inject this MTP layer's z-loss before lm_head so backward through mtp_loss - # traverses the AuxLossScaler node and releases this layer's logsumexp activations. - mtp_hidden_states = self.aux_loss.accumulate( - selected_router_weights=mtp_router_weights.index_select(0, mtp_nonpad_indices) - .contiguous() - .float(), - selected_router_logits=mtp_router_results.index_select(0, mtp_nonpad_indices).contiguous().float(), - hidden_states=mtp_hidden_states, - balancing_ctx=balancing_ctx, - z_ctx=z_ctx, - num_tokens_local=mtp_non_pad_token, - num_tokens_global=mtp_num_tokens_global, - world_size=mtp_z_world_size, + # Forward through MTP block + mtp_outputs = mtp_block( + layer_hidden_states, + embed_tokens_fn=self.embed_tokens, + position_embeddings=position_embeddings, + seq_ctx=mtp_seq_ctx, ) - mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) - mtp_losses += mtp_loss - # Average MTP losses across depths and scale - mtp_losses = mtp_losses / len(mtp_loss_ctx_list) - scaled_mtp_loss = mtp_losses * self.config.mtp_config.loss_scaling_factor # type: ignore + # Compute MTP losses for each depth + mtp_losses = torch.tensor(0.0, device=DEVICE) + mtp_loss_ctx_list = mtp_loss_ctx_dict[name] + for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): + mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden + + if keep_router: + output["router_logits"][f"{name}_mtp_layer{idx}"] = mtp_router_results + output["router_weights"][f"{name}_mtp_layer{idx}"] = mtp_router_weights + # Inject this MTP layer's z-loss before lm_head so backward through mtp_loss + # traverses the AuxLossScaler node and releases this layer's logsumexp activations. + mtp_hidden_states = self.aux_loss.accumulate( + selected_router_weights=mtp_router_weights.index_select(0, mtp_nonpad_indices) + .contiguous() + .float(), + selected_router_logits=mtp_router_results.index_select(0, mtp_nonpad_indices).contiguous().float(), + hidden_states=mtp_hidden_states, + balancing_ctx=balancing_ctx, + z_ctx=z_ctx, + num_tokens_local=mtp_non_pad_token, + num_tokens_global=mtp_num_tokens_global, + world_size=mtp_z_world_size, + ) + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) + mtp_losses += mtp_loss - # Add to total loss - output["mtp_loss"] = scaled_mtp_loss + # Average MTP losses across depths and scale + mtp_losses = mtp_losses / len(mtp_loss_ctx_list) + scaled_mtp_loss = mtp_losses * mtp_config.loss_scaling_factor # type: ignore + + # Add to total loss + output["mtp_loss"][name] = scaled_mtp_loss split_aux_output = self.aux_loss.finalize( balancing_ctx=balancing_ctx, @@ -896,16 +944,38 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: layers.__class__.__repr__ = module_dict_repr # type: ignore[method-assign] return layers - def build_mtp_block(self, config: MoEConfig) -> MTPBlock: + def build_mtp_block_list(self, config): + mtp_block_list = [] + layer_idx_offset = 0 # Cumulative offset for layer indices across all mtp_configs + mtp_name_list = [] + + for mtp_config in config.mtp_config: + if mtp_config.name not in ("normal", "sci"): + raise ValueError(f"Expected mtp keys to be either `normal` or `sci`, but got `{mtp_config.name}`") + if mtp_config.name in mtp_name_list: + raise ValueError(f"Duplicate mtp name: `{mtp_config.name}`") + + mtp_name_list.append(mtp_config.name) + # Build the MTP block with the current offset + mtp_block_list.append(self.build_mtp_block(config, mtp_config, layer_idx_offset)) + + # Update offset: number of physical layers for this mtp_config + num_physical_layer = 1 if mtp_config.share_weights else mtp_config.num_layers + layer_idx_offset += num_physical_layer + + return nn.ModuleList(mtp_block_list) + + def build_mtp_block(self, config: MoEConfig, mtp_config: MTPConfig, layer_idx_offset: int) -> MTPBlock: """Build MTP block with MoE decoder layers. Args: config (MoEConfig): Model configuration. + mtp_config (MTPConfig): MTP configuration for this specific block. + layer_idx_offset (int): Offset for layer indices to ensure uniqueness across multiple mtp_configs. Returns: MTPBlock: Constructed MTP block. """ - mtp_config = config.mtp_config assert mtp_config is not None, "mtp_config must be provided" mtp_layers = [] @@ -949,7 +1019,7 @@ def build_mtp_block(self, config: MoEConfig) -> MTPBlock: router_compute_dtype=config.router_compute_dtype, moe_act_fn_cfg=config.moe_act_fn_cfg, float8_cfg=config.float8_cfg, - layer_idx=config.num_hidden_layers + i, + layer_idx=config.num_hidden_layers + layer_idx_offset + i, dispatcher=config.dispatcher, ep_mesh=self.ep_mesh, ) @@ -1084,30 +1154,38 @@ def fully_shard( # Shard MTP block if it exists if self.mtp_block is not None: - for mtp_idx, mtp_layer in enumerate(self.mtp_block.layers): - if self._should_recompute(None, mtp_idx=mtp_idx) or ( - self.config.mtp_config is not None and self.config.mtp_config.share_weights - ): # share mtp head must recompute - mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) - self.mtp_block.layers[mtp_idx] = mtp_layer - - reshard_after_forward = mtp_idx != len(self.mtp_block.layers) - 1 - self._fully_shard( - mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, - mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, - offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, - module=mtp_layer, - ) - if mtp_idx == 0: - layer_next.set_modules_to_forward_prefetch([mtp_layer]) # type: ignore + total_mtp_layers = sum([len(mtp_block.layers) for mtp_block in self.mtp_block]) + global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs + mtp_block_layers = [] + for mtp_block in self.mtp_block: + mtp_config = mtp_block.mtp_config + for local_mtp_idx, mtp_layer in enumerate(mtp_block.layers): + if self._should_recompute(None, mtp_idx=global_mtp_idx) or ( + mtp_config is not None and mtp_config.share_weights + ): # share mtp head must recompute + mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) + mtp_block.layers[local_mtp_idx] = mtp_layer + + reshard_after_forward = global_mtp_idx != total_mtp_layers - 1 + self._fully_shard( + mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, + module=mtp_layer, + ) + # Only set prefetch for the first MTP layer across all mtp_configs + if global_mtp_idx == 0: + layer_next.set_modules_to_forward_prefetch([mtp_layer]) # type: ignore + global_mtp_idx += 1 + + mtp_block_layers.extend(list(mtp_block.layers)) - if self.config.mtp_config is not None and self.config.mtp_config.num_layers > 0: - for prev_mtp_layer, next_mtp_layer in zip( - list(self.mtp_block.layers)[:-1], - list(self.mtp_block.layers)[1:], - ): - prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore + for prev_mtp_layer, next_mtp_layer in zip( + mtp_block_layers[:-1], + mtp_block_layers[1:], + ): + prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore self._fully_shard( mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, @@ -1329,7 +1407,9 @@ def _should_recompute( """ num_layers = self.config.num_hidden_layers if self.config.mtp_config is not None: - mtp_layers = 1 if self.config.mtp_config.share_weights else self.config.mtp_config.num_layers + mtp_layers = sum( + [1 if mtp_config.share_weights else mtp_config.num_layers for mtp_config in self.config.mtp_config] + ) else: mtp_layers = 0 recompute_ratio = self.fsdp_config.recompute_ratio if self.fsdp_config is not None else 0.0 diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index e4cf2e2fc3..94fb488051 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -45,19 +45,29 @@ class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): def to_hf_key_list(self, key: str) -> list[str]: # Handle MTP parameters if key.startswith("mtp_block."): - # Remove "mtp_block." prefix - key = key.replace("mtp_block.", "", 1) + + # Extract MTP name from mtp_block.{mtp_name}.{rest} + match = re.match(r"mtp_block\.(\d+)\.(.*)", key) + if not match: + raise ValueError( + f"Invalid mtp_block key format: {key}. " + f"Expected 'mtp_block.{{idx}}.*" + ) + + mtp_idx = int(match.group(1)) + mtp_name = self.config.mtp_config[mtp_idx].name + key = match.group(2) # Handle MTP layer-specific parameters - # xtuner: mtp_block.layers.{idx}.decoder_layer.{param} - # HF: mtp.layers.{idx}.{param} + # xtuner: mtp_block.{mtp_name}.layers.{idx}.decoder_layer.{param} + # HF normal: mtp.layers.{idx}.{param} + # HF sci: mtp.sci.layers.{idx}.{param} key = re.sub(r"layers\.(\d+)\.decoder_layer\.", r"layers.\1.", key) # Handle MTP normalization layers - # xtuner: mtp_block.layers.{idx}.enorm -> HF: mtp.pre_fc_norm_embedding - # xtuner: mtp_block.layers.{idx}.hnorm -> HF: mtp.pre_fc_norm_hidden - # xtuner: mtp_block.layers.{idx}.final_layernorm -> HF: mtp.norm - # Note: Currently assuming single MTP layer (idx=0), may need adjustment for multiple layers + # xtuner: mtp_block.{mtp_name}.layers.{idx}.enorm -> HF: mtp[.sci].pre_fc_norm_embedding + # xtuner: mtp_block.{mtp_name}.layers.{idx}.hnorm -> HF: mtp[.sci].pre_fc_norm_hidden + # xtuner: mtp_block.{mtp_name}.layers.{idx}.final_layernorm -> HF: mtp[.sci].norm if ".enorm." in key: key = re.sub(r"layers\.\d+\.enorm\.", "pre_fc_norm_embedding.", key) elif ".hnorm." in key: @@ -66,7 +76,7 @@ def to_hf_key_list(self, key: str) -> list[str]: key = re.sub(r"layers\.\d+\.final_layernorm\.", "norm.", key) # Handle MTP projection layer - # xtuner: mtp_block.layers.{idx}.eh_proj -> HF: mtp.fc + # xtuner: mtp_block.{mtp_name}.layers.{idx}.eh_proj -> HF: mtp.{mtp_name}.fc if ".eh_proj." in key: key = re.sub(r"layers\.\d+\.eh_proj\.", "fc.", key) @@ -74,6 +84,12 @@ def to_hf_key_list(self, key: str) -> list[str]: key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts|shared_expert_gate)", r"layers.\1.mlp.\2", key) key = key.replace("shared_experts", "shared_expert") + # Determine HF prefix based on mtp_name + # Normal MTP (mtp_block.normal.*): mtp.{key} + # Science MTP (mtp_block.sci.*): mtp.sci.{key} + # TODO: normal mtp prefix + hf_prefix = "mtp." if mtp_name == "normal" else f"mtp.{mtp_name}." + # Handle fused weights n_routed_experts = self.config.n_routed_experts if "fused_w1w3.weight" in key: @@ -83,15 +99,15 @@ def to_hf_key_list(self, key: str) -> list[str]: w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.gate_proj.weight")) w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.up_proj.weight")) - return [f"mtp.{key}" for key in w1w3_keys] + return [f"{hf_prefix}{key}" for key in w1w3_keys] elif "fused_w2.weight" in key: w2_keys: list[str] = [] for i in range(n_routed_experts): w2_keys.append(key.replace("fused_w2.weight", f"{i}.down_proj.weight")) - return [f"mtp.{key}" for key in w2_keys] + return [f"{hf_prefix}{key}" for key in w2_keys] else: - return ["mtp." + key] + return [hf_prefix + key] # Handle main model parameters if "layers" in key or "embed_tokens" in key: diff --git a/xtuner/v1/module/mtp/config.py b/xtuner/v1/module/mtp/config.py index d6bcf7a9b5..6926358125 100644 --- a/xtuner/v1/module/mtp/config.py +++ b/xtuner/v1/module/mtp/config.py @@ -1,10 +1,14 @@ """Configuration for Multi-Token Prediction (MTP).""" -from typing import Annotated +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated from cyclopts import Parameter from pydantic import BaseModel, ConfigDict +from xtuner.v1.loss.mtp_loss import MTPLossConfig + class MTPConfig(BaseModel): """Configuration for Multi-Token Prediction (MTP). @@ -18,6 +22,7 @@ class MTPConfig(BaseModel): decoder layers. Args: + name (str): Name of mtp module. num_layers (int): Number of MTP layers (prediction depths). Each layer predicts tokens at increasing future positions (i+1, i+2, ..., i+D). share_weights (bool): Whether to share the weights of the MTP layers. @@ -30,6 +35,8 @@ class MTPConfig(BaseModel): loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss is computed as the average of losses across all depths, multiplied by this factor. Default: 0.1. + loss_cfg (MTPLossConfig | None): Loss configuration for MTP. + If None, loss config will be constructed from MTPLossConfig(). Default: None. Example: >>> # In model config @@ -39,14 +46,17 @@ class MTPConfig(BaseModel): ... num_layers=2, ... share_weights=True, ... loss_scaling_factor=0.1, + ... loss_cfg=MTPLossConfig() ... ), ... ) """ model_config = ConfigDict(extra="forbid") + name: Annotated[str, Parameter(group="model")] num_layers: Annotated[int, Parameter(group="model")] share_weights: Annotated[bool, Parameter(group="model")] = False detach_mtp_lm_head_weight: Annotated[bool, Parameter(group="model")] = False detach_mtp_inputs: Annotated[bool, Parameter(group="model")] = False loss_scaling_factor: Annotated[float, Parameter(group="model")] = 0.1 + loss_cfg: MTPLossConfig | None = None