diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 2e8e2bb5f149d7..0f39fde40a7c49 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -36,15 +36,15 @@ def ForCausalLMLoss( logits = logits.float() labels = labels.to(logits.device) # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() + labels = nn.functional.pad(labels, (0, 1), value=ignore_index) shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - shift_logits = shift_logits.view(-1, vocab_size) + logits = logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs) + shift_labels = shift_labels.to(logits.device) + loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs) return loss