Support loading models without ISQ using device map (#1045)
* Device mapping support for normal cache (no paged attn)

* Attempt to load layers according to device map

* Don't recompile regexes

* Don't allocate all the memory
EricLBuehler authored Jan 10, 2025
1 parent 5e1a615 commit dba4c76
Showing 12 changed files with 223 additions and 34 deletions.
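The core of the change: each tensor name is resolved to a device at load time, and per-layer KV caches follow their layer's mapped device. A minimal sketch of the routing type and a resolver, assuming the shapes shown in the diffs below (the `resolve` helper is hypothetical, not part of the commit):

```rust
use candle_core::Device;

// Routing decision produced per tensor name (mirrors the new
// utils::varbuilder_utils::DeviceForLoadTensor used throughout the diffs).
pub enum DeviceForLoadTensor {
    Base,       // load on the base (non-mapped) device
    Idx(usize), // load on the device mapped to this layer index
}

// Hypothetical resolver: turn a routing decision into a concrete device,
// falling back to the base device when the layer index is out of range.
fn resolve<'a>(
    choice: &DeviceForLoadTensor,
    base: &'a Device,
    layer_devices: &'a [Device],
) -> &'a Device {
    match choice {
        DeviceForLoadTensor::Base => base,
        DeviceForLoadTensor::Idx(i) => layer_devices.get(*i).unwrap_or(base),
    }
}
```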
15 changes: 10 additions & 5 deletions mistralrs-core/src/device_map.rs
@@ -203,6 +203,7 @@ pub trait DeviceMapper: Debug {
/// Set non mapped layer device. This is for ISQ + device mapping support
/// If ISQ layer, then do not change the device. *They will do it later in NormalModel::quantize*
fn set_nm_device<'a>(&self, varbuilder: VarBuilder<'a>, loading_isq: bool) -> VarBuilder<'a>;
fn num_device_mapping_layers(&self) -> usize;

// === IMMEDIATELY AFTER INIT ===
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
@@ -263,6 +264,9 @@ impl DeviceMapper for LayerDeviceMapper {
.try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
.map_err(candle_core::Error::msg)
}
fn num_device_mapping_layers(&self) -> usize {
self.mappings.len()
}
}

#[derive(Debug)]
@@ -286,11 +290,8 @@ impl DeviceMapper for DummyDeviceMapper {
varbuilder.set_device(self.nm_device.clone())
}
}
fn device_for(&self, _: usize, loading_isq: bool) -> Option<&Device> {
if loading_isq {
return Some(&self.nm_device);
}
None
fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
Some(&self.nm_device)
}
fn get_unique_devices(&self) -> Vec<Device> {
vec![self.nm_device.clone()]
@@ -314,6 +315,10 @@ impl DeviceMapper for DummyDeviceMapper {
.try_into_dtype(&[&self.nm_device])
.map_err(candle_core::Error::msg)
}
fn num_device_mapping_layers(&self) -> usize {
// Effectively one layer
1
}
}

/// Get all devices on the same device type but different ordinals
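Behavioral note on the `DummyDeviceMapper` change above: `device_for` now always reports the non-mapped device instead of only doing so during ISQ loading, so callers that enumerate devices per layer (see cache_manager.rs below) get a usable answer even without a real mapping. A minimal sketch of that contract, with a hypothetical construction of the mapper:

```rust
// Hypothetical: assumes the struct is constructible here with its
// `nm_device` field as shown in the diff above.
let mapper = DummyDeviceMapper { nm_device: Device::Cpu };
for layer in 0..mapper.num_device_mapping_layers() {
    // Previously `None` unless loading ISQ; now always the non-mapped device.
    assert!(mapper.device_for(layer, false).is_some());
}
```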
20 changes: 15 additions & 5 deletions mistralrs-core/src/diffusion_models/flux/stepper.rs
@@ -1,4 +1,4 @@
use std::{cmp::Ordering, fs::File};
use std::{cmp::Ordering, fs::File, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
@@ -14,7 +14,7 @@ use crate::{
DiffusionGenerationParams,
},
pipeline::DiffusionModel,
utils::varbuilder_utils::from_mmaped_safetensors,
utils::varbuilder_utils::{from_mmaped_safetensors, DeviceForLoadTensor},
};

use super::{autoencoder::AutoEncoder, model::Flux};
@@ -104,9 +104,11 @@ fn get_t5_model(
vec![],
Some(dtype),
device,
vec![None],
silent,
None,
|_| true,
Arc::new(|_| DeviceForLoadTensor::Base),
)?;
let config_filename = repo.get("config.json").map_err(candle_core::Error::msg)?;
let config = std::fs::read_to_string(config_filename)?;
@@ -125,9 +127,17 @@ fn get_clip_model_and_tokenizer(
));

let model_file = repo.get("model.safetensors")?;
let vb = from_mmaped_safetensors(vec![model_file], vec![], None, device, silent, None, |_| {
true
})?;
let vb = from_mmaped_safetensors(
vec![model_file],
vec![],
None,
device,
vec![None],
silent,
None,
|_| true,
Arc::new(|_| DeviceForLoadTensor::Base),
)?;
let config_file = repo.get("config.json")?;
let config: ClipConfig = serde_json::from_reader(File::open(config_file)?)?;
let config = config.text_config;
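Both call sites above follow the same widened signature. An annotated restatement of the CLIP load for orientation; the parameter names in the comments are descriptive guesses, not the helper's real argument names:

```rust
let vb = from_mmaped_safetensors(
    vec![model_file],                        // safetensors paths
    vec![],                                  // adapter (X-LoRA) paths
    None,                                    // optional dtype override
    device,                                  // base device
    vec![None],                              // per-layer device overrides (none here)
    silent,                                  // suppress progress output
    None,                                    // ISQ regexes
    |_| true,                                // per-tensor load predicate
    Arc::new(|_| DeviceForLoadTensor::Base), // everything on the base device
)?;
```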
24 changes: 22 additions & 2 deletions mistralrs-core/src/pipeline/cache_manager.rs
@@ -389,7 +389,18 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCa
let template_cache_dim = pipeline.cache().normal().0[0].k.dim;
let template_cache_msl = pipeline.cache().normal().0[0].k.max_seq_len;

for layer in pipeline.cache().normal().0.iter_mut() {
let layer_devices = if let Some(device_mapper) = pipeline.device_mapper() {
let mut layer_devices = Vec::new();
for layer in 0..device_mapper.num_device_mapping_layers() {
let device = device_mapper.device_for(layer, false).cloned();
layer_devices.push(device.expect("Internal bug, layer out of range!"));
}
Some(layer_devices)
} else {
None
};

for (layer_idx, layer) in pipeline.cache().normal().0.iter_mut().enumerate() {
if !load_preallocated_cache {
layer.reset();
continue;
@@ -398,8 +409,17 @@
let mut k_caches = Vec::new();
let mut v_caches = Vec::new();
for seq in seqs.iter_mut() {
let (k_preallocated_cache, v_preallocated_cache) =
let (mut k_preallocated_cache, mut v_preallocated_cache) =
(*seq.preallocated_cache().as_ref().unwrap()).clone();
if let Some(layer_devices) = &layer_devices {
let layer_dev = &layer_devices[layer_idx];
k_preallocated_cache = k_preallocated_cache
.to_device(layer_dev)
.expect("Could not prepare cache");
v_preallocated_cache = v_preallocated_cache
.to_device(layer_dev)
.expect("Could not prepare cache");
}
k_caches.push(k_preallocated_cache);
v_caches.push(v_preallocated_cache);
}
6 changes: 5 additions & 1 deletion mistralrs-core/src/pipeline/diffusion.rs
@@ -11,6 +11,7 @@ use crate::paged_attention::AttentionImplementation;
use crate::pipeline::ChatTemplate;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
use anyhow::Result;
@@ -175,14 +176,17 @@ impl Loader for DiffusionLoader {
.iter()
.zip(self.inner.force_cpu_vb())
.map(|(path, force_cpu)| {
let dev = if force_cpu { &Device::Cpu } else { device };
from_mmaped_safetensors(
vec![path.clone()],
Vec::new(),
Some(dtype),
if force_cpu { &Device::Cpu } else { device },
dev,
vec![None],
silent,
None,
|_| true,
Arc::new(|_| DeviceForLoadTensor::Base),
)
})
.collect::<candle_core::Result<Vec<_>>>()?;
9 changes: 7 additions & 2 deletions mistralrs-core/src/pipeline/loaders/mod.rs
@@ -389,8 +389,8 @@ pub trait DeviceMappedModelLoader {

let mut per_layer_avail = Vec::new();
for dev in devices.clone() {
let usage = MemoryUsage.get_memory_available(&dev)?;
per_layer_avail.push((usage, dev));
let avail = MemoryUsage.get_memory_available(&dev)?;
per_layer_avail.push((avail, dev));
}
// Reverse so we don't use the cpu first!
per_layer_avail.reverse();
@@ -403,8 +403,13 @@
let (device_capacity, device) = per_layer_avail
.pop()
.context("No more devices to map to. The model does not fit on this system.")?;
// Allow usage of 90% of the memory as a maximum.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
let device_capacity = (device_capacity as f64 * 0.90) as usize;
let layers_on_device = if device_capacity >= remaining_to_map {
num_layers - current_layer
} else if current_ordinal == 0 {
(device_capacity - non_mapped_size_in_bytes) / per_layer_size_in_bytes
} else {
device_capacity / per_layer_size_in_bytes
};
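A worked example of the capped greedy assignment above, under hypothetical numbers (60 layers at 512 MiB each, 4 GiB of non-mapped weights, two GPUs with 24 GiB free apiece):

```rust
fn main() {
    const GIB: usize = 1024 * 1024 * 1024;
    let per_layer_size_in_bytes = 512 * 1024 * 1024;
    let non_mapped_size_in_bytes = 4 * GIB;
    let num_layers = 60usize;

    // 90% cap, as in the diff above: ~21.6 GiB usable per device.
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    let device_capacity = ((24 * GIB) as f64 * 0.90) as usize;

    // Ordinal 0 also holds the non-mapped weights, so it takes fewer layers.
    let on_dev0 = (device_capacity - non_mapped_size_in_bytes) / per_layer_size_in_bytes;
    assert_eq!(on_dev0, 35);

    // The remaining 25 layers fit within the next device's capped capacity
    // (room for 43), so it takes exactly the remainder.
    let remaining_to_map = (num_layers - on_dev0) * per_layer_size_in_bytes;
    assert!(device_capacity >= remaining_to_map);
}
```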
23 changes: 22 additions & 1 deletion mistralrs-core/src/pipeline/loaders/normal_loaders.rs
@@ -2,6 +2,7 @@ use std::{
collections::HashMap,
fmt::{Debug, Display},
str::FromStr,
sync::Arc,
};

use crate::{
@@ -16,7 +17,7 @@ use crate::{
EitherCache, IsqModel,
},
serde_default_fn,
utils::log::once_log_info,
utils::{log::once_log_info, varbuilder_utils::DeviceForLoadTensor},
xlora_models::NonGranularState,
DeviceMapMetadata,
};
@@ -115,6 +116,26 @@ pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoa
fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
/// Get total num_hidden_layers for the layers which will be device mapped.
fn get_total_device_mapping_num_layers(&self, config: &str) -> Result<usize>;
fn get_device_for_tensor(
&self,
_config: &str,
_mapper: &dyn DeviceMapper,
) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
let closure = move |name: String| {
if let Some(captures) = re.captures(&name) {
captures
.get(1)
.and_then(|m| m.as_str().parse::<usize>().ok())
.map(DeviceForLoadTensor::Idx)
.unwrap_or(DeviceForLoadTensor::Base)
} else {
DeviceForLoadTensor::Base
}
};

Ok(Arc::new(closure))
}
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
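To illustrate the default routing above (assuming a `loader`, `config`, and `mapper` in scope; the tensor names are hypothetical): the `Regex` is compiled once when `get_device_for_tensor` builds the closure and then reused for every tensor, which is the "don't recompile regexes" fix from the commit message.

```rust
let f = loader.get_device_for_tensor(config, &*mapper)?;

// Transformer-block tensors are routed by their layer index...
assert!(matches!(
    f("model.layers.12.self_attn.q_proj.weight".to_string()),
    DeviceForLoadTensor::Idx(12)
));
// ...while embeddings, norms, and heads fall back to the base device.
assert!(matches!(
    f("model.embed_tokens.weight".to_string()),
    DeviceForLoadTensor::Base
));
```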
22 changes: 22 additions & 0 deletions mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -14,11 +14,13 @@ use serde::Deserialize;

use super::{DeviceMappedModelLoader, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
@@ -82,6 +84,26 @@ pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoa
) -> Arc<dyn Processor + Send + Sync>;
fn supports_paged_attention(&self) -> bool;
fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer>;
fn get_device_for_tensor(
&self,
_config: &str,
_mapper: &dyn DeviceMapper,
) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
let closure = move |name: String| {
if let Some(captures) = re.captures(&name) {
captures
.get(1)
.and_then(|m| m.as_str().parse::<usize>().ok())
.map(DeviceForLoadTensor::Idx)
.unwrap_or(DeviceForLoadTensor::Base)
} else {
DeviceForLoadTensor::Base
}
};

Ok(Arc::new(closure))
}
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
33 changes: 31 additions & 2 deletions mistralrs-core/src/pipeline/macros.rs
@@ -380,6 +380,7 @@ macro_rules! normal_model_loader {
$paths:expr,
$dtype:expr,
$device:expr,
$layer_devices:expr,
$config:expr,
$loader:expr,
$use_flash_attn:expr,
@@ -401,15 +402,18 @@
} else {
None
};
let get_device_for_tensor = $loader.get_device_for_tensor(&$config, &*$mapper)?;

let vb = from_mmaped_safetensors(
$paths.get_weight_filenames().to_vec(),
Vec::new(),
$dtype,
$device,
$layer_devices,
$silent,
regexes,
|_| true, // Will be overwritten...
get_device_for_tensor,
)?;

$loader.load(
@@ -433,6 +437,7 @@ macro_rules! vision_normal_model_loader {
$paths:expr,
$dtype:expr,
$device:expr,
$layer_devices:expr,
$config:expr,
$loader:expr,
$use_flash_attn:expr,
@@ -449,15 +454,18 @@
} else {
None
};
let get_device_for_tensor = $loader.get_device_for_tensor(&$config, &*$mapper)?;

let vb = from_mmaped_safetensors(
$paths.get_weight_filenames().to_vec(),
Vec::new(),
$dtype,
$device,
$layer_devices,
$silent,
regexes,
|_| true,
|_| true, // Will be overwritten...
get_device_for_tensor,
)?;

$loader.load(
@@ -481,6 +489,7 @@ macro_rules! xlora_model_loader {
$paths:expr,
$dtype:expr,
$device:expr,
$layer_devices:expr,
$config:expr,
$loader:expr,
$use_flash_attn:expr,
@@ -491,6 +500,8 @@
) => {{
let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
safetensors_paths.push($paths.get_classifier_path().as_ref().unwrap());
let get_device_for_tensor = $loader.get_device_for_tensor(&$config, &*$mapper)?;

let vb = from_mmaped_safetensors(
safetensors_paths
.iter()
@@ -505,9 +516,11 @@
.collect::<Vec<_>>(),
$dtype,
$device,
$layer_devices,
$silent,
None,
|_| true,
get_device_for_tensor,
)?;

$loader.load_xlora(
@@ -530,8 +543,22 @@
#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
(
$paths:expr,
$dtype:expr,
$device:expr,
$layer_devices:expr,
$config:expr,
$loader:expr,
$use_flash_attn:expr,
$silent:expr,
$mapper:expr,
$loading_isq:expr,
$real_device:expr
) => {{
let safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
let get_device_for_tensor = $loader.get_device_for_tensor(&$config, &*$mapper)?;

let vb = from_mmaped_safetensors(
safetensors_paths
.iter()
@@ -546,9 +573,11 @@
.collect::<Vec<_>>(),
Some($dtype),
$device,
$layer_devices,
$silent,
None,
|_| true,
get_device_for_tensor,
)?;

$loader.load_xlora(
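All four macros gain the same `$layer_devices` parameter. A sketch of how a hypothetical caller might build it from a mapper (mirroring the per-layer enumeration added in cache_manager.rs) and pass it through; the surrounding bindings and the macro's return shape are assumptions, not shown in this diff:

```rust
// Build the new `$layer_devices` argument by enumerating the mapper.
let layer_devices: Vec<Option<Device>> = (0..mapper.num_device_mapping_layers())
    .map(|layer| mapper.device_for(layer, false).cloned())
    .collect();

// Hand it to the updated macro along with the existing arguments.
let model = lora_model_loader!(
    paths, dtype, device, layer_devices, config, loader, use_flash_attn,
    silent, mapper, loading_isq, real_device
);
```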