This repository contains the reference implementation of a sequence-to-score model based on Bayesian Neural Networks (BNNs); the score corresponds to relative binding affinity. We train BNNs inside a loop that explicitly simulates the phage display process and its measurement noise, in order to reason jointly about experimental noise and model uncertainty.
Abstract context: Phage display is a powerful technique for studying interactions between proteins and other molecules. Its combination with deep learning for protein design remains underused due to high noise levels and complex pre-processing and interpretation. We propose a BNN trained within a loop that simulates the experiment and its noise, aiming to enable reliable application by accounting for experimental noise and model uncertainty. We validate on real binding affinity measurements rather than only on proxy values from phage display experiments.
- Bayesian training loop tailored to phage display noise
- CNN/MLP/Transformer fitness predictors with configurable architectures
- Hydra-based configuration for clean overrides and reproducibility
- Optional Neptune logging for experiment tracking
- Optional explainability with Captum (Grad-CAM, LRP, Integrated Gradients, DeepLIFT/SHAP). CAUTION: Explainability uses CDRs to compute certain localization metrics (AUC-ROC, MA). CDRs are only supported for antibody data. For other types of proteins, the explanation features may not work.
- `scripts/`: Entrypoints for training and prediction (`train.py`, `predict.py`).
- `config/`: Hydra config (`default.yaml`, model and preprocessing profiles).
- `src/`: Data processing, model definitions, explainers, utilities.
- `tests/`: Unit tests on small datasets for CI, including a test that checks whether your CNN configuration is valid.
- `notebook/`: Notebooks to visualize experiment results.
- `data/`: Example data; place your own data here.
- Python >= 3.10
- Optional: NVIDIA GPU + CUDA for faster training/inference, and multiple CPUs for large experiments
# Install uv (see https://docs.astral.sh/uv/ for details)
curl -Ls https://astral.sh/uv/install.sh | sh
# From the repo root
uv sync
# To find the CDR regions during explanation (not required for training)
uv run ANARCI --build_models

This project uses Hydra. The main config is `config/default.yaml`, which composes:
- `preprocessing/default.yaml`
- `model/baseline.yaml`
Below is a concise reference of the most relevant options and their allowed values. All keys can be overridden from the CLI.
- `to_cuda`: `cuda` or `cpu`. If `cuda` is requested but no GPU is available, the code falls back to CPU.
- `raise_exception_if_gpus_not_available`: `true` to error out when a GPU is requested but not available; otherwise the code continues on CPU.
- `use_neptune`: Enable/disable Neptune logging (default: `false`).
- `neptune_project`: Neptune project name when logging is enabled.
- `paths.data_dir`: Base directory for input data.
- `paths.results_dir`: Base directory where experiment outputs are written.
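For orientation, a minimal top-level fragment might look like this (illustrative values only; the shipped defaults may differ):

```yaml
# Illustrative fragment of config/default.yaml (example values, not the defaults)
to_cuda: cuda
raise_exception_if_gpus_not_available: false
use_neptune: false
neptune_project: MyOrg/MyProject
paths:
  data_dir: data
  results_dir: results
```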
- `esm_model_name`: Hugging Face ESM or ESM++ checkpoint name. Examples:
  - ESM2: `facebook/esm2_t6_8M_UR50D`, `facebook/esm2_t12_35M_UR50D`, `facebook/esm2_t30_150M_UR50D`, `facebook/esm2_t33_650M_UR50D`
  - ESM++: `Synthyra/ESMplusplus_small`, `Synthyra/ESMplusplus_large`
- `esm_batch_size`: Batch size for embedding extraction.
- `representation_layer`: Integer layer index to extract representations from.
- `storage_device`: `cuda` or `cpu`; where to hold the embedding model.
- `dataloader_type`: `normal` (default; reserved for future options).
- `type`: Profile type (e.g., `baseline`).
- `embedding_size`: `[E, L]`, the embedding dimension and sequence length. Usually inferred from the chosen ESM model; can be left as reference/default.
- `cnn_sequence_length`: Sequence length used when the CNN/Transformer consumes per-position embeddings.
- `tags`: Free-form list of labels to attach to experiments.
- `splits.training_pair`, `splits.validation_pair`, `splits.test_pair`: Lists of experiment identifiers defining the train/val/test pairs.
- `splits.relative_csv_file_location`: Optional mapping by assay (e.g., `DNA`, `PVP`) to CSV paths relative to `paths.data_dir`.
- `phage_initial_population_size`: Initial clone population size used by the simulator.
- `sequencing_depth`: Reads per selection round used for measurement-noise simulation.
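To make the role of `sequencing_depth` concrete, here is a minimal sketch of multinomial sequencing noise: reads are sampled from the clone frequency distribution, so shallow depths yield noisy count estimates. This illustrates the general idea only; it is not the simulator code from `src/`, and the function name and signature are hypothetical.

```python
import random
from collections import Counter

def simulate_read_counts(clone_frequencies, sequencing_depth, seed=0):
    """Sample `sequencing_depth` reads from the clone frequency distribution
    and return per-clone read counts (multinomial noise). Illustration only."""
    rng = random.Random(seed)
    clones = list(clone_frequencies)
    weights = [clone_frequencies[c] for c in clones]
    reads = rng.choices(clones, weights=weights, k=sequencing_depth)
    counts = Counter(reads)
    return {c: counts.get(c, 0) for c in clones}

# Three clones observed at depth 1000: counts approximate the true frequencies
counts = simulate_read_counts({"A": 0.7, "B": 0.2, "C": 0.1}, sequencing_depth=1000)
assert sum(counts.values()) == 1000
```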
- `fitness_predictor_architecture`: One of:
  - `MLP`: Classic MLP taking the mean of the ESM-2 embeddings as input.
  - `CNN`: Classic 2D CNN taking the ESM-2 embeddings as input.
  - `Transformer`: Classic Transformer taking the ESM-2 embeddings as input.
  - `HybridCNN`: Classic 2D CNN taking the ESM-2 embeddings as input; here only the MLP head is Bayesian.
  - `DoubleCNNWithNegative`: Creates one CNN (with 2 outputs) or two CNNs (with 1 output each) scoring both the positive and the negative selection.
  - `CNNWithBoltzmann`: Creates 4, 2, or 1 CNNs (with corresponding outputs) scoring the Boltzmann constants of the four modes: bind to base, do not bind to base, bind to target, do not bind to target.
Each architecture has a dedicated sub-config:
- `fitness_predictor_mlp`:
  - `layers_size`: Hidden sizes ending with `1` (output). The first element should match the flattened embedding dimension.
- `fitness_predictor_cnn`:
  - `list_of_channels`: Convolutional channels per conv block; the first equals the embedding channels.
  - `list_of_kernel_size`: Kernel sizes per conv block.
  - `list_of_stride`: Stride per conv block.
  - `list_of_padding`: Padding per conv block. If residual connections are used, ensure `padding = (kernel - 1) // 2` to preserve length.
  - `layers_size`: Fully-connected head sizes, ending with `1`.
  - `pooling_name`: `MaxPool1d`, `AvgPool1d`, or `none`.
  - `pool_size`, `pool_stride`: Pooling parameters per conv block.
  - `batch_norm_name`: `BatchNorm1d` or `none`.
  - `residual_connections`: `true`/`false` to enable skip connections.
  - `residual_skip_layers`: Integer skip distance for residual connections.
- `fitness_predictor_transformer`:
  - `d_model`: Model hidden size (should match the embedding channels).
  - `nhead`: Number of attention heads.
  - `num_encoder_layers`: Number of encoder layers.
  - `dim_feedforward`: Feed-forward network size.
  - `ffn_activation_name`: `tanh`, `sigmoid`, `ReLU`, or `LeakyReLU`.
  - `pooling_type`: `mean` or `cls`.
- Paths and naming:
  - `relative_model_checkpoint_location`: Resume from a checkpoint path relative to the experiment directory (or `null`).
  - `relative_csv_file_location`: Optional CSV path relative to `paths.data_dir` pointing directly to training data. If `null`, the path is inferred from `splits`.
  - `experiment_name`: Folder name under `results/`.
- Optimization and regularization:
  - `learning_rate`, `weight_decay`, `clip_norm`, `dropout_value`.
  - `activation_name`: `tanh`, `sigmoid`, `ReLU`, or `LeakyReLU`.
  - `optimizer`: `AdamW`, `Adam`, or `RMSprop`.
  - `use_scheduler`: `true`/`false`, with `gamma_scheduler` and `end_of_warmup`.
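One common way to combine `end_of_warmup` and `gamma_scheduler` is linear warmup followed by exponential decay. The sketch below illustrates that pattern only; it is not the repository's scheduler implementation, and the function name is hypothetical.

```python
def lr_at_step(step, base_lr, gamma_scheduler, end_of_warmup):
    """Linear warmup to `base_lr` over `end_of_warmup` steps, then
    exponential decay by `gamma_scheduler` per step (illustrative)."""
    if step < end_of_warmup:
        return base_lr * (step + 1) / end_of_warmup
    return base_lr * gamma_scheduler ** (step - end_of_warmup)

assert lr_at_step(0, 1.0, 0.95, 10) == 0.1    # first warmup step
assert lr_at_step(10, 1.0, 0.95, 10) == 1.0   # warmup just finished
```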
- Bayesian/SVI controls:
  - `prior_scale`: Prior standard-deviation scale for weights.
  - `number_of_particles`: Number of MC particles for the ELBO.
  - `kl_annealing`: Steps per annealing cycle; `1` disables annealing.
  - `kl_annealing_type`: `linear`, `exponential`, or `cyclic`.
  - `latents_to_anneal`: Optional list to restrict annealing to specific latents.
  - `guide_type`: One of `MultivariateNormal`, `DiagonalNormal`, `Normal`, `LowRankMultivariateNormal`.
  - `number_of_models`: `distinct`, `one`, `two`, `sequential_latent`, or `sequential_output`.
  - `estimate_method`: `batch`, `global`, `N-1`, or `moving_average`.
  - `EMA_beta`: Exponential moving average coefficient for some estimates.
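To clarify how `kl_annealing` and `kl_annealing_type` interact, here is a sketch of typical annealing schedules. This is an illustration of the general pattern, not the code used in this repository; the exact schedules may differ.

```python
def kl_weight(step, kl_annealing, kl_annealing_type="linear"):
    """Illustrative KL-annealing weight in [0, 1]. `kl_annealing` is the
    number of steps per cycle; 1 disables annealing (weight stays at 1)."""
    if kl_annealing <= 1:
        return 1.0
    if kl_annealing_type == "linear":
        return min(1.0, step / kl_annealing)            # ramp once, then hold
    if kl_annealing_type == "cyclic":
        return (step % kl_annealing) / kl_annealing     # restart every cycle
    if kl_annealing_type == "exponential":
        return 1.0 - 0.5 ** (step / kl_annealing)       # saturating ramp
    raise ValueError(kl_annealing_type)

assert kl_weight(0, 1) == 1.0                   # annealing disabled
assert kl_weight(50, 100, "linear") == 0.5
assert kl_weight(150, 100, "cyclic") == 0.5
```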
- Loop/runtime:
  - `epochs`, `batch_size`, `shuffle`, `num_workers`.
  - `best_model_decision_set`: `train` or `valid`.
  - `every_n_epochs`: Evaluation/artifact-saving cadence.
- Likelihood and data tweaks:
  - `distribution`: Fully-qualified likelihood class. Provided options include:
    - Poisson: `src.models.baseline_model.PhageDisplayPoissonRegressor`
    - Multinomial: `src.models.baseline_model.PhageDisplayMultinomialRegressor`
  - `delete_initial_null_counts`: `add_pseudo_counts`, `delete`, or `null`.
  - `is_null_model`: `true` to train/evaluate a null baseline for sanity checks.
  - `test_data_dir`: Optional alternative directory for test-time data.
- `model.prediction`:
  - `model_dir_path`: Path to a trained experiment directory from which to load the model.
  - `experiment_name`: Name used for prediction outputs.
  - `background_csv_file_location`: Optional CSV (relative to `paths.data_dir`) providing background sequences for certain explainers.
  - `to_predict_csv_file_location`: CSV to score (relative to `paths.data_dir`).
  - `experiment_pair_background`, `experiment_pair_predict`: Experiment IDs to contextualize background and prediction.
  - `explain`: `true`/`false` to enable explanation export. Note: explanations are currently computed only when `fitness_predictor_architecture` is `CNN`.
  - `use_training_cfg`: If `true`, reuse the architecture/training config saved alongside the trained model.
  - `number_of_particles`: MC samples during predictive inference.
  - `save_deterministic_ckpt`, `deterministic_checkpoint_path`: Optional deterministic checkpointing during prediction.
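A hedged example of a prediction block (the paths and experiment names below are placeholders, not values shipped with the repository):

```yaml
# Illustrative prediction settings (hypothetical paths and names)
model:
  prediction:
    model_dir_path: results/experiments/my_training_run
    experiment_name: my_prediction_run
    to_predict_csv_file_location: my_sequences.csv
    use_training_cfg: true
    explain: true          # only honored for the CNN architecture
    number_of_particles: 100
```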
- `model.explainer`:
  - `type`: One of `lrp_captum`, `deeplift`, `integrated_gradients`, `deeplift_shap`, `gradcam_captum`, `saliency`, `random`.
  - Generic knobs:
    - `use_padding_for_baseline`: Whether to pad sequences for baseline handling.
    - `num_model_to_explain`: If `0`, explanations use the mean model; otherwise, the number of sampled models.
    - `use_neptune`, `save_images`.
    - `aggregation_mode`: `mean` or `percentile`, with `percentile_union` and `percentile_inter`.
    - `epsilon_uai`: Threshold for the UAI+ calculation.
  - Method-specific sub-configs (present when applicable):
    - `lrp_captum.rule`: `epsilon` or `alpha_beta`, with `epsilon`, `alpha`, `beta`, plus `calculate_metrics`.
    - `gradcam_captum.target_layers`: Conv layer index or list of indices.
    - `deeplift_shap.num_baselines`: Number of random examples used to define the SHAP baseline.
    - `integrated_gradients.steps`: Number of steps in the integral approximation.
You can override any key inline. Examples:
# Choose architecture
uv run train_phage_display model.fitness_predictor_architecture=CNN
# Point to a specific training CSV and turn on Neptune
uv run train_phage_display \
model.training.relative_csv_file_location="data/my_dataset.csv" \
use_neptune=true neptune_project=MyOrg/MyProject
- Neptune can be toggled with `use_neptune` in `config/default.yaml` (default: `false`).
- To enable, set `use_neptune: true` and export `NEPTUNE_API_TOKEN`.
You can use the installed console script or run via uv. Don't forget to configure the train/validation datasets first.
uv run train_phage_display

Upon successful completion of training, a new directory will be created in `results/experiments`. This folder will include:
- The best model checkpoint, saved as a `.pth` file.
- A `cfg.yaml` file capturing all relevant model architecture and training parameters.
You should reuse this configuration when running the prediction script. If you enable `use_training_cfg`, the prediction script will automatically load `cfg.yaml` from your experiment folder to ensure the architecture and other settings precisely match your trained model. This avoids any issues due to configuration drift between training and inference.
If Neptune logging is enabled (use_neptune: true), training progress and results—including metrics, curves, and checkpoints—will be tracked in your Neptune project for easy monitoring and comparison. Otherwise, training and validation performance curves will be saved locally in the graphs/ directory at the end of training for your review.
The prediction script loads the model config (if present in the experiment directory) to ensure architecture compatibility, scores sequences, and optionally saves explainability outputs. If you do not use this option, ensure that the configuration of the model you are loading matches the current one. Don't forget to configure the model checkpoint path (created under `experiments` during training) and the prediction dataset beforehand.
uv run predict_phage_display

After running prediction and (optionally) explainability analyses, a new results directory will be created under `results/experiments` with a name starting with `predict`. This directory will contain comprehensive prediction and explainability outputs, including:
- `mean_prediction.npy`: Array of the mean predicted value for each input sequence.
- `predictions.csv`: CSV containing per-sequence prediction results, including batch means and relevant metadata.
- `predictions_raw.npy`: All raw predictions for each sequence (no mean aggregation), enabling accurate standard-deviation and uncertainty estimation.
- `sequence.npy`: Array of processed sequence representations used in prediction.
- `std_prediction.npy`: Standard deviations corresponding to each prediction, reflecting model uncertainty.
- `explainer_images/`: Folder containing all visual outputs from the selected explainability method(s), such as saliency maps or attention scores.
These files provide both quantitative results and visual explanations to support further downstream analysis or reporting.
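The relationship between `predictions_raw.npy`, `mean_prediction.npy`, and `std_prediction.npy` can be sketched in plain Python (the repository stores these as NumPy arrays; this only illustrates the aggregation step):

```python
import math

def aggregate(raw_predictions):
    """Collapse per-sequence MC samples into a mean and a (population)
    standard deviation, the per-sequence uncertainty estimate."""
    means, stds = [], []
    for samples in raw_predictions:        # one list of MC samples per sequence
        n = len(samples)
        mean = sum(samples) / n
        var = sum((x - mean) ** 2 for x in samples) / n
        means.append(mean)
        stds.append(math.sqrt(var))
    return means, stds

# Two sequences, three MC particles each
means, stds = aggregate([[1.0, 2.0, 3.0], [5.0, 5.0, 5.0]])
assert means == [2.0, 5.0]
assert stds[1] == 0.0      # identical samples imply zero estimated uncertainty
```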
Training and prediction operate on CSV files with sequence columns consistent with the selected experimental pairs. When model.training.relative_csv_file_location is null, the path is inferred from model.splits.relative_csv_file_location using the first training pair. Otherwise, set model.training.relative_csv_file_location to point to your dataset relative to paths.data_dir.
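As a hedged illustration of the splits mapping described above (experiment IDs, assay names, and paths below are placeholders):

```yaml
# Hypothetical splits block; assay keys and paths depend on your data layout
model:
  splits:
    training_pair: [exp1, exp2]
    validation_pair: [exp3]
    test_pair: [exp4]
    relative_csv_file_location:
      DNA: dna/counts.csv
      PVP: pvp/counts.csv
```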
To run the tests, use the following command:

uv run pytest

The public datasets included in this repository originate from the following publication:
Boyer et al. (2016) — PubMed
Please credit the original authors when using these data in derivative works or publications.
If you use this codebase in your research, please cite the associated paper once available. A BibTeX entry will be added here upon publication.