Skip to content

Commit

Permalink
[generate] can instantiate `GenerationConfig(cache_implementation="st…
Browse files Browse the repository at this point in the history
…atic")` (huggingface#35679)

fix failing instantiation
  • Loading branch information
gante authored Jan 16, 2025
1 parent aaa969e commit 2818307
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
NEEDS_CACHE_CONFIG = {}
CACHE_CONFIG_MAPPING = {}
NEED_SETUP_CACHE_CLASSES_MAPPING = {}
QUANT_BACKEND_CLASSES_MAPPING = {}
ALL_CACHE_IMPLEMENTATIONS = []
Expand All @@ -62,8 +62,8 @@
)
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor

NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig
CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig
NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
"offloaded_static": OffloadedStaticCache,
Expand All @@ -73,7 +73,7 @@
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"]
)


Expand Down Expand Up @@ -409,11 +409,9 @@ def __init__(self, **kwargs):
self.use_cache = kwargs.pop("use_cache", True)
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING:
cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation]
if isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)

Expand Down Expand Up @@ -766,7 +764,7 @@ def validate(self, is_init=False):
f"{ALL_CACHE_IMPLEMENTATIONS}"
)
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
cache_class = CACHE_CONFIG_MAPPING.get(self.cache_implementation)
if cache_class is None:
raise ValueError(
"You provided a `cache_config` but the cache implementation you are using "
Expand Down
6 changes: 6 additions & 0 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@ def test_generation_mode(self):
config = GenerationConfig()
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)

def test_static_cache_without_cache_config(self):
"""Regression test for #35026 -- static cache should work without a cache config."""
config = GenerationConfig(cache_implementation="static")
self.assertEqual(config.cache_implementation, "static")
self.assertEqual(config.cache_config, None)


class GenerationConfigSerializationTest(unittest.TestCase):
def test_serialize_generation_sequence_bias(self):
Expand Down

0 comments on commit 2818307

Please sign in to comment.