From 976f2b895e3ee8fce4a9bcbde6ace30539e7845a Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Fri, 10 Jan 2025 13:55:50 +0100
Subject: [PATCH] add simple Feed-forward network (for ESM2->chebi task)

---
 chebai/models/ffn.py                 | 67 ++++++++++++++++++++++++++++
 configs/data/deepGO/deepgo2_esm2.yml |  5 +++
 configs/model/ffn.yml                |  7 ++++
 3 files changed, 79 insertions(+)
 create mode 100644 chebai/models/ffn.py
 create mode 100644 configs/data/deepGO/deepgo2_esm2.yml
 create mode 100644 configs/model/ffn.yml

diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py
new file mode 100644
index 00000000..77046ae6
--- /dev/null
+++ b/chebai/models/ffn.py
@@ -0,0 +1,67 @@
+from typing import Any, Dict, Tuple
+
+import torch
+from torch import Tensor
+
+from chebai.models import ChebaiBaseNet
+
+
+class FFN(ChebaiBaseNet):
+    """A simple feed-forward network, e.g. for predicting classes from
+    precomputed ESM2 protein embeddings."""
+
+    NAME = "FFN"
+
+    def __init__(
+        self,
+        input_size: int = 1000,
+        num_hidden_layers: int = 3,
+        hidden_size: int = 128,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # Input layer, `num_hidden_layers` hidden layers, and an output layer.
+        self.layers = torch.nn.ModuleList()
+        self.layers.append(torch.nn.Linear(input_size, hidden_size))
+        for _ in range(num_hidden_layers):
+            self.layers.append(torch.nn.Linear(hidden_size, hidden_size))
+        self.layers.append(torch.nn.Linear(hidden_size, self.out_dim))
+
+    def _get_prediction_and_labels(self, data, labels, model_output):
+        d = model_output["logits"]
+        loss_kwargs = data.get("loss_kwargs", dict())
+        if "non_null_labels" in loss_kwargs:
+            n = loss_kwargs["non_null_labels"]
+            # Restrict predictions to the samples that actually have labels.
+            d = d[n]
+        return torch.sigmoid(d), labels.int() if labels is not None else None
+
+    def _process_for_loss(
+        self,
+        model_output: Dict[str, Tensor],
+        labels: Tensor,
+        loss_kwargs: Dict[str, Any],
+    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
+        """
+        Process the model output for calculating the loss.
+
+        Args:
+            model_output (Dict[str, Tensor]): The output of the model.
+            labels (Tensor): The target labels.
+            loss_kwargs (Dict[str, Any]): Additional loss arguments.
+
+        Returns:
+            tuple: A tuple containing the processed model output, labels, and loss arguments.
+        """
+        kwargs_copy = dict(loss_kwargs)
+        if labels is not None:
+            labels = labels.float()
+        return model_output["logits"], labels, kwargs_copy
+
+    def forward(self, data, **kwargs):
+        x = data["features"]
+        # ReLU after every layer except the last one, which returns raw logits.
+        for layer in self.layers[:-1]:
+            x = torch.relu(layer(x))
+        return {"logits": self.layers[-1](x)}
diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml
new file mode 100644
index 00000000..4b3ae3b1
--- /dev/null
+++ b/configs/data/deepGO/deepgo2_esm2.yml
@@ -0,0 +1,5 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData
+init_args:
+  go_branch: "MF"
+  max_sequence_length: 1000
+  use_esm2_embeddings: True
\ No newline at end of file
diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml
new file mode 100644
index 00000000..193c6f64
--- /dev/null
+++ b/configs/model/ffn.yml
@@ -0,0 +1,7 @@
+class_path: chebai.models.ffn.FFN
+init_args:
+  optimizer_kwargs:
+    lr: 1e-3
+  hidden_size: 128
+  num_hidden_layers: 3
+  input_size: 2560
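
Usage note (not part of the patch): with both new configs in place, a training
run pairs the model config with the data config. A minimal invocation sketch,
assuming the repository's LightningCLI entry point and an existing default
trainer config (adjust paths to your checkout):

    python -m chebai fit \
        --trainer=configs/training/default_trainer.yml \
        --model=configs/model/ffn.yml \
        --data=configs/data/deepGO/deepgo2_esm2.yml

The configured input_size of 2560 matches the embedding width of the ESM2 3B
checkpoint (esm2_t36_3B_UR50D); it must agree with the embeddings the data
module actually produces.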
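
A quick shape check for the new module, as a sketch: out_dim is assumed here to
be a constructor keyword accepted by ChebaiBaseNet (the model reads
self.out_dim above), and the feature batch mimics precomputed ESM2 embeddings:

    import torch

    from chebai.models.ffn import FFN

    # Hypothetical smoke test, not part of the patch: four dummy 2560-dim
    # embeddings run through the network; the output is a dict with raw
    # logits, one score per class.
    model = FFN(input_size=2560, hidden_size=128, num_hidden_layers=3, out_dim=50)
    batch = {"features": torch.randn(4, 2560)}
    out = model(batch)
    assert out["logits"].shape == (4, 50)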