
Commit

Deploying to gh-pages from @ dba4c76 🚀
EricLBuehler committed Jan 10, 2025
1 parent 5be6ffe commit db15d8a
Showing 64 changed files with 720 additions and 342 deletions.
10 changes: 5 additions & 5 deletions mistralrs/enum.NormalLoaderType.html

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions mistralrs/enum.VisionLoaderType.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs/struct.DiffusionLoader.html

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions mistralrs/struct.DiffusionLoaderBuilder.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs/struct.DiffusionSpecificConfig.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.GemmaLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.Idefics2Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.LLaVALoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.LLaVANextLoader.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs/struct.LayerDeviceMapper.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.LlamaLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.MistralLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.MixtralLoader.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs/struct.NormalLoader.html

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions mistralrs/struct.NormalLoaderBuilder.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs/struct.NormalSpecificConfig.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.Phi2Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.Phi3Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.Phi3VLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.Qwen2Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs/struct.Starcoder2Loader.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs/struct.VisionLoader.html

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions mistralrs/struct.VisionLoaderBuilder.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs/struct.VisionSpecificConfig.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs/trait.Loader.html

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions mistralrs_core/enum.NormalLoaderType.html

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions mistralrs_core/enum.VisionLoaderType.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs_core/struct.DiffusionLoader.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs_core/struct.DiffusionLoaderBuilder.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs_core/struct.DiffusionSpecificConfig.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.GemmaLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.Idefics2Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.LLaVALoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.LLaVANextLoader.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs_core/struct.LayerDeviceMapper.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.LlamaLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.MistralLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.MixtralLoader.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs_core/struct.NormalLoader.html

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions mistralrs_core/struct.NormalLoaderBuilder.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs_core/struct.NormalSpecificConfig.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.Phi2Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.Phi3Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.Phi3VLoader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.Qwen2Loader.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mistralrs_core/struct.Starcoder2Loader.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs_core/struct.VisionLoader.html

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions mistralrs_core/struct.VisionLoaderBuilder.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mistralrs_core/struct.VisionSpecificConfig.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mistralrs_core/trait.Loader.html

Large diffs are not rendered by default.

352 changes: 176 additions & 176 deletions pyo3/mistralrs.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyo3/search.js

Large diffs are not rendered by default.

24 changes: 17 additions & 7 deletions src/mistralrs_core/device_map.rs.html
@@ -371,7 +371,12 @@

use std::fmt::Debug;

<span class="kw">use crate</span>::{
utils::{debug::DeviceRepr, log::once_log_info},
Expand Down Expand Up @@ -576,6 +581,7 @@
<span class="doccomment">/// Set non mapped layer device. This is for ISQ + device mapping support
/// If ISQ layer, then do not change the device. *They will do it later in NormalModel::quantize*
</span><span class="kw">fn </span>set_nm_device&lt;<span class="lifetime">'a</span>&gt;(<span class="kw-2">&amp;</span><span class="self">self</span>, varbuilder: VarBuilder&lt;<span class="lifetime">'a</span>&gt;, loading_isq: bool) -&gt; VarBuilder&lt;<span class="lifetime">'a</span>&gt;;
<span class="kw">fn </span>num_device_mapping_layers(<span class="kw-2">&amp;</span><span class="self">self</span>) -&gt; usize;

<span class="comment">// === IMMEDIATELY AFTER INIT ===
</span><span class="kw">fn </span>get_min_dtype(<span class="kw-2">&amp;</span><span class="self">self</span>, dtype: <span class="kw-2">&amp;</span><span class="kw">dyn </span>TryIntoDType) -&gt; <span class="prelude-ty">Result</span>&lt;DType&gt;;
Expand Down Expand Up @@ -636,6 +642,9 @@
.try_into_dtype(<span class="kw-2">&amp;</span><span class="self">self</span>.mappings.iter().collect::&lt;Vec&lt;<span class="kw">_</span>&gt;&gt;())
.map_err(candle_core::Error::msg)
}
<span class="kw">fn </span>num_device_mapping_layers(<span class="kw-2">&amp;</span><span class="self">self</span>) -&gt; usize {
<span class="self">self</span>.mappings.len()
}
}

<span class="attr">#[derive(Debug)]
Expand All @@ -659,12 +668,9 @@
varbuilder.set_device(<span class="self">self</span>.nm_device.clone())
}
}
<span class="kw">fn </span>device_for(<span class="kw-2">&amp;</span><span class="self">self</span>, <span class="kw">_</span>: usize, loading_isq: bool) -&gt; <span class="prelude-ty">Option</span>&lt;<span class="kw-2">&amp;</span>Device&gt; {
<span class="kw">if </span>loading_isq {
<span class="kw">return </span><span class="prelude-val">Some</span>(<span class="kw-2">&amp;</span><span class="self">self</span>.nm_device);
}
<span class="prelude-val">None
</span>}
<span class="kw">fn </span>device_for(<span class="kw-2">&amp;</span><span class="self">self</span>, <span class="kw">_</span>: usize, _loading_isq: bool) -&gt; <span class="prelude-ty">Option</span>&lt;<span class="kw-2">&amp;</span>Device&gt; {
<span class="prelude-val">Some</span>(<span class="kw-2">&amp;</span><span class="self">self</span>.nm_device)
}
<span class="kw">fn </span>get_unique_devices(<span class="kw-2">&amp;</span><span class="self">self</span>) -&gt; Vec&lt;Device&gt; {
<span class="macro">vec!</span>[<span class="self">self</span>.nm_device.clone()]
}
Expand All @@ -687,6 +693,10 @@
.try_into_dtype(<span class="kw-2">&amp;</span>[<span class="kw-2">&amp;</span><span class="self">self</span>.nm_device])
.map_err(candle_core::Error::msg)
}
<span class="kw">fn </span>num_device_mapping_layers(<span class="kw-2">&amp;</span><span class="self">self</span>) -&gt; usize {
<span class="comment">// Effectively one layer
</span><span class="number">1
</span>}
}

<span class="doccomment">/// Get all devices on the same device type but different ordinals
Expand Down
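This hunk adds `num_device_mapping_layers` to the mapper interface so callers can enumerate mapped layers without knowing how the mapping was built. Below is a minimal sketch of how a caller might combine the two methods touched here, assuming the `DeviceMapper` trait from this file is in scope; the helper name `collect_layer_devices` is hypothetical.

use candle_core::Device;

// Hypothetical helper: gather the device assigned to each mapped layer,
// using only the two trait methods touched by this diff.
fn collect_layer_devices(mapper: &dyn DeviceMapper) -> Vec<Device> {
    (0..mapper.num_device_mapping_layers())
        .map(|layer| {
            mapper
                .device_for(layer, false)
                .cloned()
                .expect("layer index is within num_device_mapping_layers")
        })
        .collect()
}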
30 changes: 25 additions & 5 deletions src/mistralrs_core/diffusion_models/flux/stepper.rs.html
@@ -281,7 +281,17 @@

use std::{cmp::Ordering, fs::File};
use std::{cmp::Ordering, fs::File, sync::Arc};

<span class="kw">use </span>candle_core::{DType, Device, <span class="prelude-ty">Result</span>, Tensor, D};
<span class="kw">use </span>candle_nn::{Module, VarBuilder};
Expand All @@ -297,7 +307,7 @@
DiffusionGenerationParams,
},
pipeline::DiffusionModel,
utils::varbuilder_utils::from_mmaped_safetensors,
utils::varbuilder_utils::{from_mmaped_safetensors, DeviceForLoadTensor},
};

<span class="kw">use super</span>::{autoencoder::AutoEncoder, model::Flux};
@@ -387,9 +397,11 @@

            vec![],
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;
        let config_filename = repo.get("config.json").map_err(candle_core::Error::msg)?;
        let config = std::fs::read_to_string(config_filename)?;

@@ -408,9 +420,17 @@
        ));

        let model_file = repo.get("model.safetensors")?;
        let vb = from_mmaped_safetensors(vec![model_file], vec![], None, device, silent, None, |_| {
            true
        })?;
        let vb = from_mmaped_safetensors(
            vec![model_file],
            vec![],
            None,
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;
        let config_file = repo.get("config.json")?;
        let config: ClipConfig = serde_json::from_reader(File::open(config_file)?)?;
        let config = config.text_config;
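Both loading paths in this file now thread a per-tensor device selector (`Arc::new(|_| DeviceForLoadTensor::Base)`) through `from_mmaped_safetensors`. The selector type itself is not shown in this diff; based on the `Base` and `Idx` variants used elsewhere in the commit, it presumably looks roughly like the sketch below, which always picks the base device just as the Flux stepper does. The enum definition and the tensor name in `main` are assumptions, not taken from this commit.

use std::sync::Arc;

// Assumed shape of the selector type, inferred from the variants used in
// this commit; the real definition lives in `utils::varbuilder_utils` and
// may differ.
#[derive(Clone, Copy, Debug)]
enum DeviceForLoadTensor {
    /// Place the tensor on the base (non-mapped) device.
    Base,
    /// Place the tensor on the device mapped for layer `idx`.
    Idx(usize),
}

fn main() {
    // The Flux stepper keeps every tensor on the base device, so its
    // selector ignores the tensor name entirely.
    let selector: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync> =
        Arc::new(|_name| DeviceForLoadTensor::Base);
    let _ = selector("transformer.double_blocks.0.img_attn.qkv.weight".to_string());
}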
46 changes: 43 additions & 3 deletions src/mistralrs_core/pipeline/cache_manager.rs.html
@@ -751,7 +751,27 @@

use std::sync::{Arc, Mutex, MutexGuard};

<span class="kw">use </span>candle_core::{<span class="prelude-ty">Result</span>, Tensor, D};

@@ -1142,7 +1162,18 @@

        let template_cache_dim = pipeline.cache().normal().0[0].k.dim;
        let template_cache_msl = pipeline.cache().normal().0[0].k.max_seq_len;

        for layer in pipeline.cache().normal().0.iter_mut() {
        let layer_devices = if let Some(device_mapper) = pipeline.device_mapper() {
            let mut layer_devices = Vec::new();
            for layer in 0..device_mapper.num_device_mapping_layers() {
                let device = device_mapper.device_for(layer, false).cloned();
                layer_devices.push(device.expect("Internal bug, layer out of range!"));
            }
            Some(layer_devices)
        } else {
            None
        };

        for (layer_idx, layer) in pipeline.cache().normal().0.iter_mut().enumerate() {
            if !load_preallocated_cache {
                layer.reset();
                continue;
@@ -1151,8 +1182,17 @@

            let mut k_caches = Vec::new();
            let mut v_caches = Vec::new();
            for seq in seqs.iter_mut() {
                let (k_preallocated_cache, v_preallocated_cache) =
                let (mut k_preallocated_cache, mut v_preallocated_cache) =
                    (*seq.preallocated_cache().as_ref().unwrap()).clone();
                if let Some(layer_devices) = &layer_devices {
                    let layer_dev = &layer_devices[layer_idx];
                    k_preallocated_cache = k_preallocated_cache
                        .to_device(layer_dev)
                        .expect("Could not prepare cache");
                    v_preallocated_cache = v_preallocated_cache
                        .to_device(layer_dev)
                        .expect("Could not prepare cache");
                }
                k_caches.push(k_preallocated_cache);
                v_caches.push(v_preallocated_cache);
            }
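The substantive change here: each sequence's preallocated K/V cache is now copied onto the device the mapper assigned to that layer before it is collected, so preallocated caches work with device mapping. A standalone sketch of that per-layer move, assuming candle's `Tensor::to_device` and a `layer_devices` vector built as above; the helper name `cache_for_layer` is hypothetical.

use candle_core::{Device, Result, Tensor};

// Hypothetical helper mirroring the loop body in this hunk: clone a
// sequence's preallocated (k, v) cache and move it to the device that the
// mapper assigned to `layer_idx`; with no mapper, leave it where it is.
fn cache_for_layer(
    preallocated: &(Tensor, Tensor),
    layer_devices: Option<&Vec<Device>>,
    layer_idx: usize,
) -> Result<(Tensor, Tensor)> {
    let (mut k, mut v) = preallocated.clone();
    if let Some(devices) = layer_devices {
        let dev = &devices[layer_idx];
        // `to_device` is cheap when the tensor is already resident on `dev`.
        k = k.to_device(dev)?;
        v = v.to_device(dev)?;
    }
    Ok((k, v))
}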
12 changes: 10 additions & 2 deletions src/mistralrs_core/pipeline/diffusion.rs.html
@@ -338,7 +338,11 @@

use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
<span class="kw">use super</span>::{
AdapterActivationMixin, AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType,
DiffusionModel, DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult,
Expand All @@ -351,6 +355,7 @@
<span class="kw">use </span><span class="kw">crate</span>::pipeline::ChatTemplate;
<span class="kw">use </span><span class="kw">crate</span>::prefix_cacher_v2::PrefixCacheManagerV2;
<span class="kw">use </span><span class="kw">crate</span>::sequence::Sequence;
<span class="kw">use </span><span class="kw">crate</span>::utils::varbuilder_utils::DeviceForLoadTensor;
<span class="kw">use </span><span class="kw">crate</span>::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
<span class="kw">use crate</span>::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
<span class="kw">use </span>anyhow::Result;
@@ -515,14 +520,17 @@

            .iter()
            .zip(self.inner.force_cpu_vb())
            .map(|(path, force_cpu)| {
                let dev = if force_cpu { &Device::Cpu } else { device };
                from_mmaped_safetensors(
                    vec![path.clone()],
                    Vec::new(),
                    Some(dtype),
                    if force_cpu { &Device::Cpu } else { device },
                    dev,
                    vec![None],
                    silent,
                    None,
                    |_| true,
                    Arc::new(|_| DeviceForLoadTensor::Base),
                )
            })
            .collect::<candle_core::Result<Vec<_>>>()?;
16 changes: 13 additions & 3 deletions src/mistralrs_core/pipeline/loaders/mod.rs.html
@@ -498,7 +498,12 @@

mod diffusion_loaders;
mod normal_loaders;
mod vision_loaders;

@@ -889,8 +894,8 @@

<span class="kw">let </span><span class="kw-2">mut </span>per_layer_avail = Vec::new();
<span class="kw">for </span>dev <span class="kw">in </span>devices.clone() {
<span class="kw">let </span>usage = MemoryUsage.get_memory_available(<span class="kw-2">&amp;</span>dev)<span class="question-mark">?</span>;
per_layer_avail.push((usage, dev));
<span class="kw">let </span>avail = MemoryUsage.get_memory_available(<span class="kw-2">&amp;</span>dev)<span class="question-mark">?</span>;
per_layer_avail.push((avail, dev));
}
<span class="comment">// Reverse so we don't use the cpu first!
</span>per_layer_avail.reverse();
Expand All @@ -903,8 +908,13 @@
<span class="kw">let </span>(device_capacity, device) = per_layer_avail
.pop()
.context(<span class="string">"No more devices to map to. The model does not fit on this system."</span>)<span class="question-mark">?</span>;
<span class="comment">// All usage of 90% of the memory as a maximum.
</span><span class="attr">#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
</span><span class="kw">let </span>device_capacity = (device_capacity <span class="kw">as </span>f64 * <span class="number">0.90</span>) <span class="kw">as </span>usize;
<span class="kw">let </span>layers_on_device = <span class="kw">if </span>device_capacity &gt;= remaining_to_map {
num_layers - current_layer
} <span class="kw">else if </span>current_ordinal == <span class="number">0 </span>{
(device_capacity - non_mapped_size_in_bytes) / per_layer_size_in_bytes
} <span class="kw">else </span>{
device_capacity / per_layer_size_in_bytes
};
Expand Down
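The automatic mapper now plans against only 90% of each device's reported free memory, and the first device additionally reserves space for the non-mapped weights before dividing by the per-layer size. A self-contained sketch of that arithmetic with illustrative (not measured) sizes:

// Standalone restatement of the layer-allocation arithmetic in this hunk;
// all byte counts in `main` are illustrative.
fn layers_for_device(
    free_bytes: usize,
    per_layer_size_in_bytes: usize,
    non_mapped_size_in_bytes: usize,
    remaining_to_map: usize,
    remaining_layers: usize,
    is_first_device: bool,
) -> usize {
    // Plan against at most 90% of the reported free memory.
    let device_capacity = (free_bytes as f64 * 0.90) as usize;
    if device_capacity >= remaining_to_map {
        remaining_layers
    } else if is_first_device {
        // The first (ordinal 0) device also hosts the non-mapped weights.
        (device_capacity - non_mapped_size_in_bytes) / per_layer_size_in_bytes
    } else {
        device_capacity / per_layer_size_in_bytes
    }
}

fn main() {
    // 24 GiB free, 1 GiB per layer, 2 GiB of non-mapped weights,
    // 48 layers (48 GiB) still to place, first device in the list.
    let n = layers_for_device(24 << 30, 1 << 30, 2 << 30, 48 << 30, 48, true);
    println!("layers on first device: {n}"); // prints 19
}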
46 changes: 44 additions & 2 deletions src/mistralrs_core/pipeline/loaders/normal_loaders.rs.html
@@ -2507,10 +2507,32 @@

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{

@@ -2525,7 +2547,7 @@

        EitherCache, IsqModel,
    },
    serde_default_fn,
    utils::log::once_log_info,
    utils::{log::once_log_info, varbuilder_utils::DeviceForLoadTensor},
    xlora_models::NonGranularState,
    DeviceMapMetadata,
};
@@ -2624,6 +2646,26 @@
<span class="kw">fn </span>get_config_repr(<span class="kw-2">&amp;</span><span class="self">self</span>, config: <span class="kw-2">&amp;</span>str, use_flash_attn: bool) -&gt; <span class="prelude-ty">Result</span>&lt;Box&lt;<span class="kw">dyn </span>Debug&gt;&gt;;
<span class="doccomment">/// Get total num_hidden_layers for the layers which will be device mapped.
</span><span class="kw">fn </span>get_total_device_mapping_num_layers(<span class="kw-2">&amp;</span><span class="self">self</span>, config: <span class="kw-2">&amp;</span>str) -&gt; <span class="prelude-ty">Result</span>&lt;usize&gt;;
<span class="kw">fn </span>get_device_for_tensor(
<span class="kw-2">&amp;</span><span class="self">self</span>,
_config: <span class="kw-2">&amp;</span>str,
_mapper: <span class="kw-2">&amp;</span><span class="kw">dyn </span>DeviceMapper,
) -&gt; <span class="prelude-ty">Result</span>&lt;Arc&lt;<span class="kw">dyn </span>Fn(String) -&gt; DeviceForLoadTensor + Send + Sync + <span class="lifetime">'static</span>&gt;&gt; {
<span class="kw">let </span>re = Regex::new(<span class="string">r"\.layers\.(\d+)\."</span>).unwrap();
<span class="kw">let </span>closure = <span class="kw">move </span>|name: String| {
<span class="kw">if let </span><span class="prelude-val">Some</span>(captures) = re.captures(<span class="kw-2">&amp;</span>name) {
captures
.get(<span class="number">1</span>)
.and_then(|m| m.as_str().parse::&lt;usize&gt;().ok())
.map(DeviceForLoadTensor::Idx)
.unwrap_or(DeviceForLoadTensor::Base)
} <span class="kw">else </span>{
DeviceForLoadTensor::Base
}
};

<span class="prelude-val">Ok</span>(Arc::new(closure))
}
}

<span class="attr">#[cfg_attr(feature = <span class="string">"pyo3_macros"</span>, pyclass(eq, eq_int))]
Expand Down
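The new default `get_device_for_tensor` turns a tensor's name into a device choice: anything matching `.layers.<n>.` is routed to the device mapped for layer `n`, everything else to the base device. A standalone sketch of just that routing, using the same regex; the simplified enum mirrors the assumed definition above and the tensor names are illustrative.

use regex::Regex;

// Simplified stand-in for the selector enum used by this trait method.
#[derive(Debug, PartialEq)]
enum DeviceForLoadTensor {
    Base,
    Idx(usize),
}

// Mirrors the default closure in this hunk: extract the layer index from
// names like "model.layers.17.self_attn.q_proj.weight".
fn route(name: &str) -> DeviceForLoadTensor {
    let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
    re.captures(name)
        .and_then(|caps| caps.get(1))
        .and_then(|m| m.as_str().parse::<usize>().ok())
        .map(DeviceForLoadTensor::Idx)
        .unwrap_or(DeviceForLoadTensor::Base)
}

fn main() {
    assert_eq!(
        route("model.layers.17.self_attn.q_proj.weight"),
        DeviceForLoadTensor::Idx(17)
    );
    assert_eq!(route("lm_head.weight"), DeviceForLoadTensor::Base);
}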

0 comments on commit db15d8a
