From bdba4428ce26ca2fb8285a87880714329610fec5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 4 Nov 2024 13:02:22 +0100 Subject: [PATCH 01/33] script to evaluate go predictions --- chebai/result/evaluate_predictions.py | 70 +++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 chebai/result/evaluate_predictions.py diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py new file mode 100644 index 00000000..e25c2130 --- /dev/null +++ b/chebai/result/evaluate_predictions.py @@ -0,0 +1,70 @@ +import torch +from jsonargparse import CLI +from torchmetrics.functional.classification import multilabel_auroc + +from chebai.result.utils import load_results_from_buffer + + +class EvaluatePredictions: + def __init__(self, eval_dir: str): + """ + Initializes the EvaluatePredictions class. + + Args: + eval_dir (str): Path to the directory containing evaluation files. + """ + self.eval_dir = eval_dir + self.metrics = [] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_labels = None + + @staticmethod + def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None: + """ + Validates that the number of labels matches the number of predictions, + ensuring that they have the same shape. + + Args: + label_files (torch.Tensor): Tensor containing label data. + pred_files (torch.Tensor): Tensor containing prediction data. + + Raises: + ValueError: If label and prediction tensors are mismatched in shape. + """ + if label_files is None or pred_files is None: + raise ValueError("Both label and prediction tensors must be provided.") + + # Check if the number of labels matches the number of predictions + if label_files.shape[0] != pred_files.shape[0]: + raise ValueError( + "Number of label tensors does not match the number of prediction tensors." + ) + + # Validate that the last dimension matches the expected number of classes + if label_files.shape[1] != pred_files.shape[1]: + raise ValueError( + "Label and prediction tensors must have the same shape in terms of class outputs." + ) + + def evaluate(self) -> None: + """ + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC. 
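+
+        Assumes `load_results_from_buffer` returns prediction and label tensors
+        of shape (num_samples, num_labels); this is inferred from the shape
+        checks in `validate_eval_dir`, not guaranteed by that helper itself.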
+ """ + test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) + self.validate_eval_dir(test_labels, test_preds) + self.num_labels = test_preds.shape[1] + + ml_auroc = multilabel_auroc( + test_preds, test_labels, num_labels=self.num_labels + ).item() + + print("Multilabel AUC-ROC:", ml_auroc) + + +class Main: + def evaluate(self, eval_dir: str): + EvaluatePredictions(eval_dir).evaluate() + + +if __name__ == "__main__": + CLI(Main) From 6c0fce185ef8fc754a05c2923dd1bbb9382d2f06 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 4 Nov 2024 15:41:16 +0100 Subject: [PATCH 02/33] add fmax to evaluation script --- chebai/result/evaluate_predictions.py | 39 ++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py index e25c2130..48ddef83 100644 --- a/chebai/result/evaluate_predictions.py +++ b/chebai/result/evaluate_predictions.py @@ -1,7 +1,11 @@ +from typing import Tuple + +import numpy as np import torch from jsonargparse import CLI from torchmetrics.functional.classification import multilabel_auroc +from chebai.callbacks.epoch_metrics import MacroF1 from chebai.result.utils import load_results_from_buffer @@ -48,7 +52,7 @@ def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> No def evaluate(self) -> None: """ - Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC. + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax. """ test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) self.validate_eval_dir(test_labels, test_preds) @@ -60,6 +64,38 @@ def evaluate(self) -> None: print("Multilabel AUC-ROC:", ml_auroc) + fmax, threshold = self.calculate_fmax(test_preds, test_labels) + print(f"F-max : {fmax}, threshold: {threshold}") + + def calculate_fmax( + self, test_preds: torch.Tensor, test_labels: torch.Tensor + ) -> Tuple[float, float]: + """ + Calculates the Fmax metric using the F1 score at various thresholds. + + Args: + test_preds (torch.Tensor): Predicted scores for the labels. + test_labels (torch.Tensor): True labels for the evaluation. + + Returns: + Tuple[float, float]: The maximum F1 score and the corresponding threshold. 
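+
+        Example (illustrative sketch only; the evaluation directory, tensor
+        shapes and values below are made-up placeholders, not taken from a
+        real run):
+
+            evaluator = EvaluatePredictions("path/to/eval_dir")
+            evaluator.num_labels = 3
+            preds = torch.rand(8, 3)
+            labels = torch.randint(0, 2, (8, 3))
+            fmax, threshold = evaluator.calculate_fmax(preds, labels)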
+ """ + thresholds = np.linspace(0, 1, 100) + fmax = 0.0 + best_threshold = 0.0 + + for t in thresholds: + custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t) + custom_f1_metric.update(test_preds, test_labels) + custom_f1_metric_score = custom_f1_metric.compute().item() + + # Check if the current score is the best we've seen + if custom_f1_metric_score > fmax: + fmax = custom_f1_metric_score + best_threshold = t + + return fmax, best_threshold + class Main: def evaluate(self, eval_dir: str): @@ -67,4 +103,5 @@ def evaluate(self, eval_dir: str): if __name__ == "__main__": + # evaluate_predictions.py evaluate CLI(Main) From 58ae92d9a73b889256e02fe6a88d97ebbce5437f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 5 Nov 2024 23:37:29 +0100 Subject: [PATCH 03/33] add base code for deep_go data migration - migration from deep go format to chebai->go_uniprot format --- .../migration/deep_go_data_mirgration.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 chebai/preprocessing/migration/deep_go_data_mirgration.py diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go_data_mirgration.py new file mode 100644 index 00000000..ce35ff0b --- /dev/null +++ b/chebai/preprocessing/migration/deep_go_data_mirgration.py @@ -0,0 +1,54 @@ +from typing import List + +import pandas as pd + +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 +NAMESPACES = { + "cc": "cellular_component", + "mf": "molecular_function", + "bp": "biological_process", +} + +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 +MAXLEN = 1000 + + +def load_data(data_dir): + test_df = pd.DataFrame(pd.read_pickle("test_data.pkl")) + train_df = pd.DataFrame(pd.read_pickle("train_data.pkl")) + validation_df = pd.DataFrame(pd.read_pickle("valid_data.pkl")) + + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 + "exp_annotations", # Directly associated GO ids + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Transitively associated GO ids + ] + + new_df = pd.concat( + [ + train_df[required_columns], + validation_df[required_columns], + test_df[required_columns], + ], + ignore_index=True, + ) + # Generate splits.csv file to store ids of each corresponding split + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": train_df["proteins"], "split": "train"}), + pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}), + pd.DataFrame({"id": test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + + +def save_data(data_dir, data_df): + pass + + +if __name__ == "__main__": + pass From 78a38de062c603b9e6d193a1a3a2278a56f9da82 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 5 Nov 2024 23:38:01 +0100 Subject: [PATCH 04/33] varry fmax threshold as per paper --- chebai/result/evaluate_predictions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py index 48ddef83..355c07c0 100644 --- a/chebai/result/evaluate_predictions.py +++ b/chebai/result/evaluate_predictions.py @@ -80,7 +80,8 @@ def calculate_fmax( Returns: Tuple[float, float]: The maximum F1 score and the corresponding threshold. 
""" - thresholds = np.linspace(0, 1, 100) + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/metrics.py#L51-L52 + thresholds = np.linspace(0, 1, 101) fmax = 0.0 best_threshold = 0.0 From 3a4e007fc0267ad72d4c2f43f7bfb99fdb1245ee Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 5 Nov 2024 23:38:40 +0100 Subject: [PATCH 05/33] go_uniprot: add sequence len to docstring --- chebai/preprocessing/datasets/go_uniprot.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index a2c4ae54..12bb0adc 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -56,10 +56,16 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC): Args: dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - **kwargs: Additional keyword arguments passed to XYBaseDataModule. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. Attributes: dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. splits_file_path (Optional[str]): Path to the CSV file containing split assignments. 
""" From 227a014af32479932208c2dcc565babc8b3fbbf8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 6 Nov 2024 15:47:46 +0100 Subject: [PATCH 06/33] update experiment evidence codes as per DeepGo SE - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- chebai/preprocessing/datasets/go_uniprot.py | 7 ++++++- chebai/preprocessing/datasets/protein_pretraining.py | 4 ++-- tests/unit/mock_data/ontology_mock_data.py | 4 ++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 12bb0adc..73edc976 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -43,6 +43,11 @@ "IEP", "TAS", "IC", + "HTP", + "HDA", + "HMP", + "HGI", + "HEP", } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 @@ -414,7 +419,7 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Quote from the DeepGo Paper: `We select proteins with annotations having experimental evidence codes - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC) and filter the proteins by a + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.` diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/protein_pretraining.py index 8550db2b..f6e9d66d 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/protein_pretraining.py @@ -96,8 +96,8 @@ def _download_required_data(self) -> str: def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: """ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid - Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code defined in + `EXPERIMENTAL_EVIDENCE_CODES`. The DataFrame includes the following columns: - "swiss_id": The unique identifier for each Swiss-Prot record. diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index a05b89f1..ca6148e7 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -668,8 +668,8 @@ def get_UniProt_raw_data() -> str: - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. Note: - A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + A valid GO label is the one which has one of the following evidence code defined in + `EXPERIMENTAL_EVIDENCE_CODES`. Returns: str: The raw UniProt data in string format. 
From c6d60cddd23e1e8d137ec4d02285061baa987d31 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 6 Nov 2024 16:40:53 +0100 Subject: [PATCH 07/33] consIder `X` as a valid amino acid as per DeepGO-SE - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- chebai/preprocessing/datasets/go_uniprot.py | 27 +++++++++++++------ .../datasets/protein_pretraining.py | 4 +-- chebai/preprocessing/reader.py | 4 ++- tests/unit/mock_data/ontology_mock_data.py | 13 ++++----- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 73edc976..7b1c16e3 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -1,13 +1,22 @@ -# Reference for this file : +# References for this file : +# Reference 1: # Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; # DeepGO: Predicting protein functions from sequence and interactions # using a deep ontology-aware classifier, Bioinformatics, 2017. # https://doi.org/10.1093/bioinformatics/btx624 # Github: https://github.com/bio-ontology-research-group/deepgo + +# Reference 2: # https://www.ebi.ac.uk/GOA/downloads # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.uniprot.org/uniprotkb +# Reference 3: +# Kulmanov, M., Guzmán-Vega, F.J., Duek Roggli, +# P. et al. Protein function prediction as approximate semantic entailment. Nat Mach Intell 6, 220–228 (2024). +# https://doi.org/10.1038/s42256-024-00795-w +# https://github.com/bio-ontology-research-group/deepgo2 + __all__ = [ "GOUniProtOver250", "GOUniProtOver50", @@ -34,6 +43,7 @@ from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +# https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 EXPERIMENTAL_EVIDENCE_CODES = { "EXP", "IDA", @@ -43,6 +53,8 @@ "IEP", "TAS", "IC", + # New evidence codes added in latest paper year 2024 Reference number 3 + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L24-L26 "HTP", "HDA", "HMP", @@ -51,7 +63,9 @@ } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 -AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L10 +# `X` is now considered as valid amino acid, as per latest paper year 2024 Refernce number 3 +AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "Z", "*"} class _GOUniProtDataExtractor(_DynamicDataset, ABC): @@ -416,12 +430,9 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Note: This mapping is necessary because the GO data does not include the protein sequence representation. - - Quote from the DeepGo Paper: - `We select proteins with annotations having experimental evidence codes - `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a - maximum length of 1002, ignoring proteins with ambiguous amino acid codes - (B, O, J, U, X, Z) in their sequence.` + We select proteins with annotations having experimental evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with + ambiguous amino acid codes specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence. 
Check the link below for keyword details: https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/protein_pretraining.py index f6e9d66d..63d53144 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/protein_pretraining.py @@ -96,7 +96,7 @@ def _download_required_data(self) -> str: def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: """ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid - Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code defined in + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence codes, as specified in `EXPERIMENTAL_EVIDENCE_CODES`. The DataFrame includes the following columns: @@ -104,7 +104,7 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: - "sequence": The protein sequence. Note: - We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.` + We ignore proteins with ambiguous amino acid specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence.` Returns: pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with not associated valid GO. diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index e220e1e4..a08a3f91 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -348,7 +348,7 @@ class ProteinDataReader(DataReader): COLLATOR = RaggedCollator - # 20 natural amino acid notation + # 21 natural amino acid notation AA_LETTER = [ "A", "R", @@ -370,6 +370,8 @@ class ProteinDataReader(DataReader): "W", "Y", "V", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5 + "X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py ] def name(self) -> str: diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index ca6148e7..552d2918 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -658,18 +658,19 @@ def get_UniProt_raw_data() -> str: - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. - - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'X'. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'B'. - **Swiss_Prot_5**: Has a sequence but no GO classes associated. - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. - **Swiss_Prot_7**: Has a GO class with an invalid evidence code. - **Swiss_Prot_8**: Has a sequence length > 1002 and has only invalid GO class. - - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'X', in its sequence. + - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'B', in its sequence. - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. Note: - A valid GO label is the one which has one of the following evidence code defined in - `EXPERIMENTAL_EVIDENCE_CODES`. 
+ A valid GO label is the one which has one of the following evidence code specified in + go_uniprot.py->`EXPERIMENTAL_EVIDENCE_CODES`. + Invalid amino acids are specified in go_uniprot.py->`AMBIGUOUS_AMINO_ACIDS`. Returns: str: The raw UniProt data in string format. @@ -715,7 +716,7 @@ def get_UniProt_raw_data() -> str: "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with sequence string but has no GO class "ID Swiss_Prot_5 Reviewed; 60 AA.\n" @@ -749,7 +750,7 @@ def get_UniProt_raw_data() -> str: "ID Swiss_Prot_9 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with a `valid` associated GO class but without sequence string "ID Swiss_Prot_10 Reviewed; 60 AA.\n" From ca5461fce0bf4a431f620af0e7ad3df81c61b1b5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 6 Nov 2024 20:40:09 +0100 Subject: [PATCH 08/33] deepgo se mirgration : add class to migrate --- .../migration/deep_go_data_mirgration.py | 342 +++++++++++++++--- 1 file changed, 297 insertions(+), 45 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go_data_mirgration.py index ce35ff0b..a33e407b 100644 --- a/chebai/preprocessing/migration/deep_go_data_mirgration.py +++ b/chebai/preprocessing/migration/deep_go_data_mirgration.py @@ -1,54 +1,306 @@ -from typing import List +import os +from collections import OrderedDict +from random import randint +from typing import List, Literal import pandas as pd +from jsonargparse import CLI -# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 -NAMESPACES = { - "cc": "cellular_component", - "mf": "molecular_function", - "bp": "biological_process", -} - -# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 -MAXLEN = 1000 - - -def load_data(data_dir): - test_df = pd.DataFrame(pd.read_pickle("test_data.pkl")) - train_df = pd.DataFrame(pd.read_pickle("train_data.pkl")) - validation_df = pd.DataFrame(pd.read_pickle("valid_data.pkl")) - - required_columns = [ - "proteins", - "accessions", - "sequences", - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 - "exp_annotations", # Directly associated GO ids - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 - "prop_annotations", # Transitively associated GO ids - ] - - new_df = pd.concat( - [ - train_df[required_columns], - validation_df[required_columns], - test_df[required_columns], - ], - ignore_index=True, - ) - # Generate splits.csv file to store ids of each corresponding split - split_assignment_list: List[pd.DataFrame] = [ - pd.DataFrame({"id": train_df["proteins"], "split": "train"}), - pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}), - pd.DataFrame({"id": test_df["proteins"], "split": "test"}), - ] +from chebai.preprocessing.datasets.go_uniprot import ( + GOUniProtOver50, + GOUniProtOver250, + _GOUniProtDataExtractor, +) + + +class DeepGoDataMigration: + """ 
+ A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE + data structure to our data structure followed for GO-UniProt data. + + Attributes: + _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. + _MAXLEN (int): Maximum sequence length for sequences. + _LABELS_START_IDX (int): Starting index for labels in the dataset. + + Methods: + __init__(data_dir, go_branch): Initializes the data directory and GO branch. + _load_data(): Loads train, validation, test, and terms data from the specified directory. + _record_splits(): Creates a DataFrame with IDs and their corresponding split. + migrate(): Executes the migration process including data loading, processing, and saving. + _extract_required_data_from_splits(): Extracts required columns from the splits data. + _generate_labels(data_df): Generates label columns for the data based on GO terms. + extract_go_id(go_list): Extracts GO IDs from a list. + save_migrated_data(data_df, splits_df): Saves the processed data and splits. + """ + + # Link for the namespaces convention used for GO branch + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 + _CORRESPONDING_GO_CLASSES = { + "cc": GOUniProtOver50, + "mf": GOUniProtOver50, + "bp": GOUniProtOver250, + } + + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + _MAXLEN = 1000 + _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + """ + valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir = os.path.join(data_dir, go_branch) + self._train_df: pd.DataFrame = None + self._test_df: pd.DataFrame = None + self._validation_df: pd.DataFrame = None + self._terms_df: pd.DataFrame = None + self._classes: List[str] = None + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + try: + print(f"Loading data from {self._data_dir}......") + self._test_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) + self._train_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) + self._validation_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) + ) + except FileNotFoundError as e: + print(f"Error loading data: {e}") + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. 
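+
+        Example of the returned layout (illustrative only; the protein IDs are
+        made up):
+
+            id        split
+            PROT_A    train
+            PROT_B    validation
+            PROT_C    test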
+ """ + print("Recording splits...") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Migration started......") + self._load_data() + if not all( + [self._train_df, self._validation_df, self._test_df, self._terms_df] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all([data_with_labels_df, splits_df, self._classes]): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_df, splits_df) + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... ") + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 + "exp_annotations", # Directly associated GO ids + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Transitively associated GO ids + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["exp_annotations"]) + + self.extract_go_id(row["prop_annotations"]), + axis=1, + ) - combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + return data_df + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates label columns for each GO term in the dataset. -def save_data(data_dir, data_df): - pass + Args: + data_df (pd.DataFrame): DataFrame containing data with GO IDs. + + Returns: + pd.DataFrame: DataFrame with new label columns. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df.apply( + lambda row: self.extract_go_id(row["gos"]) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=all_go_ids_list + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[str]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. 
+ + Returns: + List[str]: List of parsed GO IDs. + """ + return [ + _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list + ] + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. + """ + print("Saving transformed data......") + go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ + self._go_branch + ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN) + + go_class_instance.save_processed( + data_df, go_class_instance.processed_file_names_dict["data"] + ) + print( + f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + ) + + splits_df.to_csv( + os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + index=False, + ) + print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + + classes = sorted(self._classes) + with open( + os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes.txt saved to {go_class_instance.processed_dir_main}") + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGoDataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + def migrate(self, data_dir: str, go_branch: str) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGoDataMigration(data_dir, go_branch).migrate() + + +class Main1: + def __init__(self, max_prize: int = 100): + """ + Args: + max_prize: Maximum prize that can be awarded. + """ + self.max_prize = max_prize + + def person(self, name: str, additional_prize: int = 0): + """ + Args: + name: Name of the winner. + additional_prize: Additional prize that can be added to the prize amount. + """ + prize = randint(0, self.max_prize) + additional_prize + return f"{name} won {prize}€!" if __name__ == "__main__": - pass + # Example: python script_name.py migrate data_dir="data/deep_go_se_training_data" go_branch="bp" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. 
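+    # Roughly equivalent direct (non-CLI) usage, shown as a sketch only; the
+    # directory is the same placeholder used in the example above:
+    #     DeepGoDataMigration("data/deep_go_se_training_data", "bp").migrate()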
+ CLI( + Main1, + description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + ) From dfb9430795a7a45826eb350d3068074e2b567a83 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 11:12:15 +0100 Subject: [PATCH 09/33] migration: rectify errors --- .../migration/deep_go_data_mirgration.py | 65 ++++++++----------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go_data_mirgration.py index a33e407b..5c22b389 100644 --- a/chebai/preprocessing/migration/deep_go_data_mirgration.py +++ b/chebai/preprocessing/migration/deep_go_data_mirgration.py @@ -1,7 +1,6 @@ import os from collections import OrderedDict -from random import randint -from typing import List, Literal +from typing import List, Literal, Optional import pandas as pd from jsonargparse import CLI @@ -59,12 +58,12 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch - self._data_dir = os.path.join(data_dir, go_branch) - self._train_df: pd.DataFrame = None - self._test_df: pd.DataFrame = None - self._validation_df: pd.DataFrame = None - self._terms_df: pd.DataFrame = None - self._classes: List[str] = None + self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None def _load_data(self) -> None: """ @@ -114,7 +113,13 @@ def migrate(self) -> None: print("Migration started......") self._load_data() if not all( - [self._train_df, self._validation_df, self._test_df, self._terms_df] + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] ): raise Exception( "Data splits or terms data is not available in instance variables." @@ -124,7 +129,9 @@ def migrate(self) -> None: data_df = self._extract_required_data_from_splits() data_with_labels_df = self._generate_labels(data_df) - if not all([data_with_labels_df, splits_df, self._classes]): + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): raise Exception( "Data splits or terms data is not available in instance variables." ) @@ -184,8 +191,8 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: pd.DataFrame: DataFrame with new label columns. """ print("Generating labels based on terms.pkl file.......") - parsed_go_ids: pd.Series = self._terms_df.apply( - lambda row: self.extract_go_id(row["gos"]) + parsed_go_ids: pd.Series = self._terms_df["gos"].apply( + lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list @@ -203,7 +210,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: return data_df @staticmethod - def extract_go_id(go_list: List[str]) -> List[str]: + def extract_go_id(go_list: List[str]) -> List[int]: """ Extracts and parses GO IDs from a list of GO annotations. 
@@ -230,13 +237,13 @@ def save_migrated_data( print("Saving transformed data......") go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ self._go_branch - ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN) + ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) go_class_instance.save_processed( - data_df, go_class_instance.processed_file_names_dict["data"] + data_df, go_class_instance.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" ) splits_df.to_csv( @@ -263,7 +270,8 @@ class Main: Initiates the migration process for the specified data directory and GO branch. """ - def migrate(self, data_dir: str, go_branch: str) -> None: + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: """ Initiates the migration process by creating a DeepGoDataMigration instance and invoking its migrate method. @@ -278,29 +286,12 @@ def migrate(self, data_dir: str, go_branch: str) -> None: DeepGoDataMigration(data_dir, go_branch).migrate() -class Main1: - def __init__(self, max_prize: int = 100): - """ - Args: - max_prize: Maximum prize that can be awarded. - """ - self.max_prize = max_prize - - def person(self, name: str, additional_prize: int = 0): - """ - Args: - name: Name of the winner. - additional_prize: Additional prize that can be added to the prize amount. - """ - prize = randint(0, self.max_prize) + additional_prize - return f"{name} won {prize}€!" - - if __name__ == "__main__": - # Example: python script_name.py migrate data_dir="data/deep_go_se_training_data" go_branch="bp" + # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp" # --data_dir specifies the directory containing the data files. # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. 
CLI( - Main1, + Main, description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, ) From 085b13b5798398d4dca9477ed8ad80ecf50d2e0b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 13:25:21 +0100 Subject: [PATCH 10/33] protein trigram containing tokenS with `X` - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- .../bin/protein_token_3_gram/tokens.txt | 359 ++++++++++++++++++ 1 file changed, 359 insertions(+) diff --git a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt index 69dca126..534e5db1 100644 --- a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt +++ b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt @@ -7998,3 +7998,362 @@ WWC WCC WCH WWM +TAX +AXD +XDR +IEX +EXV +QAX +AXX +XXE +XES +MXN +XNF +NRX +RXX +XXX +XXR +XRI +SAX +AXG +XGG +PRX +RXR +XRX +RXE +XEF +QEX +EXQ +XQR +REX +EXR +RXQ +XQQ +DRX +RXP +XPG +QMX +MXT +XTX +TXR +XRM +APX +PXX +XXG +XGI +NLX +LXX +XXM +XMA +LNX +NXE +XEA +GTX +TXN +XND +LIX +IXI +XIM +MVX +VXX +XXK +XKT +GLX +LXP +XPP +QGX +GXD +XDL +XAP +QNX +NXM +XMN +VAX +XGV +IKX +KXY +KEX +EXL +XLY +GQX +QXE +XEP +PLX +XKC +PVX +XKE +RXI +XIR +AXL +XLN +LLX +LXD +XDA +AXE +XEL +GGX +GXG +KAX +XXA +XAG +XWS +SPX +PXC +XCD +GWX +WXH +XHF +MPX +ESX +SXN +XNK +DLX +LXN +XNS +QXG +XGD +ITX +XRG +NEX +EXA +XAL +LDX +DXI +XII +TPX +PXM +XMR +NXG +XGY +ASX +SXV +XVE +TKX +KXA +KRX +XXT +XTL +IDX +DXX +XXL +XLV +AKX +KXX +QHX +HXV +XVN +NSX +SXX +XKX +XDP +DAX +AXK +XKQ +PIX +IXX +XXF +VLX +XDI +DIX +IXL +XLK +LKX +KXV +XVA +DNX +NXD +ILX +LXK +XKV +VYX +YXE +XEI +RXS +XSH +KGX +XGF +AVX +VXY +XYG +HVX +XXI +XID +TVX +XXS +XSA +ENX +NXX +XMD +IIX +XMQ +AEX +EXX +XME +PGX +GXP +XPR +SKX +KXF +XFT +HRX +XSW +PQX +XGR +QQX +VTX +XRP +PSX +SXP +XPL +VGX +GXY +RSX +SXS +XSL +VSX +XST +AXV +XVL +AGX +GXX +XTK +KLX +LXR +XRV +AHX +HXC +XCS +LVX +VXN +XNR +NGX +GXL +TSX +SXQ +XQN +KXL +XLL +VIX +IXG +XGA +GFX +FXG +XGL +PTX +TXT +XTS +EMX +MXQ +SXY +XYA +IQX +QXY +XYR +TXK +IGX +XPS +PXT +XTG +NXQ +VKX +KXS +XSN +GVX +VXE +GRX +XRE +YKX +KXE +XEE +EEX +EXT +XTI +EHX +HXN +XNL +NDX +DXD +IAX +KSX +SXL +RRX +XRK +DDX +DXE +RXG +VXL +XLS +DTX +TXG +VXF +XFA +XIG +VXT +XTA +ISX +SXR +XRY +VQX +QXP +XPC +LGX +GXS +HGX +XGH +XXD +XDD +KKX +XXV +PKX +XLT +XSP +XLD +RAX +AXS +XSI +IYX +YXX +XXP +XPI +MSX +SXT +GEX +XHP +LFX +FXX +VXI +XIW +QTX +TXX +XXQ +XQA +FLX +DXN +XNC +MXS +XSR +YLX +EQX +QXS +TMX +MXC +XCY +NXA +XAV +EXE +XEQ +HPX +PXP +LMX +MXX +KTX +XKK +XXH +XHS +MKX +XIH +WRX +XKS +EXY +XYQ +QKX From 3e0bae0d75c0d3330a75c3c72e6ffa023ae2b37b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 13:28:59 +0100 Subject: [PATCH 11/33] protein token unigram contain `X` - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- chebai/preprocessing/bin/protein_token/tokens.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt index 72ad1b6d..c31c5b72 100644 --- a/chebai/preprocessing/bin/protein_token/tokens.txt +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -18,3 +18,4 @@ W E V H +X From 99b5af1e263aa86ccaf1f350fb8703da202e13ec Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 12 Nov 2024 00:21:29 +0100 Subject: [PATCH 12/33] add migration for deepgo1 - 2018 paper --- .../migration/deep_go/__init__.py | 0 
.../deep_go/migrate_deep_go_1_data.py | 310 ++++++++++++++++++ .../migrate_deep_go_2_data.py} | 10 +- 3 files changed, 318 insertions(+), 2 deletions(-) create mode 100644 chebai/preprocessing/migration/deep_go/__init__.py create mode 100644 chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py rename chebai/preprocessing/migration/{deep_go_data_mirgration.py => deep_go/migrate_deep_go_2_data.py} (96%) diff --git a/chebai/preprocessing/migration/deep_go/__init__.py b/chebai/preprocessing/migration/deep_go/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py new file mode 100644 index 00000000..be709364 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -0,0 +1,310 @@ +import os +from collections import OrderedDict +from typing import List, Literal, Optional + +import pandas as pd +from jsonargparse import CLI + +from chebai.preprocessing.datasets.go_uniprot import ( + GOUniProtOver50, + GOUniProtOver250, + _GOUniProtDataExtractor, +) + + +class DeepGo1DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. + It migrates the deepGO data to our data structure followed for GO-UniProt data. + + It migrates the data of DeepGO model of the below research paper: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624), + + Attributes: + _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. + _MAXLEN (int): Maximum sequence length for sequences. + _LABELS_START_IDX (int): Starting index for labels in the dataset. + + Methods: + __init__(data_dir, go_branch): Initializes the data directory and GO branch. + _load_data(): Loads train, validation, test, and terms data from the specified directory. + _record_splits(): Creates a DataFrame with IDs and their corresponding split. + migrate(): Executes the migration process including data loading, processing, and saving. + _extract_required_data_from_splits(): Extracts required columns from the splits data. + _get_labels_columns(data_df): Generates label columns for the data based on GO terms. + extract_go_id(go_list): Extracts GO IDs from a list. + save_migrated_data(data_df, splits_df): Saves the processed data and splits. + """ + + # Number of annotations for each go_branch as per the research paper + _CORRESPONDING_GO_CLASSES = { + "cc": GOUniProtOver50, + "mf": GOUniProtOver50, + "bp": GOUniProtOver250, + } + + _MAXLEN = 1002 + _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). 
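+
+        Example (sketch only; the directory is the placeholder used in the CLI
+        example at the bottom of this module):
+
+            migrator = DeepGo1DataMigration("data/deep_go1", "mf")
+            migrator.migrate()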
+ """ + valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = rf"{data_dir}" + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + try: + print(f"Loading data from {self._data_dir}......") + self._test_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") + ) + ) + self._train_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") + ) + ) + # self._validation_df = pd.DataFrame( + # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) + # ) + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) + ) + + except FileNotFoundError as e: + print(f"Error loading data: {e}") + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording splits...") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + # pd.DataFrame( + # {"id": self._validation_df["proteins"], "split": "validation"} + # ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Migration started......") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + # self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... 
") + required_columns = [ + "proteins", + "accessions", + "sequences", + # Note: The GO classes here only directly related one, and not transitive GO classes + "gos", + "labels", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + # self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["gos"]), axis=1 + ) + + labels_df = self._get_labels_colums(new_df) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + + df = pd.concat([data_df, labels_df], axis=1) + + return df + + def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates a DataFrame with one-hot encoded columns for each GO term label, + based on the terms provided in `self._terms_df` and the existing labels in `data_df`. + + This method extracts GO IDs from the `functions` column of `self._terms_df`, + creating a list of all unique GO IDs. It then uses this list to create new + columns in the returned DataFrame, where each row has binary values + (0 or 1) indicating the presence of each GO ID in the corresponding entry of + `data_df['labels']`. + + Args: + data_df (pd.DataFrame): DataFrame containing data with a 'labels' column, + which holds lists of GO ID labels for each row. + + Returns: + pd.DataFrame: A DataFrame with the same index as `data_df` and one column + per GO ID, containing binary values indicating label presence. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df["functions"].apply( + lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + + new_label_columns = pd.DataFrame( + data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list + ) + + return new_label_columns + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [ + _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list + ] + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. 
+ """ + print("Saving transformed data......") + go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ + self._go_branch + ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + + go_class_instance.save_processed( + data_df, go_class_instance.processed_main_file_names_dict["data"] + ) + print( + f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + ) + + splits_df.to_csv( + os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + index=False, + ) + print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + + classes = sorted(self._classes) + with open( + os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes.txt saved to {go_class_instance.processed_dir_main}") + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGo1DataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGo1DataMigration(data_dir, go_branch).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. + CLI( + Main, + description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py similarity index 96% rename from chebai/preprocessing/migration/deep_go_data_mirgration.py rename to chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 5c22b389..0d5266ef 100644 --- a/chebai/preprocessing/migration/deep_go_data_mirgration.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -12,11 +12,17 @@ ) -class DeepGoDataMigration: +class DeepGo2DataMigration: """ A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE data structure to our data structure followed for GO-UniProt data. + It migrates the data of DeepGO model of the below research paper: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624), + Attributes: _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. _MAXLEN (int): Maximum sequence length for sequences. @@ -283,7 +289,7 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: "mf" for molecular_function, or "bp" for biological_process). 
""" - DeepGoDataMigration(data_dir, go_branch).migrate() + DeepGo2DataMigration(data_dir, go_branch).migrate() if __name__ == "__main__": From a15d49254c1d5a378dc8ac64508392c55fcb3841 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 12 Nov 2024 17:40:39 +0100 Subject: [PATCH 13/33] deepgo1: create non-exclusive val set as a placeholder --- .../deep_go/migrate_deep_go_1_data.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index be709364..f42b08c3 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -89,6 +89,14 @@ def _load_data(self) -> None: # self._validation_df = pd.DataFrame( # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) # ) + + # DeepGO1 data does not include a separate validation split, but our data structure requires one. + # To accommodate this, we will create a placeholder validation split by duplicating a small subset of the + # training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set + # without creating an exclusive validation split from it. + # Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not + # reflect true validation performance. + self._validation_df = self._train_df[len(self._train_df) - 5 :] self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) ) @@ -106,9 +114,9 @@ def _record_splits(self) -> pd.DataFrame: print("Recording splits...") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), - # pd.DataFrame( - # {"id": self._validation_df["proteins"], "split": "validation"} - # ), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), ] @@ -125,7 +133,7 @@ def migrate(self) -> None: df is not None for df in [ self._train_df, - # self._validation_df, + self._validation_df, self._test_df, self._terms_df, ] @@ -166,7 +174,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: new_df = pd.concat( [ self._train_df[required_columns], - # self._validation_df[required_columns], + self._validation_df[required_columns], self._test_df[required_columns], ], ignore_index=True, From e0a85247f2f7b561593d6cb9536e66aefb9ecebf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 22:45:32 +0100 Subject: [PATCH 14/33] deepgo1: further split train set into train and val for - +migration structure changes --- .../deep_go/migrate_deep_go_1_data.py | 241 +++++++++--------- 1 file changed, 118 insertions(+), 123 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index f42b08c3..48188cd7 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -1,53 +1,29 @@ import os from collections import OrderedDict -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Tuple import pandas as pd +from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import ( - 
GOUniProtOver50, - GOUniProtOver250, - _GOUniProtDataExtractor, -) +from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData class DeepGo1DataMigration: """ A class to handle data migration and processing for the DeepGO project. - It migrates the deepGO data to our data structure followed for GO-UniProt data. + It migrates the DeepGO data to our data structure followed for GO-UniProt data. - It migrates the data of DeepGO model of the below research paper: + This class handles data from the DeepGO model as described in: Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 - (https://doi.org/10.1093/bioinformatics/btx624), - - Attributes: - _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. - _MAXLEN (int): Maximum sequence length for sequences. - _LABELS_START_IDX (int): Starting index for labels in the dataset. - - Methods: - __init__(data_dir, go_branch): Initializes the data directory and GO branch. - _load_data(): Loads train, validation, test, and terms data from the specified directory. - _record_splits(): Creates a DataFrame with IDs and their corresponding split. - migrate(): Executes the migration process including data loading, processing, and saving. - _extract_required_data_from_splits(): Extracts required columns from the splits data. - _get_labels_columns(data_df): Generates label columns for the data based on GO terms. - extract_go_id(go_list): Extracts GO IDs from a list. - save_migrated_data(data_df, splits_df): Saves the processed data and splits. + (https://doi.org/10.1093/bioinformatics/btx624). """ - # Number of annotations for each go_branch as per the research paper - _CORRESPONDING_GO_CLASSES = { - "cc": GOUniProtOver50, - "mf": GOUniProtOver50, - "bp": GOUniProtOver250, - } - + # Max sequence length as per DeepGO1 _MAXLEN = 1002 - _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + _LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): """ @@ -55,9 +31,9 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): Args: data_dir (str): Directory containing the data files. - go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. """ - valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch @@ -69,34 +45,60 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._terms_df: Optional[pd.DataFrame] = None self._classes: Optional[List[str]] = None + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." 
+ ) + splits_df = self._record_splits() + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ try: - print(f"Loading data from {self._data_dir}......") + print(f"Loading data files from directory: {self._data_dir}") self._test_df = pd.DataFrame( pd.read_pickle( os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") ) ) - self._train_df = pd.DataFrame( + + # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set. + # Although this reduces the training data slightly compared to the original DeepGO setup, + # given the data size, the impact should be minimal. + train_df = pd.DataFrame( pd.read_pickle( os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") ) ) - # self._validation_df = pd.DataFrame( - # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) - # ) - - # DeepGO1 data does not include a separate validation split, but our data structure requires one. - # To accommodate this, we will create a placeholder validation split by duplicating a small subset of the - # training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set - # without creating an exclusive validation split from it. - # Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not - # reflect true validation performance. - self._validation_df = self._train_df[len(self._train_df) - 5 :] + + self._train_df, self._validation_df = self._get_train_val_split(train_df) + self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) ) @@ -104,6 +106,35 @@ def _load_data(self) -> None: except FileNotFoundError as e: print(f"Error loading data: {e}") + @staticmethod + def _get_train_val_split( + train_df: pd.DataFrame, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Splits the training data into a smaller training set and a validation set. + + Args: + train_df (pd.DataFrame): Original training DataFrame. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames. + """ + labels_list_train = train_df["labels"].tolist() + train_split = 0.85 + test_size = ((1 - train_split) ** 2) / train_split + + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=42 + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_train, labels_list_train) + ) + + df_validation = train_df.iloc[validation_indices] + df_train = train_df.iloc[train_indices] + return df_train, df_validation + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. @@ -111,7 +142,7 @@ def _record_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A combined DataFrame containing split assignments. 
""" - print("Recording splits...") + print("Recording data splits for train, validation, and test sets.") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), pd.DataFrame( @@ -123,37 +154,6 @@ def _record_splits(self) -> pd.DataFrame: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) return combined_split_assignment - def migrate(self) -> None: - """ - Executes the data migration by loading, processing, and saving the data. - """ - print("Migration started......") - self._load_data() - if not all( - df is not None - for df in [ - self._train_df, - self._validation_df, - self._test_df, - self._terms_df, - ] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - splits_df = self._record_splits() - - data_with_labels_df = self._extract_required_data_from_splits() - - if not all( - var is not None for var in [data_with_labels_df, splits_df, self._classes] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - - self.save_migrated_data(data_with_labels_df, splits_df) - def _extract_required_data_from_splits(self) -> pd.DataFrame: """ Extracts required columns from the combined data splits. @@ -161,12 +161,11 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A DataFrame containing the essential columns for processing. """ - print("Combining the data splits with required data..... ") + print("Combining data splits into a single DataFrame with required columns.") required_columns = [ "proteins", "accessions", "sequences", - # Note: The GO classes here only directly related one, and not transitive GO classes "gos", "labels", ] @@ -183,7 +182,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: lambda row: self.extract_go_id(row["gos"]), axis=1 ) - labels_df = self._get_labels_colums(new_df) + labels_df = self._get_labels_columns(new_df) data_df = pd.DataFrame( OrderedDict( @@ -198,28 +197,32 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: return df - def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: """ - Generates a DataFrame with one-hot encoded columns for each GO term label, - based on the terms provided in `self._terms_df` and the existing labels in `data_df`. + Extracts and parses GO IDs from a list of GO annotations. - This method extracts GO IDs from the `functions` column of `self._terms_df`, - creating a list of all unique GO IDs. It then uses this list to create new - columns in the returned DataFrame, where each row has binary values - (0 or 1) indicating the presence of each GO ID in the corresponding entry of - `data_df['labels']`. + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[int]: List of parsed GO IDs. + """ + return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates columns for labels based on provided selected terms. Args: - data_df (pd.DataFrame): DataFrame containing data with a 'labels' column, - which holds lists of GO ID labels for each row. + data_df (pd.DataFrame): DataFrame with GO annotations and labels. Returns: - pd.DataFrame: A DataFrame with the same index as `data_df` and one column - per GO ID, containing binary values indicating label presence. 
+ pd.DataFrame: DataFrame with label columns. """ - print("Generating labels based on terms.pkl file.......") + print("Generating label columns from provided selected terms.") parsed_go_ids: pd.Series = self._terms_df["functions"].apply( - lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) + lambda gos: DeepGO1MigratedData._parse_go_id(gos) ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list @@ -230,21 +233,6 @@ def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: return new_label_columns - @staticmethod - def extract_go_id(go_list: List[str]) -> List[int]: - """ - Extracts and parses GO IDs from a list of GO annotations. - - Args: - go_list (List[str]): List of GO annotation strings. - - Returns: - List[str]: List of parsed GO IDs. - """ - return [ - _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list - ] - def save_migrated_data( self, data_df: pd.DataFrame, splits_df: pd.DataFrame ) -> None: @@ -255,31 +243,38 @@ def save_migrated_data( data_df (pd.DataFrame): Data with GO labels. splits_df (pd.DataFrame): Split assignment DataFrame. """ - print("Saving transformed data......") - go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ - self._go_branch - ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + print("Saving transformed data files.") - go_class_instance.save_processed( - data_df, go_class_instance.processed_main_file_names_dict["data"] + deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData( + go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" ) + # Save splits file splits_df.to_csv( - os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"), index=False, ) - print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}") + # Save classes file classes = sorted(self._classes) with open( - os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"), + "wt", ) as fout: fout.writelines(str(node) + "\n" for node in classes) - print(f"classes.txt saved to {go_class_instance.processed_dir_main}") - print("Migration completed!") + print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration process completed!") class Main: From 093be281a3784972a80abd647dff7f79ceaca553 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 22:56:44 +0100 Subject: [PATCH 15/33] migration script update --- .../deep_go/migrate_deep_go_1_data.py | 2 +- .../deep_go/migrate_deep_go_2_data.py | 163 ++++++++---------- 2 files changed, 71 insertions(+), 94 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index 48188cd7..ad8ae322 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -14,7 
+14,7 @@ class DeepGo1DataMigration: A class to handle data migration and processing for the DeepGO project. It migrates the DeepGO data to our data structure followed for GO-UniProt data. - This class handles data from the DeepGO model as described in: + This class handles migration of data from the DeepGO paper below: Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 0d5266ef..3d4109e1 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -5,11 +5,7 @@ import pandas as pd from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import ( - GOUniProtOver50, - GOUniProtOver250, - _GOUniProtDataExtractor, -) +from chebai.preprocessing.datasets.go_uniprot import DeepGO2MigratedData class DeepGo2DataMigration: @@ -17,39 +13,16 @@ class DeepGo2DataMigration: A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE data structure to our data structure followed for GO-UniProt data. - It migrates the data of DeepGO model of the below research paper: + This class handles migration of data from the DeepGO paper below: Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 - (https://doi.org/10.1093/bioinformatics/btx624), - - Attributes: - _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. - _MAXLEN (int): Maximum sequence length for sequences. - _LABELS_START_IDX (int): Starting index for labels in the dataset. - - Methods: - __init__(data_dir, go_branch): Initializes the data directory and GO branch. - _load_data(): Loads train, validation, test, and terms data from the specified directory. - _record_splits(): Creates a DataFrame with IDs and their corresponding split. - migrate(): Executes the migration process including data loading, processing, and saving. - _extract_required_data_from_splits(): Extracts required columns from the splits data. - _generate_labels(data_df): Generates label columns for the data based on GO terms. - extract_go_id(go_list): Extracts GO IDs from a list. - save_migrated_data(data_df, splits_df): Saves the processed data and splits. + (https://doi.org/10.1093/bioinformatics/btx624) """ - # Link for the namespaces convention used for GO branch - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 - _CORRESPONDING_GO_CLASSES = { - "cc": GOUniProtOver50, - "mf": GOUniProtOver50, - "bp": GOUniProtOver250, - } - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 _MAXLEN = 1000 - _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): """ @@ -57,9 +30,9 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): Args: data_dir (str): Directory containing the data files. 
- go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. """ - valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch @@ -71,13 +44,45 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._terms_df: Optional[pd.DataFrame] = None self._classes: Optional[List[str]] = None + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_df, splits_df) + def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ try: - print(f"Loading data from {self._data_dir}......") + print(f"Loading data from directory: {self._data_dir}......") self._test_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) ) @@ -100,7 +105,7 @@ def _record_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A combined DataFrame containing split assignments. """ - print("Recording splits...") + print("Recording data splits for train, validation, and test sets.") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), pd.DataFrame( @@ -112,38 +117,6 @@ def _record_splits(self) -> pd.DataFrame: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) return combined_split_assignment - def migrate(self) -> None: - """ - Executes the data migration by loading, processing, and saving the data. - """ - print("Migration started......") - self._load_data() - if not all( - df is not None - for df in [ - self._train_df, - self._validation_df, - self._test_df, - self._terms_df, - ] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - splits_df = self._record_splits() - - data_df = self._extract_required_data_from_splits() - data_with_labels_df = self._generate_labels(data_df) - - if not all( - var is not None for var in [data_with_labels_df, splits_df, self._classes] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - - self.save_migrated_data(data_df, splits_df) - def _extract_required_data_from_splits(self) -> pd.DataFrame: """ Extracts required columns from the combined data splits. @@ -186,6 +159,19 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: ) return data_df + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. 
+ + Returns: + List[str]: List of parsed GO IDs. + """ + return [DeepGO2MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: """ Generates label columns for each GO term in the dataset. @@ -198,7 +184,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: """ print("Generating labels based on terms.pkl file.......") parsed_go_ids: pd.Series = self._terms_df["gos"].apply( - lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) + lambda gos: DeepGO2MigratedData._parse_go_id(gos) ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list @@ -215,21 +201,6 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] return data_df - @staticmethod - def extract_go_id(go_list: List[str]) -> List[int]: - """ - Extracts and parses GO IDs from a list of GO annotations. - - Args: - go_list (List[str]): List of GO annotation strings. - - Returns: - List[str]: List of parsed GO IDs. - """ - return [ - _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list - ] - def save_migrated_data( self, data_df: pd.DataFrame, splits_df: pd.DataFrame ) -> None: @@ -241,29 +212,35 @@ def save_migrated_data( splits_df (pd.DataFrame): Split assignment DataFrame. """ print("Saving transformed data......") - go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ - self._go_branch - ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( + go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) - go_class_instance.save_processed( - data_df, go_class_instance.processed_main_file_names_dict["data"] + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" ) + # Save split file splits_df.to_csv( - os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go2.csv"), index=False, ) - print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + print(f"splits_deep_go2.csv saved to {deepgo_migr_inst.processed_dir_main}") + # Save classes.txt file classes = sorted(self._classes) with open( - os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go2.txt"), + "wt", ) as fout: fout.writelines(str(node) + "\n" for node in classes) - print(f"classes.txt saved to {go_class_instance.processed_dir_main}") + print(f"classes_deep_go2.txt saved to {deepgo_migr_inst.processed_dir_main}") + print("Migration completed!") From 14db9d641a8b627dd2a878eee736158abdeddbcc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 23:00:18 +0100 Subject: [PATCH 16/33] add classes to use migrated deepgo data --- chebai/preprocessing/datasets/go_uniprot.py | 186 ++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 7b1c16e3..16bd6a31 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ 
b/chebai/preprocessing/datasets/go_uniprot.py @@ -22,6 +22,8 @@ "GOUniProtOver50", "EXPERIMENTAL_EVIDENCE_CODES", "AMBIGUOUS_AMINO_ACIDS", + "DeepGO1MigratedData", + "DeepGO2MigratedData", ] import gzip @@ -731,3 +733,187 @@ class GOUniProtOver50(_GOUniProtOverX): """ THRESHOLD: int = 50 + + +class _DeepGOMigratedData(_GOUniProtDataExtractor, ABC): + """ + Base class for use of the migrated DeepGO data with common properties, name formatting, and file paths. + + Attributes: + READER (dr.ProteinDataReader): Protein data reader class. + THRESHOLD (Optional[int]): Threshold value for GO class selection, + determined by the GO branch type in derived classes. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: Optional[int] = None + + # Mapping from GO branch conventions used in DeepGO to our conventions + GO_BRANCH_MAPPING: dict = { + "cc": "CC", + "mf": "MF", + "bp": "BP", + } + + @property + def _name(self) -> str: + """ + Generates a unique identifier for the migrated data based on the GO + branch and max sequence length, optionally including a threshold. + + Returns: + str: A formatted name string for the data. + """ + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "" + + if self.go_branch != self._ALL_GO_BRANCHES: + return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" + + return f"{threshold_part}{self.max_sequence_length}" + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Checks for the existence of migrated DeepGO data in the specified directory. + Raises an error if the required data file is not found, prompting + migration from DeepGO to this data structure. + + Args: + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Raises: + FileNotFoundError: If the processed data file does not exist. + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + raise FileNotFoundError( + f"File {processed_name} not found.\n" + f"You must run the appropriate DeepGO migration script " + f"(chebai/preprocessing/migration/deep_go) before executing this configuration " + f"to migrate data from DeepGO to this data structure." + ) + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + # Selection of GO classes not needed for migrated data + pass + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining main processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for main processed file names. + """ + pass + + @property + @abstractmethod + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining additional processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for processed file names. 
+ """ + pass + + +class DeepGO1MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO1. Sets threshold values according + to the research paper based on the GO branch. + + Note: + Refer reference number 1 at the top of this file for the corresponding research paper. + + Args: + **kwargs: Arbitrary keyword arguments passed to the superclass. + + Raises: + ValueError: If an unsupported GO branch is provided. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1002 + + # Set threshold based on GO branch, as per DeepGO1 paper and its data. + if kwargs.get("go_branch") in ["CC", "MF"]: + self.THRESHOLD = 50 + elif kwargs.get("go_branch") == "BP": + self.THRESHOLD = 250 + else: + raise ValueError( + f"DeepGO1 paper has no defined threshold for branch {self.go_branch}" + ) + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with the main data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pt"} + + +class DeepGO2MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO2, inheriting from DeepGO1MigratedData + with different processed file names. + + Note: + Refer reference number 3 at the top of this file for the corresponding research paper. + + Returns: + dict: Dictionary with file names specific to DeepGO2. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1000 + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with the main data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pt"} From 8922d4dc9c403648f6a039ac1144091383703f68 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 23:24:07 +0100 Subject: [PATCH 17/33] deepgo: minor code change --- chebai/preprocessing/datasets/go_uniprot.py | 2 +- .../migration/deep_go/migrate_deep_go_1_data.py | 5 ++++- .../migration/deep_go/migrate_deep_go_2_data.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 16bd6a31..22d13e3f 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -764,7 +764,7 @@ def _name(self) -> str: Returns: str: A formatted name string for the data. 
""" - threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "" + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "GO_" if self.go_branch != self._ALL_GO_BRANCHES: return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index ad8ae322..d9122c75 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -104,7 +104,10 @@ def _load_data(self) -> None: ) except FileNotFoundError as e: - print(f"Error loading data: {e}") + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) @staticmethod def _get_train_val_split( diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 3d4109e1..b24b3cfb 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -96,7 +96,10 @@ def _load_data(self) -> None: pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) ) except FileNotFoundError as e: - print(f"Error loading data: {e}") + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) def _record_splits(self) -> pd.DataFrame: """ From 796356cc3253e40eabfcc5a3d884c8bec089e086 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 23:42:11 +0100 Subject: [PATCH 18/33] modify prints to display actual file name --- chebai/preprocessing/datasets/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index dfa0f999..f382f050 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -728,7 +728,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): - print("Missing processed data file (`data.pkl` file)") + print(f"Missing processed data file (`{processed_name}` file)") os.makedirs(self.processed_dir_main, exist_ok=True) data_path = self._download_required_data() g = self._extract_class_hierarchy(data_path) @@ -812,12 +812,15 @@ def setup_processed(self) -> None: None """ os.makedirs(self.processed_dir, exist_ok=True) - print("Missing transformed data (`data.pt` file). Transforming data.... ") + processed_main_file_name = self.processed_main_file_names_dict["data"] + print( + f"Missing transformed data (`{processed_main_file_name}` file). Transforming data.... 
" + ) torch.save( self._load_data_from_file( os.path.join( self.processed_dir_main, - self.processed_main_file_names_dict["data"], + processed_main_file_name, ) ), os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), From 3c11a690718ca743ac28d75438fa9bf9996adf84 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 17 Nov 2024 23:42:20 +0100 Subject: [PATCH 19/33] create sub dir for deego dataset and move rel files --- chebai/preprocessing/datasets/deepGO/__init__.py | 0 chebai/preprocessing/datasets/{ => deepGO}/go_uniprot.py | 0 chebai/preprocessing/datasets/{ => deepGO}/protein_pretraining.py | 0 .../preprocessing/datasets/deepGO/protein_protein_interactions.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 chebai/preprocessing/datasets/deepGO/__init__.py rename chebai/preprocessing/datasets/{ => deepGO}/go_uniprot.py (100%) rename chebai/preprocessing/datasets/{ => deepGO}/protein_pretraining.py (100%) create mode 100644 chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py diff --git a/chebai/preprocessing/datasets/deepGO/__init__.py b/chebai/preprocessing/datasets/deepGO/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py similarity index 100% rename from chebai/preprocessing/datasets/go_uniprot.py rename to chebai/preprocessing/datasets/deepGO/go_uniprot.py diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py similarity index 100% rename from chebai/preprocessing/datasets/protein_pretraining.py rename to chebai/preprocessing/datasets/deepGO/protein_pretraining.py diff --git a/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py b/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py new file mode 100644 index 00000000..e69de29b From 2b571c5f3b3d30fadc2ec77329ce1d16b70a99d1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 17 Nov 2024 23:51:14 +0100 Subject: [PATCH 20/33] update imports as per new deepGO dir --- chebai/preprocessing/datasets/deepGO/protein_pretraining.py | 2 +- .../preprocessing/migration/deep_go/migrate_deep_go_1_data.py | 2 +- .../preprocessing/migration/deep_go/migrate_deep_go_2_data.py | 2 +- tests/unit/dataset_classes/testGOUniProDataExtractor.py | 2 +- tests/unit/dataset_classes/testGoUniProtOverX.py | 2 +- tutorials/data_exploration_go.ipynb | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py index d2a2b6db..8f7e9c4d 100644 --- a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py +++ b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py @@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.go_uniprot import ( +from chebai.preprocessing.datasets.deepGO.go_uniprot import ( AMBIGUOUS_AMINO_ACIDS, EXPERIMENTAL_EVIDENCE_CODES, GOUniProtOver250, diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index d9122c75..7d59c699 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -6,7 +6,7 @@ from iterstrat.ml_stratifiers import 
MultilabelStratifiedShuffleSplit from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData class DeepGo1DataMigration: diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index b24b3cfb..d63bcad3 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -5,7 +5,7 @@ import pandas as pd from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData class DeepGo2DataMigration: diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 9da48bee..96ff9a3a 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -6,7 +6,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py index d4157770..3f329c56 100644 --- a/tests/unit/dataset_classes/testGoUniProtOverX.py +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -5,7 +5,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tutorials/data_exploration_go.ipynb b/tutorials/data_exploration_go.ipynb index 6f67c82b..1a205e37 100644 --- a/tutorials/data_exploration_go.ipynb +++ b/tutorials/data_exploration_go.ipynb @@ -70,7 +70,7 @@ } }, "outputs": [], - "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250" + "source": "from chebai.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250" }, { "cell_type": "code", From f75e30bcbbc3c3a7d916fa30ddee8fa34af1c486 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 17 Nov 2024 23:54:53 +0100 Subject: [PATCH 21/33] update import dir for pretrain test --- tests/unit/dataset_classes/testProteinPretrainingData.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py index cb6b0688..caac3eac 100644 --- a/tests/unit/dataset_classes/testProteinPretrainingData.py +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -1,7 +1,9 @@ import unittest from unittest.mock import PropertyMock, mock_open, patch -from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData +from chebai.preprocessing.datasets.deepGO.protein_pretraining import ( + _ProteinPretrainingData, +) from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData From 1b8b270c4b4ec99d81739c80ca658c9f7696da10 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 12:06:11 +0100 Subject: [PATCH 
22/33] migration fix : truncate seq and save data with labels --- .../deep_go/migrate_deep_go_2_data.py | 62 +++++++++++++++---- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index d63bcad3..1edec52b 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -20,17 +20,19 @@ class DeepGo2DataMigration: (https://doi.org/10.1093/bioinformatics/btx624) """ - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 - _MAXLEN = 1000 _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX - def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + def __init__( + self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ): """ Initializes the data migration object with a data directory and GO branch. Args: data_dir (str): Directory containing the data files. go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 """ valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: @@ -38,6 +40,8 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._go_branch = go_branch self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._max_len: int = max_len + self._train_df: Optional[pd.DataFrame] = None self._test_df: Optional[pd.DataFrame] = None self._validation_df: Optional[pd.DataFrame] = None @@ -74,33 +78,61 @@ def migrate(self) -> None: "Data splits or terms data is not available in instance variables." ) - self.save_migrated_data(data_df, splits_df) + self.save_migrated_data(data_with_labels_df, splits_df) def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ + try: print(f"Loading data from directory: {self._data_dir}......") - self._test_df = pd.DataFrame( - pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + self._test_df = self._truncate_sequences( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) ) - self._train_df = pd.DataFrame( - pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + self._train_df = self._truncate_sequences( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) ) - self._validation_df = pd.DataFrame( - pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + self._validation_df = self._truncate_sequences( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) ) + self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) ) + except FileNotFoundError as e: raise FileNotFoundError( f"Data file not found in directory: {e}. " "Please ensure all required files are available in the specified directory." ) + def _truncate_sequences( + self, df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Truncate sequences in a specified column of a dataframe to the maximum length. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217 + + Args: + df (pd.DataFrame): The input dataframe containing the data to be processed. 
+ column (str, optional): The column containing sequences to truncate. + Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with sequences truncated to `self._max_len`. + """ + df[column] = df[column].apply(lambda x: x[: self._max_len]) + return df + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. @@ -217,7 +249,7 @@ def save_migrated_data( print("Saving transformed data......") deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], - max_sequence_length=self._MAXLEN, + max_sequence_length=self._max_len, ) # Save data file @@ -257,7 +289,9 @@ class Main: """ @staticmethod - def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + def migrate( + data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ) -> None: """ Initiates the migration process by creating a DeepGoDataMigration instance and invoking its migrate method. @@ -268,8 +302,10 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: ("cc" for cellular_component, "mf" for molecular_function, or "bp" for biological_process). + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 """ - DeepGo2DataMigration(data_dir, go_branch).migrate() + DeepGo2DataMigration(data_dir, go_branch, max_len).migrate() if __name__ == "__main__": From bcda11ca7517c4e60456303bc98418f345ca6f08 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 12:37:56 +0100 Subject: [PATCH 23/33] Delete protein_protein_interactions.py --- .../preprocessing/datasets/deepGO/protein_protein_interactions.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py diff --git a/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py b/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py deleted file mode 100644 index e69de29b..00000000 From 85c47a05aa36a2bde9f07fca71f8838fe8fd5e96 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 15:55:57 +0100 Subject: [PATCH 24/33] migration: replace invalid amino acid with "X" notation - https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 --- .../deep_go/migrate_deep_go_2_data.py | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 1edec52b..0bb07914 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -1,4 +1,5 @@ import os +import re from collections import OrderedDict from typing import List, Literal, Optional @@ -6,6 +7,7 @@ from jsonargparse import CLI from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.reader import ProteinDataReader class DeepGo2DataMigration: @@ -88,17 +90,25 @@ def _load_data(self) -> None: try: print(f"Loading data from directory: {self._data_dir}......") - self._test_df = self._truncate_sequences( + + print( + "Pre-processing the data before loading them into instance variables\n" + f"2-Steps preprocessing: \n" + f"\t 1: Truncating every sequence to {self._max_len}\n" + f"\t 2: Replacing every amino 
acid which is not in {ProteinDataReader.AA_LETTER}" + ) + + self._test_df = self._pre_process_data( pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) ) ) - self._train_df = self._truncate_sequences( + self._train_df = self._pre_process_data( pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) ) ) - self._validation_df = self._truncate_sequences( + self._validation_df = self._pre_process_data( pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) ) @@ -114,6 +124,21 @@ def _load_data(self) -> None: "Please ensure all required files are available in the specified directory." ) + def _pre_process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Pre-processes the input dataframe by truncating sequences to the maximum + length and replacing invalid amino acids with 'X'. + + Args: + df (pd.DataFrame): The dataframe to preprocess. + + Returns: + pd.DataFrame: The processed dataframe. + """ + df = self._truncate_sequences(df) + df = self._replace_invalid_amino_acids(df) + return df + def _truncate_sequences( self, df: pd.DataFrame, column: str = "sequences" ) -> pd.DataFrame: @@ -133,6 +158,30 @@ def _truncate_sequences( df[column] = df[column].apply(lambda x: x[: self._max_len]) return df + @staticmethod + def _replace_invalid_amino_acids( + df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Replaces invalid amino acids in a sequence with 'X' using regex. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L26-L33 + https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 + + Args: + df (pd.DataFrame): The dataframe containing the sequences to be processed. + column (str, optional): The column containing the sequences. Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with invalid amino acids replaced by 'X'. + """ + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + # Replace any character not in the valid set with 'X' + df[column] = df[column].apply( + lambda x: re.sub(f"[^{valid_amino_acids}]", "X", x) + ) + return df + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. 
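
A minimal standalone sketch of the two-step sequence preprocessing introduced in this patch (truncate each sequence to max_len, then map residues outside the reader's alphabet to "X"); AA_LETTER below is only an assumed stand-in for ProteinDataReader.AA_LETTER, and the sample sequences are invented:

    import re

    import pandas as pd

    AA_LETTER = list("ACDEFGHIKLMNPQRSTVWYX")  # assumption: mirrors ProteinDataReader.AA_LETTER
    max_len = 1000  # DeepGO2 default used by the migration script

    df = pd.DataFrame({"sequences": ["MKTAYIAKQR" * 150, "MKBZTAYU"]})
    # Step 1: truncate every sequence to max_len characters
    df["sequences"] = df["sequences"].apply(lambda s: s[:max_len])
    # Step 2: replace any character outside the valid alphabet with "X"
    valid = "".join(AA_LETTER)
    df["sequences"] = df["sequences"].apply(lambda s: re.sub(f"[^{valid}]", "X", s))
    print(df["sequences"].str.len().tolist())  # [1000, 8]

Mapping non-standard residues (e.g., B, Z, U) to a single placeholder keeps the protein reader's token vocabulary fixed.
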
From fbb5c58064e2171964290d0a7d6d7f1d3da35173 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Wed, 4 Dec 2024 16:41:22 +0100
Subject: [PATCH 25/33] update deepgo configs

---
 configs/data/deepGO/deepgo_1_migrated_data.yml | 4 ++++
 configs/data/deepGO/deepgo_2_migrated_data.yml | 4 ++++
 configs/data/deepGO/go250.yml | 3 +++
 configs/data/deepGO/go50.yml | 1 +
 configs/data/go250.yml | 3 ---
 configs/data/go50.yml | 1 -
 6 files changed, 12 insertions(+), 4 deletions(-)
 create mode 100644 configs/data/deepGO/deepgo_1_migrated_data.yml
 create mode 100644 configs/data/deepGO/deepgo_2_migrated_data.yml
 create mode 100644 configs/data/deepGO/go250.yml
 create mode 100644 configs/data/deepGO/go50.yml
 delete mode 100644 configs/data/go250.yml
 delete mode 100644 configs/data/go50.yml

diff --git a/configs/data/deepGO/deepgo_1_migrated_data.yml b/configs/data/deepGO/deepgo_1_migrated_data.yml
new file mode 100644
index 00000000..0924e023
--- /dev/null
+++ b/configs/data/deepGO/deepgo_1_migrated_data.yml
@@ -0,0 +1,4 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData
+init_args:
+  go_branch: "MF"
+  max_sequence_length: 1002
diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml
new file mode 100644
index 00000000..1ed2ad09
--- /dev/null
+++ b/configs/data/deepGO/deepgo_2_migrated_data.yml
@@ -0,0 +1,4 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData
+init_args:
+  go_branch: "MF"
+  max_sequence_length: 1000
diff --git a/configs/data/deepGO/go250.yml b/configs/data/deepGO/go250.yml
new file mode 100644
index 00000000..01e34aa4
--- /dev/null
+++ b/configs/data/deepGO/go250.yml
@@ -0,0 +1,3 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver250
+init_args:
+  go_branch: "BP"
diff --git a/configs/data/deepGO/go50.yml b/configs/data/deepGO/go50.yml
new file mode 100644
index 00000000..bee43773
--- /dev/null
+++ b/configs/data/deepGO/go50.yml
@@ -0,0 +1 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50
diff --git a/configs/data/go250.yml b/configs/data/go250.yml
deleted file mode 100644
index 5598495c..00000000
--- a/configs/data/go250.yml
+++ /dev/null
@@ -1,3 +0,0 @@
-class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver250
-init_args:
-  go_branch: "BP"
diff --git a/configs/data/go50.yml b/configs/data/go50.yml
deleted file mode 100644
index 2ed4d14c..00000000
--- a/configs/data/go50.yml
+++ /dev/null
@@ -1 +0,0 @@
-class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver50

From 272446db7a5dd0f2aa08de6e96fd9a6d11a0e3d2 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 9 Dec 2024 13:04:30 +0100
Subject: [PATCH 26/33] add esm2 reader for deepGO

---
 chebai/preprocessing/reader.py | 257 ++++++++++++++++++++++++++++++++-
 setup.py | 1 +
 2 files changed, 257 insertions(+), 1 deletion(-)

diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index a08a3f91..dff2ff51 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -1,8 +1,18 @@
 import os
-from typing import Any, Dict, List, Optional
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+from urllib.error import HTTPError

 import deepsmiles
 import selfies as sf
+import torch
+from esm import Alphabet
+from esm.model.esm2 import ESM2
+from esm.pretrained import (
+    _has_regression_weights,
+    load_model_and_alphabet_core,
+    load_model_and_alphabet_local,
+)
 from pysmiles.read_smiles import
_tokenize from transformers import RobertaTokenizerFast @@ -471,3 +481,248 @@ def on_finish(self) -> None: print(f"Saving {len(self.cache)} tokens to {self.token_path}...") print(f"First 10 tokens: {self.cache[:10]}") pk.writelines([f"{c}\n" for c in self.cache]) + + +class ESM2EmbeddingReader(DataReader): + """ + A data reader to process protein sequences using the ESM2 model for embeddings. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py + + Note: + For layer availability by model, Please check below link: + https://github.com/facebookresearch/esm?tab=readme-ov-file#pre-trained-models- + + To test this reader, try lighter models: + esm2_t6_8M_UR50D: 6 layers (valid layers: 1–6), (~28 Mb) - A tiny 8M parameter model. + esm2_t12_35M_UR50D: 12 layers (valid layers: 1–12), (~128 Mb) - A slightly larger, 35M parameter model. + These smaller models are good for testing and debugging purposes. + + """ + + # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53 + _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt" + _REGRESSION_URL = ( + "https://dl.fbaipublicfiles.com/fair-esm/regression/{}-contact-regression.pt" + ) + + def __init__( + self, + save_model_dir: str, + model_name: str = "esm2_t36_3B_UR50D", + device: Optional[torch.device] = None, + truncation_length: int = 1022, + toks_per_batch: int = 4096, + return_contacts: bool = False, + repr_layer: int = 36, + *args, + **kwargs, + ): + """ + Initialize the ESM2EmbeddingReader class. + + Args: + save_model_dir (str): Directory to save/load the pretrained ESM model. + model_name (str): Name of the pretrained model. Defaults to "esm2_t36_3B_UR50D". + device (torch.device or str, optional): Device for computation (e.g., 'cpu', 'cuda'). + truncation_length (int): Maximum sequence length for truncation. Defaults to 1022. + toks_per_batch (int): Tokens per batch for data processing. Defaults to 4096. + return_contacts (bool): Whether to return contact maps. Defaults to False. + repr_layers (int): Layer number to extract representations from. Defaults to 36. + """ + self.save_model_dir = save_model_dir + if not os.path.exists(self.save_model_dir): + os.makedirs((os.path.dirname(self.save_model_dir)), exist_ok=True) + self.model_name = model_name + self.device = device + self.truncation_length = truncation_length + self.toks_per_batch = toks_per_batch + self.return_contacts = return_contacts + self.repr_layer = repr_layer + + self._model: Optional[ESM2] = None + self._alphabet: Optional[Alphabet] = None + + self._model, self._alphabet = self.load_model_and_alphabet() + self._model.eval() + + if self.device: + self._model = self._model.to(device) + + super().__init__(*args, **kwargs) + + def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]: + """ + Load the ESM2 model and its alphabet. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L24-L28 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. + """ + model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt") + if os.path.exists(model_location): + return load_model_and_alphabet_local(model_location) + else: + return self.load_model_and_alphabet_hub() + + def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]: + """ + Load the model and alphabet from the hub URL. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L62-L64 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. 
+ """ + model_url = self._MODELS_URL.format(self.model_name) + model_data = self.load_hub_workaround(model_url) + regression_data = None + if _has_regression_weights(self.model_name): + regression_url = self._REGRESSION_URL.format(self.model_name) + regression_data = self.load_hub_workaround(regression_url) + return load_model_and_alphabet_core( + self.model_name, model_data, regression_data + ) + + def load_hub_workaround(self, url) -> torch.Tensor: + """ + Workaround to load models from the PyTorch Hub. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L31-L43 + + Returns: + torch.Tensor: Loaded model state dictionary. + """ + try: + data = torch.hub.load_state_dict_from_url( + url, self.save_model_dir, progress=True, map_location=self.device + ) + + except RuntimeError: + # Handle PyTorch version issues + fn = Path(url).name + data = torch.load( + f"{torch.hub.get_dir()}/checkpoints/{fn}", + map_location="cpu", + ) + except HTTPError as e: + raise Exception( + f"Could not load {url}. Did you specify the correct model name?" + ) + return data + + def name(self) -> None: + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + return "esm2_embedding" + + @property + def token_path(self) -> None: + """ + Not used as no token file is not created for this reader. + + Returns: + str: Empty string since this method is not implemented. + """ + return + + def _read_data(self, raw_data: str) -> List[int]: + """ + Reads protein sequence data and generates embeddings. + + Args: + raw_data (str): The protein sequence. + + Returns: + List[int]: Embeddings generated for the sequence. + """ + alp_tokens_idx = self._sequence_to_alphabet_tokens_idx(raw_data) + return self._alphabet_tokens_to_esm_embedding(alp_tokens_idx).tolist() + + def _sequence_to_alphabet_tokens_idx(self, sequence: str) -> torch.Tensor: + """ + Converts a protein sequence into ESM alphabet token indices. + + Args: + sequence (str): Protein sequence. + + References: + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L249-L250 + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L262-L297 + + Returns: + torch.Tensor: Tokenized sequence with special tokens (BOS/EOS) included. + """ + seq_encoded = self._alphabet.encode(sequence) + tokens = [] + + # Add BOS token if configured + if self._alphabet.prepend_bos: + tokens.append(self._alphabet.cls_idx) + + # Add the main sequence + tokens.extend(seq_encoded) + + # Add EOS token if configured + if self._alphabet.append_eos: + tokens.append(self._alphabet.eos_idx) + + # Convert to PyTorch tensor and return + return torch.tensor([tokens], dtype=torch.int64) + + def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts alphabet tokens into ESM embeddings. + + Args: + tokens (torch.Tensor): Tokenized protein sequences. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107 + + Returns: + torch.Tensor: Protein embedding from the specified representation layer. 
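            Note (illustrative): the length of the returned vector equals the hidden size of the
            chosen checkpoint, e.g. 2560 for the default esm2_t36_3B_UR50D and 320 for the small
            esm2_t6_8M_UR50D mentioned in the class docstring, so downstream models must size their
            input layer accordingly.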
+ """ + if self.device: + tokens = tokens.to(self.device, non_blocking=True) + + with torch.no_grad(): + out = self._model( + tokens, + repr_layers=[ + self.repr_layer, + ], + return_contacts=self.return_contacts, + ) + + # Extract representations and compute the mean embedding for each layer + representations = { + layer: t.to(self.device) for layer, t in out["representations"].items() + } + truncate_len = min(self.truncation_length, tokens.size(1)) + + result = { + "mean_representations": { + layer: t[0, 1 : truncate_len + 1].mean(0).clone() + for layer, t in representations.items() + } + } + return result["mean_representations"][self.repr_layer] + + def on_finish(self) -> None: + """ + Not used here as no token file exists for this reader. + + Returns: + None + """ + pass diff --git a/setup.py b/setup.py index 58bfc75b..ba134e41 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ "pyyaml", "torchmetrics", "biopython", + "fair-esm", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, ) From a12354b527f670da28ac6b8f200b659d4d67ab43 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 9 Dec 2024 15:03:03 +0100 Subject: [PATCH 27/33] increase electra vocab size --- chebai/models/electra.py | 2 +- configs/model/electra.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 7009406d..dc6c719b 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -329,7 +329,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: except RuntimeError as e: print(f"RuntimeError at forward: {e}") print(f'data[features]: {data["features"]}') - raise Exception + raise e inp = self.word_dropout(inp) electra = self.electra(inputs_embeds=inp, **kwargs) d = electra.last_hidden_state[:, 0, :] diff --git a/configs/model/electra.yml b/configs/model/electra.yml index c3cf2fdf..ade89acd 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -3,7 +3,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 1400 + vocab_size: 8500 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 From 66732a7cf5e9e8f0f2f338848a00333bb0375ec4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Dec 2024 21:53:06 +0100 Subject: [PATCH 28/33] fix: print right name of missing file --- chebai/preprocessing/datasets/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f382f050..fc64c808 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -812,18 +812,18 @@ def setup_processed(self) -> None: None """ os.makedirs(self.processed_dir, exist_ok=True) - processed_main_file_name = self.processed_main_file_names_dict["data"] + transformed_file_name = self.processed_file_names_dict["data"] print( - f"Missing transformed data (`{processed_main_file_name}` file). Transforming data.... " + f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... 
" ) torch.save( self._load_data_from_file( os.path.join( self.processed_dir_main, - processed_main_file_name, + self.processed_main_file_names_dict["data"], ) ), - os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + os.path.join(self.processed_dir, transformed_file_name), ) @staticmethod From e7b3d800da1f3ae2aeb17e9202f1a2d45e1a5083 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Dec 2024 21:56:35 +0100 Subject: [PATCH 29/33] migration : add esm2 embeddings - modify deepgo2 migration script to migrate the esm2 embeddings too - modify migration class to use esm2 embeddings or reader features, based on input --- .../datasets/deepGO/go_uniprot.py | 95 ++++++++++++++++++- .../deep_go/migrate_deep_go_2_data.py | 2 + chebai/preprocessing/reader.py | 3 +- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/deepGO/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py index 22d13e3f..3c957e6c 100644 --- a/chebai/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -40,6 +40,7 @@ import pandas as pd import requests import torch +import tqdm from Bio import SwissProt from chebai.preprocessing import reader as dr @@ -892,12 +893,95 @@ class DeepGO2MigratedData(_DeepGOMigratedData): dict: Dictionary with file names specific to DeepGO2. """ - def __init__(self, **kwargs): + _LABELS_START_IDX: int = 5 # additional esm2_embeddings column in the dataframe + _ESM_EMBEDDINGS_COL_IDX: int = 4 + + def __init__(self, use_esm2_embeddings=False, **kwargs): # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 assert int(kwargs.get("max_sequence_length")) == 1000 - + self.use_esm2_embeddings: bool = use_esm2_embeddings super(_DeepGOMigratedData, self).__init__(**kwargs) + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: + """ + Load and process data from a file into a list of dictionaries containing features and labels. + + This method processes data differently based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, raw dictionaries from `_load_dict` are returned, _load_dict already returns + the numerical features (esm2 embeddings) from the data file, hence no reader is required. + - Otherwise, a reader is used to process the data (generate numerical features). + + Args: + path (str): The path to the input file. + + Returns: + List[Dict[str, Any]]: A list of dictionaries with the following keys: + - `features`: Sequence or embedding data, depending on the context. + - `labels`: A boolean array of labels. + - `ident`: The identifier for the sequence. + """ + lines = self._get_data_size(path) + print(f"Processing {lines} lines...") + + if self.use_esm2_embeddings: + data = [ + d + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + else: + data = [ + self.reader.to_data(d) + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + + # filter for missing features in resulting data + data = [val for val in data if val["features"] is not None] + + return data + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. 
+ + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data at row index `self._ESM2_EMBEDDINGS_COL_IDX`: ESM2 embeddings of the protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + The method adapts based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, features are loaded from the column specified by `self._ESM_EMBEDDINGS_COL_IDX`. + - Otherwise, features are loaded from the column specified by `self._DATA_REPRESENTATION_IDX`. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (Any): Sequence or embedding data for the instance. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + + if self.use_esm2_embeddings: + features_idx = self._ESM_EMBEDDINGS_COL_IDX + else: + features_idx = self._DATA_REPRESENTATION_IDX + + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + yield dict( + features=row[features_idx], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Raw Properties ----------------------------------- @property def processed_main_file_names_dict(self) -> Dict[str, str]: """ @@ -917,3 +1001,10 @@ def processed_file_names_dict(self) -> Dict[str, str]: dict: Dictionary with data file name for DeepGO2. """ return {"data": "data_deep_go2.pt"} + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + if self.use_esm2_embeddings: + return (dr.ESM2EmbeddingReader.name(),) + return (self.reader.name(),) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 0bb07914..68d7dc78 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -217,6 +217,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: "exp_annotations", # Directly associated GO ids # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 "prop_annotations", # Transitively associated GO ids + "esm2", ] new_df = pd.concat( @@ -239,6 +240,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: accession=new_df["accessions"], go_ids=new_df["go_ids"], sequence=new_df["sequences"], + esm2_embeddings=new_df["esm2"], ) ) return data_df diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index dff2ff51..88e4fedd 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -616,7 +616,8 @@ def load_hub_workaround(self, url) -> torch.Tensor: ) return data - def name(self) -> None: + @staticmethod + def name() -> None: """ Returns the name of the data reader. This method identifies the specific type of data reader. 
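Taken together, patches 26 to 29 let the DeepGO2 data pipeline work on precomputed ESM2 protein embeddings instead of token sequences. As a quick, self-contained sanity check of the new reader, a sketch along the following lines should work; the cache directory and the test sequence are made up for the example, the small 8M checkpoint is only chosen to keep the download small, and it is assumed that the DataReader base class needs no further constructor arguments:

    from chebai.preprocessing.reader import ESM2EmbeddingReader

    # Small checkpoint (~28 MB) recommended in the class docstring for testing.
    reader = ESM2EmbeddingReader(
        save_model_dir="models/esm2",    # assumed local cache directory
        model_name="esm2_t6_8M_UR50D",   # 6-layer test model
        repr_layer=6,                    # last layer of this checkpoint
    )

    # _read_data returns the mean-pooled embedding of one sequence as a plain list.
    embedding = reader._read_data("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
    print(len(embedding))  # 320 dimensions for the 8M model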
From 862c8ef5743f3711c92c9b922cd269203940936e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 5 Jan 2025 17:07:27 +0100 Subject: [PATCH 30/33] scope dataset: add scope abstract code --- .../preprocessing/datasets/scope/__init__.py | 0 chebai/preprocessing/datasets/scope/scope.py | 381 ++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 chebai/preprocessing/datasets/scope/__init__.py create mode 100644 chebai/preprocessing/datasets/scope/scope.py diff --git a/chebai/preprocessing/datasets/scope/__init__.py b/chebai/preprocessing/datasets/scope/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py new file mode 100644 index 00000000..a987f53d --- /dev/null +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -0,0 +1,381 @@ +import gzip +import itertools +import os +import pickle +import shutil +from abc import ABC +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import fastobo +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SeqIO +from Bio.Seq import Seq + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.reader import ProteinDataReader + + +class _SCOPeDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. + """ + + _GO_DATA_INIT = "GO" + _SWISS_DATA_INIT = "SWISS" + + # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` + # "swiss_id" at row index 0 + # "accession" at row index 1 + # "go_ids" at row index 2 + # "sequence" at row index 3 + # labels starting from row index 4 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column + _LABELS_START_IDX: int = 4 + + _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt" + _PDB_SEQUENCE_DATA_URL = ( + "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz" + ) + + def __init__( + self, + scope_version: float, + scope_version_train: Optional[float] = None, + **kwargs, + ): + + self.scope_version: float = scope_version + self.scope_version_train: float = scope_version_train + + super(_SCOPeDataExtractor, self).__init__(**kwargs) + + if self.scope_version_train is not None: + # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given + # This is to get the data from respective directory related to "scope_version_train" + _init_kwargs = kwargs + _init_kwargs["chebi_version"] = self.scope_version_train + self._scope_version_train_obj = self.__class__( + **_init_kwargs, + ) + + @staticmethod + def _get_scope_url(data_type: str, version_number: float) -> str: + """ + Generates the URL for downloading SCOPe files. 
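            For example, data_type "cla" and version_number 2.08 resolve to
            https://scop.berkeley.edu/downloads/parse/dir.cla.scope.2.08-stable.txt.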
+ + Args: + data_type (str): The type of data (e.g., 'cla', 'hie', 'des'). + version_number (str): The version of the SCOPe file. + + Returns: + str: The formatted SCOPe file URL. + """ + return _SCOPeDataExtractor._SCOPE_GENERAL_URL.format( + data_type=data_type, version_number=version_number + ) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. + + Returns: + str: Path to the downloaded data. + """ + self._download_pdb_sequence_data() + return self._download_scope_raw_data() + + def _download_pdb_sequence_data(self) -> None: + pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]) + os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) + + if not os.path.isfile(pdb_seq_file_path): + print(f"Downloading PDB sequence data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._PDB_SEQUENCE_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = pdb_seq_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + def _download_scope_raw_data(self) -> str: + os.makedirs(self.raw_dir, exist_ok=True) + for data_type in ["CLA", "COM", "HIE", "DES"]: + data_file_name = self.raw_file_names_dict[data_type] + scope_path = os.path.join(self.raw_dir, data_file_name) + if not os.path.isfile(scope_path): + print(f"Missing Scope: {data_file_name} raw data, Downloading...") + r = requests.get( + self._get_scope_url(data_type.lower(), self.scope_version), + allow_redirects=False, + verify=False, # Disable SSL verification + ) + r.raise_for_status() # Check if the request was successful + open(scope_path, "wb").write(r.content) + return "dummy/path" + + def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: + pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} + for record in SeqIO.parse( + os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" + ): + pdb_id, chain = record.id.split("_") + pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq) + return pdb_chain_seq_mapping + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + print("Extracting class hierarchy...") + + # Load and preprocess CLA file + df_cla = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), + sep="\t", + header=None, + comment="#", + ) + df_cla.columns = [ + "sid", + "PDB_ID", + "description", + "sccs", + "sunid", + "ancestor_nodes", + ] + df_cla["sunid"] = pd.to_numeric( + df_cla["sunid"], errors="coerce", downcast="integer" + ) + df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply( + lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))} + ) + df_cla.set_index("sunid", inplace=True) + + # Load and preprocess HIE file + df_hie = 
pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), + sep="\t", + header=None, + comment="#", + ) + df_hie.columns = ["sunid", "parent_sunid", "children_sunids"] + df_hie["sunid"] = pd.to_numeric( + df_hie["sunid"], errors="coerce", downcast="integer" + ) + df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int) + df_hie["children_sunids"] = df_hie["children_sunids"].apply( + lambda x: list(map(int, x.split(","))) if x != "-" else [] + ) + + # Initialize directed graph + g = nx.DiGraph() + + # Add nodes and edges efficiently + g.add_edges_from( + df_hie[df_hie["parent_sunid"] != -1].apply( + lambda row: (row["parent_sunid"], row["sunid"]), axis=1 + ) + ) + g.add_edges_from( + df_hie.explode("children_sunids") + .dropna() + .apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1) + ) + + pdb_chain_seq_mapping = self._parse_pdb_sequence_file() + + node_to_pdb_id = df_cla["PDB_ID"].to_dict() + + for node in g.nodes(): + pdb_id = node_to_pdb_id[node] + chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {}) + + # Add nodes and edges for chains in the mapping + for chain, sequence in chain_mapping.items(): + chain_node = f"{pdb_id}_{chain}" + g.add_node(chain_node, sequence=sequence) + g.add_edge(node, chain_node) + + print("Compute transitive closure...") + return nx.transitive_closure_dag(g) + + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes + Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. + + Note: + - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value + indicates whether a Swiss-Prot protein is associated with that GO term. + - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins + and GO terms. + + Data Format: pd.DataFrame + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + print(f"Processing graph") + + data_df = self._get_swiss_to_go_mapping() + # add ancestors to go ids + data_df["go_ids"] = data_df["go_ids"].apply( + lambda go_ids: sorted( + set( + itertools.chain.from_iterable( + [ + [go_id] + list(g.predecessors(go_id)) + for go_id in go_ids + if go_id in g.nodes + ] + ) + ) + ) + ) + # Initialize the GO term labels/columns to False + selected_classes = self.select_classes(g, data_df=data_df) + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=selected_classes + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + # Set True for the corresponding GO IDs in the DataFrame go labels/columns + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. 
+        # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least
+        # one GO term from the set of the GO terms for the model`
+        data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
+        return data_df
+
+    # ------------------------------ Phase: Setup data -----------------------------------
+    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
+        with open(input_file_path, "rb") as input_file:
+            df = pd.read_pickle(input_file)
+            for row in df.values:
+                labels = row[self._LABELS_START_IDX :].astype(bool)
+                # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group
+                # "group" set to None, by default as no such entity for this data
+                yield dict(
+                    features=row[self._DATA_REPRESENTATION_IDX],
+                    labels=labels,
+                    ident=row[self._ID_IDX],
+                )
+
+    # ------------------------------ Phase: Dynamic Splits -----------------------------------
+    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+        try:
+            filename = self.processed_file_names_dict["data"]
+            data_go = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
+        except FileNotFoundError:
+            raise FileNotFoundError(
+                f"File {filename} does not exist. "
+                f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
+            )
+
+        df_go_data = pd.DataFrame(data_go)
+        train_df_go, df_test = self.get_test_split(
+            df_go_data, seed=self.dynamic_data_split_seed
+        )
+
+        # Get all splits
+        df_train, df_val = self.get_train_val_splits_given_test(
+            train_df_go,
+            df_test,
+            seed=self.dynamic_data_split_seed,
+        )
+
+        return df_train, df_val, df_test
+
+    # ------------------------------ Phase: Raw Properties -----------------------------------
+    @property
+    def base_dir(self) -> str:
+        """
+        Returns the base directory path for storing SCOPe data.
+
+        Returns:
+            str: The path to the base directory, e.g. "data/SCOPe/version_2.08".
+        """
+        return os.path.join("data", "SCOPe", f"version_{self.scope_version}")
+
+    @property
+    def raw_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary of raw file names used in data processing.
+
+        Returns:
+            dict: A dictionary mapping dataset names to their respective file names.
+                For example, {"CLA": "cla.txt", "PDB": "pdb_sequences.txt"}.
+ """ + return { + "CLA": "cla.txt", + "DES": "des.txt", + "HIE": "hie.txt", + "COM": "com.txt", + "PDB": "pdb_sequences.txt", + } + + +class SCOPE(_SCOPeDataExtractor): + READER = ProteinDataReader + + @property + def _name(self) -> str: + return "test" + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + pass + + +if __name__ == "__main__": + scope = SCOPE(scope_version=2.08) + scope._parse_pdb_sequence_file() From 7da8963c169c3e59ac9eb512c65e73313f7370cd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 5 Jan 2025 17:17:02 +0100 Subject: [PATCH 31/33] base: make _name property abstract method - this will help to identify methods that needs to be implemented during coding and not during runtime --- chebai/preprocessing/datasets/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index fc64c808..6158b9dc 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -155,8 +155,19 @@ def fold_dir(self) -> str: return f"cv_{self.inner_k_folds}_fold" @property + @abstractmethod def _name(self) -> str: - raise NotImplementedError + """ + Abstract property representing the name of the data module. + + This property should be implemented in subclasses to provide a unique name for the data module. + The name is used to create subdirectories within the base directory or `processed_dir_main` + for storing relevant data associated with this module. + + Returns: + str: The name of the data module. + """ + pass def _filter_labels(self, row: dict) -> dict: """ From 976f2b895e3ee8fce4a9bcbde6ace30539e7845a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 10 Jan 2025 13:55:50 +0100 Subject: [PATCH 32/33] add simple Feed-forward network (for ESM2->chebi task) --- chebai/models/ffn.py | 55 ++++++++++++++++++++++++++++ configs/data/deepGO/deepgo2_esm2.yml | 5 +++ configs/model/ffn.yml | 7 ++++ 3 files changed, 67 insertions(+) create mode 100644 chebai/models/ffn.py create mode 100644 configs/data/deepGO/deepgo2_esm2.yml create mode 100644 configs/model/ffn.yml diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py new file mode 100644 index 00000000..77046ae6 --- /dev/null +++ b/chebai/models/ffn.py @@ -0,0 +1,55 @@ +from typing import Dict, Any, Tuple + +from chebai.models import ChebaiBaseNet +import torch +from torch import Tensor + +class FFN(ChebaiBaseNet): + + NAME = "FFN" + + def __init__(self, input_size: int = 1000, num_hidden_layers: int = 3, hidden_size: int = 128, **kwargs): + super().__init__(**kwargs) + + self.layers = torch.nn.ModuleList() + self.layers.append(torch.nn.Linear(input_size, hidden_size)) + for _ in range(num_hidden_layers): + self.layers.append(torch.nn.Linear(hidden_size, hidden_size)) + self.layers.append(torch.nn.Linear(hidden_size, self.out_dim)) + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["logits"] + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = data[n] + return torch.sigmoid(d), labels.int() if labels is not None else None + + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. 
+ loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. + """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + + def forward(self, data, **kwargs): + x = data["features"] + for layer in self.layers: + x = torch.relu(layer(x)) + return {"logits": x} + diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml new file mode 100644 index 00000000..4b3ae3b1 --- /dev/null +++ b/configs/data/deepGO/deepgo2_esm2.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 + use_esm2_embeddings: True \ No newline at end of file diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml new file mode 100644 index 00000000..193c6f64 --- /dev/null +++ b/configs/model/ffn.yml @@ -0,0 +1,7 @@ +class_path: chebai.models.ffn.FFN +init_args: + optimizer_kwargs: + lr: 1e-3 + hidden_size: 128 + num_hidden_layers: 3 + input_size: 2560 From 3b174875ecdcc981a3c0a245e535d83bdd5811e3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 10 Jan 2025 14:11:51 +0100 Subject: [PATCH 33/33] reformat using Black --- chebai/models/ffn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index 77046ae6..ca1f6f22 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -4,11 +4,18 @@ import torch from torch import Tensor + class FFN(ChebaiBaseNet): NAME = "FFN" - def __init__(self, input_size: int = 1000, num_hidden_layers: int = 3, hidden_size: int = 128, **kwargs): + def __init__( + self, + input_size: int = 1000, + num_hidden_layers: int = 3, + hidden_size: int = 128, + **kwargs + ): super().__init__(**kwargs) self.layers = torch.nn.ModuleList() @@ -52,4 +59,3 @@ def forward(self, data, **kwargs): for layer in self.layers: x = torch.relu(layer(x)) return {"logits": x} -
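With the two new config files above, the feed-forward head is meant to consume precomputed ESM2 embeddings rather than token sequences. A minimal shape check of the model on its own might look like the following sketch; the out_dim value is an illustrative assumption (in practice it comes from the number of GO classes in the prepared dataset), passing out_dim through to ChebaiBaseNet is assumed to work as for the other models in this repository, and 2560 mirrors input_size in configs/model/ffn.yml, i.e. the embedding width of the default esm2_t36_3B_UR50D checkpoint:

    import torch

    from chebai.models.ffn import FFN

    # 2560-dimensional inputs, e.g. mean-pooled ESM2 3B embeddings of proteins.
    model = FFN(input_size=2560, hidden_size=128, num_hidden_layers=3, out_dim=1024)

    batch = {"features": torch.randn(8, 2560)}  # dummy batch of 8 embeddings
    logits = model(batch)["logits"]             # expected shape: (8, 1024)

For an end-to-end run, configs/model/ffn.yml would be combined with configs/data/deepGO/deepgo2_esm2.yml through the project's existing CLI. Note that forward applies ReLU after the final linear layer as well, so the returned "logits" are non-negative before the sigmoid applied in _get_prediction_and_labels.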