From f839d32758c229483177dad36d2e9e79fcd6b7b0 Mon Sep 17 00:00:00 2001 From: Allan dos Santos Costa Date: Mon, 22 Jul 2024 13:16:52 -0400 Subject: [PATCH] fast folding updates --- moleculib/protein/dataset.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index 340bca6..b238ed9 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -426,7 +426,7 @@ class FastFoldingDataset: def __init__( self, - base = "/mas/projects/molecularmachines/db/FastFoldingProteins/untar/", + base = "/mas/projects/molecularmachines/db/FastFoldingProteins/memmap/", proteins=None, tau=0, shuffle=True, @@ -434,8 +434,10 @@ def __init__( preload=False, epoch_size=10000, padded=True, + num_folders=1, ): - proteins = list(FAST_FOLDING_PROTEINS.keys()) + if proteins == None: + proteins = list(FAST_FOLDING_PROTEINS.keys()) self.base_path = base self.proteins = proteins @@ -443,12 +445,13 @@ def __init__( self.time_sort = True self.stride = stride self.epoch_size = epoch_size * len(proteins) - + self.num_folders = num_folders + self.describe() self.atom_arrays = { protein: pdb.PDBFile.read( - self.base_path + protein + "_trajectories/0/filtered.pdb" + self.base_path + protein + "/0/filtered.pdb" ).get_structure()[0] for protein in proteins } self.atom_arrays = { @@ -468,8 +471,8 @@ def describe(self): self.num_frames = defaultdict(int) for protein in self.proteins: - protein_path = self.base_path + protein + "_trajectories/" - for trajectory in os.listdir(protein_path): + protein_path = self.base_path + protein + "/" + for idx, trajectory in enumerate(os.listdir(protein_path)): if trajectory.startswith("."): continue self.num_trajectories[protein] += 1 trajectory_path = protein_path + trajectory + "/" @@ -481,6 +484,10 @@ def describe(self): if self.files[protein].get(int(trajectory)) is None: self.files[protein][int(trajectory)] = [] self.files[protein][int(trajectory)].append(trajectory_path + supframe) + if idx == self.num_folders - 1: + break + + for protein in self.proteins: print(f"{protein}: {self.num_trajectories[protein]} trajectories, {self.num_frames[protein]} total frames") # print(f"Trajectory lengths: {self.num_frames_per_traj[protein]}") @@ -491,7 +498,7 @@ def __len__(self): def __getitem__(self, idx): protein = self.proteins[idx % len(self.proteins)] - # sample traj and sample file + # need to check if len(coord) > self.tau + 1 while True: traj_idx = np.random.randint(0, self.num_trajectories[protein]) subtraj_idx = np.random.randint(0, self.num_frames_per_traj[protein][traj_idx]) @@ -499,6 +506,7 @@ def __getitem__(self, idx): mmap_path = self.files[protein][traj_idx][subtraj_idx] coord = open_memmap(mmap_path, mode='r', dtype=np.float32) template = self.atom_arrays[protein] + if len(coord) > self.tau + 1: break idx1 = np.random.randint(0, len(coord) - self.tau - 1) @@ -509,7 +517,7 @@ def __getitem__(self, idx): p1 = ProteinDatum.from_atom_array( aa1, header=dict( - idcode=None, + idcode=protein, resolution=None, ), ) @@ -524,10 +532,11 @@ def __getitem__(self, idx): p2 = ProteinDatum.from_atom_array( aa2, header=dict( - idcode=None, + idcode=protein, resolution=None, ), ) + p2 = self.pad.transform(p2) return [p1, p2] @@ -552,7 +561,6 @@ class ShardedFastFoldingDataset: def __init__( self, base = "/mas/projects/molecularmachines/db/FastFoldingProteins/web/", - num_shards=240, proteins=None, tau=0, padded=True, @@ -566,6 +574,8 @@ def __init__( for protein in proteins: assert protein in FAST_FOLDING_PROTEINS.keys(), f'{protein} is not a valid option' + num_shards = len(list(filter(lambda x: 'shards-' in x, os.listdir(base)))) + self.base_path = base self.proteins = proteins self.tau = tau