Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion xtuner/v1/engine/train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,9 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor:
loss = torch.tensor(0.0, device=DEVICE)
for key in model_outputs.model_fields:
value = getattr(model_outputs, key)
if "loss" in key and isinstance(value, torch.Tensor):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loss aggregation should be owned by the model itself rather than being hardcoded here.

if key == "mtp_loss" and isinstance(value, dict):
for mtp_loss_name, mtp_loss in value.items():
loss += mtp_loss
elif "loss" in key and isinstance(value, torch.Tensor):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if key == "mtp_loss" and isinstance(value, dict):
for mtp_loss_name, mtp_loss in value.items():
loss += mtp_loss
elif "loss" in key and isinstance(value, torch.Tensor):
elif "loss" in key:
loss_values = list(value.values()) if isinstance(value, dict) else [value]
loss_values = [i for i in loss_values if isinstance(i, torch.Tensor)]
for value in loss_values:
loss += value

loss += value
return loss
40 changes: 39 additions & 1 deletion xtuner/v1/loss/mtp_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any
from typing import Any, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -60,6 +60,7 @@ class MTPLossConfig(CELossConfig):

mtp_depth: int
detach_mtp_lm_head_weight: bool = False
mask_type: Optional[str] = None

@property
def loss_ctx_cls(self) -> type["MTPLossContext"]:
Expand Down Expand Up @@ -167,6 +168,12 @@ def forward(
head_weight = head_weight.detach()
head_bias = head_bias.detach() if head_bias is not None else None
# Dispatch to eager_mode/chunk_mode via base class, which calls loss_fn per chunk

mask_type = self.loss_cfg.mask_type
if mask_type == "v1":
self.process_loss_weight_v1()
elif mask_type is not None:
raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation logic of loss should not be hard-coded here; please implement a new loss_context.

return super().forward(hidden_states, head_weight, head_bias)

def loss_fn(
Expand Down Expand Up @@ -214,3 +221,34 @@ def _kl_loss_fn(
)

return kl_loss, (None, {})

def process_loss_weight_v1(self):
    """Keep MTP loss weight only at, or shortly after, digit / dot labels.

    Reads ``self.loss_kwargs.shifted_labels`` (indexed as a 2-D tensor,
    rows x positions) and ``self.loss_kwargs.loss_weight``. A position keeps
    its weight when its own label, or one of the ``self.loss_cfg.mtp_depth``
    labels immediately before it in the same row, is a digit or dot token id.
    Every other position gets weight ``0``. If any weight survives, the
    result is rescaled so its sum equals the sum before masking.

    The look-back never crosses the row boundary: labels at the end of a row
    do not influence positions at the start of it. (The earlier
    ``torch.roll`` formulation padded a single column, which only prevented
    that wrap-around for ``mtp_depth == 1``.)

    Side effects:
        ``loss_weight`` is zeroed in place, and
        ``self.loss_kwargs.loss_weight`` is reassigned to the (possibly
        rescaled) result.
    """
    # TODO: digit and dot token config
    # NOTE(review): the ids below presumably match one specific tokenizer
    # (15..24 -> "0".."9", 13 -> ".") -- confirm before reusing elsewhere.
    digit_id_min, digit_id_max = 15, 24
    dot_id = 13

    look_back = max(self.loss_cfg.mtp_depth, 0)
    shifted_labels = self.loss_kwargs.shifted_labels
    loss_weight = self.loss_kwargs.loss_weight
    total_before_mask = loss_weight.sum()

    is_digit = (shifted_labels >= digit_id_min) & (shifted_labels <= digit_id_max)
    is_digit_or_dot = is_digit | (shifted_labels == dot_id)

    # Propagate each hit forward by 1..look_back positions with plain slices,
    # so nothing wraps from the end of a row back to its start.
    keep = is_digit_or_dot.clone()
    seq_len = shifted_labels.size(-1)
    for shift in range(1, min(look_back, seq_len - 1) + 1):
        keep[:, shift:] |= is_digit_or_dot[:, :-shift]

    # In-place masking is kept from the original implementation; any other
    # holder of this tensor will also observe the zeroed entries.
    loss_weight[~keep] = 0.0

    total_after_mask = loss_weight.sum()
    if total_after_mask.item() != 0:
        # Renormalise so the overall loss scale is unchanged by the masking.
        loss_weight = loss_weight * total_before_mask / total_after_mask

    self.loss_kwargs.loss_weight = loss_weight
13 changes: 12 additions & 1 deletion xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,18 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
output_copy = output.model_copy()
for name in output_copy.model_fields:
obj = getattr(output_copy, name)
if "loss" in name and isinstance(obj, torch.Tensor):
if name == "mtp_loss" and isinstance(obj, dict):
for key, value in obj.items():
loss_item = value.item()
local_total_loss += loss_item
reduced_name = f"{key}_reduced_mtp_loss"

if reduced_name not in reduced_other_losses:
reduced_other_losses[reduced_name] = loss_item
else:
reduced_other_losses[reduced_name] += loss_item

elif "loss" in name and isinstance(obj, torch.Tensor):
loss_item = obj.item()
local_total_loss += loss_item
reduced_name = f"reduced_{name}"
Expand Down
Loading
Loading