Describe the bug
The output of the Llama model changes before and after applying deepspeed.initialize.

To Reproduce
1. Install dependent libraries.
2. Compare outputs of the Llama model before and after applying deepspeed.initialize, using the script below.
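For step 1, an install along these lines should work (a sketch: the package names are inferred from the imports below, and the issue's environment per the ds_report output was deepspeed 0.16.2 on torch 2.5 inside an nvidia/pytorch container):

pip install transformers accelerate deepspeed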
import os
import torch
from accelerate import Accelerator, DeepSpeedPlugin
from transformers import LlamaForCausalLM
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7998"
dummy = torch.randint(0, 128, (3, 5)).to("cuda")  # dummy batch of token ids
# Get output from Llama model (bfloat16)
model1 = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)
model1 = model1.to("cuda")
output1 = model1(dummy)
# Apply deepspeed to the model
deepspeed_plugin = DeepSpeedPlugin(zero_stage=0)
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 1
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)
model2 = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)
model2 = accelerator.prepare(model2)
output2 = model2(dummy)
# These outputs should be identical, but they are not:
# assert torch.all(output1.logits == output2.logits)  # fails
The cause of this difference is the precision of the inv_freq buffer of the Llama model:
# The problem is that the original inv_freq is in fp32,
assert model1.model.rotary_emb.inv_freq.dtype == torch.float32
# while deepspeed.initialize, called from accelerator.prepare, casts inv_freq to bf16
assert model2.model.rotary_emb.inv_freq.dtype == torch.bfloat16
# Overwriting the inv_freq of model2 (bf16) with that of model1 (fp32) ...
model2.model.rotary_emb.inv_freq.data = model1.model.rotary_emb.inv_freq.data
output3 = model2(dummy)
assert torch.all(output1.logits == output3.logits) # resolves the problem
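A more general version of this workaround copies every floating-point buffer back from an untouched reference model. This is a sketch only: restore_reference_buffers is a hypothetical helper (not an accelerate or DeepSpeed API), and it assumes the DeepSpeed engine wraps the model as .module. Note that simply casting the bf16 buffer back to fp32 would not help, since the precision is already lost to rounding.

def restore_reference_buffers(prepared_model, reference_model):
    # unwrap the DeepSpeed engine so buffer names match the reference model
    inner = getattr(prepared_model, "module", prepared_model)
    ref = dict(reference_model.named_buffers())
    for name, buf in inner.named_buffers():
        if name in ref and buf.dtype != ref[name].dtype:
            # copy the original fp32 values; a bf16 -> fp32 cast would keep the rounded values
            buf.data = ref[name].data.to(buf.device)

restore_reference_buffers(model2, model1)  # equivalent to the manual overwrite above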
Expected behavior
There should be no difference between the outputs of the model before and after applying deepspeed.initialize.
Cause of this behavior
deepspeed.initialize eventually calls the bfloat16() method on the model (DeepSpeed/deepspeed/runtime/engine.py, line 1160 at commit f2cc809), which casts not only the parameters of the model but also its buffers, including inv_freq (https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/models/llama/modeling_llama.py#L125), where inv_freq is supposed to stay in fp32 even when the parameters of the model are in bf16.
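To illustrate the mechanism in isolation, here is a minimal sketch (a standalone toy module, not DeepSpeed's actual code path) showing that nn.Module.bfloat16() casts registered buffers along with parameters:

import torch
import torch.nn as nn

class ToyRotary(nn.Module):
    def __init__(self):
        super().__init__()
        # inv_freq registered as a buffer, as modeling_llama.py does
        self.register_buffer("inv_freq", torch.rand(4, dtype=torch.float32))

m = ToyRotary()
assert m.inv_freq.dtype == torch.float32
m.bfloat16()  # the same nn.Module method engine.py calls on the whole model
assert m.inv_freq.dtype == torch.bfloat16  # the buffer is cast along with the parameters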
ds_report output
$ ds_report
[2025-01-07 22:05:15,391] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
[WARNING] gds requires the dev libaio .so object and headers but these were not found.
[WARNING] gds: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.5.0a0+e000cf0ad9.nv24.10
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.16.2, unknown, unknown
torch cuda version ............... 12.6
torch hip version ................ None
nvcc version ..................... 12.6
deepspeed wheel compiled w. ...... torch 2.5, cuda 12.6
shared memory (/dev/shm) size .... 1007.78 GB
System info
OS: Ubuntu 22.04
GPU: 1x H100
Python version: 3.10.12
Docker context: nvidia/pytorch
Additional context
This difference causes training/serving skew, which degrades the performance of trained models: e.g., a model trained with DeepSpeed in bf16 uses the bf16-cast inv_freq, while the same checkpoint later served in bf16 without DeepSpeed recomputes inv_freq in fp32 (it is a non-persistent buffer and is never saved), so the rotary embeddings differ between training and serving.
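To get a feel for the magnitude (a sketch using the standard RoPE angle formula position * inv_freq; the base and dim values here are illustrative, not taken from the issue), the angle error introduced by rounding inv_freq to bf16 grows with token position:

import torch

dim, base = 64, 500000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
rounded = inv_freq.to(torch.bfloat16).to(torch.float32)  # inv_freq after the bf16 cast
for pos in (128, 1024, 8192):
    err = (pos * inv_freq - pos * rounded).abs().max().item()
    print(f"position {pos:>5d}: max rotary angle error {err:.4f} rad")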