From 54bca3fdb27b763e49f4f5b7017e666be37d2dcf Mon Sep 17 00:00:00 2001
From: Robert
Date: Fri, 26 Apr 2024 13:22:38 -0700
Subject: [PATCH] Final changes for tensor decomposition.

---
 galore_torch/adamw.py                      |   8 +-
 galore_torch/galore_projector_sketching.py | 126 ---------------
 galore_torch/galore_projector_tensor.py    | 102 ++++++++++++
 scripts/tensor_test/neural_operator.py     | 173 +++++++++++++++++++++
 4 files changed, 280 insertions(+), 129 deletions(-)
 delete mode 100755 galore_torch/galore_projector_sketching.py
 create mode 100644 galore_torch/galore_projector_tensor.py
 create mode 100644 scripts/tensor_test/neural_operator.py

diff --git a/galore_torch/adamw.py b/galore_torch/adamw.py
index 49be772..a677985 100755
--- a/galore_torch/adamw.py
+++ b/galore_torch/adamw.py
@@ -10,7 +10,7 @@
 from transformers.utils.versions import require_version
 
 from .galore_projector import GaLoreProjector
-from .galore_projector_sketching import GaLoreProjectorSketching
+from .galore_projector_tensor import GaLoreProjectorTensor
 
 
 class AdamW(Optimizer):
@@ -92,8 +92,10 @@ def step(self, closure: Callable = None):
                 # GaLore Projection
                 if "rank" in group:
                     if "projector" not in state:
-                        state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
-
+                        if group['dim'] <=2:
+                            state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
+                        else:
+                            state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
 
                     grad = state["projector"].project(grad, state["step"])
 
                 # State initialization
diff --git a/galore_torch/galore_projector_sketching.py b/galore_torch/galore_projector_sketching.py
deleted file mode 100755
index 79f346b..0000000
--- a/galore_torch/galore_projector_sketching.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import torch
-
-def sketch_for_low_rank_approx_left(A, k, l, old_Q = False, Omega_old=None, old_Y = None):
-    """
-    Implement Algorithm 1: Sketch for Low-Rank Approximation
-
-    Args:
-        A (torch.Tensor): Input matrix of size (m, n)
-        k (int): Sketch size parameter for range sketch
-        l (int): Sketch size parameter for co-range sketch
-
-    Returns:
-        Omega (torch.Tensor): Random test matrix of size (n, k)
-        Psi (torch.Tensor): Random test matrix of size (l, m)
-        Y (torch.Tensor): Range sketch Y = AΩ of size (m, k)
-        W (torch.Tensor): Co-range sketch W = ΨA of size (l, n)
-    """
-    m, n = A.size()
-    original_type = A.data.dtype
-    original_device = A.data.device
-    theta = 1.0
-    zeta = 1.0
-    # Generate random test matrices
-    if old_Q:
-        Omega = Omega_old
-        Y = theta * old_Y + zeta * A @ Omega
-    else:
-        Omega = torch.randn(n, k).to(original_device).type(original_type)
-        Y = A @ Omega
-    Psi = None
-    W = None
-
-    Q, _ = torch.linalg.qr(Y.type(torch.float32))
-    Q = Q.type(torch.float32)
-    return Omega, Psi, Y, W, Q
-
-def sketch_for_low_rank_approx_right(A, k, l, old_Q = False, Psi_old=None, old_W = None):
-    m, n = A.size()
-    original_type = A.data.dtype
-    original_device = A.data.device
-    theta = 1.0
-    zeta = 1.0
-    # Generate random test matrices
-    if old_Q:
-        Psi = Psi_old
-        W = theta * old_W + zeta * Psi @ A
-    else:
-        Psi = torch.randn(l, m).to(original_device).type(original_type)
-        W = Psi @ A
-    Omega = None
-    Y = None
-
-    Q, _ = torch.linalg.qr(W.T.type(torch.float32))
-    Q = Q.type(torch.float32)
-    return Omega, Psi, Y, W, Q
-
-def low_rank_approx(Y, W, Psi):
-    """
-    Implement Algorithm 4: Low-Rank Approximation
-
-    Args:
-        Y (torch.Tensor): Range sketch Y = AΩ of size (m, k)
-        W (torch.Tensor): Co-range sketch W = ΨA of size (l, n)
-        Psi (torch.Tensor): Random test matrix Psi of size (l, m)
-
-    Returns:
-        Q (torch.Tensor): Orthonormal basis for range of Y of size (m, k)
-        X (torch.Tensor): Factor matrix of size (k, n)
-        A_approx (torch.Tensor): Low-rank approximation QX of size (m, n)
-    """
-    # Step 1: Form an orthogonal basis for the range of Y
-    Q, _ = torch.linalg.qr(Y.type(torch.float32))
-    Q = Q.type(torch.float32)
-    Psi = Psi.type(torch.float32)
-    W = W.type(torch.float32)
-    # Step 2: Orthogonal-triangular factorization of ΨQ
-    PsiQ = Psi @ Q
-    U, T = torch.linalg.qr(PsiQ)
-
-    # Step 3: Solve the least-squares problem to obtain X
-    X = torch.linalg.lstsq(T, U.T @ W).solution
-
-    # Step 4: Construct the rank-k approximation
-    A_approx = Q @ X
-
-    return Q, X, A_approx
-
-class GaLoreProjectorSketching:
-    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
-        self.rank = rank
-        self.verbose = verbose
-        self.update_proj_gap = update_proj_gap
-        self.scale = scale
-        self.proj_type = proj_type
-        self.Omega = None
-        self.Psi = None
-        self.Y = None
-        self.W = None
-
-    def project(self, full_rank_grad, iter):
-        if self.Omega is None or self.Psi is None or iter % self.update_proj_gap == 0:
-            if self.Omega is None or self.Psi is None:
-                if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
-                    self.Omega, self.Psi, self.Y, self.W, self.Q = sketch_for_low_rank_approx_left(full_rank_grad, self.rank, self.rank, False, None, None)
-                else:
-                    self.Omega, self.Psi, self.Y, self.W, self.Q = sketch_for_low_rank_approx_right(full_rank_grad, self.rank, self.rank, False, None, None)
-            else:
-                if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
-                    self.Omega, self.Psi, self.Y, self.W, self.Q = sketch_for_low_rank_approx_left(full_rank_grad, self.rank, self.rank, True, self.Psi, self.Y)
-                else:
-                    self.Omega, self.Psi, self.Y, self.W, self.Q = sketch_for_low_rank_approx_right(full_rank_grad, self.rank, self.rank, True, self.Omega, self.W)
-        original_device = full_rank_grad.device
-        original_type = full_rank_grad.dtype
-        self.Q = self.Q.to(original_device).type(original_type)
-        if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
-            low_rank_grad = torch.matmul(self.Q.T, full_rank_grad)
-        else:
-            low_rank_grad = torch.matmul(full_rank_grad, self.Q)
-        return low_rank_grad
-
-    def project_back(self, low_rank_grad):
-        if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
-            full_rank_grad = torch.matmul(self.Q, low_rank_grad.T).T
-        else:
-            full_rank_grad = torch.matmul(low_rank_grad.T, self.Q.T).T
-        return full_rank_grad * self.scale
\ No newline at end of file
diff --git a/galore_torch/galore_projector_tensor.py b/galore_torch/galore_projector_tensor.py
new file mode 100644
index 0000000..ae8e104
--- /dev/null
+++ b/galore_torch/galore_projector_tensor.py
@@ -0,0 +1,102 @@
+import torch
+from tensorly.decomposition import tucker
+from tensorly import tenalg

+# The GaLoreProjectorTensor class implements a projection method that uses an orthogonal
+# (Tucker) decomposition to build low-rank approximations of gradients for general tensors
+# of dimension > 2. The decomposition comes from the tensorly library: https://tensorly.org/stable/index.html
+class GaLoreProjectorTensor:
+    """
+    A class that represents a projector for the GaLore algorithm.
+
+    Args:
+        rank (int or float): The rank passed to the Tucker decomposition of the gradients.
+        verbose (bool, optional): Whether to print verbose output. Defaults to False.
+        update_proj_gap (int, optional): The number of iterations between updating the orthogonal matrix. Defaults to 200.
+        scale (float, optional): The scaling factor for the projected gradients. Defaults to 1.0.
+    """
+
+    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0):
+        self.rank = rank
+        self.verbose = verbose
+        self.update_proj_gap = update_proj_gap
+        self.scale = scale
+        self.ortho_matrix = None
+        self.transformed_low_rank = None
+
+    def project(self, full_rank_grad, iter):
+        """
+        Projects the full-rank gradients onto the low-rank subspace.
+
+        Args:
+            full_rank_grad (torch.Tensor): The full-rank gradients.
+            iter (int): The current iteration.
+
+        Returns:
+            torch.Tensor: The transformed low-rank gradients.
+        """
+        if self.ortho_matrix is None and iter % self.update_proj_gap == 0:
+            self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank)
+        self.transformed_low_rank = self.transform(self.ortho_matrix, full_rank_grad)
+        return self.transformed_low_rank
+
+    def project_back(self, low_rank_grad):
+        """
+        Projects the low-rank gradients back to the full-rank space.
+
+        Args:
+            low_rank_grad (torch.Tensor): The low-rank gradients.
+
+        Returns:
+            torch.Tensor: The full-rank gradients.
+        """
+        full_rank_grad = self.inverse_transform(self.ortho_matrix, self.transformed_low_rank)
+        return full_rank_grad * self.scale
+
+    # tucker decomposition
+    def get_orthogonal_matrix(self, weights, rank_all):
+        """
+        Computes the orthogonal factors using a Tucker decomposition.
+
+        Args:
+            weights (torch.Tensor): The weights to decompose.
+            rank_all (int): The desired rank of the decomposition.
+
+        Returns:
+            tuple: A tuple containing the core and factors of the Tucker decomposition.
+        """
+        module_params = weights
+        if module_params.data.dtype != torch.float:
+            matrix = module_params.data.float()
+        else:
+            matrix = module_params.data
+        tucker_tensor = tucker(matrix, rank=rank_all)
+        return tucker_tensor
+
+    def transform(self, tensor, x):
+        """
+        Transforms the input tensor using the factors of the Tucker decomposition.
+
+        Args:
+            tensor (tuple): A tuple containing the core and factors of the Tucker decomposition.
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The transformed tensor.
+        """
+        _, factors = tensor
+        return tenalg.multi_mode_dot(x, factors, transpose=True)
+
+    def inverse_transform(self, tensor, x):
+        """
+        Inverse transforms the input tensor using the factors of the Tucker decomposition.
+
+        Args:
+            tensor (tuple): A tuple containing the core and factors of the Tucker decomposition.
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The inverse transformed tensor.
+        """
+        _, factors = tensor
+        return tenalg.multi_mode_dot(x, factors)
diff --git a/scripts/tensor_test/neural_operator.py b/scripts/tensor_test/neural_operator.py
new file mode 100644
index 0000000..a328b16
--- /dev/null
+++ b/scripts/tensor_test/neural_operator.py
@@ -0,0 +1,173 @@
+"""
+Training a neural operator on Darcy-Flow - Author Robert Joseph
+========================================
+In this example, we demonstrate how to train on the small Darcy-Flow example that ships with the package using incremental FNO and incremental resolution, together with GaLore tensor decomposition.
+
+This assumes the neuraloperator library is installed; installation instructions can be found here: https://github.com/NeuralOperator/neuraloperator
+"""
+
+# %%
+#
+import torch
+import matplotlib.pyplot as plt
+import sys
+from neuralop.training.callbacks import BasicLoggerCallback
+from neuralop.models import FNO
+from neuralop import Trainer
+from neuralop.datasets import load_darcy_flow_small
+from neuralop.utils import count_model_params
+from neuralop.training.callbacks import IncrementalCallback
+from neuralop.datasets import data_transforms
+from neuralop import LpLoss, H1Loss
+from neuralop.training import AdamW
+
+
+# %%
+# Loading the Darcy flow dataset
+train_loader, test_loaders, data_processor = load_darcy_flow_small(
+    n_train=1000, batch_size=32,
+    test_resolutions=[16, 32], n_tests=[100, 50],
+    test_batch_sizes=[32, 32],
+    positional_encoding=True
+)
+# %%
+# Choose device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# %%
+# Set up the incremental FNO model
+# We start with 10 modes in each dimension and allow growth up to 20
+# The modes are updated with the incremental loss-gap algorithm configured in the callback below
+
+starting_modes = (10, 10)
+incremental = False
+
+model = FNO(
+    max_n_modes=(20, 20),
+    n_modes=starting_modes,
+    hidden_channels=64,
+    in_channels=1,
+    out_channels=1,
+    n_layers=4
+)
+callbacks = [
+    IncrementalCallback(
+        incremental_loss_gap=True,
+        incremental_grad=False,
+        incremental_grad_eps=0.9999,
+        incremental_buffer=5,
+        incremental_max_iter=1,
+        incremental_grad_max_iter=2,
+    )
+]
+model = model.to(device)
+n_params = count_model_params(model)
+galore_params = []
+galore_params.extend(list(model.fno_blocks.convs.parameters()))
+print(galore_params[0].shape, galore_params[1].shape, galore_params[2].shape, galore_params[3].shape)
+galore_params.pop(0)
+id_galore_params = [id(p) for p in galore_params]
+# put the parameters without a "rank" entry into a separate group
+regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
+# then construct the GaLore AdamW optimizer with two parameter groups
+# In this case we have a 5d tensor representing the weights in the spectral layers of the FNO
+# A good rule of thumb for tensor decomposition is to limit the rank to at most 0.75, and to increase the number of epochs and tune the learning rate relative to the baseline.
+# Low-rank decomposition takes longer to converge, but it is more memory efficient.
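+# Note: the 'dim': 5 entry below is what routes these parameters to the tensor projector;
+# the adamw.py hunk above builds a GaLoreProjectorTensor whenever dim > 2 and keeps the
+# original matrix GaLoreProjector when dim <= 2.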
+param_groups = [{'params': regular_params},
+                {'params': galore_params, 'rank': 0.2, 'update_proj_gap': 10, 'scale': 0.25, 'proj_type': "std", 'dim': 5}]
+optimizer = AdamW(param_groups, lr=0.01)
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
+data_transform = data_transforms.IncrementalDataProcessor(
+    in_normalizer=None,
+    out_normalizer=None,
+    positional_encoding=None,
+    device=device,
+    dataset_sublist=[2, 1],
+    dataset_resolution=16,
+    dataset_indices=[2, 3],
+    epoch_gap=10,
+    verbose=True,
+)
+
+data_transform = data_transform.to(device)
+# %%
+# Set up the losses
+l2loss = LpLoss(d=2, p=2)
+h1loss = H1Loss(d=2)
+train_loss = h1loss
+eval_losses = {"h1": h1loss, "l2": l2loss}
+print("\n### OPTIMIZER ###\n", optimizer)
+sys.stdout.flush()
+
+# Finally pass all of these to the Trainer
+trainer = Trainer(
+    model=model,
+    n_epochs=100,
+    data_processor=data_transform,
+    callbacks=callbacks,
+    device=device,
+    verbose=True,
+)
+
+# %%
+# Train the model
+trainer.train(
+    train_loader,
+    test_loaders,
+    optimizer,
+    scheduler,
+    regularizer=False,
+    training_loss=train_loss,
+    eval_losses=eval_losses,
+)
+
+# %%
+# Plot the prediction, and compare with the ground-truth
+# Note that we trained on a very small resolution for
+# a very small number of epochs
+# In practice, we would train at a larger resolution, on many more samples.
+#
+# However, for practicality, we created a minimal example that
+# i) fits in just a few Mb of memory
+# ii) can be trained quickly on CPU
+#
+# In practice we would train a Neural Operator on one or multiple GPUs
+
+test_samples = test_loaders[32].dataset
+
+fig = plt.figure(figsize=(7, 7))
+for index in range(3):
+    data = test_samples[index]
+    # Input x
+    x = data["x"].to(device)
+    # Ground-truth
+    y = data["y"].to(device)
+    # Model prediction
+    out = model(x.unsqueeze(0))
+    ax = fig.add_subplot(3, 3, index * 3 + 1)
+    x = x.cpu().squeeze().detach().numpy()
+    y = y.cpu().squeeze().detach().numpy()
+    ax.imshow(x, cmap="gray")
+    if index == 0:
+        ax.set_title("Input x")
+    plt.xticks([], [])
+    plt.yticks([], [])
+
+    ax = fig.add_subplot(3, 3, index * 3 + 2)
+    ax.imshow(y.squeeze())
+    if index == 0:
+        ax.set_title("Ground-truth y")
+    plt.xticks([], [])
+    plt.yticks([], [])
+
+    ax = fig.add_subplot(3, 3, index * 3 + 3)
+    ax.imshow(out.cpu().squeeze().detach().numpy())
+    if index == 0:
+        ax.set_title("Model prediction")
+    plt.xticks([], [])
+    plt.yticks([], [])
+
+fig.suptitle("Inputs, ground-truth output and prediction.", y=0.98)
+plt.tight_layout()
+fig.show()
\ No newline at end of file
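
Usage sketch (reviewer illustration, not part of the patch): a minimal example of how the new GaLoreProjectorTensor is expected to behave in isolation, assuming torch and tensorly are installed; the gradient shape, rank, and update_proj_gap values below are made up for illustration and are not taken from this patch.

    import torch
    import tensorly as tl
    from galore_torch.galore_projector_tensor import GaLoreProjectorTensor

    tl.set_backend('pytorch')  # keep tucker/multi_mode_dot operating on torch tensors

    # Stand-in for a >2-D gradient, e.g. a spectral-convolution weight gradient.
    grad = torch.randn(8, 8, 4, 4)

    projector = GaLoreProjectorTensor(rank=2, update_proj_gap=10, scale=1.0)
    low_rank = projector.project(grad, 0)        # Tucker factors are computed at step 0
    restored = projector.project_back(low_rank)  # approximate reconstruction, times `scale`

    # full, Tucker-core, and reconstructed shapes: (8, 8, 4, 4), (2, 2, 2, 2), (8, 8, 4, 4)
    print(grad.shape, low_rank.shape, restored.shape)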