[BUG] deepspeed.initialize changes the output of Llama model #6929

Open
Ktakuya332C opened this issue Jan 7, 2025 · 0 comments
Labels: bug, training

Ktakuya332C commented Jan 7, 2025

Describe the bug
The output of a Llama model changes before and after applying `deepspeed.initialize`.

To Reproduce
Install the required libraries:

```bash
pip install mpi4py==4.0.1 deepspeed==0.16.2 transformers==4.47.1 accelerate==1.2.1
```

Then compare the outputs of the Llama model before and after applying `deepspeed.initialize`:

```python
import os
import torch
from accelerate import Accelerator, DeepSpeedPlugin
from transformers import LlamaForCausalLM

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7998"

dummy = torch.randint(0, 128, (3, 5)).to("cuda")

# Get the output of the Llama model (bfloat16)
model1 = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)
model1 = model1.to("cuda")
output1 = model1(dummy)

# Apply DeepSpeed to the model
deepspeed_plugin = DeepSpeedPlugin(zero_stage=0)
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 1
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)

model2 = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)
model2 = accelerator.prepare(model2)
output2 = model2(dummy)

# These outputs should be the same, but they are not
# assert torch.all(output1.logits == output2.logits)
```
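
As an optional check (not part of the original report), the size of the mismatch can be quantified instead of asserting exact equality; the variable names follow the reproduction script above:

```python
# Optional sanity check (sketch): measure the largest logit difference
# between the plain model and the DeepSpeed-prepared model.
diff = (output1.logits.float() - output2.logits.float()).abs().max()
print(f"max |logits1 - logits2| = {diff.item():.6g}")
```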

The cause of this difference is the precision of the `inv_freq` buffer of the Llama model:

```python
# The problem is that the original inv_freq is in fp32,
assert model1.model.rotary_emb.inv_freq.dtype == torch.float32
# while deepspeed.initialize, called from accelerator.prepare, casts inv_freq to bf16.
assert model2.model.rotary_emb.inv_freq.dtype == torch.bfloat16

# Overwriting the inv_freq of model2 (bf16) with the one from model1 (fp32) ...
model2.model.rotary_emb.inv_freq.data = model1.model.rotary_emb.inv_freq.data
output3 = model2(dummy)
assert torch.all(output1.logits == output3.logits)  # ... resolves the problem
```

Expected behavior
There should be no difference between the outputs of the model before and after applying `deepspeed.initialize`.

Cause of this behavior
`deepspeed.initialize` eventually calls the `bfloat16()` method on the model,

```python
self.module.bfloat16()
```

which casts not only the parameters of the model but also its buffers, including `inv_freq`:
https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/models/llama/modeling_llama.py#L125
Here `inv_freq` is supposed to stay in fp32 even when the parameters of the model are in bf16.
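
A possible workaround (a sketch, not an official DeepSpeed or Accelerate API; the helper names are hypothetical) is to snapshot the fp32 buffers before `accelerator.prepare` and restore them afterwards, mirroring the manual fix shown above:

```python
import torch

def snapshot_fp32_buffers(model, name_filter="inv_freq"):
    # Keep an fp32 copy of every matching buffer before DeepSpeed casts it to bf16.
    return {
        name: buf.detach().clone()
        for name, buf in model.named_buffers()
        if name_filter in name and buf.dtype == torch.float32
    }

def restore_fp32_buffers(model, snapshot):
    # Write the saved fp32 values (and dtype) back into the cast buffers.
    for name, buf in model.named_buffers():
        if name in snapshot:
            buf.data = snapshot[name].to(buf.device)

# Usage in the reproduction script above (hypothetical):
#   saved = snapshot_fp32_buffers(model2)
#   model2 = accelerator.prepare(model2)
#   restore_fp32_buffers(accelerator.unwrap_model(model2), saved)
```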

ds_report output

$ ds_report
[2025-01-07 22:05:15,391] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
 [WARNING]  gds requires the dev libaio .so object and headers but these were not found.
 [WARNING]  gds: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.5.0a0+e000cf0ad9.nv24.10
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.16.2, unknown, unknown
torch cuda version ............... 12.6
torch hip version ................ None
nvcc version ..................... 12.6
deepspeed wheel compiled w. ...... torch 2.5, cuda 12.6
shared memory (/dev/shm) size .... 1007.78 GB

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • 1 H100 GPU
  • Python version: 3.10.12

Docker context
nvidia/pytorch

Additional context

  • This difference causes a training/serving skew, which degrades the performance of trained models (e.g. training a model with DeepSpeed in bf16 and later serving it in bf16).
  • The issue "Enable deepspeed.zero.Init causes very strange spikes in PPO policy_loss" #4932 seems deeply related to this one and appears to have been resolved, but I believe the problem still persists. I suspect some change on the transformers side made it reappear, but I am not sure.
Ktakuya332C added the bug and training labels on Jan 7, 2025
hwchen2017 self-assigned this on Jan 7, 2025