From be5dfae75c737e54af31f07b117bdba4cf5da7e9 Mon Sep 17 00:00:00 2001 From: lvapeab Date: Wed, 27 Feb 2019 16:38:09 +0100 Subject: [PATCH] Add losses --- keras/losses.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/test_api.py | 2 +- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/keras/losses.py b/keras/losses.py index 4a295bd314a..b08d6854c42 100644 --- a/keras/losses.py +++ b/keras/losses.py @@ -118,6 +118,44 @@ def log_diff(args): cost_difference = K.switch(cost_difference > -1., cost_difference, 0.) return cost_difference + +def minmax_categorical_crossentropy(args): + """Minmax_categorical_crossentropy loss. + + # Arguments + args: y_true, y_pred, h_true, h_pred, mask_y, mask_h, weight_y, weight_h + + # Returns + weight_y * xent(mask_y * y_true, y_pred) + + weight_h * xent(mask_h * h_true, h_pred) + + """ + y_true, y_pred, h_true, h_pred, mask_y, mask_h, weight_y, weight_h = args + ce_y = categorical_crossentropy(mask_y[:, :, None] * y_true, y_pred) + ce_h = categorical_crossentropy(mask_h[:, :, None] * h_true, h_pred) + ce_y = weight_y[:, :, None] * ce_y + ce_h = weight_h[:, :, None] * ce_h + return ce_y + ce_h + + +def weighted_log_diff(args): + """Cross-entropy difference between a GT and a hypothesis. + + # Arguments + args: y_pred, y_true, h_pred, h_true. + + # Returns + cost_difference(categorical_crossentropy(y_true, y_pred) - + weight * categorical_crossentropy(h_true, h_pred)). + """ + y_true, y_pred, h_true, h_pred, mask_y, mask_h, weight = args + ce_y = categorical_crossentropy(mask_y[:, :, None] * y_true, + mask_y[:, :, None] * y_pred) + ce_h = categorical_crossentropy(mask_h[:, :, None] * h_true, + mask_h[:, :, None] * h_pred) + + return ce_y - weight[:, :, None] * ce_h + # Aliases. mse = MSE = mean_squared_error @@ -126,6 +164,7 @@ def log_diff(args): msle = MSLE = mean_squared_logarithmic_error kld = KLD = kullback_leibler_divergence cosine = cosine_proximity +pas_weighted_log_diff = weighted_log_diff def serialize(loss): diff --git a/tests/test_api.py b/tests/test_api.py index 67ddbc29252..da1b0af2ee2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,7 +6,7 @@ import keras.backend.tensorflow_backend import keras.backend.theano_backend -import keras.backend.cntk_backend +# import keras.backend.cntk_backend import keras.backend.numpy_backend import keras.utils.test_utils