Commit

Optimize ForCausalLMLoss by removing unnecessary contiguous() call to reduce memory overhead (huggingface#35646)

efsotr authored Jan 16, 2025
1 parent 1302c32 · commit 8ebe9d7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/loss/loss_utils.py
@@ -36,15 +36,15 @@ def ForCausalLMLoss(
     logits = logits.float()
     labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
+    labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
     shift_labels = labels[..., 1:].contiguous()
 
     # Flatten the tokens
-    shift_logits = shift_logits.view(-1, vocab_size)
+    logits = logits.view(-1, vocab_size)
     shift_labels = shift_labels.view(-1)
     # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+    shift_labels = shift_labels.to(logits.device)
+    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
     return loss


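Why this is safe: right-padding the labels with ignore_index and then shifting them pairs each logit position with exactly the same target as the old logit-slicing approach (the padded position is masked out by ignore_index), while the full logits tensor stays contiguous and can be view()ed without materializing a second large float tensor. A minimal standalone sketch checking that both formulations produce the same loss — not the transformers source; the shapes are illustrative and plain F.cross_entropy stands in for fixed_cross_entropy:

import torch
import torch.nn.functional as F

batch, seq_len, vocab_size, ignore_index = 2, 5, 11, -100
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Old formulation: slicing logits yields a non-contiguous view, so
# .contiguous() materializes a second (batch, seq_len-1, vocab_size)
# float tensor before the loss is computed.
shift_logits = logits[..., :-1, :].contiguous()
old_labels = labels[..., 1:].contiguous()
loss_old = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    old_labels.view(-1),
    ignore_index=ignore_index,
)

# New formulation: pad labels on the right with ignore_index and shift
# them instead; the padded position contributes nothing to the loss,
# and the full, already-contiguous logits tensor is reshaped copy-free.
new_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:]
loss_new = F.cross_entropy(
    logits.view(-1, vocab_size),
    new_labels.reshape(-1),
    ignore_index=ignore_index,
)

# Identical contributing (logit, target) pairs, so the means match.
assert torch.allclose(loss_old, loss_new)

The avoided copy matters because logits are typically the largest activation in the training step: for example, a float32 logits tensor of shape (4, 4096, 128256) is about 8.4 GB, and the sliced .contiguous() copy used to allocate nearly the same amount again.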
