fix bug in model's parameters
adelmemariani committed Dec 12, 2023
1 parent 29fc9d4 commit ffaa2c6
Showing 1 changed file with 12 additions and 12 deletions.
chebai/models/electra.py: 24 changes (12 additions, 12 deletions)
@@ -483,17 +483,17 @@ def __init__(self, hidden_size=None, embeddings_to_points_hidden_size=None, embe
         super().__init__(**kwargs)
 
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
-        in_dim = self.config.hidden_size
-        hidden_dim = self.config.embeddings_to_points_hidden_size
-        out_dim = self.config.embeddings_dimensions
+        self.in_dim = self.config.hidden_size
+        self.hidden_dim = self.config.embeddings_to_points_hidden_size
+        self.out_dim = self.config.embeddings_dimensions
 
         #self.boxes = nn.Parameter(torch.rand((self.config.num_labels, out_dim, 2)))
-        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, out_dim, 2)) * 3 )
+        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 )
 
         self.embeddings_to_points = nn.Sequential(
-            nn.Linear(in_dim, hidden_dim),
+            nn.Linear(self.in_dim, self.hidden_dim),
             nn.ReLU(),
-            nn.Linear(hidden_dim, out_dim)
+            nn.Linear(self.hidden_dim, self.out_dim)
         )
 
     def forward(self, data, **kwargs):
@@ -530,16 +530,16 @@ def __init__(self, hidden_size=None, embeddings_to_points_hidden_size=None, embe
         super().__init__(**kwargs)
 
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
-        in_dim = self.config.hidden_size
-        hidden_dim = self.config.embeddings_to_points_hidden_size
-        out_dim = self.config.embeddings_dimensions
+        self.in_dim = self.config.hidden_size
+        self.hidden_dim = self.config.embeddings_to_points_hidden_size
+        self.out_dim = self.config.embeddings_dimensions
 
-        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, out_dim, 2)))
+        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)))
 
         self.embeddings_to_points = nn.Sequential(
-            nn.Linear(in_dim, hidden_dim),
+            nn.Linear(self.in_dim, self.hidden_dim),
             nn.ReLU(),
-            nn.Linear(hidden_dim, out_dim)
+            nn.Linear(self.hidden_dim, self.out_dim)
         )
 
     def forward(self, data, **kwargs):
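The change in both hunks is the same: the three dimension variables are promoted from `__init__`-local names to instance attributes, so methods other than `__init__` (such as `forward`) can read the same values that shaped the `boxes` parameter. Below is a minimal, self-contained sketch of the fixed pattern. The class name, the default dimensions, and the box-membership logic in `forward` are illustrative assumptions, not the actual model in `chebai/models/electra.py`:

```python
import torch
from torch import nn


class BoxEmbeddingSketch(nn.Module):
    """Minimal sketch of the pattern this commit adopts.

    Hypothetical stand-in for the Electra-based model in
    chebai/models/electra.py; names and sizes are illustrative only.
    """

    def __init__(self, num_labels=5, in_dim=256, hidden_dim=128, out_dim=8):
        super().__init__()
        # The commit's fix: keep the dimensions on self so that every
        # method reads the same values, not just __init__.
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # One box per label: a (lower, upper) bound along each dimension.
        self.boxes = nn.Parameter(torch.rand((num_labels, self.out_dim, 2)))

        # MLP mapping input embeddings to points in the box space.
        self.embeddings_to_points = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.out_dim),
        )

    def forward(self, embeddings):
        # self.out_dim is reachable here only because __init__ stored it
        # as an instance attribute rather than a local variable.
        points = self.embeddings_to_points(embeddings)  # (batch, out_dim)
        lower = self.boxes[..., 0]                      # (num_labels, out_dim)
        upper = self.boxes[..., 1]                      # (num_labels, out_dim)
        # Illustrative membership test (an assumption, not the repo's rule):
        # a point belongs to a label if it lies inside that label's box.
        inside = (points.unsqueeze(1) >= lower) & (points.unsqueeze(1) <= upper)
        return inside.all(dim=-1).float()               # (batch, num_labels)


model = BoxEmbeddingSketch()
scores = model(torch.randn(4, 256))  # shape: (4, 5)
```

Note that the diff itself barely changes runtime behavior, since the tensors are built from the same config values either way; the practical effect is that the dimensions become part of the module's state and stay available outside `__init__`.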
