From 275300acffba72b8f1c67b9ceb327794502383f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joaqu=C3=ADn=20Jim=C3=A9nez?= <37051995+johacks@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:09:55 +0100 Subject: [PATCH] Add Keras 3 example for "Transformer model for MIDI music generation" (#1992) * Pre-PR review: base script file working in tf * Fix too slow inference on jax * Refactor MIDI generation script: PR feedback. - Update import statement for keras_hub - Fix documentation for audio dependencies - Remove unnecessary markdown annotation Other changes: - Update last modified date in MIDI generation example - Fix: Increase training epochs from 1 to 20 in MIDI generation example - Fix: use flip for unsupported negative step in torch - Fix: Remove unused learning rate * Torch backend fixes: add cast on generate, change learning rate scale * Compact script * Add auto generated files --- .../midi_generation_with_transformer.ipynb | 1016 ++++++++++++++++ .../md/midi_generation_with_transformer.md | 1059 +++++++++++++++++ .../midi_generation_with_transformer.py | 722 +++++++++++ 3 files changed, 2797 insertions(+) create mode 100644 examples/generative/ipynb/midi_generation_with_transformer.ipynb create mode 100644 examples/generative/md/midi_generation_with_transformer.md create mode 100644 examples/generative/midi_generation_with_transformer.py diff --git a/examples/generative/ipynb/midi_generation_with_transformer.ipynb b/examples/generative/ipynb/midi_generation_with_transformer.ipynb new file mode 100644 index 0000000000..e034e37edc --- /dev/null +++ b/examples/generative/ipynb/midi_generation_with_transformer.ipynb @@ -0,0 +1,1016 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# Music Generation with Transformer Models\n", + "\n", + "**Author:** [Joaquin Jimenez](https://github.com/johacks/)
\n", + "**Date created:** 2024/11/22
\n", + "**Last modified:** 2024/11/26
\n", + "**Description:** Use a Transformer model to train on MIDI data and generate music sequences." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Introduction\n", + "\n", + "In this tutorial, we learn how to build a music generation model using a\n", + "Transformer decoder-only architecture.\n", + "The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro)\n", + "and implemented using Keras 3.\n", + "In the process, we explore MIDI tokenization and relative global attention mechanisms.\n", + "\n", + "This example is based on the paper \"Music Transformer\" by Huang et al. (2018).\n", + "Check out the original [paper](https://arxiv.org/abs/1809.04281) and\n", + "[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup\n", + "\n", + "Before we start, let's import and install all the libraries we need." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!pip install -qq midi_neural_processor\n", + "!pip install -qq keras_hub\n", + "!pip install -qq \"keras>=3.6.0\" # Allows use of keras.utils.Config."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Optional dependencies\n", + "\n", + "To hear the audio, install the following additional dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!sudo apt-get -qq install -y fluidsynth 2> /dev/null\n", + "!pip install -qq pyfluidsynth scipy" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import tempfile\n", + "\n", + "import keras\n", + "import midi_neural_processor.processor as midi_tokenizer\n", + "import numpy as np\n", + "from keras import callbacks, layers, ops, optimizers, utils\n", + "from keras_hub import layers as hub_layers\n", + "from os import path" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Configuration\n", + "\n", + "Let's define the configuration for the model and the dataset to be used in this example."
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "event_range = midi_tokenizer.RANGE_NOTE_ON\n", + "event_range += midi_tokenizer.RANGE_NOTE_OFF\n", + "event_range += midi_tokenizer.RANGE_TIME_SHIFT\n", + "event_range += midi_tokenizer.RANGE_VEL\n", + "CONFIG = utils.Config(\n", + " max_sequence_len=2048,\n", + " embedding_dim=256,\n", + " num_transformer_blocks=6,\n", + " batch_size=6,\n", + " token_pad=event_range,\n", + " token_start_of_sentence=event_range + 1,\n", + " token_end_of_sentence=event_range + 2,\n", + " vocabulary_size=event_range + 3,\n", + " model_out=\"tmp/music_transformer.keras\",\n", + " seed=42,\n", + ")\n", + "utils.set_random_seed(CONFIG.seed)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Maestro dataset\n", + "\n", + "The Maestro dataset contains MIDI files for piano performances.\n", + "\n", + "### Download the dataset\n", + "\n", + "We now download and extract the dataset, then move the MIDI files to a new directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def download_maestro(output_dir=None):\n", + " \"\"\"Download the Maestro MIDI dataset.\n", + " Extracted from: https://magenta.tensorflow.org/datasets/maestro\n", + " \"\"\"\n", + " # Ensure the output directory exists\n", + " output_dir = tempfile.mkdtemp() if output_dir is None else output_dir\n", + " os.makedirs(output_dir, exist_ok=True)\n", + "\n", + " # Download and extract zip file\n", + " dir = utils.get_file(\n", + " origin=\"https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip\",\n", + " extract=True,\n", + " )\n", + "\n", + " # Gather all MIDI files\n", + " midi_files, file_paths = set(), list()\n", + " for root, _, files in os.walk(dir):\n", + " for file in files:\n", + " if file.lower().endswith(\".midi\") or file.lower().endswith(\".mid\"):\n", + " midi_files.add(path.join(root, file))\n", + "\n", + " # Move the files to the output directory\n", + " for file in sorted(midi_files):\n", + " file_paths.append(new_path := path.join(output_dir, path.basename(file)))\n", + " os.rename(file, new_path)\n", + " return file_paths\n", + "\n", + "\n", + "paths = list(sorted(download_maestro(output_dir=\"datasets/maestro\")))\n", + "output_dir = path.dirname(paths[0])\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Split the dataset\n", + "\n", + "We can now split the dataset into training and validation sets." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "indices = np.random.permutation(len(paths))\n", + "split = int(len(paths) * 0.1)\n", + "train_paths = [paths[i] for i in indices[split:]]\n", + "val_paths = [paths[i] for i in indices[:split]]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Hear a MIDI file\n", + "\n", + "We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio.\n", + "This allows us to listen to the data samples before and after processing.\n", + "\n", + "The following dependencies are required to play the audio:\n", + "- fluidsynth: `sudo apt install -y fluidsynth`\n", + "- pyfluidsynth, scipy: `pip install pyfluidsynth scipy`" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None):\n", + " import pretty_midi\n", + " from scipy.io.wavfile import write as write_wav\n", + " from IPython.display import Audio\n", + "\n", + " # Create the audio waveform\n", + " pretty_midi_file = pretty_midi.PrettyMIDI(midi_path)\n", + " waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate]\n", + "\n", + " # Display the audio if no path is provided\n", + " if out_dir is None:\n", + " # IPython display\n", + " return Audio(waveform, rate=sampling_rate)\n", + "\n", + " # Save the audio to a file\n", + " os.makedirs(out_dir, exist_ok=True)\n", + " audio_path = path.join(out_dir, path.basename(midi_path).split(\".\")[0] + \".wav\")\n", + " write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16))\n", + " return audio_path\n", + "\n", + "\n", + "print(visualize_midi(train_paths[0], out_dir=\"tmp/\")) # Saved audio path\n", + "visualize_midi(train_paths[0]) # Display the audio if in a 
Jupyter notebook\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Tokenize the data\n", + "\n", + "We now preprocess the MIDI files into a tokenized format for training." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def encode_midi_task(midi_path):\n", + " \"\"\"Define a task that tokenizes a MIDI file.\"\"\"\n", + " import midi_neural_processor.processor as midi_tokenizer\n", + "\n", + " return midi_tokenizer.encode_midi(midi_path)\n", + "\n", + "\n", + "def preprocess_midi_files(file_paths, save_dir=None):\n", + " \"\"\"Preprocess a list of MIDI files and save the notes to a file.\"\"\"\n", + " from multiprocessing import Pool, cpu_count\n", + "\n", + " # Assume all files are in the same directory and save to the same directory\n", + " save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " # Check if the notes have already been preprocessed\n", + " output_file = path.join(save_dir, \"notes.npz\")\n", + " if path.exists(output_file):\n", + " npz_file = np.load(output_file)\n", + " return [npz_file[key] for key in npz_file.keys()]\n", + "\n", + " # Preprocess the MIDI files in parallel\n", + " progbar = utils.Progbar(len(file_paths), unit_name=\"MIDI_file\", interval=5)\n", + " pool = Pool(cpu_count() - 1)\n", + " all_notes = []\n", + " for notes in pool.imap_unordered(encode_midi_task, file_paths):\n", + " progbar.add(1)\n", + " all_notes.append(np.array(notes))\n", + "\n", + " # Save the notes to a file\n", + " np.savez(output_file, *all_notes)\n", + " return all_notes\n", + "\n", + "\n", + "train_midis = preprocess_midi_files(train_paths, path.join(output_dir, \"train\"))\n", + "val_midis = preprocess_midi_files(val_paths, path.join(output_dir, \"val\"))\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"colab_type": "text" + }, + "source": [ + "### Dataset objects\n", + "\n", + "We now define a dataset class that yields batches of input sequences and target sequences." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "class MidiDataset(utils.PyDataset):\n", + " \"\"\"A dataset for MIDI files that yields batches of input sequences and target sequences.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " encoded_midis,\n", + " batch_size=CONFIG.batch_size,\n", + " max_sequence_len=CONFIG.max_sequence_len,\n", + " ):\n", + " super(MidiDataset, self).__init__()\n", + " self.batch_size = batch_size\n", + " self.max_sequence_len = max_sequence_len\n", + " self.encoded_midis = encoded_midis\n", + " batches, last_batch_size = divmod(len(encoded_midis), batch_size)\n", + " self._num_batches = batches + int(last_batch_size > 0)\n", + "\n", + " def __len__(self):\n", + " \"\"\"Get the number of batches.\"\"\"\n", + " return self._num_batches\n", + "\n", + " def __getitem__(self, idx):\n", + " \"\"\"Generate random inputs and corresponding targets for the model.\"\"\"\n", + " # Same as in the original paper, we always get a random batch.\n", + " # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py\n", + " batch = random.sample(self.encoded_midis, k=self.batch_size)\n", + "\n", + " # Convert the batch to sequences\n", + " batch_data = [\n", + " self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch\n", + " ]\n", + " batch_data = np.array(batch_data)\n", + "\n", + " # Split the data into input and target sequences\n", + " return batch_data[:, :-1], batch_data[:, 1:]\n", + "\n", + " def _get_sequence(self, data, max_length):\n", + " \"\"\"Get a random sequence of notes from a file.\"\"\"\n", + " # Truncate or pad the sequence\n", + " if len(data) > max_length:\n", + " start = random.randrange(0, 
len(data) - max_length)\n", + " data = data[start : start + max_length]\n", + " elif len(data) < max_length:\n", + " data = np.append(data, CONFIG.token_end_of_sentence)\n", + "\n", + " # Pad the sequence if necessary\n", + " if len(data) < max_length:\n", + " data = np.concatenate(\n", + " (data, np.full(max_length - len(data), CONFIG.token_pad))\n", + " )\n", + " return np.asanyarray(data, dtype=\"int32\")\n", + "\n", + "\n", + "train_dataset, val_dataset = MidiDataset(train_midis), MidiDataset(val_midis)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Model definition\n", + "\n", + "It is time to define the model architecture. We use a Transformer decoder\n", + "architecture with a custom attention mechanism, relative global attention.\n", + "\n", + "### Relative Global Attention\n", + "\n", + "The following code implements the Relative Global Attention layer. It is used\n", + "in place of the standard multi-head attention layer in the Transformer decoder.\n", + "The main difference is that it includes a relative positional encoding that\n", + "allows the model to learn relative positional information between tokens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable()\n", + "class RelativeGlobalAttention(layers.Layer):\n", + " \"\"\"\n", + " From Music Transformer (Huang et al., 2018)\n", + " https://arxiv.org/abs/1809.04281\n", + " \"\"\"\n", + "\n", + " def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.key_length = None\n", + " self.max_sequence_len = max_sequence_len\n", + " self.relative_embedding = None\n", + " self.num_heads = num_heads\n", + " self.embedding_dim = embedding_dim\n", + " self.head_dim = embedding_dim // num_heads\n", + " self.query_dense = layers.Dense(int(self.embedding_dim))\n", + " self.key_dense = layers.Dense(int(self.embedding_dim))\n", + " self.value_dense = layers.Dense(int(self.embedding_dim))\n", + " self.output_dense = layers.Dense(embedding_dim, name=\"output\")\n", + "\n", + " def build(self, input_shape):\n", + " self.query_length = input_shape[0][1]\n", + " self.key_length = input_shape[1][1]\n", + " self.relative_embedding = self.add_weight(\n", + " (self.max_sequence_len, int(self.head_dim)), name=\"relative_embedding\"\n", + " )\n", + "\n", + " def _apply_dense_layer_and_split_heads(self, inputs, dense_layer):\n", + " # Apply linear transformation\n", + " inputs = dense_layer(inputs)\n", + " new_shape = ops.shape(inputs)\n", + " # Reshape to split by attention heads\n", + " reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1))\n", + " # Transpose for head-first format\n", + " return ops.transpose(reshaped, (0, 2, 1, 3))\n", + "\n", + " def call(self, inputs, mask=None):\n", + " # Compute Q, K, V: Batch, head, sequence, features\n", + " query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense)\n", + " key = self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense)\n", + " value = 
self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense)\n", + "\n", + " # Compute scaled dot-product attention scores\n", + " attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2]))\n", + "\n", + " # Compute relative positional encoding and combine with attention scores\n", + " start_idx = max(0, self.max_sequence_len - ops.shape(query)[2])\n", + " relative_embedding = self.relative_embedding[start_idx:, :]\n", + " attention_scores += self._compute_attention_scores(query, relative_embedding)\n", + " logits = attention_scores / ops.sqrt(self.head_dim)\n", + "\n", + " # Apply mask if provided\n", + " if mask is not None:\n", + " logits += ops.cast(mask, \"float32\") * -1e9\n", + "\n", + " # Compute attention weights\n", + " attention_weights = ops.nn.softmax(logits, axis=-1)\n", + " attention_output = ops.matmul(attention_weights, value)\n", + "\n", + " # Merge heads and apply final linear transformation\n", + " merged_attention = ops.transpose(attention_output, (0, 2, 1, 3))\n", + " merged_attention = ops.reshape(\n", + " merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim)\n", + " )\n", + " output = self.output_dense(merged_attention)\n", + "\n", + " return output, attention_weights\n", + "\n", + " def _compute_attention_scores(self, query, relative_embedding):\n", + " \"\"\"\n", + " Compute relative attention scores using positional encodings.\n", + " \"\"\"\n", + " relative_scores = ops.einsum(\"bhld, md->bhlm\", query, relative_embedding)\n", + " relative_scores = self._apply_mask_to_relative_scores(relative_scores)\n", + " return self._skew_attention_scores(relative_scores)\n", + "\n", + " def _apply_mask_to_relative_scores(self, scores):\n", + " \"\"\"\n", + " Apply masking to relative positional scores to ignore future positions.\n", + " \"\"\"\n", + " mask = ops.flip(\n", + " ops.tri(scores.shape[-2], scores.shape[-1], dtype=\"float32\"), axis=1\n", + " )\n", + " return mask * scores\n", + "\n", + " def 
_skew_attention_scores(self, scores):\n", + " \"\"\"\n", + " Perform skewing operation to align relative attention scores with the sequence.\n", + " \"\"\"\n", + " padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0)))\n", + " padded_shape = ops.shape(padded_scores)\n", + " reshaped_scores = ops.reshape(\n", + " padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2])\n", + " )\n", + " skewed_scores = reshaped_scores[:, :, 1:, :]\n", + "\n", + " if self.key_length > self.query_length:\n", + " size_diff = self.key_length - self.query_length\n", + " return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]])\n", + " else:\n", + " return skewed_scores[:, :, :, : self.key_length]\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Decoder Layer\n", + "\n", + "Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It is mostly like\n", + "the standard Transformer decoder layer but with the custom attention mechanism." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable()\n", + "class DecoderLayer(layers.Layer):\n", + " def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1):\n", + " super(DecoderLayer, self).__init__()\n", + "\n", + " # Initialize attributes\n", + " self.embedding_dim = embedding_dim\n", + " self.num_heads = num_heads\n", + " self.max_sequence_len = max_sequence_len\n", + "\n", + " # Initialize layers\n", + " self.relative_global_attention_1 = RelativeGlobalAttention(\n", + " num_heads, embedding_dim, max_sequence_len\n", + " )\n", + "\n", + " self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, \"relu\")\n", + " self.feed_forward_network_pos = layers.Dense(self.embedding_dim)\n", + "\n", + " self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6)\n", + " self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6)\n", + "\n", + " self.dropout_1 = layers.Dropout(dropout)\n", + " self.dropout_2 = layers.Dropout(dropout)\n", + "\n", + " def call(self, inputs, mask=None, training=False):\n", + " # Attention block. 
Inputs are (query, key, value)\n", + " attention_out, attention_weights = self.relative_global_attention_1(\n", + " (inputs, inputs, inputs), mask=mask\n", + " )\n", + " attention_out = self.dropout_1(attention_out, training=training)\n", + " attention_out_normalized = self.layer_normalization_1(attention_out + inputs)\n", + "\n", + " ffn_out = self.feed_forward_network_pre(attention_out)\n", + " ffn_out = self.feed_forward_network_pos(ffn_out)\n", + " ffn_out = self.dropout_2(ffn_out, training=training)\n", + " out = self.layer_normalization_2(attention_out_normalized + ffn_out)\n", + "\n", + " return out, attention_weights\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Decoder\n", + "\n", + "The Decoder layer is composed of multiple DecoderLayer blocks. It also includes\n", + "an embedding layer that converts our tokenized input into an embedding representation." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable()\n", + "class Decoder(layers.Layer):\n", + " def __init__(\n", + " self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout\n", + " ):\n", + " super(Decoder, self).__init__()\n", + "\n", + " self.embedding_dim = embedding_dim\n", + " self.num_blocks = num_blocks\n", + "\n", + " self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim)\n", + " self.positional_encoding = hub_layers.SinePositionEncoding()\n", + "\n", + " self.decode_layers = [\n", + " DecoderLayer(\n", + " embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout\n", + " )\n", + " for _ in range(num_blocks)\n", + " ]\n", + " self.dropout = layers.Dropout(dropout)\n", + "\n", + " def call(self, inputs, mask=None, training=False, return_attention_weights=False):\n", + " weights = []\n", + "\n", + " # Adding embedding and position encoding.\n", + " x 
= self.embedding(inputs)\n", + " x = x * ops.sqrt(ops.cast(self.embedding_dim, \"float32\"))\n", + " x = x + self.positional_encoding(x)\n", + " x = self.dropout(x, training=training)\n", + "\n", + " # Passing through the transformer blocks.\n", + " for i in range(self.num_blocks):\n", + " x, w = self.decode_layers[i](x, mask=mask, training=training)\n", + " weights.append(w)\n", + " if return_attention_weights:\n", + " return x, weights\n", + " return x\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Music Transformer Decoder\n", + "\n", + "With the above layers defined, we can now define the MusicTransformerDecoder model. It applies\n", + "a linear transformation to the output of the decoder to get the logits for each token." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable()\n", + "class MusicTransformerDecoder(keras.Model):\n", + " def __init__(\n", + " self,\n", + " embedding_dim=CONFIG.embedding_dim,\n", + " vocabulary_size=CONFIG.vocabulary_size,\n", + " num_blocks=CONFIG.num_transformer_blocks,\n", + " max_sequence_len=CONFIG.max_sequence_len,\n", + " dropout=0.2,\n", + " ):\n", + " # Initialize attributes\n", + " super(MusicTransformerDecoder, self).__init__()\n", + " self.embedding_dim = embedding_dim\n", + " self.vocabulary_size = vocabulary_size\n", + " self.num_blocks = num_blocks\n", + " self.max_sequence_len = max_sequence_len\n", + "\n", + " # Initialize layers\n", + " # Transformer decoder\n", + " self.decoder = Decoder(\n", + " embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout\n", + " )\n", + " # Output layer\n", + " self.fc = layers.Dense(self.vocabulary_size, activation=None, name=\"output\")\n", + "\n", + " @staticmethod\n", + " def get_look_ahead_mask(max_sequence_len, inputs):\n", + " sequence_length = 
min(max_sequence_len, inputs.shape[1])\n", + " sequence_mask = ops.logical_not(\n", + " ops.tri(sequence_length, sequence_length, dtype=\"bool\")\n", + " )\n", + "\n", + " inputs = ops.cast(inputs[:, None, None, :], \"int32\")\n", + " output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad\n", + " decoder_output_mask = ops.equal(inputs, output_pad_tensor)\n", + " return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), \"int32\")\n", + "\n", + " def call(self, inputs, training=False):\n", + " mask = self.get_look_ahead_mask(self.max_sequence_len, inputs)\n", + " decoding = self.decoder(\n", + " inputs, mask=mask, training=training, return_attention_weights=False\n", + " )\n", + " return self.fc(decoding)\n", + "\n", + " # --- Sequence generation methods\n", + "\n", + " def generate(self, inputs: list, length=CONFIG.max_sequence_len, top_k=5):\n", + " inputs = ops.convert_to_tensor([inputs])\n", + "\n", + " # Generate a new token using output distribution at given index\n", + " def generate_token(inputs, end_idx):\n", + " distribution = ops.stop_gradient(self.call(inputs)[0, end_idx])\n", + "\n", + " # Select the top-k tokens and their probabilities\n", + " top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k)\n", + "\n", + " # Sample from the top-k probabilities\n", + " new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1)\n", + " return ops.take(top_k_indices, new_token_idx[0])\n", + "\n", + " # Compute the number of tokens to add\n", + " added_tokens = min(length, self.max_sequence_len - inputs.shape[1])\n", + " progbar = utils.Progbar(added_tokens, unit_name=\"token\", interval=5)\n", + "\n", + " # Pad the input sequence that will be filled with generated tokens\n", + " out = ops.pad(inputs, ((0, 0), (0, added_tokens)), \"constant\", CONFIG.token_pad)\n", + "\n", + " # Generate tokens using top-k sampling\n", + " for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens):\n", + " token = 
ops.cast(generate_token(out, end_idx=token_idx), out.dtype)\n", + " out = ops.scatter_update(out, ((0, token_idx + 1),), token)\n", + " progbar.add(1)\n", + "\n", + " return ops.convert_to_numpy(out[0])\n", + "\n", + " # --- Serialization methods\n", + "\n", + " def get_config(self):\n", + " atts = [\"embedding_dim\", \"vocabulary_size\", \"num_blocks\", \"max_sequence_len\"]\n", + " return {a: getattr(self, a) for a in atts}\n", + "\n", + " @classmethod\n", + " def from_config(cls, config):\n", + " return cls(**config)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Loss function\n", + "\n", + "We define a custom loss function that computes the categorical cross-entropy\n", + "loss for the model. It is computed only for non-padding tokens and uses\n", + "`from_logits=True` since the model outputs logits." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable()\n", + "def train_loss(y_true, y_pred):\n", + " mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), \"float32\")\n", + " y_true = ops.one_hot(ops.cast(y_true, \"int32\"), CONFIG.vocabulary_size)\n", + " return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Learning rate schedule\n", + "\n", + "Following the Music Transformer paper, we define an adapted exponential decay\n", + "learning rate schedule that takes into account the embedding dimension." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable()\n", + "class CustomSchedule(optimizers.schedules.LearningRateSchedule):\n", + " def __init__(self, embedding_dim, warmup_steps=4000):\n", + " super(CustomSchedule, self).__init__()\n", + "\n", + " self.embedding_dim = embedding_dim\n", + " self.warmup_steps = warmup_steps\n", + "\n", + " self._embedding_dim = ops.cast(self.embedding_dim, \"float32\")\n", + " # Numerical stability adjustment on torch, which is less precise\n", + " self._lr_adjust = 0.1 if keras.backend.backend() == \"torch\" else 1.0\n", + "\n", + " def get_config(self):\n", + " return {\"embedding_dim\": self.embedding_dim, \"warmup_steps\": self.warmup_steps}\n", + "\n", + " def __call__(self, step):\n", + " step_rsqrt = ops.rsqrt(ops.cast(step, \"float32\"))\n", + " warmup_adjust = step * (self.warmup_steps**-1.5)\n", + " output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust)\n", + " return self._lr_adjust * output\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Training the model\n", + "\n", + "We can now train the model on the Maestro dataset. First, we define a training\n", + "function. This function compiles the model, trains it, and saves the best model\n", + "checkpoint. This way, we can continue training from the best model checkpoint\n", + "if needed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def train_model(model, train_ds, val_ds, epochs=15):\n", + " # Configure optimizer\n", + " learning_rate = CustomSchedule(CONFIG.embedding_dim)\n", + " optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)\n", + "\n", + " # Compile the model\n", + " model.compile(optimizer=optimizer, loss=train_loss)\n", + "\n", + " # Train the model\n", + " save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True)\n", + " model.fit(\n", + " train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2\n", + " )\n", + " return model\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We can now train the model on the Maestro dataset. If a model checkpoint exists,\n", + "we can load it and continue training." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "if path.exists(CONFIG.model_out):\n", + " model = keras.models.load_model(CONFIG.model_out)\n", + " # Comment out to continue model training from the checkpoint\n", + " # train_model(model, train_dataset, val_dataset, epochs=10)\n", + "else:\n", + " # Train the model\n", + " model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Generate music\n", + "\n", + "We can now generate music using the trained model. We use an existing MIDI file\n", + "as a seed and generate a new sequence." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None):\n", + " # Ensure the output directory exists\n", + " out_dir = out_dir if out_dir is not None else tempfile.mkdtemp()\n", + " os.makedirs(out_dir, exist_ok=True)\n", + "\n", + " # Get some tokens from the MIDI file\n", + " inputs = midi_tokenizer.encode_midi(seed_path)[100:125]\n", + " print(f\"Seed tokens: {inputs}\")\n", + "\n", + " # Generate music that follows the input tokens until the maximum length\n", + " result = model.generate(inputs, length=length, top_k=top_k)\n", + "\n", + " output_path = path.join(out_dir, path.basename(seed_path).split(\".\")[0] + \".mid\")\n", + " midi_tokenizer.decode_midi(result, output_path)\n", + " return output_path\n", + "\n", + "\n", + "output_file = generate_music(model, val_paths[-1], out_dir=\"tmp/\", top_k=15)\n", + "print(visualize_midi(output_file, out_dir=\"tmp/\")) # Saved audio path\n", + "visualize_midi(output_file) # Display the audio if in a Jupyter notebook" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Conclusion\n", + "\n", + "In this example, we learned how to build a music generation model using a custom\n", + "Transformer decoder architecture.\n", + "\n", + "We did it following the Music Transformer paper by Huang et al. 
(2018).\n", + "To do so we had to:\n", + "\n", + "- Define a custom loss function and learning rate schedule.\n", + "- Define a custom attention mechanism.\n", + "- Preprocess MIDI files into a tokenized format.\n", + "\n", + "After training the model on the Maestro dataset, we generated music sequences\n", + "using a seed MIDI file.\n", + "\n", + "### Next steps\n", + "\n", + "We could further improve inference times by caching attention weights during the\n", + "forward pass, in a similar way as `keras_hub` `CausalLM` models, which use the\n", + "`CachedMultiHeadAttention` layer." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "midi_generation_with_transformer", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/generative/md/midi_generation_with_transformer.md b/examples/generative/md/midi_generation_with_transformer.md new file mode 100644 index 0000000000..a0280e50c3 --- /dev/null +++ b/examples/generative/md/midi_generation_with_transformer.md @@ -0,0 +1,1059 @@ +# Music Generation with Transformer Models + +**Author:** [Joaquin Jimenez](https://github.com/johacks/)
+**Date created:** 2024/11/22
+**Last modified:** 2024/11/26
+**Description:** Use a Transformer model to train on MIDI data and generate music sequences. + + + [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/midi_generation_with_transformer.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/generative/midi_generation_with_transformer.py) + + + +--- +## Introduction + +In this tutorial, we learn how to build a music generation model using a +Transformer decoder-only architecture. +The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro) +and implemented using Keras 3. +In the process, we explore MIDI tokenization and relative global attention mechanisms. + +This example is based on the paper "Music Transformer" by Huang et al. (2018). +Check out the original [paper](https://arxiv.org/abs/1809.04281) and +[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0). + +--- +## Setup + +Before we start, let's import and install all the libraries we need. + + +```python +!pip install -qq midi_neural_processor +!pip install -qq keras_hub +!pip install -qq "keras>=3.6.0" # Allows use of keras.utils.Config. +``` + +### Optional dependencies + +To hear the audio, install the following additional dependencies: + + +```python +!sudo apt-get -qq install -y fluidsynth 2> /dev/null +!pip install -qq pyfluidsynth scipy +``` + + +```python +import os +import random +import tempfile + +import keras +import midi_neural_processor.processor as midi_tokenizer +import numpy as np +from keras import callbacks, layers, ops, optimizers, utils +from keras_hub import layers as hub_layers +from os import path +``` + +--- +## Configuration + +Let's define the configuration for the model and the dataset to be used in this example.
+ + +```python +event_range = midi_tokenizer.RANGE_NOTE_ON +event_range += midi_tokenizer.RANGE_NOTE_OFF +event_range += midi_tokenizer.RANGE_TIME_SHIFT +event_range += midi_tokenizer.RANGE_VEL +CONFIG = utils.Config( + max_sequence_len=2048, + embedding_dim=256, + num_transformer_blocks=6, + batch_size=6, + token_pad=event_range, + token_start_of_sentence=event_range + 1, + token_end_of_sentence=event_range + 2, + vocabulary_size=event_range + 3, + model_out="tmp/music_transformer.keras", + seed=42, +) +utils.set_random_seed(CONFIG.seed) + +``` + +--- +## Maestro dataset + +The Maestro dataset contains MIDI files for piano performances. + +### Download the dataset + +We now download and extract the dataset, then move the MIDI files to a new directory. + + +```python + +def download_maestro(output_dir=None): + """Download the Maestro MIDI dataset. + Extracted from: https://magenta.tensorflow.org/datasets/maestro + """ + # Ensure the output directory exists + output_dir = tempfile.mkdtemp() if output_dir is None else output_dir + os.makedirs(output_dir, exist_ok=True) + + # Download and extract zip file + dir = utils.get_file( + origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip", + extract=True, + ) + + # Gather all MIDI files + midi_files, file_paths = set(), list() + for root, _, files in os.walk(dir): + for file in files: + if file.lower().endswith(".midi") or file.lower().endswith(".mid"): + midi_files.add(path.join(root, file)) + + # Move the files to the output directory + for file in sorted(midi_files): + file_paths.append(new_path := path.join(output_dir, path.basename(file))) + os.rename(file, new_path) + return file_paths + + +paths = list(sorted(download_maestro(output_dir="datasets/maestro"))) +output_dir = path.dirname(paths[0]) + +``` + +### Split the dataset + +We can now split the dataset into training and validation sets. 
+ + +```python +indices = np.random.permutation(len(paths)) +split = int(len(paths) * 0.1) +train_paths = [paths[i] for i in indices[split:]] +val_paths = [paths[i] for i in indices[:split]] +``` + +### Hear a MIDI file + +We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio. +This allows us to listen to the data samples before and after processing. + +The following dependencies are required to play the audio: +- fluidsynth: `sudo apt install -y fluidsynth` +- pyfluidsynth, scipy: `pip install pyfluidsynth scipy` + + +```python + +def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None): + import pretty_midi + from scipy.io.wavfile import write as write_wav + from IPython.display import Audio + + # Create the audio waveform + pretty_midi_file = pretty_midi.PrettyMIDI(midi_path) + waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate] + + # Display the audio if no path is provided + if out_dir is None: + # IPython display + return Audio(waveform, rate=sampling_rate) + + # Save the audio to a file + os.makedirs(out_dir, exist_ok=True) + audio_path = path.join(out_dir, path.basename(midi_path).split(".")[0] + ".wav") + write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16)) + return audio_path + + +print(visualize_midi(train_paths[0], out_dir="tmp/")) # Saved audio path +visualize_midi(train_paths[0]) # Display the audio if in a Jupyter notebook + +``` + +
+``` +tmp/MIDI-Unprocessed_03_R2_2008_01-03_ORIG_MID--AUDIO_03_R2_2008_wav--2.wav + +``` +
+ + + + + +### Tokenize the data + +We now preprocess the MIDI files into a tokenized format for training. + + +```python + +def encode_midi_task(midi_path): + """Define a task that tokenizes a MIDI file.""" + import midi_neural_processor.processor as midi_tokenizer + + return midi_tokenizer.encode_midi(midi_path) + + +def preprocess_midi_files(file_paths, save_dir=None): + """Preprocess a list of MIDI files and save the notes to a file.""" + from multiprocessing import Pool, cpu_count + + # Assume all files are in the same directory and save to the same directory + save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir + os.makedirs(save_dir, exist_ok=True) + + # Check if the notes have already been preprocessed + output_file = path.join(save_dir, "notes.npz") + if path.exists(output_file): + npz_file = np.load(output_file) + return [npz_file[key] for key in npz_file.keys()] + + # Preprocess the MIDI files in parallel + progbar = utils.Progbar(len(file_paths), unit_name="MIDI_file", interval=5) + pool = Pool(cpu_count() - 1) + all_notes = [] + for notes in pool.imap_unordered(encode_midi_task, file_paths): + progbar.add(1) + all_notes.append(np.array(notes)) + + # Save the notes to a file + np.savez(output_file, *all_notes) + return all_notes + + +train_midis = preprocess_midi_files(train_paths, path.join(output_dir, "train")) +val_midis = preprocess_midi_files(val_paths, path.join(output_dir, "val")) + +``` + + + 1/1149 ━━━━━━━━━━━━━━━━━━━━ 4:26 232ms/MIDI_file + +
+``` + +``` +
+ 197/1149 ━━━━━━━━━━━━━━━━━━━━ 24s 26ms/MIDI_file + +
+``` + +``` +
+ 380/1149 ━━━━━━━━━━━━━━━━━━━━ 20s 26ms/MIDI_file + +
+``` + +``` +
+ 560/1149 ━━━━━━━━━━━━━━━━━━━━ 15s 27ms/MIDI_file + +
+``` + +``` +
+ 755/1149 ━━━━━━━━━━━━━━━━━━━━ 10s 27ms/MIDI_file + +
+``` + +``` +
+ 953/1149 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/MIDI_file + +
+``` + +``` +
+ 1146/1149 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/MIDI_file + +
+``` + +``` +
+ 1149/1149 ━━━━━━━━━━━━━━━━━━━━ 31s 26ms/MIDI_file + + + + 1/127 ━━━━━━━━━━━━━━━━━━━━ 20s 166ms/MIDI_file + +
+``` + +``` +
+ 127/127 ━━━━━━━━━━━━━━━━━━━━ 4s 34ms/MIDI_file + + +### Dataset objects + +We now define a dataset class that yields batches of input sequences and target sequences. + + +```python + +class MidiDataset(utils.PyDataset): + """A dataset for MIDI files that yields batches of input sequences and target sequences.""" + + def __init__( + self, + encoded_midis, + batch_size=CONFIG.batch_size, + max_sequence_len=CONFIG.max_sequence_len, + ): + super(MidiDataset, self).__init__() + self.batch_size = batch_size + self.max_sequence_len = max_sequence_len + self.encoded_midis = encoded_midis + batches, last_batch_size = divmod(len(encoded_midis), batch_size) + self._num_batches = batches + int(last_batch_size > 0) + + def __len__(self): + """Get the number of batches.""" + return self._num_batches + + def __getitem__(self, idx): + """Generate random inputs and corresponding targets for the model.""" + # Same as in the original paper, we always get a random batch. + # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py + batch = random.sample(self.encoded_midis, k=self.batch_size) + + # Convert the batch to sequences + batch_data = [ + self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch + ] + batch_data = np.array(batch_data) + + # Split the data into input and target sequences + return batch_data[:, :-1], batch_data[:, 1:] + + def _get_sequence(self, data, max_length): + """Get a random sequence of notes from a file.""" + # Truncate or pad the sequence + if len(data) > max_length: + start = random.randrange(0, len(data) - max_length) + data = data[start : start + max_length] + elif len(data) < max_length: + data = np.append(data, CONFIG.token_end_of_sentence) + + # Pad the sequence if necessary + if len(data) < max_length: + data = np.concatenate( + (data, np.full(max_length - len(data), CONFIG.token_pad)) + ) + return np.asanyarray(data, dtype="int32") + + +train_dataset, val_dataset = 
MidiDataset(train_midis), MidiDataset(val_midis) + +``` + +--- +## Model definition + +It is time to define the model architecture. We use a Transformer decoder +architecture with a custom attention mechanism, relative global attention. + +### Relative Global Attention + +The following code implements the Relative Global Attention layer. It is used +in place of the standard multi-head attention layer in the Transformer decoder. +The main difference is that it includes a relative positional encoding that +allows the model to learn relative positional information between tokens. + + +```python + +@keras.utils.register_keras_serializable() +class RelativeGlobalAttention(layers.Layer): + """ + From Music Transformer (Huang et al., 2018) + https://arxiv.org/abs/1809.04281 + """ + + def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs): + super().__init__(**kwargs) + self.key_length = None + self.max_sequence_len = max_sequence_len + self.relative_embedding = None + self.num_heads = num_heads + self.embedding_dim = embedding_dim + self.head_dim = embedding_dim // num_heads + self.query_dense = layers.Dense(int(self.embedding_dim)) + self.key_dense = layers.Dense(int(self.embedding_dim)) + self.value_dense = layers.Dense(int(self.embedding_dim)) + self.output_dense = layers.Dense(embedding_dim, name="output") + + def build(self, input_shape): + self.query_length = input_shape[0][1] + self.key_length = input_shape[1][1] + self.relative_embedding = self.add_weight( + (self.max_sequence_len, int(self.head_dim)), name="relative_embedding" + ) + + def _apply_dense_layer_and_split_heads(self, inputs, dense_layer): + # Apply linear transformation + inputs = dense_layer(inputs) + new_shape = ops.shape(inputs) + # Reshape to split by attention heads + reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1)) + # Transpose for head-first format + return ops.transpose(reshaped, (0, 2, 1, 3)) + + def call(self, inputs, mask=None): + # Compute 
Q, K, V: Batch, head, sequence, features + query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense) + key = self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense) + value = self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense) + + # Compute scaled dot-product attention scores + attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2])) + + # Compute relative positional encoding and combine with attention scores + start_idx = max(0, self.max_sequence_len - ops.shape(query)[2]) + relative_embedding = self.relative_embedding[start_idx:, :] + attention_scores += self._compute_attention_scores(query, relative_embedding) + logits = attention_scores / ops.sqrt(self.head_dim) + + # Apply mask if provided + if mask is not None: + logits += ops.cast(mask, "float32") * -1e9 + + # Compute attention weights + attention_weights = ops.nn.softmax(logits, axis=-1) + attention_output = ops.matmul(attention_weights, value) + + # Merge heads and apply final linear transformation + merged_attention = ops.transpose(attention_output, (0, 2, 1, 3)) + merged_attention = ops.reshape( + merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim) + ) + output = self.output_dense(merged_attention) + + return output, attention_weights + + def _compute_attention_scores(self, query, relative_embedding): + """ + Compute relative attention scores using positional encodings. + """ + relative_scores = ops.einsum("bhld, md->bhlm", query, relative_embedding) + relative_scores = self._apply_mask_to_relative_scores(relative_scores) + return self._skew_attention_scores(relative_scores) + + def _apply_mask_to_relative_scores(self, scores): + """ + Apply masking to relative positional scores to ignore future positions. 
+ """ + mask = ops.flip( + ops.tri(scores.shape[-2], scores.shape[-1], dtype="float32"), axis=1 + ) + return mask * scores + + def _skew_attention_scores(self, scores): + """ + Perform skewing operation to align relative attention scores with the sequence. + """ + padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0))) + padded_shape = ops.shape(padded_scores) + reshaped_scores = ops.reshape( + padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2]) + ) + skewed_scores = reshaped_scores[:, :, 1:, :] + + if self.key_length > self.query_length: + size_diff = self.key_length - self.query_length + return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]]) + else: + return skewed_scores[:, :, :, : self.key_length] + +``` + +### Decoder Layer + +Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It is mostly like +the standard Transformer decoder layer but with the custom attention mechanism. + + +```python + +@keras.utils.register_keras_serializable() +class DecoderLayer(layers.Layer): + def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1): + super(DecoderLayer, self).__init__() + + # Initialize attributes + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.max_sequence_len = max_sequence_len + + # Initialize layers + self.relative_global_attention_1 = RelativeGlobalAttention( + num_heads, embedding_dim, max_sequence_len + ) + + self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, "relu") + self.feed_forward_network_pos = layers.Dense(self.embedding_dim) + + self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6) + self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6) + + self.dropout_1 = layers.Dropout(dropout) + self.dropout_2 = layers.Dropout(dropout) + + def call(self, inputs, mask=None, training=False): + # Attention block. 
Inputs are (query, key, value) + attention_out, attention_weights = self.relative_global_attention_1( + (inputs, inputs, inputs), mask=mask + ) + attention_out = self.dropout_1(attention_out, training=training) + attention_out_normalized = self.layer_normalization_1(attention_out + inputs) + + ffn_out = self.feed_forward_network_pre(attention_out) + ffn_out = self.feed_forward_network_pos(ffn_out) + ffn_out = self.dropout_2(ffn_out, training=training) + out = self.layer_normalization_2(attention_out_normalized + ffn_out) + + return out, attention_weights + +``` + +### Decoder + +The Decoder layer is composed of multiple DecoderLayer blocks. It also includes +an embedding layer that converts our tokenized input into an embedding representation. + + +```python + +@keras.utils.register_keras_serializable() +class Decoder(layers.Layer): + def __init__( + self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout + ): + super(Decoder, self).__init__() + + self.embedding_dim = embedding_dim + self.num_blocks = num_blocks + + self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim) + self.positional_encoding = hub_layers.SinePositionEncoding() + + self.decode_layers = [ + DecoderLayer( + embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout + ) + for _ in range(num_blocks) + ] + self.dropout = layers.Dropout(dropout) + + def call(self, inputs, mask=None, training=False, return_attention_weights=False): + weights = [] + + # Adding embedding and position encoding. + x = self.embedding(inputs) + x = x * ops.sqrt(ops.cast(self.embedding_dim, "float32")) + x = x + self.positional_encoding(x) + x = self.dropout(x, training=training) + + # Passing through the transformer blocks. 
+ for i in range(self.num_blocks): + x, w = self.decode_layers[i](x, mask=mask, training=training) + weights.append(w) + if return_attention_weights: + return x, weights + return x + +``` + +### Music Transformer Decoder + +With the above layers defined, we can now define the MusicTransformerDecoder model. It applies +a linear transformation to the output of the decoder to get the logits for each token. + + +```python + +@keras.utils.register_keras_serializable() +class MusicTransformerDecoder(keras.Model): + def __init__( + self, + embedding_dim=CONFIG.embedding_dim, + vocabulary_size=CONFIG.vocabulary_size, + num_blocks=CONFIG.num_transformer_blocks, + max_sequence_len=CONFIG.max_sequence_len, + dropout=0.2, + ): + # Initialize attributes + super(MusicTransformerDecoder, self).__init__() + self.embedding_dim = embedding_dim + self.vocabulary_size = vocabulary_size + self.num_blocks = num_blocks + self.max_sequence_len = max_sequence_len + + # Initialize layers + # Transformer decoder + self.decoder = Decoder( + embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout + ) + # Output layer + self.fc = layers.Dense(self.vocabulary_size, activation=None, name="output") + + @staticmethod + def get_look_ahead_mask(max_sequence_len, inputs): + sequence_length = min(max_sequence_len, inputs.shape[1]) + sequence_mask = ops.logical_not( + ops.tri(sequence_length, sequence_length, dtype="bool") + ) + + inputs = ops.cast(inputs[:, None, None, :], "int32") + output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad + decoder_output_mask = ops.equal(inputs, output_pad_tensor) + return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), "int32") + + def call(self, inputs, training=False): + mask = self.get_look_ahead_mask(self.max_sequence_len, inputs) + decoding = self.decoder( + inputs, mask=mask, training=training, return_attention_weights=False + ) + return self.fc(decoding) + + # --- Sequence generation methods + + def generate(self, inputs: list, 
length=CONFIG.max_sequence_len, top_k=5): + inputs = ops.convert_to_tensor([inputs]) + + # Generate a new token using output distribution at given index + def generate_token(inputs, end_idx): + distribution = ops.stop_gradient(self.call(inputs)[0, end_idx]) + + # Select the top-k tokens and their probabilities + top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k) + + # Sample from the top-k probabilities + new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1) + return ops.take(top_k_indices, new_token_idx[0]) + + # Compute the number of tokens to add + added_tokens = min(length, self.max_sequence_len - inputs.shape[1]) + progbar = utils.Progbar(added_tokens, unit_name="token", interval=5) + + # Pad the input sequence that will be filled with generated tokens + out = ops.pad(inputs, ((0, 0), (0, added_tokens)), "constant", CONFIG.token_pad) + + # Generate tokens using top-k sampling + for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens): + token = ops.cast(generate_token(out, end_idx=token_idx), out.dtype) + out = ops.scatter_update(out, ((0, token_idx + 1),), token) + progbar.add(1) + + return ops.convert_to_numpy(out[0]) + + # --- Serialization methods + + def get_config(self): + atts = ["embedding_dim", "vocabulary_size", "num_blocks", "max_sequence_len"] + return {a: getattr(self, a) for a in atts} + + @classmethod + def from_config(cls, config): + return cls(**config) + +``` + +### Loss function + +We define a custom loss function that computes the categorical cross-entropy +loss for the model. It is computed only for non-padding tokens and uses +`from_logits=True` since the model outputs logits. 
+ + +```python + +@keras.utils.register_keras_serializable() +def train_loss(y_true, y_pred): + mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), "float32") + y_true = ops.one_hot(ops.cast(y_true, "int32"), CONFIG.vocabulary_size) + return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask + +``` + +### Learning rate schedule + +Following the Music Transformer paper, we define a learning rate schedule with a linear +warmup followed by inverse square-root decay, scaled by the embedding dimension. + + +```python + +@keras.utils.register_keras_serializable() +class CustomSchedule(optimizers.schedules.LearningRateSchedule): + def __init__(self, embedding_dim, warmup_steps=4000): + super(CustomSchedule, self).__init__() + + self.embedding_dim = embedding_dim + self.warmup_steps = warmup_steps + + self._embedding_dim = ops.cast(self.embedding_dim, "float32") + # Numerical stability adjustment on torch, which is less precise + self._lr_adjust = 0.1 if keras.backend.backend() == "torch" else 1.0 + + def get_config(self): + return {"embedding_dim": self.embedding_dim, "warmup_steps": self.warmup_steps} + + def __call__(self, step): + step_rsqrt = ops.rsqrt(ops.cast(step, "float32")) + warmup_adjust = step * (self.warmup_steps**-1.5) + output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust) + return self._lr_adjust * output + +``` + +--- +## Training the model + +We can now train the model on the Maestro dataset. First, we define a training +function. This function compiles the model, trains it, and saves the best model +checkpoint. This way, we can continue training from the best model checkpoint +if needed.
+ + +```python + +def train_model(model, train_ds, val_ds, epochs=15): + # Configure optimizer + learning_rate = CustomSchedule(CONFIG.embedding_dim) + optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) + + # Compile the model + model.compile(optimizer=optimizer, loss=train_loss) + + # Train the model + save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True) + model.fit( + train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2 + ) + return model + +``` + +We can now train the model on the Maestro dataset. If a model checkpoint exists, +we can load it and continue training. + + +```python +if path.exists(CONFIG.model_out): + model = keras.models.load_model(CONFIG.model_out) + # Comment out to continue model training from the checkpoint + # train_model(model, train_dataset, val_dataset, epochs=10) +else: + # Train the model + model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset) + +``` + +
+``` +Epoch 1/15 + +WARNING: All log messages before absl::InitializeLog() is called are written to STDERR +I0000 00:00:1732641133.718156 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.736834 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.736873 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.738653 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.738685 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.738700 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.881476 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.881527 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. +I0000 00:00:1732641133.881559 302469 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node +Your kernel may have been built without NUMA support. 
+ +WARNING: All log messages before absl::InitializeLog() is called are written to STDERR +I0000 00:00:1732641149.500770 303551 service.cc:146] XLA service 0x7f35ec010020 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: +I0000 00:00:1732641149.500802 303551 service.cc:154] StreamExecutor device (0): NVIDIA GeForce RTX 3070 Ti, Compute Capability 8.6 + +I0000 00:00:1732641168.138225 303551 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. + +192/192 - 65s - 341ms/step - loss: 5.5919 - val_loss: 5.0251 + +Epoch 2/15 + +192/192 - 27s - 140ms/step - loss: 4.9749 - val_loss: 4.8658 + +Epoch 3/15 + +192/192 - 27s - 141ms/step - loss: 4.6788 - val_loss: 4.1796 + +Epoch 4/15 + +192/192 - 27s - 140ms/step - loss: 4.1006 - val_loss: 4.0220 + +Epoch 5/15 + +192/192 - 27s - 140ms/step - loss: 3.9812 - val_loss: 3.9015 + +Epoch 6/15 + +192/192 - 27s - 140ms/step - loss: 3.8634 - val_loss: 3.8328 + +Epoch 7/15 + +192/192 - 27s - 140ms/step - loss: 3.7634 - val_loss: 3.6601 + +Epoch 8/15 + +192/192 - 27s - 140ms/step - loss: 3.6034 - val_loss: 3.4094 + +Epoch 9/15 + +192/192 - 27s - 139ms/step - loss: 3.3404 - val_loss: 3.2729 + +Epoch 10/15 + +192/192 - 27s - 140ms/step - loss: 3.2182 - val_loss: 3.1253 + +Epoch 11/15 + +192/192 - 27s - 140ms/step - loss: 3.1626 - val_loss: 3.0725 + +Epoch 12/15 + +192/192 - 27s - 140ms/step - loss: 3.0909 - val_loss: 3.0714 + +Epoch 13/15 + +192/192 - 27s - 140ms/step - loss: 3.0565 - val_loss: 2.9813 + +Epoch 14/15 + +192/192 - 27s - 140ms/step - loss: 2.9938 - val_loss: 2.9099 + +Epoch 15/15 + +192/192 - 27s - 140ms/step - loss: 2.9512 - val_loss: 2.9054 + +``` +
+--- +## Generate music + +We can now generate music using the trained model. We use an existing MIDI file +as a seed and generate a new sequence. + + +```python + +def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None): + # Ensure the output directory exists + out_dir = out_dir if out_dir is not None else tempfile.mkdtemp() + os.makedirs(out_dir, exist_ok=True) + + # Get some tokens from the MIDI file + inputs = midi_tokenizer.encode_midi(seed_path)[100:125] + print(f"Seed tokens: {inputs}") + + # Generate music that follows the input tokens until the maximum length + result = model.generate(inputs, length=length, top_k=top_k) + + output_path = path.join(out_dir, path.basename(seed_path).split(".")[0] + ".mid") + midi_tokenizer.decode_midi(result, output_path) + return output_path + + +output_file = generate_music(model, val_paths[-1], out_dir="tmp/", top_k=15) +print(visualize_midi(output_file, out_dir="tmp/")) # Saved audio path +visualize_midi(output_file) # Display the audio if in a Jupyter notebook +``` + +
+``` +Seed tokens: [348, 367, 70, 259, 364, 63, 256, 361, 51, 363, 43, 257, 176, 264, 196, 297, 179, 257, 191, 333, 367, 72, 257, 198, 365] + +``` +
+ + 1/1024 ━━━━━━━━━━━━━━━━━━━━ 22:58 1s/token + +
+``` + +``` +
+ 67/1024 ━━━━━━━━━━━━━━━━━━━━ 1:13 76ms/token + +
+``` + +``` +
+ 133/1024 ━━━━━━━━━━━━━━━━━━━━ 1:08 76ms/token + +
+``` + +``` +
+ 199/1024 ━━━━━━━━━━━━━━━━━━━━ 1:02 76ms/token + +
+``` + +``` +
+ 266/1024 ━━━━━━━━━━━━━━━━━━━━ 57s 76ms/token + +
+``` + +``` +
+ 331/1024 ━━━━━━━━━━━━━━━━━━━━ 52s 76ms/token + +
+``` + +``` +
+ 396/1024 ━━━━━━━━━━━━━━━━━━━━ 47s 76ms/token + +
+``` + +``` +
+ 461/1024 ━━━━━━━━━━━━━━━━━━━━ 43s 76ms/token + +
+``` + +``` +
+ 528/1024 ━━━━━━━━━━━━━━━━━━━━ 37s 76ms/token + +
+``` + +``` +
+ 594/1024 ━━━━━━━━━━━━━━━━━━━━ 32s 76ms/token + +
+``` + +``` +
+ 660/1024 ━━━━━━━━━━━━━━━━━━━━ 27s 76ms/token + +
+``` + +``` +
+ 726/1024 ━━━━━━━━━━━━━━━━━━━━ 22s 76ms/token + +
+``` + +``` +
+ 793/1024 ━━━━━━━━━━━━━━━━━━━━ 17s 76ms/token + +
+``` + +``` +
+ 859/1024 ━━━━━━━━━━━━━━━━━━━━ 12s 76ms/token + +
+``` + +``` +
+ 925/1024 ━━━━━━━━━━━━━━━━━━━━ 7s 76ms/token + +
+``` + +``` +
+ 991/1024 ━━━━━━━━━━━━━━━━━━━━ 2s 76ms/token + +
+``` + +``` +
+ 1024/1024 ━━━━━━━━━━━━━━━━━━━━ 79s 76ms/token + + +
+``` +info removed pitch: 48 +info removed pitch: 68 +info removed pitch: 39 +info removed pitch: 24 +info removed pitch: 24 +info removed pitch: 30 +info removed pitch: 24 + +tmp/MIDI-Unprocessed_12_R2_2009_01_ORIG_MID--AUDIO_12_R2_2009_12_R2_2009_02_WAV.wav + +``` +
+ + + + + +--- +## Conclusion + +In this example, we learned how to build a music generation model using a custom +Transformer decoder architecture. + +We did it following the Music Transformer paper by Huang et al. (2018). +To do so we had to: + +- Define a custom loss function and learning rate schedule. +- Define a custom attention mechanism. +- Preprocess MIDI files into a tokenized format. + +After training the model on the Maestro dataset, we generated music sequences +using a seed MIDI file. + +### Next steps + +We could further improve inference times by caching attention weights during the +forward pass, in a similar way as `keras_hub` `CausalLM` models, which use the +`CachedMultiHeadAttention` layer. diff --git a/examples/generative/midi_generation_with_transformer.py b/examples/generative/midi_generation_with_transformer.py new file mode 100644 index 0000000000..1527dc86da --- /dev/null +++ b/examples/generative/midi_generation_with_transformer.py @@ -0,0 +1,722 @@ +""" +Title: Music Generation with Transformer Models +Author: [Joaquin Jimenez](https://github.com/johacks/) +Date created: 2024/11/22 +Last modified: 2024/11/26 +Description: Use a Transformer model to train on MIDI data and generate music sequences. +Accelerator: GPU +""" + +""" +## Introduction + +In this tutorial, we learn how to build a music generation model using a +Transformer decoder-only architecture. +The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro) +and implemented using Keras 3. +In the process, we explore MIDI tokenization and relative global attention mechanisms. + +This example is based on the paper "Music Transformer" by Huang et al. (2018). +Check out the original [paper](https://arxiv.org/abs/1809.04281) and +[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0). +""" + +""" +## Setup + +Before we start, let's import and install all the libraries we need.
+""" + +"""shell +pip install -qq midi_neural_processor +pip install -qq keras_hub +pip install -qq "keras>=3.6.0" # Allows use of keras.utils.Config. +""" + +""" +### Optional dependencies + +To hear the audio, install the following additional dependencies: +""" + +"""shell +sudo apt-get -qq install -y fluidsynth 2> /dev/null +pip install -qq pyfluidsynth scipy +""" + +import os +import random +import tempfile + +import keras +import midi_neural_processor.processor as midi_tokenizer +import numpy as np +from keras import callbacks, layers, ops, optimizers, utils +from keras_hub import layers as hub_layers +from os import path + +""" +## Configuration + +Lets define the configuration for the model and the dataset to be used in this example. +""" +event_range = midi_tokenizer.RANGE_NOTE_ON +event_range += midi_tokenizer.RANGE_NOTE_OFF +event_range += midi_tokenizer.RANGE_TIME_SHIFT +event_range += midi_tokenizer.RANGE_VEL +CONFIG = utils.Config( + max_sequence_len=2048, + embedding_dim=256, + num_transformer_blocks=6, + batch_size=6, + token_pad=event_range, + token_start_of_sentence=event_range + 1, + token_end_of_sentence=event_range + 2, + vocabulary_size=event_range + 3, + model_out="tmp/music_transformer.keras", + seed=42, +) +utils.set_random_seed(CONFIG.seed) + + +""" +## Maestro dataset + +The Maestro dataset contains MIDI files for piano performances. + +### Download the dataset + +We now download and extract the dataset, then move the MIDI files to a new directory. +""" + + +def download_maestro(output_dir=None): + """Download the Maestro MIDI dataset. 
+ Extracted from: https://magenta.tensorflow.org/datasets/maestro + """ + # Ensure the output directory exists + output_dir = tempfile.mkdtemp() if output_dir is None else output_dir + os.makedirs(output_dir, exist_ok=True) + + # Download and extract zip file + dir = utils.get_file( + origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip", + extract=True, + ) + + # Gather all MIDI files + midi_files, file_paths = set(), list() + for root, _, files in os.walk(dir): + for file in files: + if file.lower().endswith(".midi") or file.lower().endswith(".mid"): + midi_files.add(path.join(root, file)) + + # Move the files to the output directory + for file in sorted(midi_files): + file_paths.append(new_path := path.join(output_dir, path.basename(file))) + os.rename(file, new_path) + return file_paths + + +paths = list(sorted(download_maestro(output_dir="datasets/maestro"))) +output_dir = path.dirname(paths[0]) + + +""" +### Split the dataset + +We can now split the dataset into training and validation sets. +""" + +indices = np.random.permutation(len(paths)) +split = int(len(paths) * 0.1) +train_paths = [paths[i] for i in indices[split:]] +val_paths = [paths[i] for i in indices[:split]] + +""" +### Hear a MIDI file + +We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio. +This allows us to listen to the data samples before and after processing. 
+ +The following dependencies are required to play the audio: +- fluidsynth: `sudo apt install -y fluidsynth` +- pyfluidsynth, scipy: `pip install pyfluidsynth scipy` +""" + + +def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None): + import pretty_midi + from scipy.io.wavfile import write as write_wav + from IPython.display import Audio + + # Create the audio waveform + pretty_midi_file = pretty_midi.PrettyMIDI(midi_path) + waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate] + + # Display the audio if no path is provided + if out_dir is None: + # IPython display + return Audio(waveform, rate=sampling_rate) + + # Save the audio to a file + os.makedirs(out_dir, exist_ok=True) + audio_path = path.join(out_dir, path.basename(midi_path).split(".")[0] + ".wav") + write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16)) + return audio_path + + +print(visualize_midi(train_paths[0], out_dir="tmp/")) # Saved audio path +visualize_midi(train_paths[0]) # Display the audio if in a Jupyter notebook + + +""" +### Tokenize the data + +We now preprocess the MIDI files into a tokenized format for training. 
+""" + + +def encode_midi_task(midi_path): + """Define a task that tokenizes a MIDI file.""" + import midi_neural_processor.processor as midi_tokenizer + + return midi_tokenizer.encode_midi(midi_path) + + +def preprocess_midi_files(file_paths, save_dir=None): + """Preprocess a list of MIDI files and save the notes to a file.""" + from multiprocessing import Pool, cpu_count + + # Assume all files are in the same directory and save to the same directory + save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir + os.makedirs(save_dir, exist_ok=True) + + # Check if the notes have already been preprocessed + output_file = path.join(save_dir, "notes.npz") + if path.exists(output_file): + npz_file = np.load(output_file) + return [npz_file[key] for key in npz_file.keys()] + + # Preprocess the MIDI files in parallel + progbar = utils.Progbar(len(file_paths), unit_name="MIDI_file", interval=5) + pool = Pool(cpu_count() - 1) + all_notes = [] + for notes in pool.imap_unordered(encode_midi_task, file_paths): + progbar.add(1) + all_notes.append(np.array(notes)) + + # Save the notes to a file + np.savez(output_file, *all_notes) + return all_notes + + +train_midis = preprocess_midi_files(train_paths, path.join(output_dir, "train")) +val_midis = preprocess_midi_files(val_paths, path.join(output_dir, "val")) + + +""" +### Dataset objects + +We now define a dataset class that yields batches of input sequences and target sequences. 
+""" + + +class MidiDataset(utils.PyDataset): + """A dataset for MIDI files that yields batches of input sequences and target sequences.""" + + def __init__( + self, + encoded_midis, + batch_size=CONFIG.batch_size, + max_sequence_len=CONFIG.max_sequence_len, + ): + super(MidiDataset, self).__init__() + self.batch_size = batch_size + self.max_sequence_len = max_sequence_len + self.encoded_midis = encoded_midis + batches, last_batch_size = divmod(len(encoded_midis), batch_size) + self._num_batches = batches + int(last_batch_size > 0) + + def __len__(self): + """Get the number of batches.""" + return self._num_batches + + def __getitem__(self, idx): + """Generate random inputs and corresponding targets for the model.""" + # Same as in the original paper, we always get a random batch. + # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py + batch = random.sample(self.encoded_midis, k=self.batch_size) + + # Convert the batch to sequences + batch_data = [ + self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch + ] + batch_data = np.array(batch_data) + + # Split the data into input and target sequences + return batch_data[:, :-1], batch_data[:, 1:] + + def _get_sequence(self, data, max_length): + """Get a random sequence of notes from a file.""" + # Truncate or pad the sequence + if len(data) > max_length: + start = random.randrange(0, len(data) - max_length) + data = data[start : start + max_length] + elif len(data) < max_length: + data = np.append(data, CONFIG.token_end_of_sentence) + + # Pad the sequence if necessary + if len(data) < max_length: + data = np.concatenate( + (data, np.full(max_length - len(data), CONFIG.token_pad)) + ) + return np.asanyarray(data, dtype="int32") + + +train_dataset, val_dataset = MidiDataset(train_midis), MidiDataset(val_midis) + + +""" +## Model definition + +It is time to define the model architecture. 
We use a Transformer decoder +architecture with a custom attention mechanism, relative global attention. + +### Relative Global Attention + +The following code implements the Relative Global Attention layer. It is used +in place of the standard multi-head attention layer in the Transformer decoder. +The main difference is that it includes a relative positional encoding that +allows the model to learn relative positional information between tokens. +""" + + +@keras.utils.register_keras_serializable() +class RelativeGlobalAttention(layers.Layer): + """ + From Music Transformer (Huang et al., 2018) + https://arxiv.org/abs/1809.04281 + """ + + def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs): + super().__init__(**kwargs) + self.key_length = None + self.max_sequence_len = max_sequence_len + self.relative_embedding = None + self.num_heads = num_heads + self.embedding_dim = embedding_dim + self.head_dim = embedding_dim // num_heads + self.query_dense = layers.Dense(int(self.embedding_dim)) + self.key_dense = layers.Dense(int(self.embedding_dim)) + self.value_dense = layers.Dense(int(self.embedding_dim)) + self.output_dense = layers.Dense(embedding_dim, name="output") + + def build(self, input_shape): + self.query_length = input_shape[0][1] + self.key_length = input_shape[1][1] + self.relative_embedding = self.add_weight( + (self.max_sequence_len, int(self.head_dim)), name="relative_embedding" + ) + + def _apply_dense_layer_and_split_heads(self, inputs, dense_layer): + # Apply linear transformation + inputs = dense_layer(inputs) + new_shape = ops.shape(inputs) + # Reshape to split by attention heads + reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1)) + # Transpose for head-first format + return ops.transpose(reshaped, (0, 2, 1, 3)) + + def call(self, inputs, mask=None): + # Compute Q, K, V: Batch, head, sequence, features + query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense) + key = 
self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense) + value = self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense) + + # Compute scaled dot-product attention scores + attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2])) + + # Compute relative positional encoding and combine with attention scores + start_idx = max(0, self.max_sequence_len - ops.shape(query)[2]) + relative_embedding = self.relative_embedding[start_idx:, :] + attention_scores += self._compute_attention_scores(query, relative_embedding) + logits = attention_scores / ops.sqrt(self.head_dim) + + # Apply mask if provided + if mask is not None: + logits += ops.cast(mask, "float32") * -1e9 + + # Compute attention weights + attention_weights = ops.nn.softmax(logits, axis=-1) + attention_output = ops.matmul(attention_weights, value) + + # Merge heads and apply final linear transformation + merged_attention = ops.transpose(attention_output, (0, 2, 1, 3)) + merged_attention = ops.reshape( + merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim) + ) + output = self.output_dense(merged_attention) + + return output, attention_weights + + def _compute_attention_scores(self, query, relative_embedding): + """ + Compute relative attention scores using positional encodings. + """ + relative_scores = ops.einsum("bhld, md->bhlm", query, relative_embedding) + relative_scores = self._apply_mask_to_relative_scores(relative_scores) + return self._skew_attention_scores(relative_scores) + + def _apply_mask_to_relative_scores(self, scores): + """ + Apply masking to relative positional scores to ignore future positions. + """ + mask = ops.flip( + ops.tri(scores.shape[-2], scores.shape[-1], dtype="float32"), axis=1 + ) + return mask * scores + + def _skew_attention_scores(self, scores): + """ + Perform skewing operation to align relative attention scores with the sequence. 
+ """ + padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0))) + padded_shape = ops.shape(padded_scores) + reshaped_scores = ops.reshape( + padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2]) + ) + skewed_scores = reshaped_scores[:, :, 1:, :] + + if self.key_length > self.query_length: + size_diff = self.key_length - self.query_length + return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]]) + else: + return skewed_scores[:, :, :, : self.key_length] + + +""" +### Decoder Layer + +Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It is mostly like +the standard Transformer decoder layer but with the custom attention mechanism. +""" + + +@keras.utils.register_keras_serializable() +class DecoderLayer(layers.Layer): + def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1): + super(DecoderLayer, self).__init__() + + # Initialize attributes + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.max_sequence_len = max_sequence_len + + # Initialize layers + self.relative_global_attention_1 = RelativeGlobalAttention( + num_heads, embedding_dim, max_sequence_len + ) + + self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, "relu") + self.feed_forward_network_pos = layers.Dense(self.embedding_dim) + + self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6) + self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6) + + self.dropout_1 = layers.Dropout(dropout) + self.dropout_2 = layers.Dropout(dropout) + + def call(self, inputs, mask=None, training=False): + # Attention block. 
Inputs are (query, key, value) + attention_out, attention_weights = self.relative_global_attention_1( + (inputs, inputs, inputs), mask=mask + ) + attention_out = self.dropout_1(attention_out, training=training) + attention_out_normalized = self.layer_normalization_1(attention_out + inputs) + + ffn_out = self.feed_forward_network_pre(attention_out) + ffn_out = self.feed_forward_network_pos(ffn_out) + ffn_out = self.dropout_2(ffn_out, training=training) + out = self.layer_normalization_2(attention_out_normalized + ffn_out) + + return out, attention_weights + + +""" +### Decoder + +The Decoder layer is composed of multiple DecoderLayer blocks. It also includes +an embedding layer that converts our tokenized input into an embedding representation. +""" + + +@keras.utils.register_keras_serializable() +class Decoder(layers.Layer): + def __init__( + self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout + ): + super(Decoder, self).__init__() + + self.embedding_dim = embedding_dim + self.num_blocks = num_blocks + + self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim) + self.positional_encoding = hub_layers.SinePositionEncoding() + + self.decode_layers = [ + DecoderLayer( + embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout + ) + for _ in range(num_blocks) + ] + self.dropout = layers.Dropout(dropout) + + def call(self, inputs, mask=None, training=False, return_attention_weights=False): + weights = [] + + # Adding embedding and position encoding. + x = self.embedding(inputs) + x = x * ops.sqrt(ops.cast(self.embedding_dim, "float32")) + x = x + self.positional_encoding(x) + x = self.dropout(x, training=training) + + # Passing through the transformer blocks. 
+ for i in range(self.num_blocks): + x, w = self.decode_layers[i](x, mask=mask, training=training) + weights.append(w) + if return_attention_weights: + return x, weights + return x + + +""" +### Music Transformer Decoder + +With the above layers defined, we can now define the MusicTransformerDecoder model. It applies +a linear transformation to the output of the decoder to get the logits for each token. +""" + + +@keras.utils.register_keras_serializable() +class MusicTransformerDecoder(keras.Model): + def __init__( + self, + embedding_dim=CONFIG.embedding_dim, + vocabulary_size=CONFIG.vocabulary_size, + num_blocks=CONFIG.num_transformer_blocks, + max_sequence_len=CONFIG.max_sequence_len, + dropout=0.2, + ): + # Initialize attributes + super(MusicTransformerDecoder, self).__init__() + self.embedding_dim = embedding_dim + self.vocabulary_size = vocabulary_size + self.num_blocks = num_blocks + self.max_sequence_len = max_sequence_len + + # Initialize layers + # Transformer decoder + self.decoder = Decoder( + embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout + ) + # Output layer + self.fc = layers.Dense(self.vocabulary_size, activation=None, name="output") + + @staticmethod + def get_look_ahead_mask(max_sequence_len, inputs): + sequence_length = min(max_sequence_len, inputs.shape[1]) + sequence_mask = ops.logical_not( + ops.tri(sequence_length, sequence_length, dtype="bool") + ) + + inputs = ops.cast(inputs[:, None, None, :], "int32") + output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad + decoder_output_mask = ops.equal(inputs, output_pad_tensor) + return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), "int32") + + def call(self, inputs, training=False): + mask = self.get_look_ahead_mask(self.max_sequence_len, inputs) + decoding = self.decoder( + inputs, mask=mask, training=training, return_attention_weights=False + ) + return self.fc(decoding) + + # --- Sequence generation methods + + def generate(self, inputs: list, 
length=CONFIG.max_sequence_len, top_k=5): + inputs = ops.convert_to_tensor([inputs]) + + # Generate a new token using output distribution at given index + def generate_token(inputs, end_idx): + distribution = ops.stop_gradient(self.call(inputs)[0, end_idx]) + + # Select the top-k tokens and their probabilities + top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k) + + # Sample from the top-k probabilities + new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1) + return ops.take(top_k_indices, new_token_idx[0]) + + # Compute the number of tokens to add + added_tokens = min(length, self.max_sequence_len - inputs.shape[1]) + progbar = utils.Progbar(added_tokens, unit_name="token", interval=5) + + # Pad the input sequence that will be filled with generated tokens + out = ops.pad(inputs, ((0, 0), (0, added_tokens)), "constant", CONFIG.token_pad) + + # Generate tokens using top-k sampling + for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens): + token = ops.cast(generate_token(out, end_idx=token_idx), out.dtype) + out = ops.scatter_update(out, ((0, token_idx + 1),), token) + progbar.add(1) + + return ops.convert_to_numpy(out[0]) + + # --- Serialization methods + + def get_config(self): + atts = ["embedding_dim", "vocabulary_size", "num_blocks", "max_sequence_len"] + return {a: getattr(self, a) for a in atts} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +""" +### Loss function + +We define a custom loss function that computes the categorical cross-entropy +loss for the model. It is computed only for non-padding tokens and uses +`from_logits=True` since the model outputs logits. 
+""" + + +@keras.utils.register_keras_serializable() +def train_loss(y_true, y_pred): + mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), "float32") + y_true = ops.one_hot(ops.cast(y_true, "int32"), CONFIG.vocabulary_size) + return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask + + +""" +### Learning rate schedule + +Following the Music Transformer paper, we define an adapted exponential decay +learning rate schedule that takes into account the embedding dimension. +""" + + +@keras.utils.register_keras_serializable() +class CustomSchedule(optimizers.schedules.LearningRateSchedule): + def __init__(self, embedding_dim, warmup_steps=4000): + super(CustomSchedule, self).__init__() + + self.embedding_dim = embedding_dim + self.warmup_steps = warmup_steps + + self._embedding_dim = ops.cast(self.embedding_dim, "float32") + # Numerical stability adjustment on torch, which is less precise + self._lr_adjust = 0.1 if keras.backend.backend() == "torch" else 1.0 + + def get_config(self): + return {"embedding_dim": self.embedding_dim, "warmup_steps": self.warmup_steps} + + def __call__(self, step): + step_rsqrt = ops.rsqrt(ops.cast(step, "float32")) + warmup_adjust = step * (self.warmup_steps**-1.5) + output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust) + return self._lr_adjust * output + + +""" +## Training the model + +We can now train the model on the Maestro dataset. First, we define a training +function. This function compiles the model, trains it, and saves the best model +checkpoint. This way, we can continue training from the best model checkpoint +if needed. 
+""" + + +def train_model(model, train_ds, val_ds, epochs=15): + # Configure optimizer + learning_rate = CustomSchedule(CONFIG.embedding_dim) + optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) + + # Compile the model + model.compile(optimizer=optimizer, loss=train_loss) + + # Train the model + save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True) + model.fit( + train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2 + ) + return model + + +""" +We can now train the model on the Maestro dataset. If a model checkpoint exists, +we can load it and continue training. +""" +if path.exists(CONFIG.model_out): + model = keras.models.load_model(CONFIG.model_out) + # Comment out to continue model training from the checkpoint + # train_model(model, train_dataset, val_dataset, epochs=10) +else: + # Train the model + model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset) + + +""" +## Generate music + +We can now generate music using the trained model. We use an existing MIDI file +as a seed and generate a new sequence. 
+""" + + +def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None): + # Ensure the output directory exists + out_dir = out_dir if out_dir is not None else tempfile.mkdtemp() + os.makedirs(out_dir, exist_ok=True) + + # Get some tokens from the MIDI file + inputs = midi_tokenizer.encode_midi(seed_path)[100:125] + print(f"Seed tokens: {inputs}") + + # Generate music that follows the input tokens until the maximum length + result = model.generate(inputs, length=length, top_k=top_k) + + output_path = path.join(out_dir, path.basename(seed_path).split(".")[0] + ".mid") + midi_tokenizer.decode_midi(result, output_path) + return output_path + + +output_file = generate_music(model, val_paths[-1], out_dir="tmp/", top_k=15) +print(visualize_midi(output_file, out_dir="tmp/")) # Saved audio path +visualize_midi(output_file) # Display the audio if in a Jupyter notebook + +""" +## Conclusion + +In this example, we learned how to build a music generation model using a custom +Transformer decoder architecture. + +We did it following the Music Transformer paper by Huang et al. (2018). +To do so we had to: + +- Define a custom loss function and learning rate schedule. +- Define a custom attention mechanism. +- Preprocess MIDI files into a tokenized format. + +After training the model on the Maestro dataset, we generated music sequences +using a seed MIDI file. + +### Next steps + +We could further improve inference times by caching attention weights during the +forward pass, in a similar way as `keras_hub` `CausalLM` models, which use the +`CachedMultiHeadAttention` layer. +"""