
Question about the framework #5

Open
nehdiii opened this issue Sep 1, 2024 · 5 comments

nehdiii commented Sep 1, 2024

I want to understand the overall framework technically. From my understanding and from reading the paper, you are training vectors of shape [input_hidden, 1]—for example, [384, 1] in the case of ViT Small—for each dataset over a few iterations. Then, at inference time, you stack them into a vector of size [384, N], where N is the number of datasets or domains. In this setup, sparse MoE is used, and self.w_gate is replaced with this vector. If I'm wrong in my understanding, please help me correct it.

muqeeth (Collaborator) commented Sep 7, 2024

Hey! That's correct. Additionally, when training the gate vectors, you need to freeze the entire model, adapters included. During inference, once you stack the gate vectors into a linear layer, make sure to normalize them to account for the fact that these gates are trained independently and can have varying norms. In the paper, we normalized them to mean zero and a standard deviation of 1.
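For reference, here is a minimal sketch (not from the repo) of the stacking-plus-normalization step described above; `build_router_weight` and `gate_vectors` are hypothetical names, and each gate vector is assumed to have shape `[dim]` (e.g. 384 for ViT-Small):

```python
import torch

def build_router_weight(gate_vectors):
    """Stack independently trained gate vectors into a router weight.

    gate_vectors: list of N tensors, each of shape [dim] (one per dataset/domain).
    Returns a tensor of shape [dim, N] that can replace w_gate.
    """
    standardized = []
    for g in gate_vectors:
        # normalize each vector to mean 0, std 1 so that differing norms across
        # independently trained gates don't bias the routing softmax
        standardized.append((g - g.mean()) / (g.std() + 1e-6))
    return torch.stack(standardized, dim=1)

# hypothetical usage: block.w_gate.data.copy_(build_router_weight(trained_gates))
```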

nehdiii (Author) commented Sep 7, 2024

Thank you for your response. There’s one thing left: is there any ablation on training the vectors you mentioned in the paper? You stated that these vectors only need a few iterations (around 100), and then you use them. When I try this in my case, specifically in a ReID vision task, I see that, even with the same hyperparameters used during training, the accuracy increases slowly. How can I determine if my vectors are correctly learning the routing paths? Also, if you don’t mind, could you provide me with your contact information? I’d love to discuss ideas further. Thank you once again!

muqeeth (Collaborator) commented Sep 7, 2024

Hey, for training the gates, there's no concrete objective we used to measure whether the gates are trained properly. In our paper, we trained them for 10% of the expert training steps (which is 100 steps) with all hyperparameters the same as in expert training. The sigmoid gates start with an output of 0.5, so the initial loss should be around the same value as at the end of expert training; make sure to double-check that the loss doesn't go much higher during gate training (if it does, try lower learning rates). As long as it stays around that value, training for a fixed number of steps should give reasonable gates to use post hoc.
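As a rough illustration of the recipe above, the following sketch freezes everything except the gate vectors and trains them for a fixed, small number of steps with the expert's hyperparameters. It is not code from the repo: `model`, `model.blocks`, `expert_input_gate`, `loader`, `criterion`, `expert_lr`, and `num_gate_steps` are all assumed names.

```python
import itertools
import torch

# freeze the entire model, adapters included; only the gate vectors are trained
for p in model.parameters():
    p.requires_grad = False

gate_params = []
for blk in model.blocks:                       # assumed: one gate vector per block
    blk.expert_input_gate.requires_grad = True
    gate_params.append(blk.expert_input_gate)

optimizer = torch.optim.AdamW(gate_params, lr=expert_lr)  # same hyperparameters as expert training

for step, (images, labels) in zip(range(num_gate_steps), itertools.cycle(loader)):
    loss = criterion(model(images), labels)
    # sanity check: sigmoid gates start at 0.5, so this loss should start near the
    # final expert-training loss and should not climb much; if it does, lower the lr
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```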

nehdiii (Author) commented Sep 7, 2024

Thanks a lot for the explanation, Mohammed.

nehdiii (Author) commented Sep 7, 2024

This is a simplified implementation for the ViT case. I use `Block` during routing-vector training and `PHATGOOSEBlock` at inference. I'm wondering if I'm missing something in my code.

```python
import torch
import torch.nn as nn

# Attention, DropPath, MlpLoRA, and SparseDispatcher come from the rest of my codebase.


class Block(nn.Module):
    """ViT block used while training the per-domain routing (gate) vector."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, seq_len=129):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MlpLoRA(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # learnable PHATGOOSE input gate: one vector of shape [dim] per block
        self.learn_input_gate = True
        if self.learn_input_gate:
            self.expert_input_gate = nn.Parameter(torch.zeros(dim))

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))

        if self.learn_input_gate:
            xskip = x
            x = self.norm2(x)

            # per-token scalar gate: sigmoid of the dot product with the gate vector
            input_gate_scores = torch.sum(x * self.expert_input_gate, dim=-1)
            input_gate = torch.sigmoid(input_gate_scores)
            x = x * input_gate.unsqueeze(-1)

            x = xskip + self.drop_path(self.mlp(x))
        else:
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PHATGOOSEBlock(nn.Module):
    """ViT block used at inference: the trained gate vectors are stacked into w_gate."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, seq_len=129):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        # a simple MoE: one LoRA-MLP expert per entry (each entry is also used as that expert's LoRA rank)
        self.num_experts = [2, 2]
        self.mlp_experts = nn.ModuleList([MlpLoRA(in_features=dim,
                                                  hidden_features=mlp_hidden_dim,
                                                  act_layer=act_layer,
                                                  drop=drop, lora_rank=rank)
                                          for rank in self.num_experts])

        self.output_size = dim
        self.input_size = dim
        self.hidden_size = seq_len
        self.k = 1

        # router weight: columns are the per-domain gate vectors, frozen at inference
        self.w_gate = nn.Parameter(torch.zeros(self.input_size, len(self.num_experts)), requires_grad=False)
        self.softmax = nn.Softmax(1)

    def top_k_gating(self, x):
        # x = x / torch.norm(x, dim=-1, keepdim=True) + 1e-6
        # w_gate = self.w_gate / torch.norm(self.w_gate, dim=-1, keepdim=True) + 1e-6

        clean_logits = x @ self.w_gate
        logits = self.softmax(clean_logits)

        top_logits, top_indices = logits.topk(min(self.k, len(self.num_experts)), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = top_k_logits / (top_k_logits.sum(1, keepdim=True) + 1e-6)  # renormalize over the top-k

        # scatter the top-k gate values back into a dense [tokens, num_experts] matrix
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        return gates

    def forward(self, x, register_hook=False):
        # attention first
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))

        # experts
        x_skip = x
        x = self.norm2(x)

        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        gates = self.top_k_gating(x)
        dispatcher = SparseDispatcher(len(self.num_experts), gates)
        expert_inputs = dispatcher.dispatch(x)
        gates = dispatcher.expert_to_gates()
        expert_outputs = [self.mlp_experts[i](expert_inputs[i]) for i in range(len(self.num_experts))]
        x = dispatcher.combine(expert_outputs)
        x = x.view(bsz, length, self.input_size)

        # add back the residual connection (x_skip was computed but unused; this mirrors the gated MLP path in Block)
        x = x_skip + self.drop_path(x)

        return x
```
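One step the snippet above does not show is how `w_gate` gets filled. Below is a minimal sketch of loading the per-domain `expert_input_gate` vectors trained with `Block` into `PHATGOOSEBlock.w_gate`, reusing the mean-0 / std-1 normalization mentioned earlier; `trained_models` and the per-layer pairing are assumptions about the surrounding code, not part of the repo:

```python
import torch

@torch.no_grad()
def load_gates(phatgoose_blocks, trained_models):
    """phatgoose_blocks: list of PHATGOOSEBlock, one per layer.
    trained_models: list of per-domain models whose blocks hold expert_input_gate."""
    for layer_idx, pg_block in enumerate(phatgoose_blocks):
        cols = []
        for domain_model in trained_models:
            g = domain_model.blocks[layer_idx].expert_input_gate
            cols.append((g - g.mean()) / (g.std() + 1e-6))   # mean 0, std 1 per gate
        pg_block.w_gate.copy_(torch.stack(cols, dim=1))      # shape [dim, num_domains]
```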
