-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DeepGO-SE migration: add class to migrate data
- Loading branch information
1 parent
c6d60cd
commit ca5461f
Showing
1 changed file
with
297 additions
and
45 deletions.
There are no files selected for viewing
342 changes: 297 additions & 45 deletions
342
chebai/preprocessing/migration/deep_go_data_mirgration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,306 @@ | ||
from typing import List | ||
import os | ||
from collections import OrderedDict | ||
from random import randint | ||
from typing import List, Literal | ||
|
||
import pandas as pd | ||
from jsonargparse import CLI | ||
|
||
# GO sub-ontology abbreviations mapped to their full namespace names,
# following the convention used by DeepGO2:
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22
NAMESPACES = {
    "cc": "cellular_component",
    "mf": "molecular_function",
    "bp": "biological_process",
}

# Maximum amino-acid sequence length, taken from DeepGO2:
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
MAXLEN = 1000
|
||
|
||
def load_data(data_dir):
    """
    Load the DeepGO train/validation/test splits from ``data_dir`` and combine
    them into a single DataFrame plus a split-assignment table.

    Fixes: the original ignored ``data_dir`` (it read the pickles from the
    current working directory), built the split-assignment frames without
    using them, and returned nothing.

    Args:
        data_dir (str): Directory containing ``train_data.pkl``,
            ``valid_data.pkl`` and ``test_data.pkl``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: (combined data with the required
        columns, split assignment with columns ``id`` and ``split``).
    """
    test_df = pd.DataFrame(pd.read_pickle(os.path.join(data_dir, "test_data.pkl")))
    train_df = pd.DataFrame(pd.read_pickle(os.path.join(data_dir, "train_data.pkl")))
    validation_df = pd.DataFrame(
        pd.read_pickle(os.path.join(data_dir, "valid_data.pkl"))
    )

    required_columns = [
        "proteins",
        "accessions",
        "sequences",
        # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
        "exp_annotations",  # Directly associated GO ids
        # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
        "prop_annotations",  # Transitively associated GO ids
    ]

    new_df = pd.concat(
        [
            train_df[required_columns],
            validation_df[required_columns],
            test_df[required_columns],
        ],
        ignore_index=True,
    )

    # splits.csv content: ids of each corresponding split
    split_assignment_list: List[pd.DataFrame] = [
        pd.DataFrame({"id": train_df["proteins"], "split": "train"}),
        pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}),
        pd.DataFrame({"id": test_df["proteins"], "split": "test"}),
    ]
    combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)

    return new_df, combined_split_assignment
from chebai.preprocessing.datasets.go_uniprot import ( | ||
GOUniProtOver50, | ||
GOUniProtOver250, | ||
_GOUniProtDataExtractor, | ||
) | ||
|
||
|
||
class DeepGoDataMigration:
    """
    Handle data migration and processing for the DeepGO project: migrates data
    from the DeepGO-SE data structure to the structure used for GO-UniProt
    data in this project.

    Attributes:
        _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
        _MAXLEN (int): Maximum sequence length for sequences.
        _LABELS_START_IDX (int): Starting index for labels in the dataset.

    Methods:
        __init__(data_dir, go_branch): Initializes the data directory and GO branch.
        _load_data(): Loads train, validation, test, and terms data from the specified directory.
        _record_splits(): Creates a DataFrame with IDs and their corresponding split.
        migrate(): Executes the migration process including data loading, processing, and saving.
        _extract_required_data_from_splits(): Extracts required columns from the splits data.
        _generate_labels(data_df): Generates label columns for the data based on GO terms.
        extract_go_id(go_list): Extracts GO IDs from a list.
        save_migrated_data(data_df, splits_df): Saves the processed data and splits.
    """

    # Link for the namespaces convention used for GO branch
    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22
    _CORRESPONDING_GO_CLASSES = {
        "cc": GOUniProtOver50,
        "mf": GOUniProtOver50,
        "bp": GOUniProtOver250,
    }

    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
    _MAXLEN = 1000
    _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX

    def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
        """
        Initialize the data migration object with a data directory and GO branch.

        Args:
            data_dir (str): Directory containing the data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                (cellular_component, molecular_function, or biological_process).

        Raises:
            ValueError: If `go_branch` is not a key of `_CORRESPONDING_GO_CLASSES`.
        """
        valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
        if go_branch not in valid_go_branches:
            raise ValueError(f"go_branch must be one of {valid_go_branches}")
        self._go_branch = go_branch

        # DeepGO keeps each branch's pickles in a sub-directory named after it.
        self._data_dir = os.path.join(data_dir, go_branch)
        self._train_df: pd.DataFrame = None
        self._test_df: pd.DataFrame = None
        self._validation_df: pd.DataFrame = None
        self._terms_df: pd.DataFrame = None
        self._classes: List[str] = None

    def _load_data(self) -> None:
        """
        Load the test, train, validation, and terms data from the pickled files
        in the data directory.

        On a missing file the error is printed and the corresponding attribute
        stays None; `migrate` checks for this before proceeding.
        """
        try:
            print(f"Loading data from {self._data_dir}......")
            self._test_df = pd.DataFrame(
                pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl"))
            )
            self._train_df = pd.DataFrame(
                pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl"))
            )
            self._validation_df = pd.DataFrame(
                pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl"))
            )
            self._terms_df = pd.DataFrame(
                pd.read_pickle(os.path.join(self._data_dir, "terms.pkl"))
            )
        except FileNotFoundError as e:
            print(f"Error loading data: {e}")

    def _record_splits(self) -> pd.DataFrame:
        """
        Create a DataFrame that stores the IDs and their corresponding data splits.

        Returns:
            pd.DataFrame: A combined DataFrame with columns "id" and "split".
        """
        print("Recording splits...")
        split_assignment_list: List[pd.DataFrame] = [
            pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
            pd.DataFrame(
                {"id": self._validation_df["proteins"], "split": "validation"}
            ),
            pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}),
        ]

        combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
        return combined_split_assignment

    def migrate(self) -> None:
        """
        Execute the data migration by loading, processing, and saving the data.

        Raises:
            Exception: If the splits/terms data could not be loaded, or if
                processing produced no labelled data.
        """
        print("Migration started......")
        self._load_data()
        # Fix: `all([...])` on DataFrames raises "The truth value of a
        # DataFrame is ambiguous" — test for None explicitly instead.
        if any(
            df is None
            for df in (
                self._train_df,
                self._validation_df,
                self._test_df,
                self._terms_df,
            )
        ):
            raise Exception(
                "Data splits or terms data is not available in instance variables."
            )
        splits_df = self._record_splits()

        data_df = self._extract_required_data_from_splits()
        data_with_labels_df = self._generate_labels(data_df)

        # Same ambiguous-truth-value issue as above; check each piece explicitly.
        if data_with_labels_df is None or splits_df is None or not self._classes:
            raise Exception(
                "Data splits or terms data is not available in instance variables."
            )

        # Fix: save the labelled data — the original saved the unlabelled
        # `data_df`, discarding the label columns generated above.
        self.save_migrated_data(data_with_labels_df, splits_df)

    def _extract_required_data_from_splits(self) -> pd.DataFrame:
        """
        Extract required columns from the combined data splits.

        Returns:
            pd.DataFrame: A DataFrame with columns swiss_id, accession,
                go_ids, sequence.
        """
        print("Combining the data splits with required data..... ")
        required_columns = [
            "proteins",
            "accessions",
            "sequences",
            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
            "exp_annotations",  # Directly associated GO ids
            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
            "prop_annotations",  # Transitively associated GO ids
        ]

        new_df = pd.concat(
            [
                self._train_df[required_columns],
                self._validation_df[required_columns],
                self._test_df[required_columns],
            ],
            ignore_index=True,
        )
        # Merge direct and transitive GO annotations into one id list per row.
        new_df["go_ids"] = new_df.apply(
            lambda row: self.extract_go_id(row["exp_annotations"])
            + self.extract_go_id(row["prop_annotations"]),
            axis=1,
        )

        # Fix: removed a stray `pd.concat(split_assignment_list, ...)` line
        # that referenced an undefined name (NameError); split recording is
        # handled by `_record_splits`.
        data_df = pd.DataFrame(
            OrderedDict(
                swiss_id=new_df["proteins"],
                accession=new_df["accessions"],
                go_ids=new_df["go_ids"],
                sequence=new_df["sequences"],
            )
        )
        return data_df

    def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """
        Generate one boolean label column per GO term in the dataset, and drop
        rows that have no positive label at all.

        Args:
            data_df (pd.DataFrame): DataFrame containing data with GO IDs.

        Returns:
            pd.DataFrame: DataFrame with new label columns.
        """
        print("Generating labels based on terms.pkl file.......")
        # Fix: `DataFrame.apply` without axis=1 passes columns, so
        # `row["gos"]` failed; parse the "gos" column directly instead.
        # NOTE(review): assumes each entry of "gos" is a single GO-id string
        # (as in DeepGO2's terms.pkl) — confirm against the data.
        parsed_go_ids: pd.Series = self._terms_df["gos"].apply(
            _GOUniProtDataExtractor._parse_go_id
        )
        all_go_ids_list = parsed_go_ids.values.tolist()
        self._classes = all_go_ids_list
        new_label_columns = pd.DataFrame(
            False, index=data_df.index, columns=all_go_ids_list
        )
        data_df = pd.concat([data_df, new_label_columns], axis=1)

        # Mark each row's associated GO terms as True.
        for index, row in data_df.iterrows():
            for go_id in row["go_ids"]:
                if go_id in data_df.columns:
                    data_df.at[index, go_id] = True

        # Keep only rows with at least one positive label.
        data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
        return data_df

    @staticmethod
    def extract_go_id(go_list: List[str]) -> List[str]:
        """
        Extract and parse GO IDs from a list of GO annotations.

        Args:
            go_list (List[str]): List of GO annotation strings.

        Returns:
            List[str]: List of parsed GO IDs.
        """
        return [
            _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
        ]

    def save_migrated_data(
        self, data_df: pd.DataFrame, splits_df: pd.DataFrame
    ) -> None:
        """
        Save the processed data and split information.

        Args:
            data_df (pd.DataFrame): Data with GO labels.
            splits_df (pd.DataFrame): Split assignment DataFrame.
        """
        print("Saving transformed data......")
        go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
            self._go_branch
        ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN)

        go_class_instance.save_processed(
            data_df, go_class_instance.processed_file_names_dict["data"]
        )
        print(
            f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
        )

        splits_df.to_csv(
            os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
            index=False,
        )
        print(f"splits.csv saved to {go_class_instance.processed_dir_main}")

        classes = sorted(self._classes)
        with open(
            os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
        ) as fout:
            fout.writelines(str(node) + "\n" for node in classes)
        print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
        print("Migration completed!")
|
||
|
||
class Main:
    """
    CLI entry point exposing the DeepGO-SE → GO-UniProt data migration.

    Methods:
        migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
            Initiates the migration process for the specified data directory and GO branch.
    """

    def migrate(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
        """
        Initiate the migration process by creating a DeepGoDataMigration instance
        and invoking its migrate method.

        Args:
            data_dir (str): Directory containing the data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                ("cc" for cellular_component,
                 "mf" for molecular_function,
                 or "bp" for biological_process).
        """
        DeepGoDataMigration(data_dir, go_branch).migrate()
|
||
|
||
class Main1:
    """Toy CLI command that awards a random prize to a named person."""

    def __init__(self, max_prize: int = 100):
        """
        Args:
            max_prize: Maximum prize that can be awarded.
        """
        self.max_prize = max_prize

    def person(self, name: str, additional_prize: int = 0):
        """
        Args:
            name: Name of the winner.
            additional_prize: Additional prize that can be added to the prize amount.
        """
        total = additional_prize + randint(0, self.max_prize)
        return f"{name} won {total}€!"
|
||
|
||
if __name__ == "__main__":
    # Example: python deep_go_data_mirgration.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp"
    # --data_dir specifies the directory containing the data files.
    # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration.
    #
    # Fix: the CLI was wired to the unrelated demo class `Main1`, although the
    # description and usage example refer to the migration tool; bind `Main`
    # instead. Also removed a stray `pass`.
    CLI(
        Main,
        description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
    )