Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inf] Add config var to enable keeping module on host #6846

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions deepspeed/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
values for :any:`DeepSpeedMoEConfig`.
"""

keep_module_on_host: bool = False
"""
When loading checkpoints to model parameters, they are moved to the device. In very large models
this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on
host and not move them directly to the device (giving an option to quantize checkpoint data before
moving it to the device for example).
Set only for models with injection policies and auto TP.
"""

quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(self, model, config):
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
if is_meta_device:
self.module.to_empty(device=device)
else:
elif not config.keep_module_on_host:
self.module.to(device)

if config.tensor_parallel.tp_size > 1:
Expand Down
33 changes: 22 additions & 11 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def move(tensor, device):
def move(tensor, device, copy=True):
if tensor.is_meta:
return torch.empty_like(tensor, device=device)
else:
# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
# Using copy=True instead of clone() will help in case of cpu --> cpu.
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
return tensor.to(device, copy=True)
return tensor.to(device, copy=copy)


class ReplaceWithTensorSlicing:
Expand Down Expand Up @@ -188,7 +188,14 @@ def load(module, state_dict, prefix, mp_group=None):

class AutoTP():

def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
def __init__(self,
module,
all_reduce_linears,
prefix,
state_dict,
linear_layer_setting,
orig_layer_impl,
keep_module_on_host=False):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
Expand All @@ -200,6 +207,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
self.keep_module_on_host = keep_module_on_host

def in_module_list(module, module_list):
for item in module_list:
Expand Down Expand Up @@ -330,6 +338,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
def _replace(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
return_new_copy = not self.keep_module_on_host
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For mixtral-7x8b, need to skip MoE gate linear replace.
Expand Down Expand Up @@ -363,18 +375,17 @@ def _replace(self, child, name, conv_linear_layer):
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias,
get_accelerator().current_device_name())), self.mp_group)
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
Expand All @@ -387,22 +398,22 @@ def _replace(self, child, name, conv_linear_layer):
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
device_name, return_new_copy)

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
Expand Down
3 changes: 2 additions & 1 deletion deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
#mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

# 1. Create AutoTP object
_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
config.keep_module_on_host)

# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
Expand Down
21 changes: 19 additions & 2 deletions tests/unit/inference/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty


@pytest.mark.seq_inference
@pytest.mark.parametrize('keep_module_on_host', [True, False])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it useful to validate that tensors are on cpu when keep_module_on_host=True?

Copy link
Contributor Author

@oelayan7 oelayan7 Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjruwase
You're right.
A check was added.

@pytest.mark.parametrize(
"model_w_task",
[("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
Expand All @@ -570,6 +571,7 @@ def test(
inf_kwargs,
assert_fn,
dtype,
keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
Expand All @@ -592,13 +594,20 @@ def test(
framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

if keep_module_on_host:
for name, param in model.named_parameters():
assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"

@pytest.mark.world_size(3)
def test_odd_world_size(
self,
Expand All @@ -607,6 +616,7 @@ def test_odd_world_size(
inf_kwargs,
assert_fn,
dtype,
keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
Expand All @@ -624,13 +634,20 @@ def test_odd_world_size(
framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

if keep_module_on_host:
for name, param in model.named_parameters():
assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"


@pytest.mark.nightly
@pytest.mark.parametrize(
Expand Down