Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaweizzhao committed Sep 10, 2024
1 parent 8691652 commit 55b8419
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions galore_torch/galore_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_typ
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type

def project(self, full_rank_grad, iter):
if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
Expand Down Expand Up @@ -40,7 +40,7 @@ def project(self, full_rank_grad, iter):
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
low_rank_grad = torch.matmul(self.ortho_matrix[0].t().to(full_rank_grad.device.type), full_rank_grad) @ self.ortho_matrix[1].t().to(full_rank_grad.device.type)

return low_rank_grad

def project_back(self, low_rank_grad):
Expand All @@ -60,11 +60,11 @@ def project_back(self, low_rank_grad):
full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)
elif self.proj_type == 'full':
full_rank_grad = torch.matmul(self.ortho_matrix[0].to(low_rank_grad.device.type), low_rank_grad) @ self.ortho_matrix[1].to(low_rank_grad.device.type)


return full_rank_grad * self.scale


# Singular value decomposition (SVD)
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
Expand All @@ -77,20 +77,17 @@ def get_orthogonal_matrix(self, weights, rank, type):
else:
float_data = True
matrix = module_params.data

U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)

# Always keep the smaller factor as the orthogonal matrix
if type=='right':
A = U[:, :rank] @ torch.diag(s[:rank])
B = Vh[:rank, :]

if not float_data:
B = B.to(original_device).type(original_type)
return B
elif type=='left':
A = U[:, :rank]
B = torch.diag(s[:rank]) @ Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
return A
Expand Down

0 comments on commit 55b8419

Please sign in to comment.