Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][LoRA]LoRA support added for MolmoForCausalLM #11439

Merged
merged 12 commits into from
Dec 31, 2024
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ See [this page](#generative-models) for more information on how to use generativ
- Molmo
- T + I
- :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc.
-
- ✅︎
- ✅︎
- ✅︎
* - :code:`NVLM_D_Model`
Expand Down
29 changes: 24 additions & 5 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
Expand Down Expand Up @@ -43,7 +43,7 @@
SequenceData)
from vllm.transformers_utils.processor import get_processor

from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
Expand Down Expand Up @@ -1121,9 +1121,26 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
packed_modules_mapping = {
"att_proj": ["att_proj"],
"attn_out": ["attn_out"],
"ff_proj": ["ff_proj"],
"ff_out": ["ff_out"],
}
supported_lora_modules = [
"att_proj",
"ff_proj",
]
embedding_modules = {}
embedding_padding_modules = {}

def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
lora_config: Optional[LoRAConfig] = None):
ayylemao marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
Expand Down Expand Up @@ -1153,6 +1170,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

self.lora_config = lora_config

def _parse_and_validate_image_input(
self,
**kwargs: object,
Expand Down
Loading