From 275300acffba72b8f1c67b9ceb327794502383f4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Joaqu=C3=ADn=20Jim=C3=A9nez?=
<37051995+johacks@users.noreply.github.com>
Date: Wed, 27 Nov 2024 11:09:55 +0100
Subject: [PATCH] Add Keras 3 example for "Transformer model for MIDI music
generation" (#1992)
* Pre-PR review: base script file working in tf
* Fix too slow inference on jax
* Refactor MIDI generation script: PR feedback.
- Update import statement for keras_hub
- Fix documentation for audio dependencies
- Remove unnecessary markdown annotation
Other changes:
- Update last modified date in MIDI generation example
- Fix: Increase training epochs from 1 to 20 in MIDI generation example
- Fix: use flip for unsupported negative step in torch
- Fix: Remove unused learning rate
* Torch backend fixes: add cast on generate, change learning rate scale
* Compact script
* Add auto generated files
---
.../midi_generation_with_transformer.ipynb | 1016 ++++++++++++++++
.../md/midi_generation_with_transformer.md | 1059 +++++++++++++++++
.../midi_generation_with_transformer.py | 722 +++++++++++
3 files changed, 2797 insertions(+)
create mode 100644 examples/generative/ipynb/midi_generation_with_transformer.ipynb
create mode 100644 examples/generative/md/midi_generation_with_transformer.md
create mode 100644 examples/generative/midi_generation_with_transformer.py
diff --git a/examples/generative/ipynb/midi_generation_with_transformer.ipynb b/examples/generative/ipynb/midi_generation_with_transformer.ipynb
new file mode 100644
index 0000000000..e034e37edc
--- /dev/null
+++ b/examples/generative/ipynb/midi_generation_with_transformer.ipynb
@@ -0,0 +1,1016 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Music Generation with Transformer Models\n",
+ "\n",
+    "**Author:** [Joaquin Jimenez](https://github.com/johacks/)<br>\n",
+    "**Date created:** 2024/11/22<br>\n",
+    "**Last modified:** 2024/11/26<br>\n",
+ "**Description:** Use a Transformer model to train on MIDI data and generate music sequences."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+    "In this tutorial, we learn how to build a music generation model using a\n",
+    "Transformer decoder-only architecture.\n",
+    "The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro)\n",
+    "and implemented using Keras 3.\n",
+    "In the process, we explore MIDI tokenization and relative global attention mechanisms.\n",
+ "\n",
+ "This example is based on the paper \"Music Transformer\" by Huang et al. (2018).\n",
+ "Check out the original [paper](https://arxiv.org/abs/1809.04281) and\n",
+ "[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Before we start, let's import and install all the libraries we need."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -qq midi_neural_processor\n",
+ "!pip install -qq keras_hub\n",
+ "!pip install -qq \"keras>=3.6.0\" # Allows use of keras.utils.Config."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Optional dependencies\n",
+ "\n",
+ "To hear the audio, install the following additional dependencies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!sudo apt-get -qq install -y fluidsynth 2> /dev/null\n",
+ "!pip install -qq pyfluidsynth scipy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import random\n",
+ "import tempfile\n",
+ "\n",
+ "import keras\n",
+ "import midi_neural_processor.processor as midi_tokenizer\n",
+ "import numpy as np\n",
+ "from keras import callbacks, layers, ops, optimizers, utils\n",
+ "from keras_hub import layers as hub_layers\n",
+ "from os import path"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Configuration\n",
+ "\n",
+    "Let's define the configuration for the model and the dataset used in this example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "event_range = midi_tokenizer.RANGE_NOTE_ON\n",
+ "event_range += midi_tokenizer.RANGE_NOTE_OFF\n",
+ "event_range += midi_tokenizer.RANGE_TIME_SHIFT\n",
+ "event_range += midi_tokenizer.RANGE_VEL\n",
+ "CONFIG = utils.Config(\n",
+ " max_sequence_len=2048,\n",
+ " embedding_dim=256,\n",
+ " num_transformer_blocks=6,\n",
+ " batch_size=6,\n",
+ " token_pad=event_range,\n",
+ " token_start_of_sentence=event_range + 1,\n",
+ " token_end_of_sentence=event_range + 2,\n",
+ " vocabulary_size=event_range + 3,\n",
+ " model_out=\"tmp/music_transformer.keras\",\n",
+ " seed=42,\n",
+ ")\n",
+ "utils.set_random_seed(CONFIG.seed)\n",
+ ""
+ ]
+ },
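+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "The vocabulary is built from the event ranges defined by `midi_neural_processor`\n",
+    "(note-on, note-off, time-shift and velocity events), plus three special tokens for\n",
+    "padding, start-of-sentence and end-of-sentence. The quick check below is purely\n",
+    "illustrative; the exact range sizes come from the installed `midi_neural_processor` version."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative sanity check: how the token vocabulary decomposes.\n",
+    "print(\"note-on events:   \", midi_tokenizer.RANGE_NOTE_ON)\n",
+    "print(\"note-off events:  \", midi_tokenizer.RANGE_NOTE_OFF)\n",
+    "print(\"time-shift events:\", midi_tokenizer.RANGE_TIME_SHIFT)\n",
+    "print(\"velocity events:  \", midi_tokenizer.RANGE_VEL)\n",
+    "print(\n",
+    "    \"PAD / SOS / EOS tokens:\",\n",
+    "    CONFIG.token_pad,\n",
+    "    CONFIG.token_start_of_sentence,\n",
+    "    CONFIG.token_end_of_sentence,\n",
+    ")\n",
+    "print(\"total vocabulary size:\", CONFIG.vocabulary_size)"
+   ]
+  },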
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Maestro dataset\n",
+ "\n",
+ "The Maestro dataset contains MIDI files for piano performances.\n",
+ "\n",
+ "### Download the dataset\n",
+ "\n",
+ "We now download and extract the dataset, then move the MIDI files to a new directory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def download_maestro(output_dir=None):\n",
+ " \"\"\"Download the Maestro MIDI dataset.\n",
+ " Extracted from: https://magenta.tensorflow.org/datasets/maestro\n",
+ " \"\"\"\n",
+ " # Ensure the output directory exists\n",
+ " output_dir = tempfile.mkdtemp() if output_dir is None else output_dir\n",
+ " os.makedirs(output_dir, exist_ok=True)\n",
+ "\n",
+ " # Download and extract zip file\n",
+ " dir = utils.get_file(\n",
+ " origin=\"https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip\",\n",
+ " extract=True,\n",
+ " )\n",
+ "\n",
+ " # Gather all MIDI files\n",
+ " midi_files, file_paths = set(), list()\n",
+ " for root, _, files in os.walk(dir):\n",
+ " for file in files:\n",
+ " if file.lower().endswith(\".midi\") or file.lower().endswith(\".mid\"):\n",
+ " midi_files.add(path.join(root, file))\n",
+ "\n",
+ " # Move the files to the output directory\n",
+ " for file in sorted(midi_files):\n",
+ " file_paths.append(new_path := path.join(output_dir, path.basename(file)))\n",
+ " os.rename(file, new_path)\n",
+ " return file_paths\n",
+ "\n",
+ "\n",
+ "paths = list(sorted(download_maestro(output_dir=\"datasets/maestro\")))\n",
+ "output_dir = path.dirname(paths[0])\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Split the dataset\n",
+ "\n",
+ "We can now split the dataset into training and validation sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "indices = np.random.permutation(len(paths))\n",
+ "split = int(len(paths) * 0.1)\n",
+ "train_paths = [paths[i] for i in indices[split:]]\n",
+ "val_paths = [paths[i] for i in indices[:split]]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Hear a MIDI file\n",
+ "\n",
+ "We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio.\n",
+ "This allows us to listen to the data samples before and after processing.\n",
+ "\n",
+ "The following dependencies are required to play the audio:\n",
+ "- fluidsynth: `sudo apt install -y fluidsynth`\n",
+ "- pyfluidsynth, scipy: `pip install pyfluidsynth scipy`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None):\n",
+ " import pretty_midi\n",
+ " from scipy.io.wavfile import write as write_wav\n",
+ " from IPython.display import Audio\n",
+ "\n",
+ " # Create the audio waveform\n",
+ " pretty_midi_file = pretty_midi.PrettyMIDI(midi_path)\n",
+ " waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate]\n",
+ "\n",
+ " # Display the audio if no path is provided\n",
+ " if out_dir is None:\n",
+ " # IPython display\n",
+ " return Audio(waveform, rate=sampling_rate)\n",
+ "\n",
+ " # Save the audio to a file\n",
+ " os.makedirs(out_dir, exist_ok=True)\n",
+ " audio_path = path.join(out_dir, path.basename(midi_path).split(\".\")[0] + \".wav\")\n",
+ " write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16))\n",
+ " return audio_path\n",
+ "\n",
+ "\n",
+ "print(visualize_midi(train_paths[0], out_dir=\"tmp/\")) # Saved audio path\n",
+ "visualize_midi(train_paths[0]) # Display the audio if in a Jupyter notebook\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Tokenize the data\n",
+ "\n",
+ "We now preprocess the MIDI files into a tokenized format for training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def encode_midi_task(midi_path):\n",
+ " \"\"\"Define a task that tokenizes a MIDI file.\"\"\"\n",
+ " import midi_neural_processor.processor as midi_tokenizer\n",
+ "\n",
+ " return midi_tokenizer.encode_midi(midi_path)\n",
+ "\n",
+ "\n",
+ "def preprocess_midi_files(file_paths, save_dir=None):\n",
+ " \"\"\"Preprocess a list of MIDI files and save the notes to a file.\"\"\"\n",
+ " from multiprocessing import Pool, cpu_count\n",
+ "\n",
+ " # Assume all files are in the same directory and save to the same directory\n",
+ " save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir\n",
+ " os.makedirs(save_dir, exist_ok=True)\n",
+ "\n",
+ " # Check if the notes have already been preprocessed\n",
+ " output_file = path.join(save_dir, \"notes.npz\")\n",
+ " if path.exists(output_file):\n",
+ " npz_file = np.load(output_file)\n",
+ " return [npz_file[key] for key in npz_file.keys()]\n",
+ "\n",
+ " # Preprocess the MIDI files in parallel\n",
+ " progbar = utils.Progbar(len(file_paths), unit_name=\"MIDI_file\", interval=5)\n",
+ " pool = Pool(cpu_count() - 1)\n",
+ " all_notes = []\n",
+ " for notes in pool.imap_unordered(encode_midi_task, file_paths):\n",
+ " progbar.add(1)\n",
+ " all_notes.append(np.array(notes))\n",
+ "\n",
+ " # Save the notes to a file\n",
+ " np.savez(output_file, *all_notes)\n",
+ " return all_notes\n",
+ "\n",
+ "\n",
+ "train_midis = preprocess_midi_files(train_paths, path.join(output_dir, \"train\"))\n",
+ "val_midis = preprocess_midi_files(val_paths, path.join(output_dir, \"val\"))\n",
+ ""
+ ]
+ },
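+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "Each performance is now a 1D array of integer event tokens drawn from the vocabulary\n",
+    "defined above. The quick peek below is purely illustrative."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative peek at the tokenized data.\n",
+    "print(\"Training performances:\", len(train_midis))\n",
+    "print(\"First 16 tokens of the first performance:\", train_midis[0][:16])\n",
+    "print(\"Token count of the first performance:\", len(train_midis[0]))"
+   ]
+  },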
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Dataset objects\n",
+ "\n",
+ "We now define a dataset class that yields batches of input sequences and target sequences."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class MidiDataset(utils.PyDataset):\n",
+ " \"\"\"A dataset for MIDI files that yields batches of input sequences and target sequences.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " encoded_midis,\n",
+ " batch_size=CONFIG.batch_size,\n",
+ " max_sequence_len=CONFIG.max_sequence_len,\n",
+ " ):\n",
+ " super(MidiDataset, self).__init__()\n",
+ " self.batch_size = batch_size\n",
+ " self.max_sequence_len = max_sequence_len\n",
+ " self.encoded_midis = encoded_midis\n",
+ " batches, last_batch_size = divmod(len(encoded_midis), batch_size)\n",
+ " self._num_batches = batches + int(last_batch_size > 0)\n",
+ "\n",
+ " def __len__(self):\n",
+ " \"\"\"Get the number of batches.\"\"\"\n",
+ " return self._num_batches\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " \"\"\"Generate random inputs and corresponding targets for the model.\"\"\"\n",
+ " # Same as in the original paper, we always get a random batch.\n",
+ " # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py\n",
+ " batch = random.sample(self.encoded_midis, k=self.batch_size)\n",
+ "\n",
+ " # Convert the batch to sequences\n",
+ " batch_data = [\n",
+ " self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch\n",
+ " ]\n",
+ " batch_data = np.array(batch_data)\n",
+ "\n",
+ " # Split the data into input and target sequences\n",
+ " return batch_data[:, :-1], batch_data[:, 1:]\n",
+ "\n",
+ " def _get_sequence(self, data, max_length):\n",
+ " \"\"\"Get a random sequence of notes from a file.\"\"\"\n",
+ " # Truncate or pad the sequence\n",
+ " if len(data) > max_length:\n",
+ " start = random.randrange(0, len(data) - max_length)\n",
+ " data = data[start : start + max_length]\n",
+ " elif len(data) < max_length:\n",
+ " data = np.append(data, CONFIG.token_end_of_sentence)\n",
+ "\n",
+ " # Pad the sequence if necessary\n",
+ " if len(data) < max_length:\n",
+ " data = np.concatenate(\n",
+ " (data, np.full(max_length - len(data), CONFIG.token_pad))\n",
+ " )\n",
+ " return np.asanyarray(data, dtype=\"int32\")\n",
+ "\n",
+ "\n",
+ "train_dataset, val_dataset = MidiDataset(train_midis), MidiDataset(val_midis)\n",
+ ""
+ ]
+ },
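+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "As a quick, purely illustrative check, each batch pairs an input window with a target\n",
+    "window shifted one token to the right, both of shape `(batch_size, max_sequence_len)`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative check: targets are the inputs shifted one token to the right.\n",
+    "sample_inputs, sample_targets = train_dataset[0]\n",
+    "print(\"Inputs shape: \", sample_inputs.shape)\n",
+    "print(\"Targets shape:\", sample_targets.shape)\n",
+    "print(\"Shifted by one:\", bool(np.all(sample_inputs[0, 1:] == sample_targets[0, :-1])))"
+   ]
+  },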
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model definition\n",
+ "\n",
+ "It is time to define the model architecture. We use a Transformer decoder\n",
+ "architecture with a custom attention mechanism, relative global attention.\n",
+ "\n",
+ "### Relative Global Attention\n",
+ "\n",
+ "The following code implements the Relative Global Attention layer. It is used\n",
+ "in place of the standard multi-head attention layer in the Transformer decoder.\n",
+    "The main difference is that it adds learned relative position embeddings, which let\n",
+    "the model attend based on the distance between tokens."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "@keras.utils.register_keras_serializable()\n",
+ "class RelativeGlobalAttention(layers.Layer):\n",
+ " \"\"\"\n",
+ " From Music Transformer (Huang et al., 2018)\n",
+ " https://arxiv.org/abs/1809.04281\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.key_length = None\n",
+ " self.max_sequence_len = max_sequence_len\n",
+ " self.relative_embedding = None\n",
+ " self.num_heads = num_heads\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.head_dim = embedding_dim // num_heads\n",
+ " self.query_dense = layers.Dense(int(self.embedding_dim))\n",
+ " self.key_dense = layers.Dense(int(self.embedding_dim))\n",
+ " self.value_dense = layers.Dense(int(self.embedding_dim))\n",
+ " self.output_dense = layers.Dense(embedding_dim, name=\"output\")\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " self.query_length = input_shape[0][1]\n",
+ " self.key_length = input_shape[1][1]\n",
+ " self.relative_embedding = self.add_weight(\n",
+ " (self.max_sequence_len, int(self.head_dim)), name=\"relative_embedding\"\n",
+ " )\n",
+ "\n",
+ " def _apply_dense_layer_and_split_heads(self, inputs, dense_layer):\n",
+ " # Apply linear transformation\n",
+ " inputs = dense_layer(inputs)\n",
+ " new_shape = ops.shape(inputs)\n",
+ " # Reshape to split by attention heads\n",
+ " reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1))\n",
+ " # Transpose for head-first format\n",
+ " return ops.transpose(reshaped, (0, 2, 1, 3))\n",
+ "\n",
+ " def call(self, inputs, mask=None):\n",
+ " # Compute Q, K, V: Batch, head, sequence, features\n",
+ " query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense)\n",
+ " key = self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense)\n",
+ " value = self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense)\n",
+ "\n",
+ " # Compute scaled dot-product attention scores\n",
+ " attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2]))\n",
+ "\n",
+ " # Compute relative positional encoding and combine with attention scores\n",
+ " start_idx = max(0, self.max_sequence_len - ops.shape(query)[2])\n",
+ " relative_embedding = self.relative_embedding[start_idx:, :]\n",
+ " attention_scores += self._compute_attention_scores(query, relative_embedding)\n",
+ " logits = attention_scores / ops.sqrt(self.head_dim)\n",
+ "\n",
+ " # Apply mask if provided\n",
+ " if mask is not None:\n",
+ " logits += ops.cast(mask, \"float32\") * -1e9\n",
+ "\n",
+ " # Compute attention weights\n",
+ " attention_weights = ops.nn.softmax(logits, axis=-1)\n",
+ " attention_output = ops.matmul(attention_weights, value)\n",
+ "\n",
+ " # Merge heads and apply final linear transformation\n",
+ " merged_attention = ops.transpose(attention_output, (0, 2, 1, 3))\n",
+ " merged_attention = ops.reshape(\n",
+ " merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim)\n",
+ " )\n",
+ " output = self.output_dense(merged_attention)\n",
+ "\n",
+ " return output, attention_weights\n",
+ "\n",
+ " def _compute_attention_scores(self, query, relative_embedding):\n",
+ " \"\"\"\n",
+ " Compute relative attention scores using positional encodings.\n",
+ " \"\"\"\n",
+ " relative_scores = ops.einsum(\"bhld, md->bhlm\", query, relative_embedding)\n",
+ " relative_scores = self._apply_mask_to_relative_scores(relative_scores)\n",
+ " return self._skew_attention_scores(relative_scores)\n",
+ "\n",
+ " def _apply_mask_to_relative_scores(self, scores):\n",
+ " \"\"\"\n",
+ " Apply masking to relative positional scores to ignore future positions.\n",
+ " \"\"\"\n",
+ " mask = ops.flip(\n",
+ " ops.tri(scores.shape[-2], scores.shape[-1], dtype=\"float32\"), axis=1\n",
+ " )\n",
+ " return mask * scores\n",
+ "\n",
+ " def _skew_attention_scores(self, scores):\n",
+ " \"\"\"\n",
+ " Perform skewing operation to align relative attention scores with the sequence.\n",
+ " \"\"\"\n",
+ " padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0)))\n",
+ " padded_shape = ops.shape(padded_scores)\n",
+ " reshaped_scores = ops.reshape(\n",
+ " padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2])\n",
+ " )\n",
+ " skewed_scores = reshaped_scores[:, :, 1:, :]\n",
+ "\n",
+ " if self.key_length > self.query_length:\n",
+ " size_diff = self.key_length - self.query_length\n",
+ " return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]])\n",
+ " else:\n",
+ " return skewed_scores[:, :, :, : self.key_length]\n",
+ ""
+ ]
+ },
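+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "The masking and skewing steps are easiest to see on a tiny example. Row `i` of the raw\n",
+    "relative scores keeps its valid entries in the last `i + 1` columns (indexed by relative\n",
+    "distance); the pad-and-reshape skew then moves each surviving entry to the column of the\n",
+    "key position it refers to. The snippet below is purely illustrative and simply mirrors\n",
+    "the ops used inside the layer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative toy example of the mask + skew trick (mirrors the layer's ops).\n",
+    "toy_scores = ops.reshape(ops.arange(9, dtype=\"float32\"), (1, 1, 3, 3))\n",
+    "toy_mask = ops.flip(ops.tri(3, 3, dtype=\"float32\"), axis=1)\n",
+    "masked_scores = toy_mask * toy_scores\n",
+    "padded_scores = ops.pad(masked_scores, ((0, 0), (0, 0), (0, 0), (1, 0)))\n",
+    "skewed_scores = ops.reshape(padded_scores, (-1, 1, 4, 3))[:, :, 1:, :]\n",
+    "print(ops.convert_to_numpy(masked_scores[0, 0]))  # Valid entries right-aligned\n",
+    "print(ops.convert_to_numpy(skewed_scores[0, 0]))  # Realigned to key positions"
+   ]
+  },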
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Decoder Layer\n",
+ "\n",
+    "Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It closely\n",
+    "follows the standard Transformer decoder layer, but replaces standard multi-head\n",
+    "attention with the relative global attention defined above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "@keras.utils.register_keras_serializable()\n",
+ "class DecoderLayer(layers.Layer):\n",
+ " def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1):\n",
+ " super(DecoderLayer, self).__init__()\n",
+ "\n",
+ " # Initialize attributes\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.num_heads = num_heads\n",
+ " self.max_sequence_len = max_sequence_len\n",
+ "\n",
+ " # Initialize layers\n",
+ " self.relative_global_attention_1 = RelativeGlobalAttention(\n",
+ " num_heads, embedding_dim, max_sequence_len\n",
+ " )\n",
+ "\n",
+ " self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, \"relu\")\n",
+ " self.feed_forward_network_pos = layers.Dense(self.embedding_dim)\n",
+ "\n",
+ " self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6)\n",
+ " self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6)\n",
+ "\n",
+ " self.dropout_1 = layers.Dropout(dropout)\n",
+ " self.dropout_2 = layers.Dropout(dropout)\n",
+ "\n",
+ " def call(self, inputs, mask=None, training=False):\n",
+ " # Attention block. Inputs are (query, key, value)\n",
+ " attention_out, attention_weights = self.relative_global_attention_1(\n",
+ " (inputs, inputs, inputs), mask=mask\n",
+ " )\n",
+ " attention_out = self.dropout_1(attention_out, training=training)\n",
+ " attention_out_normalized = self.layer_normalization_1(attention_out + inputs)\n",
+ "\n",
+ " ffn_out = self.feed_forward_network_pre(attention_out)\n",
+ " ffn_out = self.feed_forward_network_pos(ffn_out)\n",
+ " ffn_out = self.dropout_2(ffn_out, training=training)\n",
+ " out = self.layer_normalization_2(attention_out_normalized + ffn_out)\n",
+ "\n",
+ " return out, attention_weights\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Decoder\n",
+ "\n",
+ "The Decoder layer is composed of multiple DecoderLayer blocks. It also includes\n",
+ "an embedding layer that converts our tokenized input into an embedding representation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "@keras.utils.register_keras_serializable()\n",
+ "class Decoder(layers.Layer):\n",
+ " def __init__(\n",
+ " self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout\n",
+ " ):\n",
+ " super(Decoder, self).__init__()\n",
+ "\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.num_blocks = num_blocks\n",
+ "\n",
+ " self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim)\n",
+ " self.positional_encoding = hub_layers.SinePositionEncoding()\n",
+ "\n",
+ " self.decode_layers = [\n",
+ " DecoderLayer(\n",
+ " embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout\n",
+ " )\n",
+ " for _ in range(num_blocks)\n",
+ " ]\n",
+ " self.dropout = layers.Dropout(dropout)\n",
+ "\n",
+ " def call(self, inputs, mask=None, training=False, return_attention_weights=False):\n",
+ " weights = []\n",
+ "\n",
+ " # Adding embedding and position encoding.\n",
+ " x = self.embedding(inputs)\n",
+ " x = x * ops.sqrt(ops.cast(self.embedding_dim, \"float32\"))\n",
+ " x = x + self.positional_encoding(x)\n",
+ " x = self.dropout(x, training=training)\n",
+ "\n",
+ " # Passing through the transformer blocks.\n",
+ " for i in range(self.num_blocks):\n",
+ " x, w = self.decode_layers[i](x, mask=mask, training=training)\n",
+ " weights.append(w)\n",
+ " if return_attention_weights:\n",
+ " return x, weights\n",
+ " return x\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Music Transformer Decoder\n",
+ "\n",
+ "With the above layers defined, we can now define the MusicTransformerDecoder model. It applies\n",
+ "a linear transformation to the output of the decoder to get the logits for each token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "@keras.utils.register_keras_serializable()\n",
+ "class MusicTransformerDecoder(keras.Model):\n",
+ " def __init__(\n",
+ " self,\n",
+ " embedding_dim=CONFIG.embedding_dim,\n",
+ " vocabulary_size=CONFIG.vocabulary_size,\n",
+ " num_blocks=CONFIG.num_transformer_blocks,\n",
+ " max_sequence_len=CONFIG.max_sequence_len,\n",
+ " dropout=0.2,\n",
+ " ):\n",
+ " # Initialize attributes\n",
+ " super(MusicTransformerDecoder, self).__init__()\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.vocabulary_size = vocabulary_size\n",
+ " self.num_blocks = num_blocks\n",
+ " self.max_sequence_len = max_sequence_len\n",
+ "\n",
+ " # Initialize layers\n",
+ " # Transformer decoder\n",
+ " self.decoder = Decoder(\n",
+ " embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout\n",
+ " )\n",
+ " # Output layer\n",
+ " self.fc = layers.Dense(self.vocabulary_size, activation=None, name=\"output\")\n",
+ "\n",
+ " @staticmethod\n",
+ " def get_look_ahead_mask(max_sequence_len, inputs):\n",
+ " sequence_length = min(max_sequence_len, inputs.shape[1])\n",
+ " sequence_mask = ops.logical_not(\n",
+ " ops.tri(sequence_length, sequence_length, dtype=\"bool\")\n",
+ " )\n",
+ "\n",
+ " inputs = ops.cast(inputs[:, None, None, :], \"int32\")\n",
+ " output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad\n",
+ " decoder_output_mask = ops.equal(inputs, output_pad_tensor)\n",
+ " return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), \"int32\")\n",
+ "\n",
+ " def call(self, inputs, training=False):\n",
+ " mask = self.get_look_ahead_mask(self.max_sequence_len, inputs)\n",
+ " decoding = self.decoder(\n",
+ " inputs, mask=mask, training=training, return_attention_weights=False\n",
+ " )\n",
+ " return self.fc(decoding)\n",
+ "\n",
+ " # --- Sequence generation methods\n",
+ "\n",
+ " def generate(self, inputs: list, length=CONFIG.max_sequence_len, top_k=5):\n",
+ " inputs = ops.convert_to_tensor([inputs])\n",
+ "\n",
+ " # Generate a new token using output distribution at given index\n",
+ " def generate_token(inputs, end_idx):\n",
+ " distribution = ops.stop_gradient(self.call(inputs)[0, end_idx])\n",
+ "\n",
+ " # Select the top-k tokens and their probabilities\n",
+ " top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k)\n",
+ "\n",
+ " # Sample from the top-k probabilities\n",
+ " new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1)\n",
+ " return ops.take(top_k_indices, new_token_idx[0])\n",
+ "\n",
+ " # Compute the number of tokens to add\n",
+ " added_tokens = min(length, self.max_sequence_len - inputs.shape[1])\n",
+ " progbar = utils.Progbar(added_tokens, unit_name=\"token\", interval=5)\n",
+ "\n",
+ " # Pad the input sequence that will be filled with generated tokens\n",
+ " out = ops.pad(inputs, ((0, 0), (0, added_tokens)), \"constant\", CONFIG.token_pad)\n",
+ "\n",
+ " # Generate tokens using top-k sampling\n",
+ " for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens):\n",
+ " token = ops.cast(generate_token(out, end_idx=token_idx), out.dtype)\n",
+ " out = ops.scatter_update(out, ((0, token_idx + 1),), token)\n",
+ " progbar.add(1)\n",
+ "\n",
+ " return ops.convert_to_numpy(out[0])\n",
+ "\n",
+ " # --- Serialization methods\n",
+ "\n",
+ " def get_config(self):\n",
+ " atts = [\"embedding_dim\", \"vocabulary_size\", \"num_blocks\", \"max_sequence_len\"]\n",
+ " return {a: getattr(self, a) for a in atts}\n",
+ "\n",
+ " @classmethod\n",
+ " def from_config(cls, config):\n",
+ " return cls(**config)\n",
+ ""
+ ]
+ },
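+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "As a purely illustrative shape check, a small, untrained instance of the model maps a\n",
+    "batch of token ids of shape `(batch, sequence)` to logits of shape\n",
+    "`(batch, sequence, vocabulary_size)`. The reduced sizes below only keep the check fast."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative shape check with a small, untrained model instance.\n",
+    "toy_model = MusicTransformerDecoder(num_blocks=1, max_sequence_len=32)\n",
+    "toy_tokens = np.random.randint(0, CONFIG.vocabulary_size, size=(2, 32))\n",
+    "print(\"Logits shape:\", toy_model(toy_tokens).shape)  # (2, 32, vocabulary_size)"
+   ]
+  },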
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Loss function\n",
+ "\n",
+ "We define a custom loss function that computes the categorical cross-entropy\n",
+ "loss for the model. It is computed only for non-padding tokens and uses\n",
+ "`from_logits=True` since the model outputs logits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "@keras.utils.register_keras_serializable()\n",
+ "def train_loss(y_true, y_pred):\n",
+ " mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), \"float32\")\n",
+ " y_true = ops.one_hot(ops.cast(y_true, \"int32\"), CONFIG.vocabulary_size)\n",
+ " return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask\n",
+ ""
+ ]
+ },
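+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "As a purely illustrative check, positions holding the padding token contribute zero loss,\n",
+    "while every other position contributes the usual cross-entropy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative check: the first position is padding and contributes zero loss.\n",
+    "toy_true = ops.convert_to_tensor([[CONFIG.token_pad, CONFIG.token_end_of_sentence]])\n",
+    "toy_logits = ops.zeros((1, 2, CONFIG.vocabulary_size))\n",
+    "print(ops.convert_to_numpy(train_loss(toy_true, toy_logits)))"
+   ]
+  },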
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Learning rate schedule\n",
+ "\n",
+    "Following the Music Transformer paper, we use a learning rate schedule with a linear\n",
+    "warmup followed by an inverse square root decay, scaled according to the embedding dimension."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "@keras.utils.register_keras_serializable()\n",
+ "class CustomSchedule(optimizers.schedules.LearningRateSchedule):\n",
+ " def __init__(self, embedding_dim, warmup_steps=4000):\n",
+ " super(CustomSchedule, self).__init__()\n",
+ "\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.warmup_steps = warmup_steps\n",
+ "\n",
+ " self._embedding_dim = ops.cast(self.embedding_dim, \"float32\")\n",
+ " # Numerical stability adjustment on torch, which is less precise\n",
+ " self._lr_adjust = 0.1 if keras.backend.backend() == \"torch\" else 1.0\n",
+ "\n",
+ " def get_config(self):\n",
+ " return {\"embedding_dim\": self.embedding_dim, \"warmup_steps\": self.warmup_steps}\n",
+ "\n",
+ " def __call__(self, step):\n",
+ " step_rsqrt = ops.rsqrt(ops.cast(step, \"float32\"))\n",
+ " warmup_adjust = step * (self.warmup_steps**-1.5)\n",
+ " output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust)\n",
+ " return self._lr_adjust * output\n",
+ ""
+ ]
+ },
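+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "Printing the schedule at a few steps is an easy, purely illustrative way to see the\n",
+    "linear warmup followed by the inverse square root decay."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative look at the schedule: linear warmup, then inverse square root decay.\n",
+    "toy_schedule = CustomSchedule(CONFIG.embedding_dim)\n",
+    "for toy_step in [100, 1000, 4000, 16000, 64000]:\n",
+    "    toy_lr = toy_schedule(ops.convert_to_tensor(toy_step, dtype=\"float32\"))\n",
+    "    print(f\"step {toy_step:>6}: learning rate {float(toy_lr):.2e}\")"
+   ]
+  },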
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Training the model\n",
+ "\n",
+    "We can now train the model on the Maestro dataset. First, we define a training\n",
+    "function that compiles the model, trains it, and saves the best checkpoint, so that\n",
+    "we can later resume training from that checkpoint if needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def train_model(model, train_ds, val_ds, epochs=15):\n",
+ " # Configure optimizer\n",
+ " learning_rate = CustomSchedule(CONFIG.embedding_dim)\n",
+ " optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)\n",
+ "\n",
+ " # Compile the model\n",
+ " model.compile(optimizer=optimizer, loss=train_loss)\n",
+ "\n",
+ " # Train the model\n",
+ " save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True)\n",
+ " model.fit(\n",
+ " train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2\n",
+ " )\n",
+ " return model\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+    "If a model checkpoint already exists, we load it and can optionally continue\n",
+    "training it; otherwise, we train a new model from scratch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "if path.exists(CONFIG.model_out):\n",
+ " model = keras.models.load_model(CONFIG.model_out)\n",
+    "    # Uncomment the line below to continue training from the checkpoint\n",
+ " # train_model(model, train_dataset, val_dataset, epochs=10)\n",
+ "else:\n",
+ " # Train the model\n",
+ " model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset)\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Generate music\n",
+ "\n",
+ "We can now generate music using the trained model. We use an existing MIDI file\n",
+ "as a seed and generate a new sequence."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None):\n",
+ " # Ensure the output directory exists\n",
+ " out_dir = out_dir if out_dir is not None else tempfile.mkdtemp()\n",
+ " os.makedirs(out_dir, exist_ok=True)\n",
+ "\n",
+ " # Get some tokens from the MIDI file\n",
+ " inputs = midi_tokenizer.encode_midi(seed_path)[100:125]\n",
+ " print(f\"Seed tokens: {inputs}\")\n",
+ "\n",
+ " # Generate music that follows the input tokens until the maximum length\n",
+ " result = model.generate(inputs, length=length, top_k=top_k)\n",
+ "\n",
+ " output_path = path.join(out_dir, path.basename(seed_path).split(\".\")[0] + \".mid\")\n",
+ " midi_tokenizer.decode_midi(result, output_path)\n",
+ " return output_path\n",
+ "\n",
+ "\n",
+ "output_file = generate_music(model, val_paths[-1], out_dir=\"tmp/\", top_k=15)\n",
+ "print(visualize_midi(output_file, out_dir=\"tmp/\")) # Saved audio path\n",
+ "visualize_midi(output_file) # Display the audio if in a Jupyter notebook"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Conclusion\n",
+ "\n",
+ "In this example, we learned how to build a music generation model using a custom\n",
+ "Transformer decoder architecture.\n",
+ "\n",
+    "We followed the Music Transformer paper by Huang et al. (2018).\n",
+    "To do so, we had to:\n",
+ "\n",
+ "- Define a custom loss function and learning rate schedule.\n",
+ "- Define a custom attention mechanism.\n",
+ "- Preprocess MIDI files into a tokenized format.\n",
+ "\n",
+ "After training the model on the Maestro dataset, we generated music sequences\n",
+ "using a seed MIDI file.\n",
+ "\n",
+ "### Next steps\n",
+ "\n",
+    "We could further improve inference times by caching attention weights during the\n",
+    "forward pass, similarly to how `keras_hub` `CausalLM` models use the\n",
+    "`CachedMultiHeadAttention` layer."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "midi_generation_with_transformer",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/generative/md/midi_generation_with_transformer.md b/examples/generative/md/midi_generation_with_transformer.md
new file mode 100644
index 0000000000..a0280e50c3
--- /dev/null
+++ b/examples/generative/md/midi_generation_with_transformer.md
@@ -0,0 +1,1059 @@
+# Music Generation with Transformer Models
+
+**Author:** [Joaquin Jimenez](https://github.com/johacks/)<br>
+**Date created:** 2024/11/22<br>
+**Last modified:** 2024/11/26<br>
+**Description:** Use a Transformer model to train on MIDI data and generate music sequences.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/midi_generation_with_transformer.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/generative/midi_generation_with_transformer.py)
+
+
+
+---
+## Introduction
+
+In this tutorial, we learn how to build a music generation model using a
+Transformer decoder-only architecture.
+The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro)
+and implemented using Keras 3.
+In the process, we explore MIDI tokenization and relative global attention mechanisms.
+
+This example is based on the paper "Music Transformer" by Huang et al. (2018).
+Check out the original [paper](https://arxiv.org/abs/1809.04281) and
+[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0).
+
+---
+## Setup
+
+Before we start, let's import and install all the libraries we need.
+
+
+```python
+!pip install -qq midi_neural_processor
+!pip install -qq keras_hub
+!pip install -qq "keras>=3.6.0" # Allows use of keras.utils.Config.
+```
+
+### Optional dependencies
+
+To hear the audio, install the following additional dependencies:
+
+
+```python
+!sudo apt-get -qq install -y fluidsynth 2> /dev/null
+!pip install -qq pyfluidsynth scipy
+```
+
+
+```python
+import os
+import random
+import tempfile
+
+import keras
+import midi_neural_processor.processor as midi_tokenizer
+import numpy as np
+from keras import callbacks, layers, ops, optimizers, utils
+from keras_hub import layers as hub_layers
+from os import path
+```
+
+---
+## Configuration
+
+Let's define the configuration for the model and the dataset used in this example.
+
+
+```python
+event_range = midi_tokenizer.RANGE_NOTE_ON
+event_range += midi_tokenizer.RANGE_NOTE_OFF
+event_range += midi_tokenizer.RANGE_TIME_SHIFT
+event_range += midi_tokenizer.RANGE_VEL
+CONFIG = utils.Config(
+ max_sequence_len=2048,
+ embedding_dim=256,
+ num_transformer_blocks=6,
+ batch_size=6,
+ token_pad=event_range,
+ token_start_of_sentence=event_range + 1,
+ token_end_of_sentence=event_range + 2,
+ vocabulary_size=event_range + 3,
+ model_out="tmp/music_transformer.keras",
+ seed=42,
+)
+utils.set_random_seed(CONFIG.seed)
+
+```
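+
+The vocabulary is built from the event ranges defined by `midi_neural_processor`
+(note-on, note-off, time-shift and velocity events), plus three special tokens for
+padding, start-of-sentence and end-of-sentence. The quick check below is purely
+illustrative; the exact range sizes come from the installed `midi_neural_processor` version.
+
+
+```python
+# Illustrative sanity check: how the token vocabulary decomposes.
+print("note-on events:   ", midi_tokenizer.RANGE_NOTE_ON)
+print("note-off events:  ", midi_tokenizer.RANGE_NOTE_OFF)
+print("time-shift events:", midi_tokenizer.RANGE_TIME_SHIFT)
+print("velocity events:  ", midi_tokenizer.RANGE_VEL)
+print(
+    "PAD / SOS / EOS tokens:",
+    CONFIG.token_pad,
+    CONFIG.token_start_of_sentence,
+    CONFIG.token_end_of_sentence,
+)
+print("total vocabulary size:", CONFIG.vocabulary_size)
+```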
+
+---
+## Maestro dataset
+
+The Maestro dataset contains MIDI files for piano performances.
+
+### Download the dataset
+
+We now download and extract the dataset, then move the MIDI files to a new directory.
+
+
+```python
+
+def download_maestro(output_dir=None):
+ """Download the Maestro MIDI dataset.
+ Extracted from: https://magenta.tensorflow.org/datasets/maestro
+ """
+ # Ensure the output directory exists
+ output_dir = tempfile.mkdtemp() if output_dir is None else output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Download and extract zip file
+ dir = utils.get_file(
+ origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip",
+ extract=True,
+ )
+
+ # Gather all MIDI files
+ midi_files, file_paths = set(), list()
+ for root, _, files in os.walk(dir):
+ for file in files:
+ if file.lower().endswith(".midi") or file.lower().endswith(".mid"):
+ midi_files.add(path.join(root, file))
+
+ # Move the files to the output directory
+ for file in sorted(midi_files):
+ file_paths.append(new_path := path.join(output_dir, path.basename(file)))
+ os.rename(file, new_path)
+ return file_paths
+
+
+paths = list(sorted(download_maestro(output_dir="datasets/maestro")))
+output_dir = path.dirname(paths[0])
+
+```
+
+### Split the dataset
+
+We can now split the dataset into training and validation sets.
+
+
+```python
+indices = np.random.permutation(len(paths))
+split = int(len(paths) * 0.1)
+train_paths = [paths[i] for i in indices[split:]]
+val_paths = [paths[i] for i in indices[:split]]
+```
+
+### Hear a MIDI file
+
+We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio.
+This allows us to listen to the data samples before and after processing.
+
+The following dependencies are required to play the audio:
+- fluidsynth: `sudo apt install -y fluidsynth`
+- pyfluidsynth, scipy: `pip install pyfluidsynth scipy`
+
+
+```python
+
+def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None):
+ import pretty_midi
+ from scipy.io.wavfile import write as write_wav
+ from IPython.display import Audio
+
+ # Create the audio waveform
+ pretty_midi_file = pretty_midi.PrettyMIDI(midi_path)
+ waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate]
+
+ # Display the audio if no path is provided
+ if out_dir is None:
+ # IPython display
+ return Audio(waveform, rate=sampling_rate)
+
+ # Save the audio to a file
+ os.makedirs(out_dir, exist_ok=True)
+ audio_path = path.join(out_dir, path.basename(midi_path).split(".")[0] + ".wav")
+ write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16))
+ return audio_path
+
+
+print(visualize_midi(train_paths[0], out_dir="tmp/")) # Saved audio path
+visualize_midi(train_paths[0]) # Display the audio if in a Jupyter notebook
+
+```
+
+