From 329035b836e374762ad07ac5fcb2bea62a113b64 Mon Sep 17 00:00:00 2001 From: Allan Costa Date: Thu, 13 Jun 2024 17:37:55 +0000 Subject: [PATCH 1/7] updates --- moleculib/graphics/py3Dmol.py | 6 +++--- moleculib/nucleic/datum.py | 5 +++++ moleculib/protein/dataset.py | 9 ++------- moleculib/protein/datum.py | 11 ++++++----- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/moleculib/graphics/py3Dmol.py b/moleculib/graphics/py3Dmol.py index cac9688..2a9b30a 100644 --- a/moleculib/graphics/py3Dmol.py +++ b/moleculib/graphics/py3Dmol.py @@ -36,8 +36,8 @@ def plot_py3dmol( colors = [c.get_hex_l() for c in colors] for i, datum in enumerate(data): - datum.plot(v, color=colors[i], **kwargs) - v.zoomTo() + datum.plot(v, colors=[colors[i]] * len(datum), **kwargs) + # v.zoomTo() v.setBackgroundColor("rgb(0,0,0)", 0) return v @@ -221,7 +221,7 @@ def plot_py3dmol_grid( plot_py3dmol_traj(datum, v, viewer=(i, j), **kwargs) else: datum.plot(v, viewer=(i, j), **kwargs) - v.zoomTo() + # v.zoomTo() v.setBackgroundColor("rgb(0,0,0)", 0) return v diff --git a/moleculib/nucleic/datum.py b/moleculib/nucleic/datum.py index 20e38d5..5b663c1 100644 --- a/moleculib/nucleic/datum.py +++ b/moleculib/nucleic/datum.py @@ -353,6 +353,11 @@ def to_pdb_str(self): return lines + def to_dict(self): + self.idcode=None + self.sequence=None + return vars(self) + def plot( self, view, diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index 075489e..2552e91 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -431,12 +431,7 @@ def __init__( self.counter = 0 self.buffer = buffer - self.files = self.files[:1000] - - # self.coords = np.concatenate( - # [self._load_subtrajs(i) for i in tqdm(range(len(self.files)))] - # ) - + # self.files = self.files[:1000] self.coords = np.concatenate( process_map(self._load_subtrajs, range(len(self.files)), max_workers=24) ) @@ -496,7 +491,7 @@ def __getitem__(self, idx): ), ) if self.tau == 0: - return [p1] + return p1 idx2 = self.shuffler[idx + self.tau] self.atom_array._coord = self.coords[idx2] diff --git a/moleculib/protein/datum.py b/moleculib/protein/datum.py index 65f29f6..9c2779a 100644 --- a/moleculib/protein/datum.py +++ b/moleculib/protein/datum.py @@ -488,8 +488,9 @@ def plot( view.addStyle({'model': -1}, {'stick': {'radius': 0.2}}, viewer=viewer) if colors is not None: + print('here') colors = {i+1: c for i, c in enumerate(colors)} - view.setStyle({'model': -1}, {'stick':{'colorscheme':{'prop':'resi','map':colors}}}) + view.setStyle({'stick':{'colorscheme':{'prop':'resi','map':colors}}}) return view @@ -555,12 +556,12 @@ def to_all_atom_array(prot): cif.set_structure(file, atom_array) file.write(filepath) - def to_dict(self): + def to_pytree(self): return vars(self) - # def to_pytree(self): - # return Pytree(self.to_dict()) - + def from_pytree(self, tree): + return ProteinDatum(**tree) + @classmethod def from_dict(cls, dict_): return cls(**dict_) \ No newline at end of file From 4a09cb8c138f8eb062f419744a202300b2ca7aaa Mon Sep 17 00:00:00 2001 From: Ilan Mitnikov Date: Sun, 23 Jun 2024 17:11:46 +0000 Subject: [PATCH 2/7] fast folding more paths --- moleculib/protein/dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index 075489e..71224b1 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -410,13 +410,19 @@ def __init__( ): base = "/mas/projects/molecularmachines/db/FastFoldingProteins/" if protein == "chignolin": - self.base_path = base + "chignolin_trajectories/filtered/" + self.base_path = base + "chignolin_trajectories/batches/0/filtered/" elif "trpcage" in protein: # trpcage0, trpcage1, trpcage2 self.base_path = base + f"rpcage_trajectories/batches/{protein[-1]}/filtered" elif protein == "villin": self.base_path = base + "villin_trajectories/filtered/" elif "bba" in protein: # bba0, bba1, bba2 self.base_path = base + f"bba_trajectories/batches/{protein[-1]}/filtered" + elif "homeodomain" == protein: + self.base_path = base + "homeodomain_trajectories/filtered" + elif "proteinb" == protein: + self.base_path = base + "proteinb_trajectories/filtered" + elif "proteing1" == protein: + self.base_path = base + "proteing_1_trajectories/testfilter" self.num_files = num_files self.tau = tau From cadd9dd7154bac19289955c4e40c1e5437f3af20 Mon Sep 17 00:00:00 2001 From: Ilan Mitnikov Date: Sat, 29 Jun 2024 00:03:41 +0000 Subject: [PATCH 3/7] fastfolding paths --- moleculib/graphics/py3Dmol.py | 2 +- moleculib/protein/dataset.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/moleculib/graphics/py3Dmol.py b/moleculib/graphics/py3Dmol.py index 5c3078b..0919f56 100644 --- a/moleculib/graphics/py3Dmol.py +++ b/moleculib/graphics/py3Dmol.py @@ -221,7 +221,7 @@ def plot_py3dmol_grid( plot_py3dmol_traj(datum, v, viewer=(i, j), **kwargs) else: datum.plot(v, viewer=(i, j), **kwargs) - # v.zoomTo() + v.zoomTo() v.setBackgroundColor("rgb(0,0,0)", 0) return v diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index b7c94c9..5d66d83 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -423,7 +423,19 @@ def __init__( self.base_path = base + "proteinb_trajectories/filtered" elif "proteing1" == protein: self.base_path = base + "proteing_1_trajectories/testfilter" - + elif "proteing2" == protein: + self.base_path = base + "proteing_2_trajectories/workspace3/Folding/folding-proteing/crystal/ss_contacts/50ns/proteinG/batches/0/filtered" + elif "wwdomain1" == protein: + self.base_path = base + "wwdomain_1_trajectories/WWdomain/batches/0/filtered" + elif "wwdomain2" == protein: + self.base_path = base + "wwdomain_2_trajectories/WWdomain/batches/0/filtered" + elif "bbl1" == protein: + self.base_path = base + "bbl_1_trajectories/folding-bbl/crystal/ss_contacts/50ns/bbl/batches/0/filtered" + elif "bbl2" == protein: + self.base_path = base + "bbl_2_trajectories/folding-bbl-2/crystal/ss_contacts/50ns/bbl/batches/0/filtered" + elif "ntl9" == protein: + self.base_path = base + "ntl9_trajectories/workspace3/Folding/folding/crystal/ss_contacts/20ns/ntl9/batches/0/filtered" + self.num_files = num_files self.tau = tau self.stride = stride From a4952d029a74e905816fae6a2b6b3e4d28a12197 Mon Sep 17 00:00:00 2001 From: Allan Costa Date: Mon, 1 Jul 2024 23:27:44 +0000 Subject: [PATCH 4/7] update fast folding proteins ds --- moleculib/protein/dataset.py | 188 ++++++++++++++++++++--------------- moleculib/protein/datum.py | 3 +- 2 files changed, 109 insertions(+), 82 deletions(-) diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index 2552e91..4c0ec03 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -1,3 +1,4 @@ +from collections import defaultdict import os import pickle import traceback @@ -388,118 +389,145 @@ def __init__( super().__init__(splits, transform, shuffle) -import mdtraj -from torch.utils.data import Dataset + +FAST_FOLDING_PROTEINS = { + 'chignolin': 10, + 'trpcage': 20, + 'bba': 28, + 'wwdomain': 34, + 'villin': 35, + 'ntl9': 39, + 'bbl': 47, + 'proteinb': 47, + 'homeodomain': 54, + 'proteing': 56, + 'a3D': 73, + # 'lambda': 80 +} + + +from moleculib.protein.datum import ProteinDatum + +from moleculib.protein.transform import ProteinPad from biotite.structure import filter_amino_acids import biotite.structure.io.pdb as pdb +import numpy as np +from tqdm import tqdm +import mdtraj +import os +from collections import defaultdict +from copy import deepcopy +from biotite.structure.io.xtc import XTCFile +from numpy.lib.format import open_memmap -from tqdm.contrib.concurrent import process_map -class FastFoldingDataset(Dataset): +class FastFoldingDataset: + def __init__( self, - protein="chignolin", - num_files=-1, - tau=0, - stride=1, - time_sort=False, - buffer=100, - preload=True, - shuffle=False, + base = "/mas/projects/molecularmachines/db/FastFoldingProteins/untar/", + proteins=None, + tau=0, + shuffle=True, + stride=1, + preload=False, + epoch_size=10000, + padded=True, ): - base = "/mas/projects/molecularmachines/db/FastFoldingProteins/" - if protein == "chignolin": - self.base_path = base + "chignolin_trajectories/filtered/" - elif "trpcage" in protein: # trpcage0, trpcage1, trpcage2 - self.base_path = base + f"rpcage_trajectories/batches/{protein[-1]}/filtered" - elif protein == "villin": - self.base_path = base + "villin_trajectories/filtered/" - elif "bba" in protein: # bba0, bba1, bba2 - self.base_path = base + f"bba_trajectories/batches/{protein[-1]}/filtered" - - self.num_files = num_files + if proteins is None: + proteins = list(FAST_FOLDING_PROTEINS.keys()) + + self.base_path = base + self.proteins = proteins self.tau = tau + self.time_sort = True self.stride = stride - self.time_sort = time_sort - self.files = self._list_files()[: self.num_files] - self.atom_array = pdb.PDBFile.read( - self.base_path + "filtered.pdb" - ).get_structure()[0] - self.aa_filter = filter_amino_acids(self.atom_array) - self.atom_array = self.atom_array[self.aa_filter] - self.counter = 0 - self.buffer = buffer - - # self.files = self.files[:1000] - self.coords = np.concatenate( - process_map(self._load_subtrajs, range(len(self.files)), max_workers=24) - ) - - if shuffle: - self.shuffler = np.random.permutation(len(self)) - else: - self.shuffler = np.arange(len(self)) - - self.splits = { 'train': self } + self.epoch_size = epoch_size * len(proteins) - print(f"{len(self)} total samples") + self.describe() - def _load_subtrajs(self, idx): - data = mdtraj.load( - self.files[idx], - top=self.base_path + "filtered.pdb", - stride=self.stride, - ) - return data.xyz[:, self.aa_filter, :] * 10 + self.atom_arrays = { + protein: pdb.PDBFile.read( + self.base_path + protein + "_trajectories/0/filtered.pdb" + ).get_structure()[0] for protein in proteins + } + self.atom_arrays = { + protein: atom_array[filter_amino_acids(atom_array)] for protein, atom_array in self.atom_arrays.items() + } - def _list_files(self): - def extract_x_y(filename): - part = os.path.basename(filename).split("_")[0] - x, y = part.strip("e").split("s") - return int(x), int(y) + if padded: self.pad = ProteinPad(pad_size=max([FAST_FOLDING_PROTEINS[protein] for protein in proteins])) + else: self.pad = lambda x: x - files_with_extension = set() - for filename in os.listdir(self.base_path): - if filename.endswith(".xtc") and not filename.startswith("."): - files_with_extension.add(self.base_path + filename) + self.splits = { 'train': self } - files = list(files_with_extension) - if self.time_sort: - return sorted(files, key=lambda x: extract_x_y(x)) - return files + def describe(self): + # file is indexed by protein, trajectory, and frame + self.files = defaultdict(lambda: defaultdict(list)) + self.num_trajectories = defaultdict(int) + self.num_frames_per_traj = defaultdict(lambda: defaultdict(int)) + self.num_frames = defaultdict(int) + + for protein in self.proteins: + protein_path = self.base_path + protein + "_trajectories/" + for trajectory in os.listdir(protein_path): + if trajectory.startswith("."): continue + self.num_trajectories[protein] += 1 + trajectory_path = protein_path + trajectory + "/" + for supframe in os.listdir(trajectory_path): + if supframe.startswith("."): continue + if not supframe.endswith('.mmap'): continue + self.num_frames[protein] += 1 + self.num_frames_per_traj[protein][int(trajectory)] += 1 + if self.files[protein].get(int(trajectory)) is None: + self.files[protein][int(trajectory)] = [] + self.files[protein][int(trajectory)].append(trajectory_path + supframe) + for protein in self.proteins: + print(f"{protein}: {self.num_trajectories[protein]} trajectories, {self.num_frames[protein]} total frames") + # print(f"Trajectory lengths: {self.num_frames_per_traj[protein]}") def __len__(self): - return len(self.coords) - self.tau + return self.epoch_size def __getitem__(self, idx): - # if self.counter > self.buffer: - # index = int(idx / self.buffer) - # index = min(index, len(self.files) - 1) - # self._load_coords(self.files[index]) - # self.counter = 0 + protein = self.proteins[idx % len(self.proteins)] - # self.counter += 1 - # idxx = np.maximum(idx % (self.coords.shape[0] - self.tau),0) - idx1 = self.shuffler[idx] - self.atom_array._coord = self.coords[idx1] + # sample traj and sample file + traj_idx = np.random.randint(0, self.num_trajectories[protein]) + subtraj_idx = np.random.randint(0, self.num_frames_per_traj[protein][traj_idx]) + + mmap_path = self.files[protein][traj_idx][subtraj_idx] + coord = open_memmap(mmap_path, mode='r', dtype=np.float32) + template = self.atom_arrays[protein] + + idx1 = np.random.randint(0, len(coord) - self.tau - 1) + + aa1 = deepcopy(template) + aa1._coord = coord[idx1] + p1 = ProteinDatum.from_atom_array( - self.atom_array, + aa1, header=dict( idcode=None, resolution=None, ), ) + + p1 = self.pad.transform(p1) if self.tau == 0: return p1 - idx2 = self.shuffler[idx + self.tau] - self.atom_array._coord = self.coords[idx2] + idx2 = idx1 + self.tau + aa2 = deepcopy(template) + aa2._coord = coord[idx2] p2 = ProteinDatum.from_atom_array( - self.atom_array, + aa2, header=dict( idcode=None, resolution=None, ), ) - return [p2, p1] \ No newline at end of file + p2 = self.pad.transform(p2) + return [p1, p2] + + diff --git a/moleculib/protein/datum.py b/moleculib/protein/datum.py index ca343c1..c26a0ee 100644 --- a/moleculib/protein/datum.py +++ b/moleculib/protein/datum.py @@ -488,9 +488,8 @@ def plot( view.addStyle({'model': -1}, {'stick': {'radius': 0.2}}, viewer=viewer) if colors is not None: - print('here') colors = {i+1: c for i, c in enumerate(colors)} - view.setStyle({'stick':{'colorscheme':{'prop':'resi','map':colors}}}) + view.addStyle({'model': -1}, {'stick':{'colorscheme':{'prop':'resi','map':colors}}}) return view From 8324d25242aab50e22dd90d2264a9a2aaba63652 Mon Sep 17 00:00:00 2001 From: Allan Costa Date: Mon, 1 Jul 2024 23:28:16 +0000 Subject: [PATCH 5/7] add py3dmol updates --- moleculib/graphics/py3Dmol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/moleculib/graphics/py3Dmol.py b/moleculib/graphics/py3Dmol.py index 5c3078b..f24da5d 100644 --- a/moleculib/graphics/py3Dmol.py +++ b/moleculib/graphics/py3Dmol.py @@ -211,7 +211,8 @@ def plot_py3dmol_grid( ): v = py3Dmol.view( viewergrid=(len(grid), len(grid[0])), - linked=True, + # linked=True, + linked=False, width=len(grid[0]) * window_size[0], height=len(grid) * window_size[1], ) From 672b849313b62506176ec5fff27d778de7741071 Mon Sep 17 00:00:00 2001 From: Allan Costa Date: Tue, 9 Jul 2024 12:20:44 +0000 Subject: [PATCH 6/7] update traj ds --- moleculib/protein/dataset.py | 96 +++++++++++++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 8 deletions(-) diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index 8d0a00f..8e7daae 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -402,7 +402,7 @@ def __init__( 'homeodomain': 54, 'proteing': 56, 'a3D': 73, - # 'lambda': 80 + 'lambda': 80 } @@ -492,13 +492,15 @@ def __getitem__(self, idx): protein = self.proteins[idx % len(self.proteins)] # sample traj and sample file - traj_idx = np.random.randint(0, self.num_trajectories[protein]) - subtraj_idx = np.random.randint(0, self.num_frames_per_traj[protein][traj_idx]) - - mmap_path = self.files[protein][traj_idx][subtraj_idx] - coord = open_memmap(mmap_path, mode='r', dtype=np.float32) - template = self.atom_arrays[protein] - + while True: + traj_idx = np.random.randint(0, self.num_trajectories[protein]) + subtraj_idx = np.random.randint(0, self.num_frames_per_traj[protein][traj_idx]) + + mmap_path = self.files[protein][traj_idx][subtraj_idx] + coord = open_memmap(mmap_path, mode='r', dtype=np.float32) + template = self.atom_arrays[protein] + if len(coord) > self.tau + 1: break + idx1 = np.random.randint(0, len(coord) - self.tau - 1) aa1 = deepcopy(template) @@ -530,3 +532,81 @@ def __getitem__(self, idx): return [p1, p2] + + + + +from biotite.structure.io import pdb +from moleculib.protein.transform import ProteinPad + + +TAUS = [0, 1, 2, 4, 8, 16] +import webdataset as wds + +from copy import deepcopy +from moleculib.protein.datum import ProteinDatum + + +class ShardedFastFoldingDataset: + + def __init__( + self, + base = "/mas/projects/molecularmachines/db/FastFoldingProteins/web/", + num_shards=240, + proteins=None, + tau=0, + padded=True, + batch_size=1, + ): + assert tau in TAUS, f"tau must be one of {TAUS}" + + proteins = list(FAST_FOLDING_PROTEINS.keys()) + + self.base_path = base + self.proteins = proteins + self.tau = tau + self.time_sort = True + self.num_shards = num_shards + self.batch_size = batch_size + + self.atom_arrays = { + protein: pdb.PDBFile.read( + self.base_path + protein + ".pdb" + ).get_structure()[0] for protein in proteins + } + + if padded: self.pad = ProteinPad(pad_size=max([FAST_FOLDING_PROTEINS[protein] for protein in proteins])) + else: self.pad = lambda x: x + + def build_datum(sample): + key, coords = sample + protein = key.split('_')[-1] + template = self.atom_arrays[protein] + new_aa = deepcopy(template) + new_aa.coord = coords + return self.pad.transform(ProteinDatum.from_atom_array( + new_aa, + header={'idcode': None, 'resolution': None} + )) + + self.web_ds = iter( + wds.WebDataset(base + 'shards-' + '{00000..%05d}.tar' % (num_shards - 1)) + .decode() + .to_tuple('__key__', "coord.npy") + .map(build_datum) + .batched(batch_size, collation_fn=lambda x: x) + ) + + self.splits = { 'train': self } + + def __len__(self): + return self.num_shards * 1000 // self.batch_size + + def __iter__(self): + return self + + def __next__(self): + return next(self.web_ds) + + def __getitem__(self, index): + return next(self) \ No newline at end of file From a9e9ec79fdc2fefea97899fc03fd3a15d4a63a49 Mon Sep 17 00:00:00 2001 From: Allan dos Santos Costa Date: Fri, 12 Jul 2024 16:32:16 -0400 Subject: [PATCH 7/7] clean up + fixes to dataset --- moleculib/assembly/datum.py | 3 +-- moleculib/protein/dataset.py | 48 ++++++++++++++++++++++++++---------- moleculib/protein/datum.py | 44 ++++++++++++++++++++++++++++++--- setup.cfg | 1 + 4 files changed, 78 insertions(+), 18 deletions(-) diff --git a/moleculib/assembly/datum.py b/moleculib/assembly/datum.py index 7c7c194..df6492c 100644 --- a/moleculib/assembly/datum.py +++ b/moleculib/assembly/datum.py @@ -2,13 +2,12 @@ from moleculib.molecule.datum import MoleculeDatum from moleculib.protein.datum import ProteinDatum from typing import List -from simple_pytree import Pytree from biotite.database import rcsb import biotite.structure.io.pdb as pdb from biotite.structure import filter_amino_acids -class AssemblyDatum(Pytree): +class AssemblyDatum: def __init__( self, diff --git a/moleculib/protein/dataset.py b/moleculib/protein/dataset.py index 8e7daae..386f992 100644 --- a/moleculib/protein/dataset.py +++ b/moleculib/protein/dataset.py @@ -559,8 +559,12 @@ def __init__( batch_size=1, ): assert tau in TAUS, f"tau must be one of {TAUS}" - - proteins = list(FAST_FOLDING_PROTEINS.keys()) + + if proteins == None: + proteins = list(FAST_FOLDING_PROTEINS.keys()) + else: + for protein in proteins: + assert protein in FAST_FOLDING_PROTEINS.keys(), f'{protein} is not a valid option' self.base_path = base self.proteins = proteins @@ -571,29 +575,47 @@ def __init__( self.atom_arrays = { protein: pdb.PDBFile.read( - self.base_path + protein + ".pdb" + os.path.join(base, protein + ".pdb") ).get_structure()[0] for protein in proteins } if padded: self.pad = ProteinPad(pad_size=max([FAST_FOLDING_PROTEINS[protein] for protein in proteins])) else: self.pad = lambda x: x - def build_datum(sample): - key, coords = sample + + + def build_webdataset(sample): + if self.tau == 0: + key, coord1 = sample + coords = [ coord1 ] + else: + key, coord1, coord2 = sample + coords = [ coord1, coord2 ] protein = key.split('_')[-1] template = self.atom_arrays[protein] - new_aa = deepcopy(template) - new_aa.coord = coords - return self.pad.transform(ProteinDatum.from_atom_array( - new_aa, - header={'idcode': None, 'resolution': None} - )) + data = [] + for coord in coords: + new_aa = deepcopy(template) + new_aa.coord = coord + data.append( + self.pad.transform( + ProteinDatum.from_atom_array( + new_aa, + header={'idcode': protein, 'resolution': None} + ) + ) + ) + return data + + keys = ('__key__', 'coord.npy') + if self.tau > 0: + keys = keys + (f'coord_{self.tau}.npy', ) self.web_ds = iter( wds.WebDataset(base + 'shards-' + '{00000..%05d}.tar' % (num_shards - 1)) .decode() - .to_tuple('__key__', "coord.npy") - .map(build_datum) + .to_tuple(*keys) + .map(build_webdataset) .batched(batch_size, collation_fn=lambda x: x) ) diff --git a/moleculib/protein/datum.py b/moleculib/protein/datum.py index c26a0ee..7ed1cd9 100644 --- a/moleculib/protein/datum.py +++ b/moleculib/protein/datum.py @@ -31,7 +31,7 @@ ) from einops import rearrange, repeat -from simple_pytree import Pytree + class ProteinSequence: @@ -493,6 +493,44 @@ def plot( return view + def to_atom_array(self): + atom_mask = self.atom_mask.astype(np.bool_) + all_atom_coords = self.atom_coord[atom_mask] + all_atom_tokens = self.atom_token[atom_mask] + all_atom_res_tokens = repeat(self.residue_token, "r -> r a", a=14)[atom_mask] + all_atom_res_indices = repeat(self.residue_index, "r -> r a", a=14)[atom_mask] + + # just in case, move to cpu + atom_mask = np.array(atom_mask) + all_atom_coords = np.array(all_atom_coords) + all_atom_tokens = np.array(all_atom_tokens) + all_atom_res_tokens = np.array(all_atom_res_tokens) + all_atom_res_indices = np.array(all_atom_res_indices) + + atoms = [] + for idx, (coord, token, res_token, res_index) in enumerate( + zip( + all_atom_coords, + all_atom_tokens, + all_atom_res_tokens, + all_atom_res_indices, + ) + ): + name = all_atoms[int(token)] + res_name = all_residues[int(res_token)] + atoms.append( + Atom( + atom_name=name, + element=name[0], + coord=coord, + res_id=res_index, + res_name=res_name, + chain_id='A', + ) + ) + + return AtomArrayConstructor(atoms) + def align_to( self, @@ -502,7 +540,7 @@ def align_to( """ Aligns the current protein datum to another protein datum based on CA atoms. """ - def to_atom_array(prot, mask): + def to_ca_atom_array(prot, mask): cas = prot.atom_coord[..., 1, :] atoms = [ Atom( @@ -520,7 +558,7 @@ def to_atom_array(prot, mask): if window is not None: common_mask = common_mask & (np.arange(len(common_mask)) < window[1]) & (np.arange(len(common_mask)) >= window[0]) - self_array, other_array = to_atom_array(self, common_mask), to_atom_array(other, common_mask) + self_array, other_array = to_ca_atom_array(self, common_mask), to_ca_atom_array(other, common_mask) _, transform = superimpose(other_array, self_array) new_atom_coord = self.atom_coord + transform.center_translation new_atom_coord = np.einsum("rca,ab->rcb", new_atom_coord, transform.rotation.squeeze(0)) diff --git a/setup.cfg b/setup.cfg index d6e61c6..9e82c56 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,3 +32,4 @@ install_requires = py3Dmol rdkit plotly + mdtraj