Skip to content

Commit

Permalink
fast folding updates
Browse files: browse the repository at this point in the history
  • Loading branch information
Allan dos Santos Costa committed Jul 22, 2024
1 parent 922ea45 commit f839d32
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions moleculib/protein/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,29 +426,32 @@ class FastFoldingDataset:

def __init__(
self,
base = "/mas/projects/molecularmachines/db/FastFoldingProteins/untar/",
base = "/mas/projects/molecularmachines/db/FastFoldingProteins/memmap/",
proteins=None,
tau=0,
shuffle=True,
stride=1,
preload=False,
epoch_size=10000,
padded=True,
num_folders=1,
):
proteins = list(FAST_FOLDING_PROTEINS.keys())
if proteins == None:
proteins = list(FAST_FOLDING_PROTEINS.keys())

self.base_path = base
self.proteins = proteins
self.tau = tau
self.time_sort = True
self.stride = stride
self.epoch_size = epoch_size * len(proteins)

self.num_folders = num_folders

self.describe()

self.atom_arrays = {
protein: pdb.PDBFile.read(
self.base_path + protein + "_trajectories/0/filtered.pdb"
self.base_path + protein + "/0/filtered.pdb"
).get_structure()[0] for protein in proteins
}
self.atom_arrays = {
Expand All @@ -468,8 +471,8 @@ def describe(self):
self.num_frames = defaultdict(int)

for protein in self.proteins:
protein_path = self.base_path + protein + "_trajectories/"
for trajectory in os.listdir(protein_path):
protein_path = self.base_path + protein + "/"
for idx, trajectory in enumerate(os.listdir(protein_path)):
if trajectory.startswith("."): continue
self.num_trajectories[protein] += 1
trajectory_path = protein_path + trajectory + "/"
Expand All @@ -481,6 +484,10 @@ def describe(self):
if self.files[protein].get(int(trajectory)) is None:
self.files[protein][int(trajectory)] = []
self.files[protein][int(trajectory)].append(trajectory_path + supframe)
if idx == self.num_folders - 1:
break


for protein in self.proteins:
print(f"{protein}: {self.num_trajectories[protein]} trajectories, {self.num_frames[protein]} total frames")
# print(f"Trajectory lengths: {self.num_frames_per_traj[protein]}")
Expand All @@ -491,14 +498,15 @@ def __len__(self):
def __getitem__(self, idx):
protein = self.proteins[idx % len(self.proteins)]

# sample traj and sample file
# need to check if len(coord) > self.tau + 1
while True:
traj_idx = np.random.randint(0, self.num_trajectories[protein])
subtraj_idx = np.random.randint(0, self.num_frames_per_traj[protein][traj_idx])

mmap_path = self.files[protein][traj_idx][subtraj_idx]
coord = open_memmap(mmap_path, mode='r', dtype=np.float32)
template = self.atom_arrays[protein]

if len(coord) > self.tau + 1: break

idx1 = np.random.randint(0, len(coord) - self.tau - 1)
Expand All @@ -509,7 +517,7 @@ def __getitem__(self, idx):
p1 = ProteinDatum.from_atom_array(
aa1,
header=dict(
idcode=None,
idcode=protein,
resolution=None,
),
)
Expand All @@ -524,10 +532,11 @@ def __getitem__(self, idx):
p2 = ProteinDatum.from_atom_array(
aa2,
header=dict(
idcode=None,
idcode=protein,
resolution=None,
),
)

p2 = self.pad.transform(p2)
return [p1, p2]

Expand All @@ -552,7 +561,6 @@ class ShardedFastFoldingDataset:
def __init__(
self,
base = "/mas/projects/molecularmachines/db/FastFoldingProteins/web/",
num_shards=240,
proteins=None,
tau=0,
padded=True,
Expand All @@ -566,6 +574,8 @@ def __init__(
for protein in proteins:
assert protein in FAST_FOLDING_PROTEINS.keys(), f'{protein} is not a valid option'

num_shards = len(list(filter(lambda x: 'shards-' in x, os.listdir(base))))

self.base_path = base
self.proteins = proteins
self.tau = tau
Expand Down

0 comments on commit f839d32

Please sign in to comment.