Skip to content

Commit

Permalink
Add losses
Browse files Browse the repository at this point in the history
  • Loading branch information
lvapeab committed Feb 27, 2019
1 parent a271eac commit be5dfae
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
39 changes: 39 additions & 0 deletions keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,44 @@ def log_diff(args):
cost_difference = K.switch(cost_difference > -1., cost_difference, 0.)
return cost_difference


def minmax_categorical_crossentropy(args):
"""Minmax_categorical_crossentropy loss.
# Arguments
args: y_true, y_pred, h_true, h_pred, mask_y, mask_h, weight_y, weight_h
# Returns
weight_y * xent(mask_y * y_true, y_pred) +
weight_h * xent(mask_h * h_true, h_pred)
"""
y_true, y_pred, h_true, h_pred, mask_y, mask_h, weight_y, weight_h = args
ce_y = categorical_crossentropy(mask_y[:, :, None] * y_true, y_pred)
ce_h = categorical_crossentropy(mask_h[:, :, None] * h_true, h_pred)
ce_y = weight_y[:, :, None] * ce_y
ce_h = weight_h[:, :, None] * ce_h
return ce_y + ce_h


def weighted_log_diff(args):
"""Cross-entropy difference between a GT and a hypothesis.
# Arguments
args: y_pred, y_true, h_pred, h_true.
# Returns
cost_difference(categorical_crossentropy(y_true, y_pred) -
weight * categorical_crossentropy(h_true, h_pred)).
"""
y_true, y_pred, h_true, h_pred, mask_y, mask_h, weight = args
ce_y = categorical_crossentropy(mask_y[:, :, None] * y_true,
mask_y[:, :, None] * y_pred)
ce_h = categorical_crossentropy(mask_h[:, :, None] * h_true,
mask_h[:, :, None] * h_pred)

return ce_y - weight[:, :, None] * ce_h

# Aliases.

mse = MSE = mean_squared_error
Expand All @@ -126,6 +164,7 @@ def log_diff(args):
msle = MSLE = mean_squared_logarithmic_error
kld = KLD = kullback_leibler_divergence
cosine = cosine_proximity
pas_weighted_log_diff = weighted_log_diff


def serialize(loss):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import keras.backend.tensorflow_backend
import keras.backend.theano_backend
import keras.backend.cntk_backend
# import keras.backend.cntk_backend
import keras.backend.numpy_backend
import keras.utils.test_utils

Expand Down

0 comments on commit be5dfae

Please sign in to comment.