From 2818307e93c6e73837ca08c4f37864adb56651df Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 16 Jan 2025 17:04:54 +0000 Subject: [PATCH] [generate] can instantiate `GenerationConfig(cache_implementation="static")` (#35679) fix failing instantiation --- .../generation/configuration_utils.py | 18 ++++++++---------- tests/generation/test_configuration_utils.py | 6 ++++++ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 18cf26b8b73415..3f142ce772984b 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -43,7 +43,7 @@ logger = logging.get_logger(__name__) METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") -NEEDS_CACHE_CONFIG = {} +CACHE_CONFIG_MAPPING = {} NEED_SETUP_CACHE_CLASSES_MAPPING = {} QUANT_BACKEND_CLASSES_MAPPING = {} ALL_CACHE_IMPLEMENTATIONS = [] @@ -62,8 +62,8 @@ ) from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor - NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig - NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig + CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig + CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, "offloaded_static": OffloadedStaticCache, @@ -73,7 +73,7 @@ } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} ALL_CACHE_IMPLEMENTATIONS = ( - list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"] + list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"] ) @@ -409,11 +409,9 @@ def __init__(self, **kwargs): self.use_cache = kwargs.pop("use_cache", True) self.cache_implementation = kwargs.pop("cache_implementation", None) self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: - cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] - if self.cache_config is None: - self.cache_config = cache_config_class() - elif isinstance(self.cache_config, dict): + if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING: + cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation] + if isinstance(self.cache_config, dict): self.cache_config = cache_config_class.from_dict(self.cache_config) self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) @@ -766,7 +764,7 @@ def validate(self, is_init=False): f"{ALL_CACHE_IMPLEMENTATIONS}" ) if self.cache_config is not None: - cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation) + cache_class = CACHE_CONFIG_MAPPING.get(self.cache_implementation) if cache_class is None: raise ValueError( "You provided a `cache_config` but the cache implementation you are using " diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 24fea85a900d67..ef30599581ffae 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -259,6 +259,12 @@ def test_generation_mode(self): config = GenerationConfig() self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION) + def test_static_cache_without_cache_config(self): + """Regression test for #35026 -- static cache should work without a cache config.""" + config = GenerationConfig(cache_implementation="static") + self.assertEqual(config.cache_implementation, "static") + self.assertEqual(config.cache_config, None) + class GenerationConfigSerializationTest(unittest.TestCase): def test_serialize_generation_sequence_bias(self):