Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
danaru29 committed Jul 16, 2024
2 parents 66469df + a9e9ec7 commit be07f13
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 99 deletions.
3 changes: 1 addition & 2 deletions moleculib/assembly/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from moleculib.molecule.datum import MoleculeDatum
from moleculib.protein.datum import ProteinDatum
from typing import List
from simple_pytree import Pytree
from biotite.database import rcsb
import biotite.structure.io.pdb as pdb
from biotite.structure import filter_amino_acids


class AssemblyDatum(Pytree):
class AssemblyDatum:

def __init__(
self,
Expand Down
7 changes: 4 additions & 3 deletions moleculib/graphics/py3Dmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def plot_py3dmol(
colors = [c.get_hex_l() for c in colors]

for i, datum in enumerate(data):
datum.plot(v, color=colors[i], **kwargs)
v.zoomTo()
datum.plot(v, colors=[colors[i]] * len(datum), **kwargs)
# v.zoomTo()
v.setBackgroundColor("rgb(0,0,0)", 0)
return v

Expand Down Expand Up @@ -211,7 +211,8 @@ def plot_py3dmol_grid(
):
v = py3Dmol.view(
viewergrid=(len(grid), len(grid[0])),
linked=True,
# linked=True,
linked=False,
width=len(grid[0]) * window_size[0],
height=len(grid) * window_size[1],
)
Expand Down
5 changes: 5 additions & 0 deletions moleculib/nucleic/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ def to_pdb_str(self):
return lines


def to_dict(self):
self.idcode=None
self.sequence=None
return vars(self)

def plot(
self,
view,
Expand Down
296 changes: 210 additions & 86 deletions moleculib/protein/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import os
import pickle
import traceback
Expand Down Expand Up @@ -388,123 +389,246 @@ def __init__(
super().__init__(splits, transform, shuffle)


import mdtraj
from torch.utils.data import Dataset

FAST_FOLDING_PROTEINS = {
'chignolin': 10,
'trpcage': 20,
'bba': 28,
'wwdomain': 34,
'villin': 35,
'ntl9': 39,
'bbl': 47,
'proteinb': 47,
'homeodomain': 54,
'proteing': 56,
'a3D': 73,
'lambda': 80
}


from moleculib.protein.datum import ProteinDatum

from moleculib.protein.transform import ProteinPad
from biotite.structure import filter_amino_acids
import biotite.structure.io.pdb as pdb
import numpy as np
from tqdm import tqdm

import mdtraj
import os
from collections import defaultdict
from copy import deepcopy
from biotite.structure.io.xtc import XTCFile
from numpy.lib.format import open_memmap

from tqdm.contrib.concurrent import process_map

class FastFoldingDataset(Dataset):
class FastFoldingDataset:

def __init__(
self,
protein="chignolin",
num_files=-1,
tau=0,
stride=1,
time_sort=False,
buffer=100,
preload=True,
shuffle=False,
base = "/mas/projects/molecularmachines/db/FastFoldingProteins/untar/",
proteins=None,
tau=0,
shuffle=True,
stride=1,
preload=False,
epoch_size=10000,
padded=True,
):
base = "/mas/projects/molecularmachines/db/FastFoldingProteins/"
if protein == "chignolin":
self.base_path = base + "chignolin_trajectories/filtered/"
elif "trpcage" in protein: # trpcage0, trpcage1, trpcage2
self.base_path = base + f"rpcage_trajectories/batches/{protein[-1]}/filtered"
elif protein == "villin":
self.base_path = base + "villin_trajectories/filtered/"
elif "bba" in protein: # bba0, bba1, bba2
self.base_path = base + f"bba_trajectories/batches/{protein[-1]}/filtered"

self.num_files = num_files
proteins = list(FAST_FOLDING_PROTEINS.keys())

self.base_path = base
self.proteins = proteins
self.tau = tau
self.time_sort = True
self.stride = stride
self.time_sort = time_sort
self.files = self._list_files()[: self.num_files]
self.atom_array = pdb.PDBFile.read(
self.base_path + "filtered.pdb"
).get_structure()[0]
self.aa_filter = filter_amino_acids(self.atom_array)
self.atom_array = self.atom_array[self.aa_filter]
self.counter = 0
self.buffer = buffer

self.files = self.files[:1000]
self.epoch_size = epoch_size * len(proteins)

# self.coords = np.concatenate(
# [self._load_subtrajs(i) for i in tqdm(range(len(self.files)))]
# )
self.describe()

self.coords = np.concatenate(
process_map(self._load_subtrajs, range(len(self.files)), max_workers=24)
)

if shuffle:
self.shuffler = np.random.permutation(len(self))
else:
self.shuffler = np.arange(len(self))

self.splits = { 'train': self }
self.atom_arrays = {
protein: pdb.PDBFile.read(
self.base_path + protein + "_trajectories/0/filtered.pdb"
).get_structure()[0] for protein in proteins
}
self.atom_arrays = {
protein: atom_array[filter_amino_acids(atom_array)] for protein, atom_array in self.atom_arrays.items()
}

print(f"{len(self)} total samples")
if padded: self.pad = ProteinPad(pad_size=max([FAST_FOLDING_PROTEINS[protein] for protein in proteins]))
else: self.pad = lambda x: x

def _load_subtrajs(self, idx):
data = mdtraj.load(
self.files[idx],
top=self.base_path + "filtered.pdb",
stride=self.stride,
)
return data.xyz[:, self.aa_filter, :] * 10

def _list_files(self):
def extract_x_y(filename):
part = os.path.basename(filename).split("_")[0]
x, y = part.strip("e").split("s")
return int(x), int(y)

files_with_extension = set()
for filename in os.listdir(self.base_path):
if filename.endswith(".xtc") and not filename.startswith("."):
files_with_extension.add(self.base_path + filename)
self.splits = { 'train': self }

files = list(files_with_extension)
if self.time_sort:
return sorted(files, key=lambda x: extract_x_y(x))
return files
def describe(self):
# file is indexed by protein, trajectory, and frame
self.files = defaultdict(lambda: defaultdict(list))
self.num_trajectories = defaultdict(int)
self.num_frames_per_traj = defaultdict(lambda: defaultdict(int))
self.num_frames = defaultdict(int)

for protein in self.proteins:
protein_path = self.base_path + protein + "_trajectories/"
for trajectory in os.listdir(protein_path):
if trajectory.startswith("."): continue
self.num_trajectories[protein] += 1
trajectory_path = protein_path + trajectory + "/"
for supframe in os.listdir(trajectory_path):
if supframe.startswith("."): continue
if not supframe.endswith('.mmap'): continue
self.num_frames[protein] += 1
self.num_frames_per_traj[protein][int(trajectory)] += 1
if self.files[protein].get(int(trajectory)) is None:
self.files[protein][int(trajectory)] = []
self.files[protein][int(trajectory)].append(trajectory_path + supframe)
for protein in self.proteins:
print(f"{protein}: {self.num_trajectories[protein]} trajectories, {self.num_frames[protein]} total frames")
# print(f"Trajectory lengths: {self.num_frames_per_traj[protein]}")

def __len__(self):
return len(self.coords) - self.tau
return self.epoch_size

def __getitem__(self, idx):
# if self.counter > self.buffer:
# index = int(idx / self.buffer)
# index = min(index, len(self.files) - 1)
# self._load_coords(self.files[index])
# self.counter = 0
protein = self.proteins[idx % len(self.proteins)]

# self.counter += 1
# idxx = np.maximum(idx % (self.coords.shape[0] - self.tau),0)
idx1 = self.shuffler[idx]
self.atom_array._coord = self.coords[idx1]
# sample traj and sample file
while True:
traj_idx = np.random.randint(0, self.num_trajectories[protein])
subtraj_idx = np.random.randint(0, self.num_frames_per_traj[protein][traj_idx])

mmap_path = self.files[protein][traj_idx][subtraj_idx]
coord = open_memmap(mmap_path, mode='r', dtype=np.float32)
template = self.atom_arrays[protein]
if len(coord) > self.tau + 1: break

idx1 = np.random.randint(0, len(coord) - self.tau - 1)

aa1 = deepcopy(template)
aa1._coord = coord[idx1]

p1 = ProteinDatum.from_atom_array(
self.atom_array,
aa1,
header=dict(
idcode=None,
resolution=None,
),
)

p1 = self.pad.transform(p1)
if self.tau == 0:
return [p1]
return p1

idx2 = self.shuffler[idx + self.tau]
self.atom_array._coord = self.coords[idx2]
idx2 = idx1 + self.tau
aa2 = deepcopy(template)
aa2._coord = coord[idx2]
p2 = ProteinDatum.from_atom_array(
self.atom_array,
aa2,
header=dict(
idcode=None,
resolution=None,
),
)
return [p2, p1]
p2 = self.pad.transform(p2)
return [p1, p2]






from biotite.structure.io import pdb
from moleculib.protein.transform import ProteinPad


TAUS = [0, 1, 2, 4, 8, 16]
import webdataset as wds

from copy import deepcopy
from moleculib.protein.datum import ProteinDatum


class ShardedFastFoldingDataset:

def __init__(
self,
base = "/mas/projects/molecularmachines/db/FastFoldingProteins/web/",
num_shards=240,
proteins=None,
tau=0,
padded=True,
batch_size=1,
):
assert tau in TAUS, f"tau must be one of {TAUS}"

if proteins == None:
proteins = list(FAST_FOLDING_PROTEINS.keys())
else:
for protein in proteins:
assert protein in FAST_FOLDING_PROTEINS.keys(), f'{protein} is not a valid option'

self.base_path = base
self.proteins = proteins
self.tau = tau
self.time_sort = True
self.num_shards = num_shards
self.batch_size = batch_size

self.atom_arrays = {
protein: pdb.PDBFile.read(
os.path.join(base, protein + ".pdb")
).get_structure()[0] for protein in proteins
}

if padded: self.pad = ProteinPad(pad_size=max([FAST_FOLDING_PROTEINS[protein] for protein in proteins]))
else: self.pad = lambda x: x



def build_webdataset(sample):
if self.tau == 0:
key, coord1 = sample
coords = [ coord1 ]
else:
key, coord1, coord2 = sample
coords = [ coord1, coord2 ]
protein = key.split('_')[-1]
template = self.atom_arrays[protein]
data = []
for coord in coords:
new_aa = deepcopy(template)
new_aa.coord = coord
data.append(
self.pad.transform(
ProteinDatum.from_atom_array(
new_aa,
header={'idcode': protein, 'resolution': None}
)
)
)
return data

keys = ('__key__', 'coord.npy')
if self.tau > 0:
keys = keys + (f'coord_{self.tau}.npy', )

self.web_ds = iter(
wds.WebDataset(base + 'shards-' + '{00000..%05d}.tar' % (num_shards - 1))
.decode()
.to_tuple(*keys)
.map(build_webdataset)
.batched(batch_size, collation_fn=lambda x: x)
)

self.splits = { 'train': self }

def __len__(self):
return self.num_shards * 1000 // self.batch_size

def __iter__(self):
return self

def __next__(self):
return next(self.web_ds)

def __getitem__(self, index):
return next(self)
Loading

0 comments on commit be07f13

Please sign in to comment.