Fixes to the "English-to-Spanish Translation with a Sequence-to-Seque…
Browse files Browse the repository at this point in the history
…nce Transformer" Code Example (#1997)

* bugfix: Encoder and decoder inputs were flipped.

Given 30 epochs of training, the model never ended up producing sensible output. Some examples:
1) Tom didn't like Mary. → [start] ha estoy qué
2) Tom called Mary and canceled their date. → [start] sola qué yo pasatiempo visto campo

When fitting the model, the following relevant warning was emitted:
```
UserWarning: The structure of `inputs` doesn't match the expected structure: ['encoder_inputs', 'decoder_inputs']. Received: the structure of inputs={'encoder_inputs': '*', 'decoder_inputs': '*'}
```

After the fix, the model outputs sentences that are close to proper Spanish:
1) That's what Tom told me. → [start] eso es lo que tom me dijo [end]
2) Does Tom like cheeseburgers? → [start] a tom le gustan las queso de queso [end]
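
The root cause: the model took its two inputs as a positional list while the data pipeline yielded them as a dict, so the structures could be paired in the wrong order. The fix keys the model's inputs, and all calls to it, by name. A minimal sketch, using the `encoder_inputs`/`decoder_inputs` names from the diff below:
```
# Sketch of the corrected wiring (names as in the diff below). Keying the
# model's inputs by name makes the structure match what the data pipeline
# yields, so the two streams can no longer be silently swapped.
transformer = keras.Model(
    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
    decoder_outputs,
    name="transformer",
)

# Inference passes the same named structure:
predictions = transformer(
    {
        "encoder_inputs": tokenized_input_sentence,
        "decoder_inputs": tokenized_target_sentence,
    }
)
```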

* Fix compute_mask in PositionalEmbedding

The check essentially disables the mask calculation: this layer is the first one to receive the input, and thus never has a mask from a previous layer, so the early return always fires.

With this change, the mask is now passed on to the encoder.

This looks like a regression: the initial commit of the example contained code very similar to this fix.
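
For reference, the fixed method simply derives the mask from the token ids, as in the diff below:
```
def compute_mask(self, inputs, mask=None):
    # The old `if mask is None: return None` guard always fired for the
    # first layer in the model, so no mask was ever produced. Computing
    # the mask from the token ids directly fixes that (token id 0 is
    # the padding token).
    return ops.not_equal(inputs, 0)
```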

* Propagate both encoder/decoder-sequence masks to the decoder

As per https://github.com/tensorflow/tensorflow/blob/6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/tensorflow/python/keras/engine/base_layer.py#L965, the inputs should be passed as a list, not as kwargs. When this is done, both masks are received as a tuple in the `mask` argument.
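
Concretely, the decoder is now called with a single list argument and unpacks the corresponding tuple of masks inside `call`, as in the diff below:
```
# Passing both tensors as one list lets Keras collect their masks and
# hand them to the layer together, as a tuple, via the `mask` argument:
x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])

# Inside TransformerDecoder.call, the inputs and masks are unpacked:
def call(self, inputs, mask=None):
    inputs, encoder_outputs = inputs
    if mask is None:
        inputs_padding_mask, encoder_outputs_padding_mask = None, None
    else:
        inputs_padding_mask, encoder_outputs_padding_mask = mask
    ...
```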

* Apply both the padding masks in the attention layers and during loss computation
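
In the cross-attention layer, this means passing the decoder-side mask as `query_mask` and the encoder-side mask as `key_mask`, and excluding padded target positions from the loss via `ignore_class=0`, both as in the diff below:
```
# Cross-attention over the encoder outputs, with both padding masks applied:
attention_output_2 = self.attention_2(
    query=out_1,
    value=encoder_outputs,
    key=encoder_outputs,
    query_mask=inputs_padding_mask,
    key_mask=encoder_outputs_padding_mask,
)

# Padding tokens (class 0) are also ignored when computing the loss:
transformer.compile(
    "rmsprop",
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
    metrics=["accuracy"],
)
```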

* Regenerate ipynb/md-files for NMT example
tomjelen authored Nov 29, 2024
1 parent 5f16942 commit 2158c91
Showing 3 changed files with 146 additions and 90 deletions.
@@ -10,7 +10,7 @@
"\n",
"**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
"**Date created:** 2021/05/26<br>\n",
"**Last modified:** 2023/02/25<br>\n",
"**Last modified:** 2024/11/18<br>\n",
"**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task."
]
},
@@ -84,7 +84,7 @@
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from keras.layers import TextVectorization\n"
"from keras.layers import TextVectorization"
]
},
{
@@ -213,7 +213,7 @@
"The English layer will use the default string standardization (strip punctuation characters)\n",
"and splitting scheme (split on whitespace), while\n",
"the Spanish layer will use a custom standardization, where we add the character\n",
"`\"\u00bf\"` to the set of punctuation characters to be stripped.\n",
"`\"¿\"` to the set of punctuation characters to be stripped.\n",
"\n",
"Note: in a production-grade machine translation model, I would not recommend\n",
"stripping the punctuation characters in either language. Instead, I would recommend turning\n",
@@ -229,7 +229,7 @@
},
"outputs": [],
"source": [
"strip_chars = string.punctuation + \"\u00bf\"\n",
"strip_chars = string.punctuation + \"¿\"\n",
"strip_chars = strip_chars.replace(\"[\", \"\")\n",
"strip_chars = strip_chars.replace(\"]\", \"\")\n",
"\n",
@@ -441,10 +441,7 @@
" return embedded_tokens + embedded_positions\n",
"\n",
" def compute_mask(self, inputs, mask=None):\n",
" if mask is None:\n",
" return None\n",
" else:\n",
" return ops.not_equal(inputs, 0)\n",
" return ops.not_equal(inputs, 0)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
@@ -481,24 +478,30 @@
" self.layernorm_3 = layers.LayerNormalization()\n",
" self.supports_masking = True\n",
"\n",
" def call(self, inputs, encoder_outputs, mask=None):\n",
" def call(self, inputs, mask=None):\n",
" inputs, encoder_outputs = inputs\n",
" causal_mask = self.get_causal_attention_mask(inputs)\n",
" if mask is not None:\n",
" padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
" padding_mask = ops.minimum(padding_mask, causal_mask)\n",
"\n",
" if mask is None:\n",
" inputs_padding_mask, encoder_outputs_padding_mask = None, None\n",
" else:\n",
" padding_mask = None\n",
" inputs_padding_mask, encoder_outputs_padding_mask = mask\n",
"\n",
" attention_output_1 = self.attention_1(\n",
" query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
" query=inputs,\n",
" value=inputs,\n",
" key=inputs,\n",
" attention_mask=causal_mask,\n",
" query_mask=inputs_padding_mask,\n",
" )\n",
" out_1 = self.layernorm_1(inputs + attention_output_1)\n",
"\n",
" attention_output_2 = self.attention_2(\n",
" query=out_1,\n",
" value=encoder_outputs,\n",
" key=encoder_outputs,\n",
" attention_mask=padding_mask,\n",
" query_mask=inputs_padding_mask,\n",
" key_mask=encoder_outputs_padding_mask,\n",
" )\n",
" out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
"\n",
@@ -527,8 +530,7 @@
" \"num_heads\": self.num_heads,\n",
" }\n",
" )\n",
" return config\n",
""
" return config\n"
]
},
{
@@ -560,14 +562,15 @@
"decoder_inputs = keras.Input(shape=(None,), dtype=\"int64\", name=\"decoder_inputs\")\n",
"encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name=\"decoder_state_inputs\")\n",
"x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)\n",
"x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n",
"x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])\n",
"x = layers.Dropout(0.5)(x)\n",
"decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n",
"decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)\n",
"\n",
"decoder_outputs = decoder([decoder_inputs, encoder_outputs])\n",
"transformer = keras.Model(\n",
" [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n",
" {\"encoder_inputs\": encoder_inputs, \"decoder_inputs\": decoder_inputs},\n",
" decoder_outputs,\n",
" name=\"transformer\",\n",
")"
]
},
@@ -598,7 +601,9 @@
"\n",
"transformer.summary()\n",
"transformer.compile(\n",
" \"rmsprop\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n",
" \"rmsprop\",\n",
" loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),\n",
" metrics=[\"accuracy\"],\n",
")\n",
"transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)"
]
@@ -635,7 +640,12 @@
" decoded_sentence = \"[start]\"\n",
" for i in range(max_decoded_sentence_length):\n",
" tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n",
" predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])\n",
" predictions = transformer(\n",
" {\n",
" \"encoder_inputs\": tokenized_input_sentence,\n",
" \"decoder_inputs\": tokenized_target_sentence,\n",
" }\n",
" )\n",
"\n",
" # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here\n",
" sampled_token_index = ops.convert_to_numpy(\n",
@@ -664,19 +674,19 @@
"After 30 epochs, we get results such as:\n",
"\n",
"> She handed him the money.\n",
"> [start] ella le pas\u00f3 el dinero [end]\n",
"> [start] ella le pasó el dinero [end]\n",
"\n",
"> Tom has never heard Mary sing.\n",
"> [start] tom nunca ha o\u00eddo cantar a mary [end]\n",
"> [start] tom nunca ha oído cantar a mary [end]\n",
"\n",
"> Perhaps she will come tomorrow.\n",
"> [start] tal vez ella vendr\u00e1 ma\u00f1ana [end]\n",
"> [start] tal vez ella vendrá mañana [end]\n",
"\n",
"> I love to write.\n",
"> [start] me encanta escribir [end]\n",
"\n",
"> His French is improving little by little.\n",
"> [start] su franc\u00e9s va a [UNK] s\u00f3lo un poco [end]\n",
"> [start] su francés va a [UNK] sólo un poco [end]\n",
"\n",
"> My hotel told me to call you.\n",
"> [start] mi hotel me dijo que te [UNK] [end]"
@@ -693,7 +703,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "venv",
"language": "python",
"name": "python3"
},
@@ -707,9 +717,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}