diff --git a/alibi_detect/cd/tensorflow/preprocess.py b/alibi_detect/cd/tensorflow/preprocess.py
index dc2d71315..632fedd17 100644
--- a/alibi_detect/cd/tensorflow/preprocess.py
+++ b/alibi_detect/cd/tensorflow/preprocess.py
@@ -4,7 +4,7 @@
 import tensorflow as tf
 
 from alibi_detect.utils.tensorflow.prediction import (
-    predict_batch, predict_batch_transformer, get_named_arg
+    predict_batch, predict_batch_transformer, get_call_arg_mapping
 )
 from tensorflow.keras.layers import Dense, Flatten, Input, Lambda
 from tensorflow.keras.models import Model
@@ -37,7 +37,7 @@ def __init__(
 
     def call(self, x: Union[np.ndarray, tf.Tensor, Dict[str, tf.Tensor]]) -> tf.Tensor:
         if not isinstance(x, (np.ndarray, tf.Tensor)):
-            x = get_named_arg(self.input_layer, x)
+            x = get_call_arg_mapping(self.input_layer, x)
             x = self.input_layer(**x)
         else:
             x = self.input_layer(x)
@@ -68,7 +68,7 @@ def __init__(
 
     def call(self, x: Union[np.ndarray, tf.Tensor, Dict[str, tf.Tensor]]) -> tf.Tensor:
         if not isinstance(x, (np.ndarray, tf.Tensor)):
-            x = get_named_arg(self.encoder, x)
+            x = get_call_arg_mapping(self.encoder, x)
             return self.encoder(**x)
         else:
             return self.encoder(x)
diff --git a/alibi_detect/saving/_tensorflow/tests/test_saving_tf.py b/alibi_detect/saving/_tensorflow/tests/test_saving_tf.py
index 42b690e76..21b9d2934 100644
--- a/alibi_detect/saving/_tensorflow/tests/test_saving_tf.py
+++ b/alibi_detect/saving/_tensorflow/tests/test_saving_tf.py
@@ -15,9 +15,6 @@
 backend = param_fixture("backend", ['tensorflow'])
 
 
-# Note: The full save/load functionality of optimizers (inc. validation) is tested in test_save_classifierdrift.
-@pytest.mark.skipif(version.parse(tf.__version__) < version.parse('2.16.0'),
-                    reason="Skipping since tensorflow < 2.16.0")
 def test_load_optimizer_object_tf2pt11(backend):
     """
     Test the _load_optimizer_config with a tensorflow optimizer config. Only run if tensorflow>=2.16.
diff --git a/alibi_detect/utils/tensorflow/prediction.py b/alibi_detect/utils/tensorflow/prediction.py
index 539c2fba3..50c2b0267 100644
--- a/alibi_detect/utils/tensorflow/prediction.py
+++ b/alibi_detect/utils/tensorflow/prediction.py
@@ -7,10 +7,14 @@
 from alibi_detect.utils.prediction import tokenize_transformer
 
 
-def get_named_arg(model: tf.keras.Model, x: Any) -> Dict[str, Any]:
-    """ Extract argument names from the model call function
-    because keras3 does not accept other types of input
-    as a positional argument.
+def get_call_arg_mapping(model: tf.keras.Model, x: Any) -> Dict[str, Any]:
+    """ Generates a dictionary mapping the first argument name of the
+    `call` method of a Keras model to the provided input value.
+
+    This function is particularly useful when working with Keras 3,
+    which enforces stricter input handling and requires named arguments
+    for certain operations. It extracts the argument names from the
+    `call` method of the provided model and maps the first argument to `x`.
 
     Parameters
     ----------
@@ -67,7 +71,7 @@ def predict_batch(
             x_batch = preprocess_fn(x_batch)
 
         if not isinstance(x_batch, (np.ndarray, tf.Tensor)):
-            x_batch = get_named_arg(model, x_batch)
+            x_batch = get_call_arg_mapping(model, x_batch)
             preds_tmp = model(**x_batch)
         else:
             preds_tmp = model(x_batch)
diff --git a/alibi_detect/utils/tests/test_saving_legacy.py b/alibi_detect/utils/tests/test_saving_legacy.py
index db73e4bf2..6dddd75e5 100644
--- a/alibi_detect/utils/tests/test_saving_legacy.py
+++ b/alibi_detect/utils/tests/test_saving_legacy.py
@@ -262,8 +262,6 @@ def test_save_load(select_detector):
 
 
 # Note: The full save/load functionality of optimizers (inc. validation) is tested in test_save_classifierdrift.
-@pytest.mark.skipif(version.parse(tf.__version__) < version.parse('2.16.0'),
-                    reason="Skipping since tensorflow < 2.16.0")
 @parametrize('legacy', [True, False])
 def test_load_optimizer_object_tf2pt11(legacy, backend):
     """
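Reviewer note: the sketch below illustrates the call-site pattern the diff introduces. The body of get_call_arg_mapping is not part of this diff, so the inspect-based stand-in (get_call_arg_mapping_sketch) and the toy TokenModel are assumptions for illustration only; the only behaviour the call sites rely on is that the helper returns a dict mapping the first argument name of the model's call method to the input, which is then unpacked via model(**x).

# Illustrative sketch only, not the library implementation: the real body of
# get_call_arg_mapping is not shown in this diff. It is assumed to map the
# first argument name of the model's `call` method to the provided input.
import inspect
from typing import Any, Dict

import numpy as np
import tensorflow as tf


def get_call_arg_mapping_sketch(model: tf.keras.Model, x: Any) -> Dict[str, Any]:
    # First non-`self` parameter name of `call`, mapped to the raw input.
    params = [p for p in inspect.signature(model.call).parameters if p != "self"]
    return {params[0]: x}


class TokenModel(tf.keras.Model):
    """Toy stand-in for a model whose inputs arrive as a dict of tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, tokens: Dict[str, tf.Tensor]) -> tf.Tensor:
        return self.dense(tf.cast(tokens["input_ids"], tf.float32))


model = TokenModel()
x: Any = {"input_ids": tf.ones((4, 8), dtype=tf.int32)}  # e.g. tokenizer output

# Same branching as the call sites in the diff: dict inputs are passed as a
# named argument (as Keras 3 requires), tensors/arrays stay positional.
if not isinstance(x, (np.ndarray, tf.Tensor)):
    preds = model(**get_call_arg_mapping_sketch(model, x))
else:
    preds = model(x)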