Skip to content

Commit

Permalink
Changes for creating splits.csv,added lines for debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
vidvath7 committed Nov 28, 2024
1 parent 131ea90 commit 7056187
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ def get_test_split(

df_train = df.iloc[train_indices]
df_test = df.iloc[test_indices]
print("Inside get_test_split")
print("Train Split : ", df_train.shape)
print("Test Split : ", df_test.shape)
return df_train, df_test

def get_train_val_splits_given_test(
Expand All @@ -534,14 +537,15 @@ def get_train_val_splits_given_test(
are the corresponding DataFrames.
"""
print(f"Split dataset into train / val with given test set")

test_ids = test_df["ident"].tolist()
# ---- list comprehension degrades performance, dataframe operations are faster
# mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]]
# df_trainval = df_trainval[mask]
df_trainval = df[~df["ident"].isin(test_ids)]
df_trainval = df
if self.aug_data==False:
test_ids = test_df["ident"].tolist()
# ---- list comprehension degrades performance, dataframe operations are faster
# mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]]
# df_trainval = df_trainval[mask]
df_trainval = df[~df["ident"].isin(test_ids)]
labels_list_trainval = df_trainval["labels"].tolist()

print("df_trainval.shape after removing overlapping points:",df_trainval.shape)
if self.use_inner_cross_validation:
folds = {}
kfold = MultilabelStratifiedKFold(
Expand Down Expand Up @@ -571,9 +575,13 @@ def get_train_val_splits_given_test(
train_indices, validation_indices = next(
msss.split(labels_list_trainval, labels_list_trainval)
)
print("train_indices.shape : ", train_indices.shape)
print("validation_indices.shape : ", validation_indices.shape)

df_validation = df_trainval.iloc[validation_indices]
df_train = df_trainval.iloc[train_indices]
print("df_train :",df_train.shape)
print("df_validation :",df_validation.shape)
return df_train, df_validation

@property
Expand Down Expand Up @@ -815,7 +823,9 @@ def _generate_dynamic_splits(self) -> None:

try:
filename = self.processed_file_names_dict["data"]
print("Directory:",os.path.join(data_dir, filename))
print("Directory of data.pt:",os.path.join(data_dir, filename))
#loading of data.pt
print("Loading : ", filename )
data_chebi_version = torch.load(os.path.join(data_dir, filename))
except FileNotFoundError:
raise FileNotFoundError(
Expand All @@ -824,10 +834,14 @@ def _generate_dynamic_splits(self) -> None:
)

df_chebi_version = pd.DataFrame(data_chebi_version)
print("Created dataframe for data.pt :",df_chebi_version)
print("Created dataframe size:",df_chebi_version.shape)
train_df_chebi_ver, df_test_chebi_ver = self.get_test_split(
df_chebi_version, seed=self.dynamic_data_split_seed
)

print("get_test_split done, splits size train: ", train_df_chebi_ver.shape)
print("get_test_split done, splits size test: ", df_test_chebi_ver.shape)
print("chebi_version_train : ",self.chebi_version_train)
if self.chebi_version_train is not None:
# Load encoded data derived from "chebi_version_train"
try:
Expand Down Expand Up @@ -872,9 +886,11 @@ def _generate_dynamic_splits(self) -> None:
pd.DataFrame({"id": df_test["ident"], "split": "test"}),
]
combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
#Saving csv
combined_split_assignment.to_csv(
os.path.join(self.processed_dir_main, "splits.csv")
)
print("Saving splits.csv")

# Store the splits in class variables
self.dynamic_df_train = df_train
Expand Down Expand Up @@ -934,6 +950,7 @@ def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]:
):
if self.splits_file_path is None:
# Generate splits based on given seed, create csv file to records the splits
print("no splits_file_path provided by the user")
self._generate_dynamic_splits()
else:
# If user has provided splits file path, use it to get the splits from the data
Expand Down

0 comments on commit 7056187

Please sign in to comment.