diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 7009406d..dc6c719b 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -329,7 +329,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: except RuntimeError as e: print(f"RuntimeError at forward: {e}") print(f'data[features]: {data["features"]}') - raise Exception + raise e inp = self.word_dropout(inp) electra = self.electra(inputs_embeds=inp, **kwargs) d = electra.last_hidden_state[:, 0, :] diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py new file mode 100644 index 00000000..ca1f6f22 --- /dev/null +++ b/chebai/models/ffn.py @@ -0,0 +1,61 @@ +from typing import Dict, Any, Tuple + +from chebai.models import ChebaiBaseNet +import torch +from torch import Tensor + + +class FFN(ChebaiBaseNet): + + NAME = "FFN" + + def __init__( + self, + input_size: int = 1000, + num_hidden_layers: int = 3, + hidden_size: int = 128, + **kwargs + ): + super().__init__(**kwargs) + + self.layers = torch.nn.ModuleList() + self.layers.append(torch.nn.Linear(input_size, hidden_size)) + for _ in range(num_hidden_layers): + self.layers.append(torch.nn.Linear(hidden_size, hidden_size)) + self.layers.append(torch.nn.Linear(hidden_size, self.out_dim)) + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["logits"] + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = data[n] + return torch.sigmoid(d), labels.int() if labels is not None else None + + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. + loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. 
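# Editor's sketch, not part of the patch: how the FFN defined above stacks its layers and how
# its logits are consumed downstream. The concrete sizes, the 997-label target and the use of
# binary cross-entropy with logits are illustrative assumptions, not values fixed by this PR.
import torch

input_size, hidden_size, num_hidden_layers, out_dim = 1000, 128, 3, 997
layers = torch.nn.ModuleList([torch.nn.Linear(input_size, hidden_size)])
layers.extend(torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_hidden_layers))
layers.append(torch.nn.Linear(hidden_size, out_dim))

x = torch.randn(4, input_size)                      # a batch of 4 precomputed feature vectors
for layer in layers:
    x = torch.relu(layer(x))                        # same activation scheme as FFN.forward
labels = torch.randint(0, 2, (4, out_dim)).float()
loss = torch.nn.functional.binary_cross_entropy_with_logits(x, labels)
probs = torch.sigmoid(x)                            # as in _get_prediction_and_labels
print(probs.shape, float(loss))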
+ """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + + def forward(self, data, **kwargs): + x = data["features"] + for layer in self.layers: + x = torch.relu(layer(x)) + return {"logits": x} diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt index 72ad1b6d..c31c5b72 100644 --- a/chebai/preprocessing/bin/protein_token/tokens.txt +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -18,3 +18,4 @@ W E V H +X diff --git a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt index 69dca126..534e5db1 100644 --- a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt +++ b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt @@ -7998,3 +7998,362 @@ WWC WCC WCH WWM +TAX +AXD +XDR +IEX +EXV +QAX +AXX +XXE +XES +MXN +XNF +NRX +RXX +XXX +XXR +XRI +SAX +AXG +XGG +PRX +RXR +XRX +RXE +XEF +QEX +EXQ +XQR +REX +EXR +RXQ +XQQ +DRX +RXP +XPG +QMX +MXT +XTX +TXR +XRM +APX +PXX +XXG +XGI +NLX +LXX +XXM +XMA +LNX +NXE +XEA +GTX +TXN +XND +LIX +IXI +XIM +MVX +VXX +XXK +XKT +GLX +LXP +XPP +QGX +GXD +XDL +XAP +QNX +NXM +XMN +VAX +XGV +IKX +KXY +KEX +EXL +XLY +GQX +QXE +XEP +PLX +XKC +PVX +XKE +RXI +XIR +AXL +XLN +LLX +LXD +XDA +AXE +XEL +GGX +GXG +KAX +XXA +XAG +XWS +SPX +PXC +XCD +GWX +WXH +XHF +MPX +ESX +SXN +XNK +DLX +LXN +XNS +QXG +XGD +ITX +XRG +NEX +EXA +XAL +LDX +DXI +XII +TPX +PXM +XMR +NXG +XGY +ASX +SXV +XVE +TKX +KXA +KRX +XXT +XTL +IDX +DXX +XXL +XLV +AKX +KXX +QHX +HXV +XVN +NSX +SXX +XKX +XDP +DAX +AXK +XKQ +PIX +IXX +XXF +VLX +XDI +DIX +IXL +XLK +LKX +KXV +XVA +DNX +NXD +ILX +LXK +XKV +VYX +YXE +XEI +RXS +XSH +KGX +XGF +AVX +VXY +XYG +HVX +XXI +XID +TVX +XXS +XSA +ENX +NXX +XMD +IIX +XMQ +AEX +EXX +XME +PGX +GXP +XPR +SKX +KXF +XFT +HRX +XSW +PQX +XGR +QQX +VTX +XRP +PSX +SXP +XPL +VGX +GXY +RSX +SXS +XSL +VSX +XST +AXV +XVL +AGX +GXX +XTK +KLX +LXR +XRV +AHX +HXC +XCS +LVX +VXN +XNR +NGX +GXL +TSX +SXQ +XQN +KXL +XLL +VIX +IXG +XGA +GFX +FXG +XGL +PTX +TXT +XTS +EMX +MXQ +SXY +XYA +IQX +QXY +XYR +TXK +IGX +XPS +PXT +XTG +NXQ +VKX +KXS +XSN +GVX +VXE +GRX +XRE +YKX +KXE +XEE +EEX +EXT +XTI +EHX +HXN +XNL +NDX +DXD +IAX +KSX +SXL +RRX +XRK +DDX +DXE +RXG +VXL +XLS +DTX +TXG +VXF +XFA +XIG +VXT +XTA +ISX +SXR +XRY +VQX +QXP +XPC +LGX +GXS +HGX +XGH +XXD +XDD +KKX +XXV +PKX +XLT +XSP +XLD +RAX +AXS +XSI +IYX +YXX +XXP +XPI +MSX +SXT +GEX +XHP +LFX +FXX +VXI +XIW +QTX +TXX +XXQ +XQA +FLX +DXN +XNC +MXS +XSR +YLX +EQX +QXS +TMX +MXC +XCY +NXA +XAV +EXE +XEQ +HPX +PXP +LMX +MXX +KTX +XKK +XXH +XHS +MKX +XIH +WRX +XKS +EXY +XYQ +QKX diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index dfa0f999..6158b9dc 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -155,8 +155,19 @@ def fold_dir(self) -> str: return f"cv_{self.inner_k_folds}_fold" @property + @abstractmethod def _name(self) -> str: - raise NotImplementedError + """ + Abstract property representing the name of the data module. + + This property should be implemented in subclasses to provide a unique name for the data module. + The name is used to create subdirectories within the base directory or `processed_dir_main` + for storing relevant data associated with this module. + + Returns: + str: The name of the data module. 
+ """ + pass def _filter_labels(self, row: dict) -> dict: """ @@ -728,7 +739,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): - print("Missing processed data file (`data.pkl` file)") + print(f"Missing processed data file (`{processed_name}` file)") os.makedirs(self.processed_dir_main, exist_ok=True) data_path = self._download_required_data() g = self._extract_class_hierarchy(data_path) @@ -812,7 +823,10 @@ def setup_processed(self) -> None: None """ os.makedirs(self.processed_dir, exist_ok=True) - print("Missing transformed data (`data.pt` file). Transforming data.... ") + transformed_file_name = self.processed_file_names_dict["data"] + print( + f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... " + ) torch.save( self._load_data_from_file( os.path.join( @@ -820,7 +834,7 @@ def setup_processed(self) -> None: self.processed_main_file_names_dict["data"], ) ), - os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + os.path.join(self.processed_dir, transformed_file_name), ) @staticmethod diff --git a/chebai/preprocessing/datasets/deepGO/__init__.py b/chebai/preprocessing/datasets/deepGO/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py similarity index 69% rename from chebai/preprocessing/datasets/go_uniprot.py rename to chebai/preprocessing/datasets/deepGO/go_uniprot.py index a2c4ae54..3c957e6c 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -1,18 +1,29 @@ -# Reference for this file : +# References for this file : +# Reference 1: # Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; # DeepGO: Predicting protein functions from sequence and interactions # using a deep ontology-aware classifier, Bioinformatics, 2017. # https://doi.org/10.1093/bioinformatics/btx624 # Github: https://github.com/bio-ontology-research-group/deepgo + +# Reference 2: # https://www.ebi.ac.uk/GOA/downloads # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.uniprot.org/uniprotkb +# Reference 3: +# Kulmanov, M., Guzmán-Vega, F.J., Duek Roggli, +# P. et al. Protein function prediction as approximate semantic entailment. Nat Mach Intell 6, 220–228 (2024). 
+# https://doi.org/10.1038/s42256-024-00795-w +# https://github.com/bio-ontology-research-group/deepgo2 + __all__ = [ "GOUniProtOver250", "GOUniProtOver50", "EXPERIMENTAL_EVIDENCE_CODES", "AMBIGUOUS_AMINO_ACIDS", + "DeepGO1MigratedData", + "DeepGO2MigratedData", ] import gzip @@ -29,11 +40,13 @@ import pandas as pd import requests import torch +import tqdm from Bio import SwissProt from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +# https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 EXPERIMENTAL_EVIDENCE_CODES = { "EXP", "IDA", @@ -43,10 +56,19 @@ "IEP", "TAS", "IC", + # New evidence codes added in latest paper year 2024 Reference number 3 + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L24-L26 + "HTP", + "HDA", + "HMP", + "HGI", + "HEP", } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 -AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L10 +# `X` is now considered as valid amino acid, as per latest paper year 2024 Refernce number 3 +AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "Z", "*"} class _GOUniProtDataExtractor(_DynamicDataset, ABC): @@ -56,10 +78,16 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC): Args: dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - **kwargs: Additional keyword arguments passed to XYBaseDataModule. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. Attributes: dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. splits_file_path (Optional[str]): Path to the CSV file containing split assignments. """ @@ -405,12 +433,9 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Note: This mapping is necessary because the GO data does not include the protein sequence representation. - - Quote from the DeepGo Paper: - `We select proteins with annotations having experimental evidence codes - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC) and filter the proteins by a - maximum length of 1002, ignoring proteins with ambiguous amino acid codes - (B, O, J, U, X, Z) in their sequence.` + We select proteins with annotations having experimental evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with + ambiguous amino acid codes specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence. 
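# Editor's sketch, not part of the patch: the selection rule described in the docstring above,
# applied to single made-up records. The real extractor works on Bio.SwissProt entries; the
# plain-string interface here is only for illustration.
EXPERIMENTAL_EVIDENCE_CODES = {"EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC",
                               "HTP", "HDA", "HMP", "HGI", "HEP"}
AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "Z", "*"}   # 'X' is no longer excluded
MAX_SEQUENCE_LENGTH = 1002

def keep_protein(sequence: str, evidence_codes: list) -> bool:
    # keep proteins with at least one experimental annotation, a bounded length,
    # and no ambiguous residues in the sequence
    return (any(code in EXPERIMENTAL_EVIDENCE_CODES for code in evidence_codes)
            and len(sequence) <= MAX_SEQUENCE_LENGTH
            and not any(aa in AMBIGUOUS_AMINO_ACIDS for aa in sequence))

print(keep_protein("MKTAYXIAK", ["IDA"]))   # True  -> 'X' is accepted now
print(keep_protein("MKTAYBIAK", ["IDA"]))   # False -> 'B' is still ambiguous
print(keep_protein("MKTAYAIAK", ["IEA"]))   # False -> no experimental evidence code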
Check the link below for keyword details: https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt @@ -709,3 +734,277 @@ class GOUniProtOver50(_GOUniProtOverX): """ THRESHOLD: int = 50 + + +class _DeepGOMigratedData(_GOUniProtDataExtractor, ABC): + """ + Base class for use of the migrated DeepGO data with common properties, name formatting, and file paths. + + Attributes: + READER (dr.ProteinDataReader): Protein data reader class. + THRESHOLD (Optional[int]): Threshold value for GO class selection, + determined by the GO branch type in derived classes. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: Optional[int] = None + + # Mapping from GO branch conventions used in DeepGO to our conventions + GO_BRANCH_MAPPING: dict = { + "cc": "CC", + "mf": "MF", + "bp": "BP", + } + + @property + def _name(self) -> str: + """ + Generates a unique identifier for the migrated data based on the GO + branch and max sequence length, optionally including a threshold. + + Returns: + str: A formatted name string for the data. + """ + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "GO_" + + if self.go_branch != self._ALL_GO_BRANCHES: + return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" + + return f"{threshold_part}{self.max_sequence_length}" + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Checks for the existence of migrated DeepGO data in the specified directory. + Raises an error if the required data file is not found, prompting + migration from DeepGO to this data structure. + + Args: + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Raises: + FileNotFoundError: If the processed data file does not exist. + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + raise FileNotFoundError( + f"File {processed_name} not found.\n" + f"You must run the appropriate DeepGO migration script " + f"(chebai/preprocessing/migration/deep_go) before executing this configuration " + f"to migrate data from DeepGO to this data structure." + ) + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + # Selection of GO classes not needed for migrated data + pass + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining main processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for main processed file names. + """ + pass + + @property + @abstractmethod + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining additional processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for processed file names. + """ + pass + + +class DeepGO1MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO1. 
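# Editor's sketch, not part of the patch: the directory names the _name property above produces.
# The 'all-branches' sentinel value is an assumption here; the class compares against its own
# _ALL_GO_BRANCHES constant.
def migrated_name(threshold, go_branch, max_sequence_length, all_branches="all"):
    threshold_part = f"GO{threshold}_" if threshold is not None else "GO_"
    if go_branch != all_branches:
        return f"{threshold_part}{go_branch}_{max_sequence_length}"
    return f"{threshold_part}{max_sequence_length}"

print(migrated_name(50, "MF", 1002))     # 'GO50_MF_1002'  (DeepGO1, molecular function)
print(migrated_name(250, "BP", 1002))    # 'GO250_BP_1002' (DeepGO1, biological process)
print(migrated_name(None, "all", 1000))  # 'GO_1000'       (no threshold, all branches)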
Sets threshold values according + to the research paper based on the GO branch. + + Note: + Refer reference number 1 at the top of this file for the corresponding research paper. + + Args: + **kwargs: Arbitrary keyword arguments passed to the superclass. + + Raises: + ValueError: If an unsupported GO branch is provided. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1002 + + # Set threshold based on GO branch, as per DeepGO1 paper and its data. + if kwargs.get("go_branch") in ["CC", "MF"]: + self.THRESHOLD = 50 + elif kwargs.get("go_branch") == "BP": + self.THRESHOLD = 250 + else: + raise ValueError( + f"DeepGO1 paper has no defined threshold for branch {self.go_branch}" + ) + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with the main data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pt"} + + +class DeepGO2MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO2, inheriting from DeepGO1MigratedData + with different processed file names. + + Note: + Refer reference number 3 at the top of this file for the corresponding research paper. + + Returns: + dict: Dictionary with file names specific to DeepGO2. + """ + + _LABELS_START_IDX: int = 5 # additional esm2_embeddings column in the dataframe + _ESM_EMBEDDINGS_COL_IDX: int = 4 + + def __init__(self, use_esm2_embeddings=False, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1000 + self.use_esm2_embeddings: bool = use_esm2_embeddings + super(_DeepGOMigratedData, self).__init__(**kwargs) + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: + """ + Load and process data from a file into a list of dictionaries containing features and labels. + + This method processes data differently based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, raw dictionaries from `_load_dict` are returned, _load_dict already returns + the numerical features (esm2 embeddings) from the data file, hence no reader is required. + - Otherwise, a reader is used to process the data (generate numerical features). + + Args: + path (str): The path to the input file. + + Returns: + List[Dict[str, Any]]: A list of dictionaries with the following keys: + - `features`: Sequence or embedding data, depending on the context. + - `labels`: A boolean array of labels. + - `ident`: The identifier for the sequence. 
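# Editor's sketch, not part of the patch: the branch-dependent thresholds DeepGO1MigratedData
# derives in its __init__, written out as a plain function for readability.
def deepgo1_threshold(go_branch: str) -> int:
    if go_branch in ("CC", "MF"):
        return 50
    if go_branch == "BP":
        return 250
    raise ValueError(f"DeepGO1 paper has no defined threshold for branch {go_branch}")

print(deepgo1_threshold("MF"), deepgo1_threshold("BP"))   # 50 250
# DeepGO1 data is pinned to max_sequence_length=1002 and DeepGO2 to 1000 (see the asserts above).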
+ """ + lines = self._get_data_size(path) + print(f"Processing {lines} lines...") + + if self.use_esm2_embeddings: + data = [ + d + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + else: + data = [ + self.reader.to_data(d) + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + + # filter for missing features in resulting data + data = [val for val in data if val["features"] is not None] + + return data + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data at row index `self._ESM2_EMBEDDINGS_COL_IDX`: ESM2 embeddings of the protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + The method adapts based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, features are loaded from the column specified by `self._ESM_EMBEDDINGS_COL_IDX`. + - Otherwise, features are loaded from the column specified by `self._DATA_REPRESENTATION_IDX`. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (Any): Sequence or embedding data for the instance. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + + if self.use_esm2_embeddings: + features_idx = self._ESM_EMBEDDINGS_COL_IDX + else: + features_idx = self._DATA_REPRESENTATION_IDX + + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + yield dict( + features=row[features_idx], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with the main data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with data file name for DeepGO2. 
+ """ + return {"data": "data_deep_go2.pt"} + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + if self.use_esm2_embeddings: + return (dr.ESM2EmbeddingReader.name(),) + return (self.reader.name(),) diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py similarity index 97% rename from chebai/preprocessing/datasets/protein_pretraining.py rename to chebai/preprocessing/datasets/deepGO/protein_pretraining.py index 6b5d1df0..8f7e9c4d 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py @@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.go_uniprot import ( +from chebai.preprocessing.datasets.deepGO.go_uniprot import ( AMBIGUOUS_AMINO_ACIDS, EXPERIMENTAL_EVIDENCE_CODES, GOUniProtOver250, @@ -96,15 +96,15 @@ def _download_required_data(self) -> str: def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: """ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid - Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES`. The DataFrame includes the following columns: - "swiss_id": The unique identifier for each Swiss-Prot record. - "sequence": The protein sequence. Note: - We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.` + We ignore proteins with ambiguous amino acid specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence.` Returns: pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with not associated valid GO. diff --git a/chebai/preprocessing/datasets/scope/__init__.py b/chebai/preprocessing/datasets/scope/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py new file mode 100644 index 00000000..a987f53d --- /dev/null +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -0,0 +1,381 @@ +import gzip +import itertools +import os +import pickle +import shutil +from abc import ABC +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import fastobo +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SeqIO +from Bio.Seq import Seq + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.reader import ProteinDataReader + + +class _SCOPeDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. 
+ **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. + """ + + _GO_DATA_INIT = "GO" + _SWISS_DATA_INIT = "SWISS" + + # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` + # "swiss_id" at row index 0 + # "accession" at row index 1 + # "go_ids" at row index 2 + # "sequence" at row index 3 + # labels starting from row index 4 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column + _LABELS_START_IDX: int = 4 + + _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt" + _PDB_SEQUENCE_DATA_URL = ( + "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz" + ) + + def __init__( + self, + scope_version: float, + scope_version_train: Optional[float] = None, + **kwargs, + ): + + self.scope_version: float = scope_version + self.scope_version_train: float = scope_version_train + + super(_SCOPeDataExtractor, self).__init__(**kwargs) + + if self.scope_version_train is not None: + # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given + # This is to get the data from respective directory related to "scope_version_train" + _init_kwargs = kwargs + _init_kwargs["chebi_version"] = self.scope_version_train + self._scope_version_train_obj = self.__class__( + **_init_kwargs, + ) + + @staticmethod + def _get_scope_url(data_type: str, version_number: float) -> str: + """ + Generates the URL for downloading SCOPe files. + + Args: + data_type (str): The type of data (e.g., 'cla', 'hie', 'des'). + version_number (str): The version of the SCOPe file. + + Returns: + str: The formatted SCOPe file URL. + """ + return _SCOPeDataExtractor._SCOPE_GENERAL_URL.format( + data_type=data_type, version_number=version_number + ) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. + + Returns: + str: Path to the downloaded data. 
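# Editor's sketch, not part of the patch: the raw-file URLs _get_scope_url above expands to,
# shown for SCOPe version 2.08 (the version used in this module's __main__ example further below).
SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt"
for data_type in ("cla", "hie", "des", "com"):
    print(SCOPE_GENERAL_URL.format(data_type=data_type, version_number=2.08))
# -> https://scop.berkeley.edu/downloads/parse/dir.cla.scope.2.08-stable.txt, and so on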
+ """ + self._download_pdb_sequence_data() + return self._download_scope_raw_data() + + def _download_pdb_sequence_data(self) -> None: + pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]) + os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) + + if not os.path.isfile(pdb_seq_file_path): + print(f"Downloading PDB sequence data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._PDB_SEQUENCE_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = pdb_seq_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + def _download_scope_raw_data(self) -> str: + os.makedirs(self.raw_dir, exist_ok=True) + for data_type in ["CLA", "COM", "HIE", "DES"]: + data_file_name = self.raw_file_names_dict[data_type] + scope_path = os.path.join(self.raw_dir, data_file_name) + if not os.path.isfile(scope_path): + print(f"Missing Scope: {data_file_name} raw data, Downloading...") + r = requests.get( + self._get_scope_url(data_type.lower(), self.scope_version), + allow_redirects=False, + verify=False, # Disable SSL verification + ) + r.raise_for_status() # Check if the request was successful + open(scope_path, "wb").write(r.content) + return "dummy/path" + + def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: + pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} + for record in SeqIO.parse( + os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" + ): + pdb_id, chain = record.id.split("_") + pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq) + return pdb_chain_seq_mapping + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + print("Extracting class hierarchy...") + + # Load and preprocess CLA file + df_cla = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), + sep="\t", + header=None, + comment="#", + ) + df_cla.columns = [ + "sid", + "PDB_ID", + "description", + "sccs", + "sunid", + "ancestor_nodes", + ] + df_cla["sunid"] = pd.to_numeric( + df_cla["sunid"], errors="coerce", downcast="integer" + ) + df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply( + lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))} + ) + df_cla.set_index("sunid", inplace=True) + + # Load and preprocess HIE file + df_hie = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), + sep="\t", + header=None, + comment="#", + ) + df_hie.columns = ["sunid", "parent_sunid", "children_sunids"] + df_hie["sunid"] = pd.to_numeric( + df_hie["sunid"], errors="coerce", downcast="integer" + ) + df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int) + df_hie["children_sunids"] = df_hie["children_sunids"].apply( + lambda x: list(map(int, x.split(","))) if x != "-" else [] + ) + + # Initialize directed graph + g = nx.DiGraph() + + # Add nodes and edges efficiently + 
g.add_edges_from( + df_hie[df_hie["parent_sunid"] != -1].apply( + lambda row: (row["parent_sunid"], row["sunid"]), axis=1 + ) + ) + g.add_edges_from( + df_hie.explode("children_sunids") + .dropna() + .apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1) + ) + + pdb_chain_seq_mapping = self._parse_pdb_sequence_file() + + node_to_pdb_id = df_cla["PDB_ID"].to_dict() + + for node in g.nodes(): + pdb_id = node_to_pdb_id[node] + chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {}) + + # Add nodes and edges for chains in the mapping + for chain, sequence in chain_mapping.items(): + chain_node = f"{pdb_id}_{chain}" + g.add_node(chain_node, sequence=sequence) + g.add_edge(node, chain_node) + + print("Compute transitive closure...") + return nx.transitive_closure_dag(g) + + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes + Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. + + Note: + - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value + indicates whether a Swiss-Prot protein is associated with that GO term. + - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins + and GO terms. + + Data Format: pd.DataFrame + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + print(f"Processing graph") + + data_df = self._get_swiss_to_go_mapping() + # add ancestors to go ids + data_df["go_ids"] = data_df["go_ids"].apply( + lambda go_ids: sorted( + set( + itertools.chain.from_iterable( + [ + [go_id] + list(g.predecessors(go_id)) + for go_id in go_ids + if go_id in g.nodes + ] + ) + ) + ) + ) + # Initialize the GO term labels/columns to False + selected_classes = self.select_classes(g, data_df=data_df) + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=selected_classes + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + # Set True for the corresponding GO IDs in the DataFrame go labels/columns + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. 
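# Editor's sketch, not part of the patch: the labelling step above on a toy frame. Real rows
# carry Swiss-Prot/SCOPe identifiers and sequences; the two records here are made up.
import pandas as pd

data_df = pd.DataFrame({"id": ["p1", "p2"], "go_ids": [[1, 2], []]})
selected_classes = [1, 2, 3]
label_columns = pd.DataFrame(False, index=data_df.index, columns=selected_classes)
data_df = pd.concat([data_df, label_columns], axis=1)

for index, row in data_df.iterrows():
    for go_id in row["go_ids"]:
        if go_id in data_df.columns:
            data_df.at[index, go_id] = True

# keep only samples annotated with at least one selected class ('p2' is dropped)
data_df = data_df[data_df.iloc[:, 2:].any(axis=1)]
print(data_df)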
+ # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least + # one GO term from the set of the GO terms for the model` + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + try: + filename = self.processed_file_names_dict["data"] + data_go = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. " + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = self.get_test_split( + df_go_data, seed=self.dynamic_data_split_seed + ) + + # Get all splits + df_train, df_val = self.get_train_val_splits_given_test( + train_df_go, + df_test, + seed=self.dynamic_data_split_seed, + ) + + return df_train, df_val, df_test + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def base_dir(self) -> str: + """ + Returns the base directory path for storing GO-Uniprot data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join("data", "SCOPe", f"version_{self.scope_version}") + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. 
+ """ + return { + "CLA": "cla.txt", + "DES": "des.txt", + "HIE": "hie.txt", + "COM": "com.txt", + "PDB": "pdb_sequences.txt", + } + + +class SCOPE(_SCOPeDataExtractor): + READER = ProteinDataReader + + @property + def _name(self) -> str: + return "test" + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + pass + + +if __name__ == "__main__": + scope = SCOPE(scope_version=2.08) + scope._parse_pdb_sequence_file() diff --git a/chebai/preprocessing/migration/deep_go/__init__.py b/chebai/preprocessing/migration/deep_go/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py new file mode 100644 index 00000000..7d59c699 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -0,0 +1,316 @@ +import os +from collections import OrderedDict +from typing import List, Literal, Optional, Tuple + +import pandas as pd +from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit +from jsonargparse import CLI + +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData + + +class DeepGo1DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. + It migrates the DeepGO data to our data structure followed for GO-UniProt data. + + This class handles migration of data from the DeepGO paper below: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624). + """ + + # Max sequence length as per DeepGO1 + _MAXLEN = 1002 + _LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + """ + valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = rf"{data_dir}" + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. 
+ """ + try: + print(f"Loading data files from directory: {self._data_dir}") + self._test_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") + ) + ) + + # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set. + # Although this reduces the training data slightly compared to the original DeepGO setup, + # given the data size, the impact should be minimal. + train_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") + ) + ) + + self._train_df, self._validation_df = self._get_train_val_split(train_df) + + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) + ) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) + + @staticmethod + def _get_train_val_split( + train_df: pd.DataFrame, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Splits the training data into a smaller training set and a validation set. + + Args: + train_df (pd.DataFrame): Original training DataFrame. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames. + """ + labels_list_train = train_df["labels"].tolist() + train_split = 0.85 + test_size = ((1 - train_split) ** 2) / train_split + + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=42 + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_train, labels_list_train) + ) + + df_validation = train_df.iloc[validation_indices] + df_train = train_df.iloc[train_indices] + return df_train, df_validation + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording data splits for train, validation, and test sets.") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining data splits into a single DataFrame with required columns.") + required_columns = [ + "proteins", + "accessions", + "sequences", + "gos", + "labels", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["gos"]), axis=1 + ) + + labels_df = self._get_labels_columns(new_df) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + + df = pd.concat([data_df, labels_df], axis=1) + + return df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. 
+ + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[int]: List of parsed GO IDs. + """ + return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates columns for labels based on provided selected terms. + + Args: + data_df (pd.DataFrame): DataFrame with GO annotations and labels. + + Returns: + pd.DataFrame: DataFrame with label columns. + """ + print("Generating label columns from provided selected terms.") + parsed_go_ids: pd.Series = self._terms_df["functions"].apply( + lambda gos: DeepGO1MigratedData._parse_go_id(gos) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + + new_label_columns = pd.DataFrame( + data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list + ) + + return new_label_columns + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. + """ + print("Saving transformed data files.") + + deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData( + go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] + ) + print( + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" + ) + + # Save splits file + splits_df.to_csv( + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"), + index=False, + ) + print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}") + + # Save classes file + classes = sorted(self._classes) + with open( + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"), + "wt", + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration process completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGo1DataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGo1DataMigration(data_dir, go_branch).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. 
+ CLI( + Main, + description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py new file mode 100644 index 00000000..68d7dc78 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -0,0 +1,370 @@ +import os +import re +from collections import OrderedDict +from typing import List, Literal, Optional + +import pandas as pd +from jsonargparse import CLI + +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.reader import ProteinDataReader + + +class DeepGo2DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE + data structure to our data structure followed for GO-UniProt data. + + This class handles migration of data from the DeepGO paper below: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624) + """ + + _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX + + def __init__( + self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + """ + valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._max_len: int = max_len + + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. 
+ """ + + try: + print(f"Loading data from directory: {self._data_dir}......") + + print( + "Pre-processing the data before loading them into instance variables\n" + f"2-Steps preprocessing: \n" + f"\t 1: Truncating every sequence to {self._max_len}\n" + f"\t 2: Replacing every amino acid which is not in {ProteinDataReader.AA_LETTER}" + ) + + self._test_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) + ) + self._train_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) + ) + self._validation_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) + ) + + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) + ) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) + + def _pre_process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Pre-processes the input dataframe by truncating sequences to the maximum + length and replacing invalid amino acids with 'X'. + + Args: + df (pd.DataFrame): The dataframe to preprocess. + + Returns: + pd.DataFrame: The processed dataframe. + """ + df = self._truncate_sequences(df) + df = self._replace_invalid_amino_acids(df) + return df + + def _truncate_sequences( + self, df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Truncate sequences in a specified column of a dataframe to the maximum length. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217 + + Args: + df (pd.DataFrame): The input dataframe containing the data to be processed. + column (str, optional): The column containing sequences to truncate. + Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with sequences truncated to `self._max_len`. + """ + df[column] = df[column].apply(lambda x: x[: self._max_len]) + return df + + @staticmethod + def _replace_invalid_amino_acids( + df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Replaces invalid amino acids in a sequence with 'X' using regex. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L26-L33 + https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 + + Args: + df (pd.DataFrame): The dataframe containing the sequences to be processed. + column (str, optional): The column containing the sequences. Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with invalid amino acids replaced by 'X'. + """ + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + # Replace any character not in the valid set with 'X' + df[column] = df[column].apply( + lambda x: re.sub(f"[^{valid_amino_acids}]", "X", x) + ) + return df + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. 
+ """ + print("Recording data splits for train, validation, and test sets.") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... ") + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 + "exp_annotations", # Directly associated GO ids + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Transitively associated GO ids + "esm2", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["exp_annotations"]) + + self.extract_go_id(row["prop_annotations"]), + axis=1, + ) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + esm2_embeddings=new_df["esm2"], + ) + ) + return data_df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [DeepGO2MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates label columns for each GO term in the dataset. + + Args: + data_df (pd.DataFrame): DataFrame containing data with GO IDs. + + Returns: + pd.DataFrame: DataFrame with new label columns. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df["gos"].apply( + lambda gos: DeepGO2MigratedData._parse_go_id(gos) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=all_go_ids_list + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. 
+ """ + print("Saving transformed data......") + deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( + go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._max_len, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] + ) + print( + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" + ) + + # Save split file + splits_df.to_csv( + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go2.csv"), + index=False, + ) + print(f"splits_deep_go2.csv saved to {deepgo_migr_inst.processed_dir_main}") + + # Save classes.txt file + classes = sorted(self._classes) + with open( + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go2.txt"), + "wt", + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes_deep_go2.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGoDataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate( + data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + """ + DeepGo2DataMigration(data_dir, go_branch, max_len).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. 
+    CLI(
+        Main,
+        description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
+        as_positional=False,
+    )
diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index e220e1e4..88e4fedd 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -1,8 +1,18 @@
 import os
-from typing import Any, Dict, List, Optional
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+from urllib.error import HTTPError

 import deepsmiles
 import selfies as sf
+import torch
+from esm import Alphabet
+from esm.model.esm2 import ESM2
+from esm.pretrained import (
+    _has_regression_weights,
+    load_model_and_alphabet_core,
+    load_model_and_alphabet_local,
+)
 from pysmiles.read_smiles import _tokenize
 from transformers import RobertaTokenizerFast
@@ -348,7 +358,7 @@ class ProteinDataReader(DataReader):

     COLLATOR = RaggedCollator

-    # 20 natural amino acid notation
+    # 20 natural amino acids plus 'X' (unknown residue), 21 tokens in total
     AA_LETTER = [
         "A",
         "R",
@@ -370,6 +380,8 @@ class ProteinDataReader(DataReader):
         "W",
         "Y",
         "V",
+        # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5
+        "X",  # Considered valid in the latest (2024) paper; see reference 3 in go_uniprot.py
     ]

     def name(self) -> str:
@@ -469,3 +481,249 @@ def on_finish(self) -> None:
         print(f"Saving {len(self.cache)} tokens to {self.token_path}...")
         print(f"First 10 tokens: {self.cache[:10]}")
         pk.writelines([f"{c}\n" for c in self.cache])
+
+
+class ESM2EmbeddingReader(DataReader):
+    """
+    A data reader to process protein sequences using the ESM2 model for embeddings.
+
+    References:
+        https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py
+
+    Note:
+        For layer availability by model, please check the link below:
+        https://github.com/facebookresearch/esm?tab=readme-ov-file#pre-trained-models-
+
+        To test this reader, try lighter models:
+            esm2_t6_8M_UR50D: 6 layers (valid layers: 1–6), ~28 MB - a tiny 8M-parameter model.
+            esm2_t12_35M_UR50D: 12 layers (valid layers: 1–12), ~128 MB - a slightly larger, 35M-parameter model.
+        These smaller models are good for testing and debugging purposes.
+
+    """
+
+    # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53
+    _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt"
+    _REGRESSION_URL = (
+        "https://dl.fbaipublicfiles.com/fair-esm/regression/{}-contact-regression.pt"
+    )
+
+    def __init__(
+        self,
+        save_model_dir: str,
+        model_name: str = "esm2_t36_3B_UR50D",
+        device: Optional[torch.device] = None,
+        truncation_length: int = 1022,
+        toks_per_batch: int = 4096,
+        return_contacts: bool = False,
+        repr_layer: int = 36,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialize the ESM2EmbeddingReader class.
+
+        Args:
+            save_model_dir (str): Directory to save/load the pretrained ESM model.
+            model_name (str): Name of the pretrained model. Defaults to "esm2_t36_3B_UR50D".
+            device (torch.device or str, optional): Device for computation (e.g., 'cpu', 'cuda').
+            truncation_length (int): Maximum sequence length for truncation. Defaults to 1022.
+            toks_per_batch (int): Tokens per batch for data processing. Defaults to 4096.
+            return_contacts (bool): Whether to return contact maps. Defaults to False.
+            repr_layer (int): Layer number to extract representations from. Defaults to 36.
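+
+        Example (illustrative sketch; the save directory, model choice, and sequence
+        are placeholders, and the lighter 8M model is used only to keep the example small):
+            >>> reader = ESM2EmbeddingReader(
+            ...     save_model_dir="models/esm2",
+            ...     model_name="esm2_t6_8M_UR50D",
+            ...     repr_layer=6,
+            ... )
+            >>> embedding = reader._read_data("MKTAYIAKQR")
+            >>> len(embedding)  # embedding width of esm2_t6_8M_UR50D
+            320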
+        """
+        self.save_model_dir = save_model_dir
+        if not os.path.exists(self.save_model_dir):
+            os.makedirs(self.save_model_dir, exist_ok=True)
+        self.model_name = model_name
+        self.device = device
+        self.truncation_length = truncation_length
+        self.toks_per_batch = toks_per_batch
+        self.return_contacts = return_contacts
+        self.repr_layer = repr_layer
+
+        self._model: Optional[ESM2] = None
+        self._alphabet: Optional[Alphabet] = None
+
+        self._model, self._alphabet = self.load_model_and_alphabet()
+        self._model.eval()
+
+        if self.device:
+            self._model = self._model.to(device)
+
+        super().__init__(*args, **kwargs)
+
+    def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]:
+        """
+        Load the ESM2 model and its alphabet.
+
+        References:
+            https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L24-L28
+
+        Returns:
+            Tuple[ESM2, Alphabet]: Loaded model and alphabet.
+        """
+        model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt")
+        if os.path.exists(model_location):
+            return load_model_and_alphabet_local(model_location)
+        else:
+            return self.load_model_and_alphabet_hub()
+
+    def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]:
+        """
+        Load the model and alphabet from the hub URL.
+
+        References:
+            https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L62-L64
+
+        Returns:
+            Tuple[ESM2, Alphabet]: Loaded model and alphabet.
+        """
+        model_url = self._MODELS_URL.format(self.model_name)
+        model_data = self.load_hub_workaround(model_url)
+        regression_data = None
+        if _has_regression_weights(self.model_name):
+            regression_url = self._REGRESSION_URL.format(self.model_name)
+            regression_data = self.load_hub_workaround(regression_url)
+        return load_model_and_alphabet_core(
+            self.model_name, model_data, regression_data
+        )
+
+    def load_hub_workaround(self, url) -> torch.Tensor:
+        """
+        Workaround to load models from the PyTorch Hub.
+
+        References:
+            https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L31-L43
+
+        Returns:
+            torch.Tensor: Loaded model state dictionary.
+        """
+        try:
+            data = torch.hub.load_state_dict_from_url(
+                url, self.save_model_dir, progress=True, map_location=self.device
+            )
+
+        except RuntimeError:
+            # Handle PyTorch version issues
+            fn = Path(url).name
+            data = torch.load(
+                f"{torch.hub.get_dir()}/checkpoints/{fn}",
+                map_location="cpu",
+            )
+        except HTTPError as e:
+            raise Exception(
+                f"Could not load {url}. Did you specify the correct model name?"
+            ) from e
+        return data
+
+    @staticmethod
+    def name() -> str:
+        """
+        Returns the name of the data reader. This method identifies the specific type of data reader.
+
+        Returns:
+            str: The name of the data reader, which is "esm2_embedding".
+        """
+        return "esm2_embedding"
+
+    @property
+    def token_path(self) -> None:
+        """
+        Not used, as no token file is created for this reader.
+
+        Returns:
+            None: This reader does not maintain a token file.
+        """
+        return
+
+    def _read_data(self, raw_data: str) -> List[float]:
+        """
+        Reads protein sequence data and generates embeddings.
+
+        Args:
+            raw_data (str): The protein sequence.
+
+        Returns:
+            List[float]: Embedding vector generated for the sequence.
+        """
+        alp_tokens_idx = self._sequence_to_alphabet_tokens_idx(raw_data)
+        return self._alphabet_tokens_to_esm_embedding(alp_tokens_idx).tolist()
+
+    def _sequence_to_alphabet_tokens_idx(self, sequence: str) -> torch.Tensor:
+        """
+        Converts a protein sequence into ESM alphabet token indices.
+
+        Args:
+            sequence (str): Protein sequence.
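+
+        Example (illustrative; reader is an initialized ESM2EmbeddingReader and the
+        ESM2 alphabet prepends BOS and appends EOS by default):
+            >>> tokens = reader._sequence_to_alphabet_tokens_idx("MKV")
+            >>> tokens.shape  # 1 x (BOS + 3 residues + EOS)
+            torch.Size([1, 5])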
+ + References: + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L249-L250 + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L262-L297 + + Returns: + torch.Tensor: Tokenized sequence with special tokens (BOS/EOS) included. + """ + seq_encoded = self._alphabet.encode(sequence) + tokens = [] + + # Add BOS token if configured + if self._alphabet.prepend_bos: + tokens.append(self._alphabet.cls_idx) + + # Add the main sequence + tokens.extend(seq_encoded) + + # Add EOS token if configured + if self._alphabet.append_eos: + tokens.append(self._alphabet.eos_idx) + + # Convert to PyTorch tensor and return + return torch.tensor([tokens], dtype=torch.int64) + + def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts alphabet tokens into ESM embeddings. + + Args: + tokens (torch.Tensor): Tokenized protein sequences. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107 + + Returns: + torch.Tensor: Protein embedding from the specified representation layer. + """ + if self.device: + tokens = tokens.to(self.device, non_blocking=True) + + with torch.no_grad(): + out = self._model( + tokens, + repr_layers=[ + self.repr_layer, + ], + return_contacts=self.return_contacts, + ) + + # Extract representations and compute the mean embedding for each layer + representations = { + layer: t.to(self.device) for layer, t in out["representations"].items() + } + truncate_len = min(self.truncation_length, tokens.size(1)) + + result = { + "mean_representations": { + layer: t[0, 1 : truncate_len + 1].mean(0).clone() + for layer, t in representations.items() + } + } + return result["mean_representations"][self.repr_layer] + + def on_finish(self) -> None: + """ + Not used here as no token file exists for this reader. + + Returns: + None + """ + pass diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py new file mode 100644 index 00000000..355c07c0 --- /dev/null +++ b/chebai/result/evaluate_predictions.py @@ -0,0 +1,108 @@ +from typing import Tuple + +import numpy as np +import torch +from jsonargparse import CLI +from torchmetrics.functional.classification import multilabel_auroc + +from chebai.callbacks.epoch_metrics import MacroF1 +from chebai.result.utils import load_results_from_buffer + + +class EvaluatePredictions: + def __init__(self, eval_dir: str): + """ + Initializes the EvaluatePredictions class. + + Args: + eval_dir (str): Path to the directory containing evaluation files. + """ + self.eval_dir = eval_dir + self.metrics = [] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_labels = None + + @staticmethod + def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None: + """ + Validates that the number of labels matches the number of predictions, + ensuring that they have the same shape. + + Args: + label_files (torch.Tensor): Tensor containing label data. + pred_files (torch.Tensor): Tensor containing prediction data. + + Raises: + ValueError: If label and prediction tensors are mismatched in shape. 
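+
+        Example (shapes are illustrative):
+            Label and prediction tensors both of shape (N, C) pass the check, while
+            shapes (N, C) and (M, C), or (N, C) and (N, C + 1), raise a ValueError.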
+ """ + if label_files is None or pred_files is None: + raise ValueError("Both label and prediction tensors must be provided.") + + # Check if the number of labels matches the number of predictions + if label_files.shape[0] != pred_files.shape[0]: + raise ValueError( + "Number of label tensors does not match the number of prediction tensors." + ) + + # Validate that the last dimension matches the expected number of classes + if label_files.shape[1] != pred_files.shape[1]: + raise ValueError( + "Label and prediction tensors must have the same shape in terms of class outputs." + ) + + def evaluate(self) -> None: + """ + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax. + """ + test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) + self.validate_eval_dir(test_labels, test_preds) + self.num_labels = test_preds.shape[1] + + ml_auroc = multilabel_auroc( + test_preds, test_labels, num_labels=self.num_labels + ).item() + + print("Multilabel AUC-ROC:", ml_auroc) + + fmax, threshold = self.calculate_fmax(test_preds, test_labels) + print(f"F-max : {fmax}, threshold: {threshold}") + + def calculate_fmax( + self, test_preds: torch.Tensor, test_labels: torch.Tensor + ) -> Tuple[float, float]: + """ + Calculates the Fmax metric using the F1 score at various thresholds. + + Args: + test_preds (torch.Tensor): Predicted scores for the labels. + test_labels (torch.Tensor): True labels for the evaluation. + + Returns: + Tuple[float, float]: The maximum F1 score and the corresponding threshold. + """ + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/metrics.py#L51-L52 + thresholds = np.linspace(0, 1, 101) + fmax = 0.0 + best_threshold = 0.0 + + for t in thresholds: + custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t) + custom_f1_metric.update(test_preds, test_labels) + custom_f1_metric_score = custom_f1_metric.compute().item() + + # Check if the current score is the best we've seen + if custom_f1_metric_score > fmax: + fmax = custom_f1_metric_score + best_threshold = t + + return fmax, best_threshold + + +class Main: + def evaluate(self, eval_dir: str): + EvaluatePredictions(eval_dir).evaluate() + + +if __name__ == "__main__": + # evaluate_predictions.py evaluate + CLI(Main) diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml new file mode 100644 index 00000000..4b3ae3b1 --- /dev/null +++ b/configs/data/deepGO/deepgo2_esm2.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 + use_esm2_embeddings: True \ No newline at end of file diff --git a/configs/data/deepGO/deepgo_1_migrated_data.yml b/configs/data/deepGO/deepgo_1_migrated_data.yml new file mode 100644 index 00000000..0924e023 --- /dev/null +++ b/configs/data/deepGO/deepgo_1_migrated_data.yml @@ -0,0 +1,4 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1002 diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml new file mode 100644 index 00000000..1ed2ad09 --- /dev/null +++ b/configs/data/deepGO/deepgo_2_migrated_data.yml @@ -0,0 +1,4 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 diff --git a/configs/data/deepGO/go250.yml b/configs/data/deepGO/go250.yml new 
file mode 100644
index 00000000..01e34aa4
--- /dev/null
+++ b/configs/data/deepGO/go250.yml
@@ -0,0 +1,3 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver250
+init_args:
+  go_branch: "BP"
diff --git a/configs/data/deepGO/go50.yml b/configs/data/deepGO/go50.yml
new file mode 100644
index 00000000..bee43773
--- /dev/null
+++ b/configs/data/deepGO/go50.yml
@@ -0,0 +1 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50
diff --git a/configs/data/go250.yml b/configs/data/go250.yml
deleted file mode 100644
index 5598495c..00000000
--- a/configs/data/go250.yml
+++ /dev/null
@@ -1,3 +0,0 @@
-class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver250
-init_args:
-  go_branch: "BP"
diff --git a/configs/data/go50.yml b/configs/data/go50.yml
deleted file mode 100644
index 2ed4d14c..00000000
--- a/configs/data/go50.yml
+++ /dev/null
@@ -1 +0,0 @@
-class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver50
diff --git a/configs/model/electra.yml b/configs/model/electra.yml
index c3cf2fdf..ade89acd 100644
--- a/configs/model/electra.yml
+++ b/configs/model/electra.yml
@@ -3,7 +3,7 @@ init_args:
   optimizer_kwargs:
     lr: 1e-3
   config:
-    vocab_size: 1400
+    vocab_size: 8500
     max_position_embeddings: 1800
     num_attention_heads: 8
     num_hidden_layers: 6
diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml
new file mode 100644
index 00000000..193c6f64
--- /dev/null
+++ b/configs/model/ffn.yml
@@ -0,0 +1,7 @@
+class_path: chebai.models.ffn.FFN
+init_args:
+  optimizer_kwargs:
+    lr: 1e-3
+  hidden_size: 128
+  num_hidden_layers: 3
+  input_size: 2560
diff --git a/setup.py b/setup.py
index 58bfc75b..ba134e41 100644
--- a/setup.py
+++ b/setup.py
@@ -51,6 +51,7 @@
         "pyyaml",
         "torchmetrics",
         "biopython",
+        "fair-esm",
     ],
     extras_require={"dev": ["black", "isort", "pre-commit"]},
 )
diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py
index 9da48bee..96ff9a3a 100644
--- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py
+++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py
@@ -6,7 +6,7 @@
 import networkx as nx
 import pandas as pd

-from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor
+from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor
 from chebai.preprocessing.reader import ProteinDataReader
 from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData
diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py
index d4157770..3f329c56 100644
--- a/tests/unit/dataset_classes/testGoUniProtOverX.py
+++ b/tests/unit/dataset_classes/testGoUniProtOverX.py
@@ -5,7 +5,7 @@
 import networkx as nx
 import pandas as pd

-from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX
+from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX
 from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData
diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py
index cb6b0688..caac3eac 100644
--- a/tests/unit/dataset_classes/testProteinPretrainingData.py
+++ b/tests/unit/dataset_classes/testProteinPretrainingData.py
@@ -1,7 +1,9 @@
 import unittest
 from unittest.mock import PropertyMock, mock_open, patch

-from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData
+from chebai.preprocessing.datasets.deepGO.protein_pretraining import ( +
_ProteinPretrainingData, +) from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index a05b89f1..552d2918 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -658,18 +658,19 @@ def get_UniProt_raw_data() -> str: - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. - - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'X'. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'B'. - **Swiss_Prot_5**: Has a sequence but no GO classes associated. - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. - **Swiss_Prot_7**: Has a GO class with an invalid evidence code. - **Swiss_Prot_8**: Has a sequence length > 1002 and has only invalid GO class. - - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'X', in its sequence. + - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'B', in its sequence. - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. Note: - A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + A valid GO label is the one which has one of the following evidence code specified in + go_uniprot.py->`EXPERIMENTAL_EVIDENCE_CODES`. + Invalid amino acids are specified in go_uniprot.py->`AMBIGUOUS_AMINO_ACIDS`. Returns: str: The raw UniProt data in string format. @@ -715,7 +716,7 @@ def get_UniProt_raw_data() -> str: "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with sequence string but has no GO class "ID Swiss_Prot_5 Reviewed; 60 AA.\n" @@ -749,7 +750,7 @@ def get_UniProt_raw_data() -> str: "ID Swiss_Prot_9 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with a `valid` associated GO class but without sequence string "ID Swiss_Prot_10 Reviewed; 60 AA.\n" diff --git a/tutorials/data_exploration_go.ipynb b/tutorials/data_exploration_go.ipynb index 6f67c82b..1a205e37 100644 --- a/tutorials/data_exploration_go.ipynb +++ b/tutorials/data_exploration_go.ipynb @@ -70,7 +70,7 @@ } }, "outputs": [], - "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250" + "source": "from chebai.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250" }, { "cell_type": "code",