diff --git a/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb b/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb index 266f141c17..16806792f7 100644 --- a/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb +++ b/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [fchollet](https://twitter.com/fchollet)
\n", "**Date created:** 2021/05/26
\n", - "**Last modified:** 2023/02/25
\n", + "**Last modified:** 2024/11/18
\n", "**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task." ] }, @@ -84,7 +84,7 @@ "import keras\n", "from keras import layers\n", "from keras import ops\n", - "from keras.layers import TextVectorization\n" + "from keras.layers import TextVectorization" ] }, { @@ -213,7 +213,7 @@ "The English layer will use the default string standardization (strip punctuation characters)\n", "and splitting scheme (split on whitespace), while\n", "the Spanish layer will use a custom standardization, where we add the character\n", - "`\"\u00bf\"` to the set of punctuation characters to be stripped.\n", + "`\"¿\"` to the set of punctuation characters to be stripped.\n", "\n", "Note: in a production-grade machine translation model, I would not recommend\n", "stripping the punctuation characters in either language. Instead, I would recommend turning\n", @@ -229,7 +229,7 @@ }, "outputs": [], "source": [ - "strip_chars = string.punctuation + \"\u00bf\"\n", + "strip_chars = string.punctuation + \"¿\"\n", "strip_chars = strip_chars.replace(\"[\", \"\")\n", "strip_chars = strip_chars.replace(\"]\", \"\")\n", "\n", @@ -441,10 +441,7 @@ " return embedded_tokens + embedded_positions\n", "\n", " def compute_mask(self, inputs, mask=None):\n", - " if mask is None:\n", - " return None\n", - " else:\n", - " return ops.not_equal(inputs, 0)\n", + " return ops.not_equal(inputs, 0)\n", "\n", " def get_config(self):\n", " config = super().get_config()\n", @@ -481,16 +478,21 @@ " self.layernorm_3 = layers.LayerNormalization()\n", " self.supports_masking = True\n", "\n", - " def call(self, inputs, encoder_outputs, mask=None):\n", + " def call(self, inputs, mask=None):\n", + " inputs, encoder_outputs = inputs\n", " causal_mask = self.get_causal_attention_mask(inputs)\n", - " if mask is not None:\n", - " padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n", - " padding_mask = ops.minimum(padding_mask, causal_mask)\n", + "\n", + " if mask is None:\n", + " inputs_padding_mask, encoder_outputs_padding_mask = None, None\n", " else:\n", - " padding_mask = None\n", + " inputs_padding_mask, encoder_outputs_padding_mask = mask\n", "\n", " attention_output_1 = self.attention_1(\n", - " query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n", + " query=inputs,\n", + " value=inputs,\n", + " key=inputs,\n", + " attention_mask=causal_mask,\n", + " query_mask=inputs_padding_mask,\n", " )\n", " out_1 = self.layernorm_1(inputs + attention_output_1)\n", "\n", @@ -498,7 +500,8 @@ " query=out_1,\n", " value=encoder_outputs,\n", " key=encoder_outputs,\n", - " attention_mask=padding_mask,\n", + " query_mask=inputs_padding_mask,\n", + " key_mask=encoder_outputs_padding_mask,\n", " )\n", " out_2 = self.layernorm_2(out_1 + attention_output_2)\n", "\n", @@ -527,8 +530,7 @@ " \"num_heads\": self.num_heads,\n", " }\n", " )\n", - " return config\n", - "" + " return config\n" ] }, { @@ -560,14 +562,15 @@ "decoder_inputs = keras.Input(shape=(None,), dtype=\"int64\", name=\"decoder_inputs\")\n", "encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name=\"decoder_state_inputs\")\n", "x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)\n", - "x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n", + "x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])\n", "x = layers.Dropout(0.5)(x)\n", "decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n", "decoder = 
keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)\n", "\n", - "decoder_outputs = decoder([decoder_inputs, encoder_outputs])\n", "transformer = keras.Model(\n", - " [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n", + " {\"encoder_inputs\": encoder_inputs, \"decoder_inputs\": decoder_inputs},\n", + " decoder_outputs,\n", + " name=\"transformer\",\n", ")" ] }, @@ -598,7 +601,9 @@ "\n", "transformer.summary()\n", "transformer.compile(\n", - " \"rmsprop\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n", + " \"rmsprop\",\n", + " loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),\n", + " metrics=[\"accuracy\"],\n", ")\n", "transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)" ] @@ -635,7 +640,12 @@ " decoded_sentence = \"[start]\"\n", " for i in range(max_decoded_sentence_length):\n", " tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n", - " predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])\n", + " predictions = transformer(\n", + " {\n", + " \"encoder_inputs\": tokenized_input_sentence,\n", + " \"decoder_inputs\": tokenized_target_sentence,\n", + " }\n", + " )\n", "\n", " # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here\n", " sampled_token_index = ops.convert_to_numpy(\n", @@ -664,19 +674,19 @@ "After 30 epochs, we get results such as:\n", "\n", "> She handed him the money.\n", - "> [start] ella le pas\u00f3 el dinero [end]\n", + "> [start] ella le pasó el dinero [end]\n", "\n", "> Tom has never heard Mary sing.\n", - "> [start] tom nunca ha o\u00eddo cantar a mary [end]\n", + "> [start] tom nunca ha oído cantar a mary [end]\n", "\n", "> Perhaps she will come tomorrow.\n", - "> [start] tal vez ella vendr\u00e1 ma\u00f1ana [end]\n", + "> [start] tal vez ella vendrá mañana [end]\n", "\n", "> I love to write.\n", "> [start] me encanta escribir [end]\n", "\n", "> His French is improving little by little.\n", - "> [start] su franc\u00e9s va a [UNK] s\u00f3lo un poco [end]\n", + "> [start] su francés va a [UNK] sólo un poco [end]\n", "\n", "> My hotel told me to call you.\n", "> [start] mi hotel me dijo que te [UNK] [end]" @@ -693,7 +703,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -707,9 +717,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.0" + "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/examples/nlp/md/neural_machine_translation_with_transformer.md b/examples/nlp/md/neural_machine_translation_with_transformer.md index 97b78eb110..760c48c4df 100644 --- a/examples/nlp/md/neural_machine_translation_with_transformer.md +++ b/examples/nlp/md/neural_machine_translation_with_transformer.md @@ -2,7 +2,7 @@ **Author:** [fchollet](https://twitter.com/fchollet)
**Date created:** 2021/05/26
-**Last modified:** 2023/02/25
+**Last modified:** 2024/11/18
**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task.

@@ -109,11 +109,11 @@ for _ in range(5):
```
-("On Saturday nights, it's difficult to find parking around here.", '[start] Los sábados por la noche es difícil encontrar aparcamiento por aquí. [end]')
-('I was the worst student in the class.', '[start] Fui el peor estudiante en la clase. [end]')
-('There is nothing to do today.', '[start] No hay nada que hacer hoy. [end]')
-('The twins do resemble each other.', '[start] Los gemelos se parecen mutuamente. [end]')
-('They found Tom in the crowd.', '[start] Encontraron a Tom entre la multitud. [end]')
+('The trouble is that we have nowhere to stay tonight.', '[start] El problema es que no tenemos donde quedarnos esta noche. [end]')
+("I want to help you, but I can't.", '[start] Quiero ayudarte, pero no puedo. [end]')
+('I can help.', '[start] Yo puedo ayudar. [end]')
+('Tom fed his dog table scraps.', '[start] Tom alimentó a su perro con sobras de la mesa. [end]')
+('Tom never eats junk food.', '[start] Tom nunca come comida chatarra. [end]')
```
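(The printed pairs above differ only because the dataset is reshuffled between runs; the change itself does not touch the data loading.) For context on the model rebuild further down in this diff: the example's data pipeline, unchanged here, already yields its features as a dict keyed `"encoder_inputs"` / `"decoder_inputs"`, which is what the rebuilt functional model and the updated `decode_sequence` call rely on. A minimal toy sketch of that pattern (the tensors below are made up; only the key names mirror the example):

```python
import tensorflow as tf

# Toy stand-in for the example's format_dataset: features are a dict whose
# keys must match the names of the keras.Input objects used to build the model.
eng = tf.constant([[4, 7, 9, 0], [3, 5, 0, 0]])        # "vectorized" English (0 = padding)
spa = tf.constant([[2, 6, 8, 1, 0], [2, 4, 1, 0, 0]])  # "vectorized" Spanish (0 = padding)

ds = tf.data.Dataset.from_tensor_slices((eng, spa)).batch(2)
ds = ds.map(
    lambda e, s: ({"encoder_inputs": e, "decoder_inputs": s[:, :-1]}, s[:, 1:])
)
for features, labels in ds:
    print(features["encoder_inputs"].shape, features["decoder_inputs"].shape, labels.shape)
```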
@@ -346,10 +346,7 @@ class PositionalEmbedding(layers.Layer): return embedded_tokens + embedded_positions def compute_mask(self, inputs, mask=None): - if mask is None: - return None - else: - return ops.not_equal(inputs, 0) + return ops.not_equal(inputs, 0) def get_config(self): config = super().get_config() @@ -386,16 +383,21 @@ class TransformerDecoder(layers.Layer): self.layernorm_3 = layers.LayerNormalization() self.supports_masking = True - def call(self, inputs, encoder_outputs, mask=None): + def call(self, inputs, mask=None): + inputs, encoder_outputs = inputs causal_mask = self.get_causal_attention_mask(inputs) - if mask is not None: - padding_mask = ops.cast(mask[:, None, :], dtype="int32") - padding_mask = ops.minimum(padding_mask, causal_mask) + + if mask is None: + inputs_padding_mask, encoder_outputs_padding_mask = None, None else: - padding_mask = None + inputs_padding_mask, encoder_outputs_padding_mask = mask attention_output_1 = self.attention_1( - query=inputs, value=inputs, key=inputs, attention_mask=causal_mask + query=inputs, + value=inputs, + key=inputs, + attention_mask=causal_mask, + query_mask=inputs_padding_mask, ) out_1 = self.layernorm_1(inputs + attention_output_1) @@ -403,7 +405,8 @@ class TransformerDecoder(layers.Layer): query=out_1, value=encoder_outputs, key=encoder_outputs, - attention_mask=padding_mask, + query_mask=inputs_padding_mask, + key_mask=encoder_outputs_padding_mask, ) out_2 = self.layernorm_2(out_1 + attention_output_2) @@ -452,14 +455,15 @@ encoder = keras.Model(encoder_inputs, encoder_outputs) decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs") encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs") x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs) -x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs) +x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs]) x = layers.Dropout(0.5)(x) decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x) decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs) -decoder_outputs = decoder([decoder_inputs, encoder_outputs]) transformer = keras.Model( - [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer" + {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs}, + decoder_outputs, + name="transformer", ) ``` @@ -478,7 +482,9 @@ epochs = 1 # This should be at least 30 for convergence transformer.summary() transformer.compile( - "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"] + "rmsprop", + loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0), + metrics=["accuracy"], ) transformer.fit(train_ds, epochs=epochs, validation_data=val_ds) ``` @@ -490,24 +496,40 @@ transformer.fit(train_ds, epochs=epochs, validation_data=val_ds) -
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)         Output Shape       Param #  Connected to         ┃
-┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
-│ encoder_inputs      │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ positional_embeddi… │ (None, None, 256) │ 3,845,… │ encoder_inputs[0][0] │
-│ (PositionalEmbeddi… │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ decoder_inputs      │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ transformer_encoder │ (None, None, 256) │ 3,155,… │ positional_embeddin… │
-│ (TransformerEncode… │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ functional_5        │ (None, None,      │ 12,959… │ decoder_inputs[0][0… │
-│ (Functional)        │ 15000)            │         │ transformer_encoder… │
-└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
+
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)         Output Shape          Param #  Connected to      ┃
+┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
+│ encoder_inputs      │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ decoder_inputs      │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ positional_embeddi… │ (None, None, 256) │  3,845,120 │ encoder_inputs[0… │
+│ (PositionalEmbeddi… │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ not_equal           │ (None, None)      │          0 │ encoder_inputs[0… │
+│ (NotEqual)          │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ positional_embeddi… │ (None, None, 256) │  3,845,120 │ decoder_inputs[0… │
+│ (PositionalEmbeddi… │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ transformer_encoder │ (None, None, 256) │  3,155,456 │ positional_embed… │
+│ (TransformerEncode… │                   │            │ not_equal[0][0]   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ not_equal_1         │ (None, None)      │          0 │ decoder_inputs[0… │
+│ (NotEqual)          │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ transformer_decoder │ (None, None, 256) │  5,259,520 │ positional_embed… │
+│ (TransformerDecode… │                   │            │ transformer_enco… │
+│                     │                   │            │ not_equal_1[0][0… │
+│                     │                   │            │ not_equal[0][0]   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ dropout_3 (Dropout) │ (None, None, 256) │          0 │ transformer_deco… │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ dense_4 (Dense)     │ (None, None,      │  3,855,000 │ dropout_3[0][0]   │
+│                     │ 15000)            │            │                   │
+└─────────────────────┴───────────────────┴────────────┴───────────────────┘
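The two `not_equal` layers that now appear in the summary are the padding masks returned by the simplified `PositionalEmbedding.compute_mask`; Keras propagates them into the encoder and decoder, where they are consumed as `query_mask` / `key_mask`. A quick sketch of what that mask computation produces (toy token ids, assuming index 0 is the padding token as in the example):

```python
import numpy as np
from keras import ops

token_ids = np.array(
    [[5, 23, 7, 0, 0],
     [12, 4, 0, 0, 0]]
)
# Mirrors the updated compute_mask: True where a real token sits, False on padding.
padding_mask = ops.not_equal(token_ids, 0)
print(ops.convert_to_numpy(padding_mask))
# [[ True  True  True False False]
#  [ True  True False False False]]
```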
 
@@ -532,14 +554,22 @@ transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
``` - 5/1302 ━━━━━━━━━━━━━━━━━━━━ 42s 33ms/step - accuracy: 0.3558 - loss: 8.3596 +/root/keras-io/venv/lib/python3.10/site-packages/keras/src/layers/layer.py:932: UserWarning: Layer 'query' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask. + warnings.warn( +/root/keras-io/venv/lib/python3.10/site-packages/keras/src/layers/layer.py:932: UserWarning: Layer 'key' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask. + warnings.warn( +/root/keras-io/venv/lib/python3.10/site-packages/keras/src/layers/layer.py:932: UserWarning: Layer 'value' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask. + warnings.warn( WARNING: All log messages before absl::InitializeLog() is called are written to STDERR -I0000 00:00:1699484373.932513 76082 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. +``` +
- 1302/1302 ━━━━━━━━━━━━━━━━━━━━ 64s 39ms/step - accuracy: 0.7073 - loss: 2.2372 - val_accuracy: 0.7329 - val_loss: 1.6477
+ 1302/1302 ━━━━━━━━━━━━━━━━━━━━ 57s 30ms/step - accuracy: 0.1042 - loss: 5.0703 - val_accuracy: 0.1926 - val_loss: 2.9115
-
+
+``` + ```
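The much lower accuracy after one epoch is expected rather than a regression: with `ignore_class=0` in the loss and the padding masks now reaching the attention layers, padding positions (which were previously trivially "correct") no longer inflate the numbers, so the reported values cover real tokens only. A toy illustration of the `ignore_class` behaviour (targets and predictions below are made up):

```python
import numpy as np
import keras

y_true = np.array([[3, 1, 0, 0]])                     # (batch, steps); 0 = padding
rng = np.random.default_rng(0)
y_pred = rng.uniform(size=(1, 4, 5))                  # (batch, steps, vocab)
y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)  # turn into probabilities

loss_all = keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred)
loss_real = keras.losses.SparseCategoricalCrossentropy(ignore_class=0)(y_true, y_pred)
# The second value is computed over the two non-padding steps only.
print(float(loss_all), float(loss_real))
```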
@@ -563,7 +593,12 @@ def decode_sequence(input_sentence): decoded_sentence = "[start]" for i in range(max_decoded_sentence_length): tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1] - predictions = transformer([tokenized_input_sentence, tokenized_target_sentence]) + predictions = transformer( + { + "encoder_inputs": tokenized_input_sentence, + "decoder_inputs": tokenized_target_sentence, + } + ) # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here sampled_token_index = ops.convert_to_numpy( diff --git a/examples/nlp/neural_machine_translation_with_transformer.py b/examples/nlp/neural_machine_translation_with_transformer.py index cef7869608..5fcfa45cb1 100644 --- a/examples/nlp/neural_machine_translation_with_transformer.py +++ b/examples/nlp/neural_machine_translation_with_transformer.py @@ -2,7 +2,7 @@ Title: English-to-Spanish translation with a sequence-to-sequence Transformer Author: [fchollet](https://twitter.com/fchollet) Date created: 2021/05/26 -Last modified: 2023/02/25 +Last modified: 2024/11/18 Description: Implementing a sequence-to-sequence Transformer and training it on a machine translation task. Accelerator: GPU """ @@ -302,10 +302,7 @@ def call(self, inputs): return embedded_tokens + embedded_positions def compute_mask(self, inputs, mask=None): - if mask is None: - return None - else: - return ops.not_equal(inputs, 0) + return ops.not_equal(inputs, 0) def get_config(self): config = super().get_config() @@ -342,16 +339,21 @@ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs): self.layernorm_3 = layers.LayerNormalization() self.supports_masking = True - def call(self, inputs, encoder_outputs, mask=None): + def call(self, inputs, mask=None): + inputs, encoder_outputs = inputs causal_mask = self.get_causal_attention_mask(inputs) - if mask is not None: - padding_mask = ops.cast(mask[:, None, :], dtype="int32") - padding_mask = ops.minimum(padding_mask, causal_mask) + + if mask is None: + inputs_padding_mask, encoder_outputs_padding_mask = None, None else: - padding_mask = None + inputs_padding_mask, encoder_outputs_padding_mask = mask attention_output_1 = self.attention_1( - query=inputs, value=inputs, key=inputs, attention_mask=causal_mask + query=inputs, + value=inputs, + key=inputs, + attention_mask=causal_mask, + query_mask=inputs_padding_mask, ) out_1 = self.layernorm_1(inputs + attention_output_1) @@ -359,7 +361,8 @@ def call(self, inputs, encoder_outputs, mask=None): query=out_1, value=encoder_outputs, key=encoder_outputs, - attention_mask=padding_mask, + query_mask=inputs_padding_mask, + key_mask=encoder_outputs_padding_mask, ) out_2 = self.layernorm_2(out_1 + attention_output_2) @@ -407,14 +410,15 @@ def get_config(self): decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs") encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs") x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs) -x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs) +x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs]) x = layers.Dropout(0.5)(x) decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x) decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs) -decoder_outputs = decoder([decoder_inputs, encoder_outputs]) transformer = keras.Model( - [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer" + {"encoder_inputs": encoder_inputs, "decoder_inputs": 
decoder_inputs}, + decoder_outputs, + name="transformer", ) """ @@ -431,7 +435,9 @@ def get_config(self): transformer.summary() transformer.compile( - "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"] + "rmsprop", + loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0), + metrics=["accuracy"], ) transformer.fit(train_ds, epochs=epochs, validation_data=val_ds) @@ -454,7 +460,12 @@ def decode_sequence(input_sentence): decoded_sentence = "[start]" for i in range(max_decoded_sentence_length): tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1] - predictions = transformer([tokenized_input_sentence, tokenized_target_sentence]) + predictions = transformer( + { + "encoder_inputs": tokenized_input_sentence, + "decoder_inputs": tokenized_target_sentence, + } + ) # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here sampled_token_index = ops.convert_to_numpy(
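Finally, the switch from a list of model inputs to `{"encoder_inputs": ..., "decoder_inputs": ...}` is what lets both `fit()` on the dict-producing dataset and the dict-style call in `decode_sequence` resolve inputs by name. A self-contained toy of the same pattern (the two-input model below is illustrative only, not the translation model):

```python
import numpy as np
import keras
from keras import layers

a = keras.Input(shape=(4,), name="encoder_inputs")
b = keras.Input(shape=(4,), name="decoder_inputs")
out = layers.Dense(1)(layers.Concatenate()([a, b]))

# Functional model built from a dict of named inputs, as in the updated example.
toy = keras.Model({"encoder_inputs": a, "decoder_inputs": b}, out)
toy.compile("rmsprop", loss="mse")

# Both training data and direct calls are passed as dicts keyed by input name.
toy.fit(
    {"encoder_inputs": np.ones((8, 4)), "decoder_inputs": np.zeros((8, 4))},
    np.ones((8, 1)),
    epochs=1,
    verbose=0,
)
preds = toy({"encoder_inputs": np.ones((2, 4)), "decoder_inputs": np.zeros((2, 4))})
print(preds.shape)
```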