diff --git a/moleculib/nucleic/datum.py b/moleculib/nucleic/datum.py index 0f3b325..992a537 100644 --- a/moleculib/nucleic/datum.py +++ b/moleculib/nucleic/datum.py @@ -1,5 +1,7 @@ import numpy as np from Bio.PDB import parse_pdb_header +from Bio.PDB import MMCIFParser, PDBIO, Select + from biotite.database import rcsb from biotite.sequence import ProteinSequence, NucleotideSequence, GeneralSequence, Alphabet from biotite.structure import ( @@ -14,17 +16,22 @@ Atom, superimpose, AffineTransformation, - rmsd + rmsd, + AtomArray ) +import RNA #ViennaRNA from biotite.structure import filter_nucleotides import os -import biotite.structure.io.mmtf as mmtf +# import biotite.structure.io.mmtf as mmtf from einops import rearrange, repeat # from biotite.structure import Atom from biotite.structure import array as AtomArrayConstructor # from biotite.structure import superimpose +from biotite.structure.io.pdb import PDBFile +from biotite.structure.io import pdbx + import plotly.graph_objects as go import plotly.offline as pyo @@ -38,7 +45,11 @@ from pathlib import Path from biotite.structure import filter_nucleotides -from biotite.structure.io.pdb import PDBFile +# from biotite.structure.io import PDBFile, MMCIFFile +from biotite.structure.io.pdb import PDBFile # +# from biotite.structure.io. import MMCIFFile +from biotite.structure.io import pdbx + import numpy as np @@ -47,7 +58,7 @@ config = {"cache_dir": os.path.join(home_dir, ".cache", "moleculib")} #not sure either -def pdb_to_atom_array(pdb_path, RNA=False): +def pdb_to_atom_array(pdb_path, cif, model=None, chain=None, RNA=False, id = None): """_summary_ Args: @@ -59,9 +70,45 @@ def pdb_to_atom_array(pdb_path, RNA=False): Returns: _type_: _description_ """ - pdb_file = PDBFile.read(pdb_path) - atom_array = pdb_file.get_structure( - model=1, extra_fields=["atom_id", "b_factor", "occupancy", "charge"]) + if model == None: + model = 1 + + if cif: + cif_file = pdbx.CIFFile.read(pdb_path) + #try block since model may be inaccurate and supposed to be a chain + try: + atom_array = pdbx.get_structure( + cif_file, model=int(model),extra_fields=["atom_id", "b_factor", "occupancy", "charge"]) + + except ValueError as e: + # Check if the error is specifically about the model not existing + if "the given model" in str(e): + print(f"Model {model} does not exist. Treating input as a chain.") + atom_array = pdbx.get_structure( + cif_file, model=1,extra_fields=["atom_id", "b_factor", "occupancy", "charge"]) + chain = model + + else: + pdb_file = PDBFile.read(pdb_path) + try: + atom_array = pdb_file.get_structure( + model=int(model), extra_fields=["atom_id", "b_factor", "occupancy", "charge"]) + except ValueError as e: + # Check if the error is specifically about the model not existing + if "the given model" in str(e): + print(f"Model {model} does not exist. Treating input as a chain.") + atom_array = pdb_file.get_structure( + model=int(model), extra_fields=["atom_id", "b_factor", "occupancy", "charge"]) + chain = model + + #get only the specific chain: + if chain is not None: + atom_array = atom_array[atom_array.chain_id == chain] + if len(atom_array) ==0 : + print(f"Chain {desired_chain} is not present in the atom array of id {id}.") + else: + print(f'Extracted chain {chain} from the atom array') + nuc_filter = filter_nucleotides(atom_array) if RNA==True: DNA = ["DA", "DC", "DG", "DI", "DT", "DU"] @@ -95,6 +142,7 @@ def __init__( atom_token: np.ndarray, atom_coord: np.ndarray, atom_mask: np.ndarray, + contact_map: np.ndarray = None, #binary map of base pairs [N, N] where 1 indicates 2 nucs are paired **kwargs, ): self.idcode = idcode @@ -107,6 +155,7 @@ def __init__( self.atom_token = atom_token self.atom_coord = atom_coord self.atom_mask = atom_mask + self.contact_map = contact_map for key, value in kwargs.items(): setattr(self, key, value) @@ -227,20 +276,28 @@ def empty_nuc(cls): chain_token=np.array([]), atom_token=np.array([]), atom_coord=np.array([]), - atom_mask=np.array([]) + atom_mask=np.array([]), ) @classmethod - def from_filepath(cls, filepath): - atom_array = pdb_to_atom_array(filepath, RNA=False) #NOTE: CHANGE RNA TO TRUE IF WANT ONLY RNA. filters pdb to only nucleotides - header = parse_pdb_header(filepath) - return cls.from_atom_array(atom_array, header=header) + def from_filepath(cls, filepath, cif, from_filepath=None,model: int = None, chain: str = None, id = None): + atom_array = pdb_to_atom_array(filepath, cif, model = model, chain = chain, RNA=True) #NOTE: CHANGE RNA TO TRUE IF WANT ONLY RNA. filters pdb to only nucleotides + header = parse_pdb_header(filepath) + # print(f'header is {header}') + return cls.from_atom_array(atom_array, header=header, id=id) @classmethod - def fetch_pdb_id(cls, id, save_path=None): - filepath = rcsb.fetch(id, "pdb", save_path) - return cls.from_filepath(filepath) - + def fetch_pdb_id(cls ,id , save_path=None, model: int = None, chain: str = None): ## + cif = False + try: + filepath = rcsb.fetch(id, "pdb", save_path) + exception_raised = False + except: + print(f"PDB format not available for {id}, trying CIF format") + filepath = rcsb.fetch(id, "cif", save_path) + cif = True + return cls.from_filepath(filepath, cif, model = model, chain = chain, id=id) + def set( self, **kwargs, @@ -254,6 +311,8 @@ def from_atom_array( cls, atom_array, header, + id = None, + chain: str = None, ): """ Reshapes atom array to residue-indexed representation to @@ -262,12 +321,17 @@ def from_atom_array( # print("length of atom array: " , len(atom_array)) if atom_array.array_length() == 0: return cls.empty_nuc() + + if chain != None: + atom_array = atom_array[atom_array.chain_name == chain] _, res_names = get_residues(atom_array) res_names = [ ("UNK" if (name not in all_nucs) else name) for name in res_names ] + + sequence = GeneralSequence(Alphabet(all_nucs), list(res_names)) # breakpoint() # index residues globally @@ -339,8 +403,65 @@ def _reshape_residue_attr(attr): residue_mask = residue_mask & (atom_extract["atom_coord"].sum((-1, -2)) != 0) chain_token = _reshape_residue_attr(chain_token) + + def secondary_dot_bracket_to_contact_map(dot_bracket): + length = len(dot_bracket) + contact_map = np.zeros((length, length), dtype=int) + stack = [] + + for i, char in enumerate(dot_bracket): + if char == '(': + stack.append(i) + elif char == ')': + j = stack.pop() + contact_map[i, j] = 1 + contact_map[j, i] = 1 + + return contact_map + + # Create a fold compound for the sequence + seq = str(sequence) + if len(seq) != len(residue_token): + # print(f'len(seq) != len(residue_token), seq is {seq} residue_token is {residue_token}') + #get seq from residue_token: + # rna_res_names = ['A', 'U', 'RT', 'G', 'C', 'I', 'UNK'] + # dna_res_names = ['DA', 'DU', 'DT', 'DG', 'DC', 'DI', 'UNK', 'PAD'] + # dna_res_tokens = list(map(lambda res: get_nucleotide_index(res), dna_res_names)) + # rna_res_tokens = list(map(lambda res: get_nucleotide_index(res), rna_res_names)) + + # rna_res_tokens_dict = {res: get_nucleotide_index(res) for res in rna_res_names} + # dna_res_tokens_dict = {res: get_nucleotide_index(res) for res in dna_res_names} + token_to_rna_letter = {'0': 'A', + '1': 'U', + '2': 'T', + '3': 'G', + '4': 'C', + '5': 'I', + '13': 'N', + '6': 'A', #DNA + '11': 'U',#DNA + '10': 'T',#DNA + '8': 'G',#DNA + '7': 'C',#DNA + '9': 'I',#DNA + '12': 'N'} #PAD + seq ='' + for r in residue_token: + seq += token_to_rna_letter[str(r)] + + + fc = RNA.fold_compound(seq) + mfe_structure, mfe = fc.mfe() #Example of mfe structure "....(...((.())))" + contact_pairs = secondary_dot_bracket_to_contact_map(mfe_structure) + if contact_pairs is None: + print("contact pairs is None, seq is ", seq) + if contact_pairs.shape[0] != len(residue_token): + print("contact_pairs.shape[0] != len(residue_token)") + # print("Done") + + return cls( - idcode=header["idcode"], + idcode=id, sequence=sequence, resolution=header["resolution"], nuc_token=residue_token, @@ -349,6 +470,9 @@ def _reshape_residue_attr(attr): chain_token=chain_token, **atom_extract, atom_mask=atom_mask, + # id=id_, + contact_map = contact_pairs, + )