Based on *Distillation of MSA Embeddings to Protein Folded Structures* (bioRxiv preprint and full latest text).

This repository stands on the shoulders of giant work by the scientific community:
- Evolutionary Scale Modeling
- SidechainNet
- Massively Parallel Natural Extension of Reference Frame
- Axial-Attention
- En-Transformer
- Graph-Transformer
- E3NN
For experimental and prototypical access to internal code, these repos are collected under `building_blocks` (except SidechainNet). As development progresses, they will be replaced by imports of the original packages.
- `debug`: whether to use wandb logging
- `submit`: submit training to a SLURM scheduler
- `name`: name of the experiment
- `note`: experimental note
- `report_frequency`: how often to log metrics in the log file
- `dataset_source`: path to a sidechainnet-formatted dataset
- `downsample`: whether to downsample the dataset uniformly at random
- `seq_clamp`: clamp data at the sequence level; size of the clamp
- `max_seq_len`: discard entries whose sequences are longer than `max_seq_len`
- `num_workers`: number of CPU workers for data fetching and loading
- `batch_size`: batch size
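The difference between `seq_clamp` (crop long sequences to a window) and `max_seq_len` (discard them entirely) can be sketched as follows. This is a minimal illustration, not the repository's actual data loader; the function name and windowing convention are assumptions:

```python
import random

def filter_and_clamp(sequences, max_seq_len=256, seq_clamp=128):
    """Hypothetical sketch: drop sequences longer than max_seq_len,
    then clamp survivors to a random window of length seq_clamp."""
    kept = [s for s in sequences if len(s) <= max_seq_len]
    clamped = []
    for s in kept:
        if len(s) > seq_clamp:
            start = random.randint(0, len(s) - seq_clamp)
            s = s[start:start + seq_clamp]
        clamped.append(s)
    return clamped

data = ["A" * 300, "B" * 200, "C" * 100]
out = filter_and_clamp(data)
# the 300-residue sequence is discarded; the 200-residue one is clamped to 128
```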
- `wipe_edge_information`: drops out all hij
- `topography_giveaway`: instead of providing language-model-based hij, produces hij based on ground-truth distance and orientation
- `giveaway_distance_resolution`: number of bins of relative distance information to input
- `giveaway_angle_resolution`: number of bins of relative orientation information to input
- `wiring_checkpoint`: checkpoint the Dense networks between model stages
- `use_msa`: use ESM-MSA-1 embeddings
- `use_seq`: use ESM-1b embeddings
- `use_at`: process hij with an Axial Transformer after distillation
- `use_gt`: project to 3D coordinates with a Graph Transformer after distillation
- `use_en`: refine the coordinates with an E(n)-Transformer
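Read together, the `use_*` flags enable successive stages of the pipeline. A schematic of the resulting control flow (plain Python; the stage names are descriptive placeholders, not the repository's API):

```python
def fold_stages(cfg):
    """Return the ordered list of stages the use_* flags enable
    (a schematic sketch, not the repository's actual control flow)."""
    stages = ["distill"]                      # Dense distillation of embeddings
    if cfg.get("use_at"):
        stages.append("axial_transformer")    # refine pairwise hij
    if cfg.get("use_gt"):
        stages.append("graph_transformer")    # project hij to 3D coordinates
    else:
        stages.append("gaussian_noise_init")  # fallback noisy backbone start
    if cfg.get("use_en"):
        stages.append("en_transformer")       # E(n)-equivariant refinement
    return stages

print(fold_stages({"use_at": True, "use_gt": True, "use_en": True}))
```

Note the fallback: when `use_gt` is off, the starting coordinates come from a Gaussian-noised backbone (see `gaussian_noise` below).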
- `node_msa_distill_layers`: hidden-layer enumeration of the Dense network for MSA node information extraction [768, 256, 256, 128]
- `edge_msa_distill_layers`: hidden-layer enumeration of the Dense network for MSA edge information extraction [96, 64, 64]
- `node_seq_distill_layers`: hidden-layer enumeration of the Dense network for sequence node information extraction [1280, 256, 128]
- `edge_seq_distill_layers`: hidden-layer enumeration of the Dense network for sequence edge information extraction [160, 64, 64]
- `node_ens_distill_layers`: hidden-layer enumeration of the Dense network for ensemble node information extraction [128, 128, 128]
- `edge_ens_distill_layers`: hidden-layer enumeration of the Dense network for ensemble edge information extraction [64, 64]
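Each `*_distill_layers` list enumerates consecutive layer widths of a small MLP; note the leading sizes match the ESM-MSA-1 (768) and ESM-1b (1280) embedding widths. A sketch of how such an enumeration maps to weight shapes (the builder function is a hypothetical illustration, not the repository's Dense implementation):

```python
def dense_layer_shapes(layer_sizes):
    """Turn a distill-layer enumeration into consecutive (in, out)
    weight shapes, as a Dense/MLP builder would (illustrative only)."""
    return list(zip(layer_sizes[:-1], layer_sizes[1:]))

# node_msa_distill_layers from above
shapes = dense_layer_shapes([768, 256, 256, 128])
# → [(768, 256), (256, 256), (256, 128)]
```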
- `at_checkpoint`: whether the axial transformer should be checkpointed
- `at_dim`: axial transformer dim
- `at_depth`: axial transformer depth
- `at_heads`: axial transformer number of attention heads
- `at_dim_head`: axial transformer dim head
- `at_window_size`: axial transformer window size (for internal Long-Short optimization)
- `gt_checkpoint`: graph transformer checkpoint
- `gt_dim`: graph transformer dim
- `gt_edim`: graph transformer edge dim
- `gt_depth`: graph transformer depth
- `gt_heads`: graph transformer number of heads
- `gt_dim_head`: graph transformer dim head
- `gaussian_noise`: if the graph transformer is not used, the Gaussian noise to add to the backbone as a starting point
- `et_checkpoint`: checkpoint the en transformer
- `et_dim`: en transformer dim
- `et_edim`: en transformer edge dim
- `et_depth`: en transformer depth
- `et_heads`: en transformer number of heads
- `et_dim_head`: en transformer dim head
- `et_coors_hidden_dim`: hidden dim of the internal coordinate-head mixer
- `en_num_neighbors`: number of neighbors to consider in 3D space
- `en_num_seq_neighbors`: number of neighbors to consider in sequence space
- `unroll_steps`: during training, applies the en transformer without gradients for N iterations, where N ~ U(0, unroll_steps) and each batch gets a different sample
- `train_fold_steps`: during training, how many en transformer iterations to perform with gradients
- `eval_fold_steps`: during testing, how many en transformer iterations to perform
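The iteration schedule implied by these three arguments can be sketched as follows (an illustration of the sampling logic only; in the real training loop the no-gradient iterations would run under `torch.no_grad()`):

```python
import random

def folding_iterations(unroll_steps, train_fold_steps,
                       training=True, eval_fold_steps=16):
    """Sketch: number of no-gradient 'unroll' iterations
    (N ~ U(0, unroll_steps), resampled per batch) followed by
    the with-gradient folding iterations."""
    if training:
        n_no_grad = random.randint(0, unroll_steps)  # fresh sample each batch
        return n_no_grad, train_fold_steps
    return 0, eval_fold_steps                        # no unrolling at eval time

no_grad_iters, grad_iters = folding_iterations(unroll_steps=10, train_fold_steps=2)
```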
- `angle_number_of_bins`: number of bins to use for predicting relative orientations
- `distance_number_of_bins`: number of bins to use for predicting relative distances
- `distance_max_radius`: maximum radius for predicting relative distances
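Discretizing a pairwise distance into `distance_number_of_bins` bins up to `distance_max_radius` might look like the sketch below. The bin-edge convention, the overflow-to-last-bin behaviour, and the example values are assumptions for illustration, not the repository's defaults:

```python
def distance_to_bin(d, number_of_bins=37, max_radius=18.5):
    """Map a pairwise distance (Å) to a bin index; distances beyond
    max_radius fall into the last bin (assumed overflow convention)."""
    width = max_radius / number_of_bins       # uniform bin width, here 0.5 Å
    return min(int(d / width), number_of_bins - 1)

bins = [distance_to_bin(d) for d in (0.2, 9.25, 18.5, 42.0)]
# → [0, 18, 36, 36]: anything past max_radius collapses into the final bin
```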
- `lr`: learning rate
- `at_loss_coeff`: axial transformer loss coefficient
- `gt_loss_coeff`: graph transformer loss coefficient
- `et_loss_coeff`: en transformer loss coefficient
- `et_drmsd`: use dRMSD for the en transformer
- `max_epochs`: number of epochs
- `validation_check_rate`: how often to perform validation checks
- `validation_start`: when to start validating
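The per-stage loss coefficients presumably weight the stage losses into a single objective, along these lines (a sketch; the actual loss terms are defined in the training code):

```python
def total_loss(losses, coeffs):
    """Weighted sum of per-stage losses (illustrative only)."""
    return sum(coeffs[k] * losses[k] for k in losses)

# hypothetical per-stage losses and at/gt/et coefficients
loss = total_loss(
    {"at": 2.0, "gt": 1.5, "et": 0.5},
    {"at": 1.0, "gt": 1.0, "et": 10.0},
)
# 2.0 + 1.5 + 5.0 = 8.5
```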
- `coordinate_reset_prob`: legacy, will be removed
- `msa_wipe_out_prob`: probability of selecting MSA embeddings for wipe-out
- `msa_wipe_out_dropout`: dropout of edge and node information for selected MSA embeddings
- `test_model`: path to model weights for testing
- `retrain_model`: path to model weights for retraining