diff --git a/chebai/models/electra.py b/chebai/models/electra.py index abbd1406..6a234d14 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -483,17 +483,17 @@ def __init__(self, hidden_size=None, embeddings_to_points_hidden_size=None, embe super().__init__(**kwargs) self.config = ElectraConfig(**kwargs["config"], output_attentions=True) - in_dim = self.config.hidden_size - hidden_dim = self.config.embeddings_to_points_hidden_size - out_dim = self.config.embeddings_dimensions + self.in_dim = self.config.hidden_size + self.hidden_dim = self.config.embeddings_to_points_hidden_size + self.out_dim = self.config.embeddings_dimensions #self.boxes = nn.Parameter(torch.rand((self.config.num_labels, out_dim, 2))) - self.boxes = nn.Parameter(torch.rand((self.config.num_labels, out_dim, 2)) * 3 ) + self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 ) self.embeddings_to_points = nn.Sequential( - nn.Linear(in_dim, hidden_dim), + nn.Linear(self.in_dim, self.hidden_dim), nn.ReLU(), - nn.Linear(hidden_dim, out_dim) + nn.Linear(self.hidden_dim, self.out_dim) ) def forward(self, data, **kwargs): @@ -530,16 +530,16 @@ def __init__(self, hidden_size=None, embeddings_to_points_hidden_size=None, embe super().__init__(**kwargs) self.config = ElectraConfig(**kwargs["config"], output_attentions=True) - in_dim = self.config.hidden_size - hidden_dim = self.config.embeddings_to_points_hidden_size - out_dim = self.config.embeddings_dimensions + self.in_dim = self.config.hidden_size + self.hidden_dim = self.config.embeddings_to_points_hidden_size + self.out_dim = self.config.embeddings_dimensions - self.boxes = nn.Parameter(torch.rand((self.config.num_labels, out_dim, 2))) + self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2))) self.embeddings_to_points = nn.Sequential( - nn.Linear(in_dim, hidden_dim), + nn.Linear(self.in_dim, self.hidden_dim), nn.ReLU(), - nn.Linear(hidden_dim, out_dim) + nn.Linear(self.hidden_dim, self.out_dim) ) def forward(self, data, **kwargs):