Skip to content

Commit

Permalink
add simple Feed-forward network (for ESM2->chebi task)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel05 committed Jan 10, 2025
1 parent 7da8963 commit 976f2b8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
55 changes: 55 additions & 0 deletions chebai/models/ffn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Dict, Any, Tuple

from chebai.models import ChebaiBaseNet
import torch
from torch import Tensor

class FFN(ChebaiBaseNet):

NAME = "FFN"

def __init__(self, input_size: int = 1000, num_hidden_layers: int = 3, hidden_size: int = 128, **kwargs):
super().__init__(**kwargs)

self.layers = torch.nn.ModuleList()
self.layers.append(torch.nn.Linear(input_size, hidden_size))
for _ in range(num_hidden_layers):
self.layers.append(torch.nn.Linear(hidden_size, hidden_size))
self.layers.append(torch.nn.Linear(hidden_size, self.out_dim))

def _get_prediction_and_labels(self, data, labels, model_output):
d = model_output["logits"]
loss_kwargs = data.get("loss_kwargs", dict())
if "non_null_labels" in loss_kwargs:
n = loss_kwargs["non_null_labels"]
d = data[n]
return torch.sigmoid(d), labels.int() if labels is not None else None

def _process_for_loss(
self,
model_output: Dict[str, Tensor],
labels: Tensor,
loss_kwargs: Dict[str, Any],
) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
"""
Process the model output for calculating the loss.
Args:
model_output (Dict[str, Tensor]): The output of the model.
labels (Tensor): The target labels.
loss_kwargs (Dict[str, Any]): Additional loss arguments.
Returns:
tuple: A tuple containing the processed model output, labels, and loss arguments.
"""
kwargs_copy = dict(loss_kwargs)
if labels is not None:
labels = labels.float()
return model_output["logits"], labels, kwargs_copy

def forward(self, data, **kwargs):
x = data["features"]
for layer in self.layers:
x = torch.relu(layer(x))
return {"logits": x}

5 changes: 5 additions & 0 deletions configs/data/deepGO/deepgo2_esm2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData
init_args:
go_branch: "MF"
max_sequence_length: 1000
use_esm2_embeddings: True
7 changes: 7 additions & 0 deletions configs/model/ffn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class_path: chebai.models.ffn.FFN
init_args:
optimizer_kwargs:
lr: 1e-3
hidden_size: 128
num_hidden_layers: 3
input_size: 2560

0 comments on commit 976f2b8

Please sign in to comment.