diff --git a/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb b/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb
index 266f141c17..16806792f7 100644
--- a/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb
+++ b/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [fchollet](https://twitter.com/fchollet)
\n",
"**Date created:** 2021/05/26
\n",
- "**Last modified:** 2023/02/25
\n",
+ "**Last modified:** 2024/11/18
\n",
"**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task."
]
},
@@ -84,7 +84,7 @@
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
- "from keras.layers import TextVectorization\n"
+ "from keras.layers import TextVectorization"
]
},
{
@@ -213,7 +213,7 @@
"The English layer will use the default string standardization (strip punctuation characters)\n",
"and splitting scheme (split on whitespace), while\n",
"the Spanish layer will use a custom standardization, where we add the character\n",
- "`\"\u00bf\"` to the set of punctuation characters to be stripped.\n",
+ "`\"¿\"` to the set of punctuation characters to be stripped.\n",
"\n",
"Note: in a production-grade machine translation model, I would not recommend\n",
"stripping the punctuation characters in either language. Instead, I would recommend turning\n",
@@ -229,7 +229,7 @@
},
"outputs": [],
"source": [
- "strip_chars = string.punctuation + \"\u00bf\"\n",
+ "strip_chars = string.punctuation + \"¿\"\n",
"strip_chars = strip_chars.replace(\"[\", \"\")\n",
"strip_chars = strip_chars.replace(\"]\", \"\")\n",
"\n",
@@ -441,10 +441,7 @@
" return embedded_tokens + embedded_positions\n",
"\n",
" def compute_mask(self, inputs, mask=None):\n",
- " if mask is None:\n",
- " return None\n",
- " else:\n",
- " return ops.not_equal(inputs, 0)\n",
+ " return ops.not_equal(inputs, 0)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
@@ -481,16 +478,21 @@
" self.layernorm_3 = layers.LayerNormalization()\n",
" self.supports_masking = True\n",
"\n",
- " def call(self, inputs, encoder_outputs, mask=None):\n",
+ " def call(self, inputs, mask=None):\n",
+ " inputs, encoder_outputs = inputs\n",
" causal_mask = self.get_causal_attention_mask(inputs)\n",
- " if mask is not None:\n",
- " padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
- " padding_mask = ops.minimum(padding_mask, causal_mask)\n",
+ "\n",
+ " if mask is None:\n",
+ " inputs_padding_mask, encoder_outputs_padding_mask = None, None\n",
" else:\n",
- " padding_mask = None\n",
+ " inputs_padding_mask, encoder_outputs_padding_mask = mask\n",
"\n",
" attention_output_1 = self.attention_1(\n",
- " query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
+ " query=inputs,\n",
+ " value=inputs,\n",
+ " key=inputs,\n",
+ " attention_mask=causal_mask,\n",
+ " query_mask=inputs_padding_mask,\n",
" )\n",
" out_1 = self.layernorm_1(inputs + attention_output_1)\n",
"\n",
@@ -498,7 +500,8 @@
" query=out_1,\n",
" value=encoder_outputs,\n",
" key=encoder_outputs,\n",
- " attention_mask=padding_mask,\n",
+ " query_mask=inputs_padding_mask,\n",
+ " key_mask=encoder_outputs_padding_mask,\n",
" )\n",
" out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
"\n",
@@ -527,8 +530,7 @@
" \"num_heads\": self.num_heads,\n",
" }\n",
" )\n",
- " return config\n",
- ""
+ " return config\n"
]
},
{
@@ -560,14 +562,15 @@
"decoder_inputs = keras.Input(shape=(None,), dtype=\"int64\", name=\"decoder_inputs\")\n",
"encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name=\"decoder_state_inputs\")\n",
"x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)\n",
- "x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n",
+ "x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])\n",
"x = layers.Dropout(0.5)(x)\n",
"decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n",
"decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)\n",
"\n",
- "decoder_outputs = decoder([decoder_inputs, encoder_outputs])\n",
"transformer = keras.Model(\n",
- " [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n",
+ " {\"encoder_inputs\": encoder_inputs, \"decoder_inputs\": decoder_inputs},\n",
+ " decoder_outputs,\n",
+ " name=\"transformer\",\n",
")"
]
},
@@ -598,7 +601,9 @@
"\n",
"transformer.summary()\n",
"transformer.compile(\n",
- " \"rmsprop\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n",
+ " \"rmsprop\",\n",
+ " loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),\n",
+ " metrics=[\"accuracy\"],\n",
")\n",
"transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)"
]
@@ -635,7 +640,12 @@
" decoded_sentence = \"[start]\"\n",
" for i in range(max_decoded_sentence_length):\n",
" tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n",
- " predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])\n",
+ " predictions = transformer(\n",
+ " {\n",
+ " \"encoder_inputs\": tokenized_input_sentence,\n",
+ " \"decoder_inputs\": tokenized_target_sentence,\n",
+ " }\n",
+ " )\n",
"\n",
" # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here\n",
" sampled_token_index = ops.convert_to_numpy(\n",
@@ -664,19 +674,19 @@
"After 30 epochs, we get results such as:\n",
"\n",
"> She handed him the money.\n",
- "> [start] ella le pas\u00f3 el dinero [end]\n",
+ "> [start] ella le pasó el dinero [end]\n",
"\n",
"> Tom has never heard Mary sing.\n",
- "> [start] tom nunca ha o\u00eddo cantar a mary [end]\n",
+ "> [start] tom nunca ha oído cantar a mary [end]\n",
"\n",
"> Perhaps she will come tomorrow.\n",
- "> [start] tal vez ella vendr\u00e1 ma\u00f1ana [end]\n",
+ "> [start] tal vez ella vendrá mañana [end]\n",
"\n",
"> I love to write.\n",
"> [start] me encanta escribir [end]\n",
"\n",
"> His French is improving little by little.\n",
- "> [start] su franc\u00e9s va a [UNK] s\u00f3lo un poco [end]\n",
+ "> [start] su francés va a [UNK] sólo un poco [end]\n",
"\n",
"> My hotel told me to call you.\n",
"> [start] mi hotel me dijo que te [UNK] [end]"
@@ -693,7 +703,7 @@
"toc_visible": true
},
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "venv",
"language": "python",
"name": "python3"
},
@@ -707,9 +717,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.0"
+ "version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}
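The notebook changes above all revolve around letting Keras propagate the padding mask itself instead of folding it into a hand-built attention mask. As a minimal standalone sketch (toy shapes and random data, not taken from the example files), this is how a padding mask built with `ops.not_equal(token_ids, 0)` can be handed to `keras.layers.MultiHeadAttention` through its `query_mask` / `key_mask` arguments, with `use_causal_mask=True` standing in for the explicit causal mask used in the decoder's self-attention:

```python
import numpy as np
import keras
from keras import ops

embed_dim, num_heads, seq_len = 8, 2, 5
mha = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

token_ids = np.array([[5, 3, 9, 0, 0]])      # toy batch; id 0 is the padding token
padding_mask = ops.not_equal(token_ids, 0)   # (batch, seq_len), True at real tokens

x = np.random.rand(1, seq_len, embed_dim).astype("float32")  # stand-in for embedded tokens
out = mha(
    query=x,
    value=x,
    key=x,
    query_mask=padding_mask,   # padding mask for the query positions
    key_mask=padding_mask,     # keeps attention away from padded key positions
    use_causal_mask=True,      # built-in causal masking for self-attention
)
print(out.shape)  # (1, 5, 8)
```

The layer combines these masks with any `attention_mask` it is given, which is what the updated `TransformerDecoder.call` relies on when it passes `query_mask=inputs_padding_mask` and `key_mask=encoder_outputs_padding_mask`.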
diff --git a/examples/nlp/md/neural_machine_translation_with_transformer.md b/examples/nlp/md/neural_machine_translation_with_transformer.md
index 97b78eb110..760c48c4df 100644
--- a/examples/nlp/md/neural_machine_translation_with_transformer.md
+++ b/examples/nlp/md/neural_machine_translation_with_transformer.md
@@ -2,7 +2,7 @@
**Author:** [fchollet](https://twitter.com/fchollet)
**Date created:** 2021/05/26
-**Last modified:** 2023/02/25
+**Last modified:** 2024/11/18
**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task.
@@ -109,11 +109,11 @@ for _ in range(5):
-┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)        ┃ Output Shape      ┃ Param # ┃ Connected to         ┃
-┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
-│ encoder_inputs      │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ positional_embeddi… │ (None, None, 256) │ 3,845,… │ encoder_inputs[0][0] │
-│ (PositionalEmbeddi… │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ decoder_inputs      │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ transformer_encoder │ (None, None, 256) │ 3,155,… │ positional_embeddin… │
-│ (TransformerEncode… │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ functional_5        │ (None, None,      │ 12,959… │ decoder_inputs[0][0… │
-│ (Functional)        │ 15000)            │         │ transformer_encoder… │
-└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
+┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
+┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
+│ encoder_inputs      │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ decoder_inputs      │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ positional_embeddi… │ (None, None, 256) │  3,845,120 │ encoder_inputs[0… │
+│ (PositionalEmbeddi… │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ not_equal           │ (None, None)      │          0 │ encoder_inputs[0… │
+│ (NotEqual)          │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ positional_embeddi… │ (None, None, 256) │  3,845,120 │ decoder_inputs[0… │
+│ (PositionalEmbeddi… │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ transformer_encoder │ (None, None, 256) │  3,155,456 │ positional_embed… │
+│ (TransformerEncode… │                   │            │ not_equal[0][0]   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ not_equal_1         │ (None, None)      │          0 │ decoder_inputs[0… │
+│ (NotEqual)          │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ transformer_decoder │ (None, None, 256) │  5,259,520 │ positional_embed… │
+│ (TransformerDecode… │                   │            │ transformer_enco… │
+│                     │                   │            │ not_equal_1[0][0… │
+│                     │                   │            │ not_equal[0][0]   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ dropout_3 (Dropout) │ (None, None, 256) │          0 │ transformer_deco… │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ dense_4 (Dense)     │ (None, None,      │  3,855,000 │ dropout_3[0][0]   │
+│                     │ 15000)            │            │                   │
+└─────────────────────┴───────────────────┴────────────┴───────────────────┘
@@ -532,14 +554,22 @@ transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
 ```
-    5/1302 [37m━━━━━━━━━━━━━━━━━━━━  42s 33ms/step - accuracy: 0.3558 - loss: 8.3596
+/root/keras-io/venv/lib/python3.10/site-packages/keras/src/layers/layer.py:932: UserWarning: Layer 'query' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
+  warnings.warn(
+/root/keras-io/venv/lib/python3.10/site-packages/keras/src/layers/layer.py:932: UserWarning: Layer 'key' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
+  warnings.warn(
+/root/keras-io/venv/lib/python3.10/site-packages/keras/src/layers/layer.py:932: UserWarning: Layer 'value' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
+  warnings.warn(
 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
-I0000 00:00:1699484373.932513 76082 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
+```
+
- 1302/1302 ━━━━━━━━━━━━━━━━━━━━ 64s 39ms/step - accuracy: 0.7073 - loss: 2.2372 - val_accuracy: 0.7329 - val_loss: 1.6477
+ 1302/1302 ━━━━━━━━━━━━━━━━━━━━ 57s 30ms/step - accuracy: 0.1042 - loss: 5.0703 - val_accuracy: 0.1926 - val_loss: 2.9115
-
+
+```
+
@@ -563,7 +593,12 @@ def decode_sequence(input_sentence):
     decoded_sentence = "[start]"
     for i in range(max_decoded_sentence_length):
         tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
-        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
+        predictions = transformer(
+            {
+                "encoder_inputs": tokenized_input_sentence,
+                "decoder_inputs": tokenized_target_sentence,
+            }
+        )
 
         # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
         sampled_token_index = ops.convert_to_numpy(
diff --git a/examples/nlp/neural_machine_translation_with_transformer.py b/examples/nlp/neural_machine_translation_with_transformer.py
index cef7869608..5fcfa45cb1 100644
--- a/examples/nlp/neural_machine_translation_with_transformer.py
+++ b/examples/nlp/neural_machine_translation_with_transformer.py
@@ -2,7 +2,7 @@
 Title: English-to-Spanish translation with a sequence-to-sequence Transformer
 Author: [fchollet](https://twitter.com/fchollet)
 Date created: 2021/05/26
-Last modified: 2023/02/25
+Last modified: 2024/11/18
 Description: Implementing a sequence-to-sequence Transformer and training it on a machine translation task.
 Accelerator: GPU
 """
@@ -302,10 +302,7 @@ def call(self, inputs):
         return embedded_tokens + embedded_positions
 
     def compute_mask(self, inputs, mask=None):
-        if mask is None:
-            return None
-        else:
-            return ops.not_equal(inputs, 0)
+        return ops.not_equal(inputs, 0)
 
     def get_config(self):
         config = super().get_config()
@@ -342,16 +339,21 @@ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
         self.layernorm_3 = layers.LayerNormalization()
         self.supports_masking = True
 
-    def call(self, inputs, encoder_outputs, mask=None):
+    def call(self, inputs, mask=None):
+        inputs, encoder_outputs = inputs
         causal_mask = self.get_causal_attention_mask(inputs)
-        if mask is not None:
-            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
-            padding_mask = ops.minimum(padding_mask, causal_mask)
+
+        if mask is None:
+            inputs_padding_mask, encoder_outputs_padding_mask = None, None
         else:
-            padding_mask = None
+            inputs_padding_mask, encoder_outputs_padding_mask = mask
 
         attention_output_1 = self.attention_1(
-            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
+            query=inputs,
+            value=inputs,
+            key=inputs,
+            attention_mask=causal_mask,
+            query_mask=inputs_padding_mask,
         )
         out_1 = self.layernorm_1(inputs + attention_output_1)
 
@@ -359,7 +361,8 @@ def call(self, inputs, encoder_outputs, mask=None):
             query=out_1,
             value=encoder_outputs,
             key=encoder_outputs,
-            attention_mask=padding_mask,
+            query_mask=inputs_padding_mask,
+            key_mask=encoder_outputs_padding_mask,
         )
         out_2 = self.layernorm_2(out_1 + attention_output_2)
 
@@ -407,14 +410,15 @@ def get_config(self):
 decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
 encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
 x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
-x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
+x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
 x = layers.Dropout(0.5)(x)
 decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
 decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)
 
-decoder_outputs = decoder([decoder_inputs, encoder_outputs])
 transformer = keras.Model(
-    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
+    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
+    decoder_outputs,
+    name="transformer",
 )
 
 """
@@ -431,7 +435,9 @@ def get_config(self):
 
 transformer.summary()
 transformer.compile(
-    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+    "rmsprop",
+    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
+    metrics=["accuracy"],
 )
 transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
 
@@ -454,7 +460,12 @@ def decode_sequence(input_sentence):
     decoded_sentence = "[start]"
     for i in range(max_decoded_sentence_length):
         tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
-        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
+        predictions = transformer(
+            {
+                "encoder_inputs": tokenized_input_sentence,
+                "decoder_inputs": tokenized_target_sentence,
+            }
+        )
 
         # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
         sampled_token_index = ops.convert_to_numpy(
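On the loss change: swapping the `"sparse_categorical_crossentropy"` string for `keras.losses.SparseCategoricalCrossentropy(ignore_class=0)` excludes target positions labeled 0 (the padding token) from the loss, so the model is no longer rewarded for predicting padding. A minimal sketch with toy values (not taken from the example):

```python
import numpy as np
import keras

loss_masked = keras.losses.SparseCategoricalCrossentropy(ignore_class=0)
loss_plain = keras.losses.SparseCategoricalCrossentropy()

y_true = np.array([[4, 2, 0, 0]])              # last two positions are padding (label 0)
rng = np.random.default_rng(0)
y_pred = rng.random((1, 4, 10))
y_pred /= y_pred.sum(axis=-1, keepdims=True)   # each row sums to 1, like a softmax output

# ignore_class=0 drops the two padded positions from the loss computation;
# the plain loss scores them like any other timestep.
print(float(loss_masked(y_true, y_pred)), float(loss_plain(y_true, y_pred)))
```

`ignore_class` takes the integer class id to skip, which matches the example's convention of reserving index 0 for padding in `TextVectorization`.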