diff --git a/galore_torch/galore_projector.py b/galore_torch/galore_projector.py index 586e17f..242d74d 100755 --- a/galore_torch/galore_projector.py +++ b/galore_torch/galore_projector.py @@ -8,7 +8,7 @@ def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_typ self.scale = scale self.ortho_matrix = None self.proj_type = proj_type - + def project(self, full_rank_grad, iter): if self.proj_type == 'std': if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: @@ -40,7 +40,7 @@ def project(self, full_rank_grad, iter): if self.ortho_matrix is None or iter % self.update_proj_gap == 0: self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full') low_rank_grad = torch.matmul(self.ortho_matrix[0].t().to(full_rank_grad.device.type), full_rank_grad) @ self.ortho_matrix[1].t().to(full_rank_grad.device.type) - + return low_rank_grad def project_back(self, low_rank_grad): @@ -60,11 +60,11 @@ def project_back(self, low_rank_grad): full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad) elif self.proj_type == 'full': full_rank_grad = torch.matmul(self.ortho_matrix[0].to(low_rank_grad.device.type), low_rank_grad) @ self.ortho_matrix[1].to(low_rank_grad.device.type) - - + + return full_rank_grad * self.scale - - + + # svd decomposition def get_orthogonal_matrix(self, weights, rank, type): module_params = weights @@ -77,20 +77,17 @@ def get_orthogonal_matrix(self, weights, rank, type): else: float_data = True matrix = module_params.data - + U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) - + #make the smaller matrix always to be orthogonal matrix if type=='right': - A = U[:, :rank] @ torch.diag(s[:rank]) B = Vh[:rank, :] - if not float_data: B = B.to(original_device).type(original_type) return B elif type=='left': A = U[:, :rank] - B = torch.diag(s[:rank]) @ Vh[:rank, :] if not float_data: A = A.to(original_device).type(original_type) return A