Skip to content

Commit

Permalink
max_files.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed May 20, 2024
1 parent e6830dc commit ff2dab0
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions moleculib/protein/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(

# shuffle and sample
self.metadata = self.metadata.sample(frac=frac).reset_index(drop=True)
logging.info(f"Loaded metadata with {len(self.metadata)} samples")
print(f"Loaded metadata with {len(self.metadata)} samples")

# specific protein attributes
protein_attrs = [
Expand Down Expand Up @@ -178,11 +178,11 @@ def _maybe_fetch_and_extract(pdb_id, format, save_path):
except KeyboardInterrupt:
exit()
except (ValueError, IndexError) as error:
logging.info(traceback.format_exc())
logging.info(error)
print(traceback.format_exc())
print(error)
return None
except biotite.database.RequestError as request_error:
logging.info(request_error)
print(request_error)
return None
if len(datum.sequence) == 0:
return None
Expand All @@ -207,7 +207,7 @@ def build(
pdb_ids = pids_file_to_list(root + "/data/pids_all.txt")
if save_path is None:
save_path = mkdtemp()
logging.info(f"Fetching {len(pdb_ids)} PDB IDs with {max_workers} workers...")
print(f"Fetching {len(pdb_ids)} PDB IDs with {max_workers} workers...")

series = {c: Series(dtype=t) for (c, t) in PDB_METADATA_FIELDS}
metadata = DataFrame(series)
Expand Down Expand Up @@ -298,7 +298,7 @@ class TinyPDBDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
base_path = os.path.join(base_path, "tinypdb.pyd")
with open(base_path, "rb") as fin:
logging.info("Loading data...")
print("Loading data...")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle, pre_transform=False)

Expand All @@ -308,7 +308,7 @@ class FrameDiffDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
base_path = os.path.join(base_path, "framediff_train_data.pyd")
with open(base_path, "rb") as fin:
logging.info("Loading data...")
print("Loading data...")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle, pre_transform=False)

Expand All @@ -318,7 +318,7 @@ class TinyPDBDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
base_path = os.path.join(base_path, "tinypdb.pyd")
with open(base_path, "rb") as fin:
logging.info("Loading data...")
print("Loading data...")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle, pre_transform=False)

Expand All @@ -328,7 +328,7 @@ class FoldingDiffDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
base_path = os.path.join(base_path, "folddiff_train_data.pyd")
with open(base_path, "rb") as fin:
logging.info("Loading data...")
print("Loading data...")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle, pre_transform=False)

Expand All @@ -338,7 +338,7 @@ class FoldDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
base_path = os.path.join(base_path, "fold.pyd")
with open(base_path, "rb") as fin:
logging.info("Loading data...")
print("Loading data...")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle)

Expand All @@ -348,7 +348,7 @@ class EnzymeCommissionDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
path = os.path.join(base_path, "ec.pyd")
with open(path, "rb") as fin:
logging.info(f"Loading data from {path}")
print(f"Loading data from {path}")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle)

Expand All @@ -360,7 +360,7 @@ def __init__(
):
path = os.path.join(base_path, f"go_{level}.pyd")
with open(path, "rb") as fin:
logging.info(f"Loading data from {path}")
print(f"Loading data from {path}")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle)

Expand All @@ -370,7 +370,7 @@ class FuncDataset(PreProcessedDataset):
def __init__(self, base_path, transform: List[Callable] = None, shuffle=True):
path = os.path.join(base_path, "func.pyd")
with open(path, "rb") as fin:
logging.info(f"Loading data from {path}")
print(f"Loading data from {path}")
splits = pickle.load(fin)
super().__init__(splits, transform, shuffle)

Expand All @@ -381,10 +381,10 @@ def __init__(
self, base_path, transform: List[Callable] = None, shuffle=True, val_split=0.0
):
with open(os.path.join(base_path, "scaffolds.pyd"), "rb") as fin:
logging.info("Loading data...")
print("Loading data...")
dataset = pickle.load(fin)
if val_split > 0.0:
logging.info(f"Splitting data into train/val with val_split={val_split}")
print(f"Splitting data into train/val with val_split={val_split}")
dataset = np.random.permutation(dataset)
num_val = int(len(dataset) * val_split)
splits = dict(train=dataset[:-num_val], val=dataset[-num_val:])
Expand Down Expand Up @@ -423,7 +423,7 @@ def __init__(
self.atom_array = self.atom_array[self.aa_filter]
self.counter = 0
self._load_coords(self.files[0])
logging.info(f"{len(self)} total samples")
print(f"{len(self)} total samples")

def _list_files(self):
def extract_x_y(filename):
Expand Down Expand Up @@ -501,17 +501,17 @@ def __init__(
if len(self.files) == 0:
raise ValueError(f"No files found in {self.base_path}")

logging.info(f"Found {len(self.files)} files in {self.base_path}")
print(f"Found {len(self.files)} files in {self.base_path}")
if max_files is not None:
self.files = self.files[:max_files]
logging.info(f"Using {max_files} files")
print(f"Using {max_files} files")

logging.info(f"Loading first file: {self.files[0]}")
print(f"Loading first file: {self.files[0]}")
self._load_coords(self.files[0])

def _list_files(self):
files_with_extension = set()
for filename in os.listdir(self.base_path):
for filename in sorted(os.listdir(self.base_path)):
if filename.endswith(".npz") and not filename.startswith("."):
files_with_extension.add(os.path.join(self.base_path, filename))
return list(files_with_extension)
Expand Down

0 comments on commit ff2dab0

Please sign in to comment.