Skip to content

Commit

Permalink
clean up + fixes to dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Allan dos Santos Costa committed Jul 12, 2024
1 parent 672b849 commit a9e9ec7
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 18 deletions.
3 changes: 1 addition & 2 deletions moleculib/assembly/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from moleculib.molecule.datum import MoleculeDatum
from moleculib.protein.datum import ProteinDatum
from typing import List
from simple_pytree import Pytree
from biotite.database import rcsb
import biotite.structure.io.pdb as pdb
from biotite.structure import filter_amino_acids


class AssemblyDatum(Pytree):
class AssemblyDatum:

def __init__(
self,
Expand Down
48 changes: 35 additions & 13 deletions moleculib/protein/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,12 @@ def __init__(
batch_size=1,
):
assert tau in TAUS, f"tau must be one of {TAUS}"

proteins = list(FAST_FOLDING_PROTEINS.keys())

if proteins == None:
proteins = list(FAST_FOLDING_PROTEINS.keys())
else:
for protein in proteins:
assert protein in FAST_FOLDING_PROTEINS.keys(), f'{protein} is not a valid option'

self.base_path = base
self.proteins = proteins
Expand All @@ -571,29 +575,47 @@ def __init__(

self.atom_arrays = {
protein: pdb.PDBFile.read(
self.base_path + protein + ".pdb"
os.path.join(base, protein + ".pdb")
).get_structure()[0] for protein in proteins
}

if padded: self.pad = ProteinPad(pad_size=max([FAST_FOLDING_PROTEINS[protein] for protein in proteins]))
else: self.pad = lambda x: x

def build_datum(sample):
key, coords = sample


def build_webdataset(sample):
if self.tau == 0:
key, coord1 = sample
coords = [ coord1 ]
else:
key, coord1, coord2 = sample
coords = [ coord1, coord2 ]
protein = key.split('_')[-1]
template = self.atom_arrays[protein]
new_aa = deepcopy(template)
new_aa.coord = coords
return self.pad.transform(ProteinDatum.from_atom_array(
new_aa,
header={'idcode': None, 'resolution': None}
))
data = []
for coord in coords:
new_aa = deepcopy(template)
new_aa.coord = coord
data.append(
self.pad.transform(
ProteinDatum.from_atom_array(
new_aa,
header={'idcode': protein, 'resolution': None}
)
)
)
return data

keys = ('__key__', 'coord.npy')
if self.tau > 0:
keys = keys + (f'coord_{self.tau}.npy', )

self.web_ds = iter(
wds.WebDataset(base + 'shards-' + '{00000..%05d}.tar' % (num_shards - 1))
.decode()
.to_tuple('__key__', "coord.npy")
.map(build_datum)
.to_tuple(*keys)
.map(build_webdataset)
.batched(batch_size, collation_fn=lambda x: x)
)

Expand Down
44 changes: 41 additions & 3 deletions moleculib/protein/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)

from einops import rearrange, repeat
from simple_pytree import Pytree


class ProteinSequence:

Expand Down Expand Up @@ -493,6 +493,44 @@ def plot(

return view

def to_atom_array(self):
    """Flatten the unmasked atoms of this datum into a single atom array.

    Every position where ``self.atom_mask`` is truthy becomes one ``Atom``.
    The atom name is looked up in ``all_atoms`` by atom token and the residue
    name in ``all_residues`` by residue token. Masked-out positions are
    dropped.

    Returns:
        The result of ``AtomArrayConstructor`` applied to the list of atoms
        (presumably a biotite ``AtomArray`` -- confirm against the import).
    """
    # Convert to host NumPy arrays *before* masking so the boolean indexing
    # below behaves the same whatever array backend the datum holds.
    atom_mask = np.asarray(self.atom_mask).astype(np.bool_)
    atom_coord = np.asarray(self.atom_coord)
    atom_token = np.asarray(self.atom_token)
    residue_token = np.asarray(self.residue_token)
    residue_index = np.asarray(self.residue_index)

    # Per-residue atom slots; taken from the mask rather than hard-coded so
    # the broadcast below always matches the mask's trailing dimension.
    atoms_per_residue = atom_mask.shape[-1]

    all_atom_coords = atom_coord[atom_mask]
    all_atom_tokens = atom_token[atom_mask]
    # Broadcast per-residue values to every atom slot, then keep valid atoms.
    all_atom_res_tokens = repeat(
        residue_token, "r -> r a", a=atoms_per_residue
    )[atom_mask]
    all_atom_res_indices = repeat(
        residue_index, "r -> r a", a=atoms_per_residue
    )[atom_mask]

    atoms = []
    for coord, token, res_token, res_index in zip(
        all_atom_coords,
        all_atom_tokens,
        all_atom_res_tokens,
        all_atom_res_indices,
    ):
        name = all_atoms[int(token)]
        atoms.append(
            Atom(
                atom_name=name,
                # The element is taken as the first character of the atom name.
                element=name[0],
                coord=coord,
                res_id=res_index,
                res_name=all_residues[int(res_token)],
                # All atoms are written to a single chain labelled 'A'.
                chain_id='A',
            )
        )

    return AtomArrayConstructor(atoms)


def align_to(
self,
Expand All @@ -502,7 +540,7 @@ def align_to(
"""
Aligns the current protein datum to another protein datum based on CA atoms.
"""
def to_atom_array(prot, mask):
def to_ca_atom_array(prot, mask):
cas = prot.atom_coord[..., 1, :]
atoms = [
Atom(
Expand All @@ -520,7 +558,7 @@ def to_atom_array(prot, mask):
if window is not None:
common_mask = common_mask & (np.arange(len(common_mask)) < window[1]) & (np.arange(len(common_mask)) >= window[0])

self_array, other_array = to_atom_array(self, common_mask), to_atom_array(other, common_mask)
self_array, other_array = to_ca_atom_array(self, common_mask), to_ca_atom_array(other, common_mask)
_, transform = superimpose(other_array, self_array)
new_atom_coord = self.atom_coord + transform.center_translation
new_atom_coord = np.einsum("rca,ab->rcb", new_atom_coord, transform.rotation.squeeze(0))
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ install_requires =
py3Dmol
rdkit
plotly
mdtraj

0 comments on commit a9e9ec7

Please sign in to comment.