Building a Variant Effect Predictor for Antibody Affinity

Introduction

Antibody affinity maturation—the process of improving the binding strength between an antibody and its target antigen—is a critical step in therapeutic antibody development. Natural affinity maturation occurs in vivo during B cell development, where somatic hypermutation generates antibody variants with improved binding. In vitro affinity maturation mimics this process, using techniques like error-prone PCR, DNA shuffling, or computational design to generate variant libraries that are then screened for improved binding. However, experimental affinity maturation remains resource-intensive, requiring the expression and screening of hundreds to thousands of variants. Machine learning offers a transformative approach: predicting the effect of mutations on binding affinity before experimental testing, enabling intelligent design of focused variant libraries and dramatically accelerating the maturation process.

Understanding Antibody-Antigen Binding

Antibody binding affinity is determined by the complementarity between the antibody’s paratope (the binding site, primarily in the complementarity-determining regions or CDRs) and the antigen’s epitope. The binding energy arises from multiple non-covalent interactions:

  • Hydrogen bonds: Polar interactions between donor and acceptor groups
  • Van der Waals forces: Weak attractive forces from electron correlation
  • Electrostatic interactions: Coulombic attractions between charged residues
  • Hydrophobic effects: Entropic driving force from water exclusion

The dissociation constant $K_D$ quantifies binding affinity:

$$K_D = \frac{k_{off}}{k_{on}}$$

where $k_{off}$ is the dissociation rate and $k_{on}$ is the association rate. Lower $K_D$ indicates higher affinity (stronger binding). Affinity is often reported in molar concentrations, with typical therapeutic antibodies achieving low nanomolar to picomolar affinity.

The change in binding free energy upon mutation, $\Delta \Delta G$, is the key quantity to predict:

$$\Delta \Delta G = \Delta G_{mut} - \Delta G_{wt}$$

where $\Delta G_{wt}$ is the wild-type binding free energy and $\Delta G_{mut}$ is the mutant binding free energy. Positive $\Delta \Delta G$ indicates reduced affinity; negative values indicate improved affinity.
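
In practice, affinity is measured as $K_D$, not free energy, so the two formulas above are linked through $\Delta G = RT \ln K_D$: a fold-change in $K_D$ maps directly to a $\Delta \Delta G$. A minimal sketch of this conversion (at 298 K, in kcal/mol):

```python
import math

R = 1.987e-3  # gas constant, kcal/(mol*K)
T = 298.0     # temperature, K

def ddg_from_kd(kd_wt, kd_mut, temperature=T):
    """Convert a K_D fold-change into a binding free-energy change (kcal/mol).

    Positive ddG => the mutant binds more weakly than wild type.
    """
    return R * temperature * math.log(kd_mut / kd_wt)

# A 10-fold loss in affinity (10 nM -> 100 nM) costs ~1.4 kcal/mol
print(round(ddg_from_kd(10e-9, 100e-9), 2))  # -> 1.36
```

This is also the conversion used to turn published $K_D$ tables into $\Delta \Delta G$ training labels.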

Approaches to Variant Effect Prediction

1. Physics-Based Methods

Physics-based approaches use molecular mechanics force fields or quantum chemical calculations to compute binding energies:

$$E_{binding} = E_{complex} - E_{antibody} - E_{antigen}$$

Popular methods include:

  • Rosetta ddG
  • FoldX
  • MM/GBSA (molecular mechanics with a generalized Born solvent model)
  • Free energy perturbation (FEP)

While physically rigorous, these methods are computationally expensive and require structural models.

2. Knowledge-Based Potentials

Statistical potentials derive interaction preferences from known protein structures:

$$W_{ij} = -\ln \frac{p_{ij}}{p_i p_j}$$

where $p_{ij}$ is the observed frequency of residue pair $(i,j)$ and $p_i$ is the marginal frequency. These potentials can be used to score variants.
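
The formula above can be implemented in a few lines. The sketch below assumes the input is a symmetric matrix of observed residue-pair contact counts (e.g., from interface contacts in solved structures); negative $W_{ij}$ marks pairs seen more often than chance, i.e. favoured contacts:

```python
import numpy as np

def pairwise_potential(contact_counts):
    """Knowledge-based potential W_ij = -ln(p_ij / (p_i * p_j)).

    contact_counts: square symmetric array of observed pair counts.
    Negative W_ij => the pair occurs more often than expected by chance.
    """
    p_ij = contact_counts / contact_counts.sum()   # joint pair frequencies
    p_i = p_ij.sum(axis=1)                         # marginal frequencies
    expected = np.outer(p_i, p_i)                  # independence baseline
    return -np.log(p_ij / expected)
```

Scoring a variant then amounts to summing $W_{ij}$ over the interface contacts it introduces or removes.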

3. Machine Learning Approaches

ML approaches learn sequence-function relationships from experimental data:

  • Sequence-based models: CNNs, RNNs, or transformers on antibody sequences
  • Structure-based models: GNNs or 3D CNNs on antibody-antigen complexes
  • Hybrid models: Combining sequence and structural features

Building a Variant Effect Predictor

Data Preparation

The foundation of any ML model is high-quality training data. For antibody affinity prediction, relevant datasets include:

  1. Therapeutic antibody binding data: Published $K_D$ measurements for antibody variants
  2. Deep mutational scanning: High-throughput assays measuring the effect of all mutations at a specific position
  3. Phage display selections: Ranking data from enrichment experiments

Each data point consists of:

  • Wild-type antibody sequence (or structure)
  • Mutation(s) introduced
  • Measured affinity change ($\Delta \Delta G$ or fold-change)
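
Mutations in such datasets are usually recorded in compact wild-type/position/mutant notation such as Y102W. A small parser and applier for that convention (assuming 1-based positions, as typically reported) keeps the rest of the pipeline clean:

```python
import re
from typing import NamedTuple

class Mutation(NamedTuple):
    wt_aa: str
    position: int  # 1-based, as conventionally reported
    mut_aa: str

MUTATION_RE = re.compile(r"^([ACDEFGHIKLMNPQRSTVWY])(\d+)([ACDEFGHIKLMNPQRSTVWY])$")

def parse_mutation(code):
    """Parse a mutation code like 'Y102W' into its components."""
    m = MUTATION_RE.match(code)
    if m is None:
        raise ValueError(f"Unrecognised mutation code: {code!r}")
    return Mutation(m.group(1), int(m.group(2)), m.group(3))

def apply_mutation(seq, mut):
    """Return the mutated sequence, checking the wild-type residue matches."""
    idx = mut.position - 1
    if seq[idx] != mut.wt_aa:
        raise ValueError(f"Expected {mut.wt_aa} at {mut.position}, found {seq[idx]}")
    return seq[:idx] + mut.mut_aa + seq[idx + 1:]
```

The wild-type check catches off-by-one numbering errors early, which are a common source of silent label corruption.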

Feature Engineering

Features capture the relevant properties of antibodies and mutations:

import numpy as np

# Amino acid physicochemical properties
AA_PROPERTIES = {
    'A': {'hydrophobicity': 1.8, 'charge': 0, 'volume': 88.6, 'polarity': 0, 'mass': 89},
    'R': {'hydrophobicity': -4.5, 'charge': 1, 'volume': 173.4, 'polarity': 52, 'mass': 174},
    # ... complete dictionary
}

def extract_mutation_features(wt_seq, mut_seq, position):
    wt_aa = wt_seq[position]
    mut_aa = mut_seq[position]
    
    features = [
        AA_PROPERTIES[mut_aa]['hydrophobicity'] - AA_PROPERTIES[wt_aa]['hydrophobicity'],
        AA_PROPERTIES[mut_aa]['charge'] - AA_PROPERTIES[wt_aa]['charge'],
        AA_PROPERTIES[mut_aa]['volume'] - AA_PROPERTIES[wt_aa]['volume'],
        AA_PROPERTIES[mut_aa]['polarity'] - AA_PROPERTIES[wt_aa]['polarity'],
        AA_PROPERTIES[mut_aa]['mass'] - AA_PROPERTIES[wt_aa]['mass'],
        # Position-specific features (approximate heavy-chain CDR
        # boundaries; exact ranges depend on the numbering scheme)
        position / len(wt_seq),  # Relative position
        1 if 26 <= position <= 35 else 0,   # CDR-H1
        1 if 50 <= position <= 65 else 0,   # CDR-H2
        1 if 95 <= position <= 110 else 0,  # CDR-H3
    ]
    return np.array(features)

Sequence Representation

Beyond simple features, modern approaches use learned representations:

  1. One-hot encoding: Simple but high-dimensional
  2. k-mer features: Capturing local sequence context
  3. Pre-trained embeddings: Using ESM-2 or ProtGPT2 for rich contextual representations

from transformers import EsmModel, EsmTokenizer
import torch

class AntibodyEncoder:
    def __init__(self, model_name="facebook/esm2_t33_650M_UR50D"):
        self.model = EsmModel.from_pretrained(model_name)
        self.model.eval()  # disable dropout for deterministic embeddings
        self.tokenizer = EsmTokenizer.from_pretrained(model_name)
        
    def encode(self, sequence):
        inputs = self.tokenizer(sequence, return_tensors="pt", 
                               padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Use mean pooling of last hidden state
        embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings.numpy()

Model Architectures

1. CNN for Sequence Classification

import torch
import torch.nn as nn

class CNNVariantPredictor(nn.Module):
    def __init__(self, vocab_size=21, embed_dim=64, kernel_sizes=[3, 5, 7]):
        super(CNNVariantPredictor, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, 128, k) for k in kernel_sizes
        ])
        
        self.fc = nn.Sequential(
            nn.Linear(128 * len(kernel_sizes), 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )
        
    def forward(self, x):
        x = self.embedding(x).transpose(1, 2)
        conv_outputs = [torch.relu(conv(x)) for conv in self.convs]
        pooled = [torch.max(conv_out, dim=2)[0] for conv_out in conv_outputs]
        concat = torch.cat(pooled, dim=1)
        return self.fc(concat)
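
The predictor above consumes integer token indices, not raw strings. A minimal encoder, assuming index 0 is reserved for padding (matching `padding_idx=0` in the embedding layer), might look like:

```python
import numpy as np

# 20 standard amino acids; index 0 is reserved for padding, matching
# nn.Embedding(vocab_size=21, ..., padding_idx=0) in the model.
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_IDX = {aa: i + 1 for i, aa in enumerate(AA_ALPHABET)}

def encode_sequences(sequences, max_len=None):
    """Integer-encode sequences and right-pad with 0 to a common length."""
    if max_len is None:
        max_len = max(len(s) for s in sequences)
    batch = np.zeros((len(sequences), max_len), dtype=np.int64)
    for row, seq in enumerate(sequences):
        for col, aa in enumerate(seq):
            batch[row, col] = AA_TO_IDX[aa]
    return batch

batch = encode_sequences(["EVQLV", "QVQ"])  # shape (2, 5)
```

`torch.from_numpy(batch)` then feeds directly into the CNN's `forward`.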

2. Transformer Architecture

class TransformerVariantPredictor(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=4):
        super(TransformerVariantPredictor, self).__init__()
        self.embedding = nn.Embedding(21, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        
        # batch_first keeps tensors as (batch, seq, d_model) throughout
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=512,
                                                   batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        
        self.regressor = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(d_model // 2, 1)
        )
        
    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = x.mean(dim=1)  # Global average pooling
        return self.regressor(x)
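
The `PositionalEncoding` module referenced above is not defined in the snippet. One standard sinusoidal implementation (Vaswani et al., 2017), written here for batch-first inputs of shape `(batch, seq_len, d_model)`:

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch-first inputs."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
        self.register_buffer("pe", pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, : x.size(1)]
```

Registering `pe` as a buffer keeps it on the right device without treating it as a trainable parameter.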

3. Graph Neural Network for Structure-Based Prediction

When antibody-antigen complex structures are available, GNNs can model the binding interface:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

class InterfaceGNN(MessagePassing):
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super(InterfaceGNN, self).__init__(aggr='add')
        self.node_encoder = nn.Linear(node_dim, hidden_dim)
        self.edge_encoder = nn.Linear(edge_dim, hidden_dim)
        self.message_mlp = nn.Sequential(
            nn.Linear(3 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.update_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
    def forward(self, x, edge_index, edge_attr):
        # Self-loops need matching edge features, so pad them with zeros
        edge_index, edge_attr = add_self_loops(
            edge_index, edge_attr, fill_value=0., num_nodes=x.size(0)
        )
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)
    
    def message(self, x_i, x_j, edge_attr):
        # Combine both endpoint embeddings with the edge features
        return self.message_mlp(torch.cat([x_i, x_j, edge_attr], dim=-1))
    
    def update(self, aggr_out, x):
        return self.update_mlp(torch.cat([x, aggr_out], dim=-1))
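
The GNN above needs a graph as input. A common construction, sketched here with an assumed 8 Å cutoff, is a residue-contact graph over C-alpha coordinates with the inter-residue distance as the edge feature:

```python
import numpy as np

def interface_graph(coords, cutoff=8.0):
    """Build a residue-contact graph from C-alpha coordinates.

    coords: (n, 3) array of positions in angstroms. Returns
    (edge_index, edge_attr): edges connect residue pairs closer than
    `cutoff`, and each edge carries the pairwise distance as a feature.
    """
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))
    src, dst = np.where((dist < cutoff) & (dist > 0))  # exclude self-pairs
    edge_index = np.stack([src, dst])                  # shape (2, n_edges)
    edge_attr = dist[src, dst][:, None]                # shape (n_edges, 1)
    return edge_index, edge_attr
```

The arrays convert to tensors via `torch.from_numpy` before being passed to `InterfaceGNN.forward`.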

Training Strategy

1. Data Splitting

Critical considerations:

  • Avoid data leakage: variants of the same parent sequence should stay in the same split
  • Use sequence clustering to ensure generalization
  • Validate on held-out targets

from sklearn.model_selection import GroupKFold

# Group variants by parent sequence so related variants stay together
groups = variant_df['parent_sequence'].values
gkf = GroupKFold(n_splits=5)

for train_idx, val_idx in gkf.split(X, y, groups):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]
    # Train on this fold; parent groupings never straddle the split

2. Loss Functions

For regression of $\Delta \Delta G$:

  • MSE Loss: Standard mean squared error
  • MAE Loss: More robust to outliers
  • Ranking Loss: Optimize for correct ordering of variants

class RankingLoss(nn.Module):
    """Pairwise hinge loss rewarding concordant orderings of variants."""
    def forward(self, pred, true):
        n = pred.size(0)
        pred_expanded = pred.unsqueeze(1).expand(n, n)
        true_expanded = true.unsqueeze(1).expand(n, n)
        
        # ranking[i, j] > 0 when pred and true agree on the order of (i, j)
        ranking = (pred_expanded - pred_expanded.T) * (true_expanded - true_expanded.T)
        return F.relu(1 - ranking).mean()

3. Uncertainty Quantification

Essential for practical use—models should express confidence:

class EnsembleUncertainty:
    def __init__(self, models):
        self.models = models
        
    def predict_with_uncertainty(self, x):
        predictions = torch.stack([model(x) for model in self.models])
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        return mean, std

Evaluation Metrics

Key metrics for variant effect prediction:

  1. Pearson Correlation: Measures linear relationship between predicted and observed
  2. Spearman Correlation: Measures rank correlation (often more relevant)
  3. RMSE/MAE: Prediction accuracy in energy units
  4. Top-K Recovery: Fraction of true top-K variants recovered in predicted top-K
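
These metrics are straightforward to compute with NumPy and SciPy. The helper below is illustrative and assumes lower $\Delta \Delta G$ means a better binder, so the "top" variants are those with the lowest values:

```python
import numpy as np
from scipy import stats

def evaluate_predictions(y_true, y_pred, k=10):
    """Headline metrics for a set of variant ddG predictions."""
    pearson = stats.pearsonr(y_true, y_pred)[0]
    spearman = stats.spearmanr(y_true, y_pred)[0]
    rmse = np.sqrt(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))
    # Top-K recovery: overlap between the true and predicted top-K sets
    true_top = set(np.argsort(y_true)[:k])  # lowest ddG = best binders
    pred_top = set(np.argsort(y_pred)[:k])
    top_k = len(true_top & pred_top) / k
    return {'pearson': pearson, 'spearman': spearman,
            'rmse': rmse, 'top_k_recovery': top_k}
```

Reporting all four together guards against models that fit the bulk of the data well but misrank the extreme variants that matter for library design.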

Practical Considerations

1. Data Augmentation

  • Reverse mutations: if A→R has a known $\Delta \Delta G$, the reverse R→A has approximately the negated value
  • Sequence masking: predict masked positions
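
Reverse-mutation augmentation exploits the antisymmetry of $\Delta \Delta G$: swapping wild type and mutant negates the label. A sketch, assuming a simple `(wt_seq, mut_seq, ddg)` record format:

```python
def augment_with_reverse(records):
    """Double a ddG dataset using antisymmetry: ddG(B->A) ~ -ddG(A->B).

    records: list of (wt_seq, mut_seq, ddg) tuples; the mutant sequence
    of each original becomes the 'wild type' of its reversed example.
    """
    augmented = list(records)
    for wt_seq, mut_seq, ddg in records:
        augmented.append((mut_seq, wt_seq, -ddg))
    return augmented
```

The reversed examples must stay in the same cross-validation fold as their originals, or the augmentation itself becomes a leakage source.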

2. Transfer Learning

Pre-train on general protein mutation data, fine-tune on antibody-specific data:

# Pre-trained on general protein variants
base_model = load_pretrained_model("protein_variant_predictor")
# Fine-tune on antibody data
for param in base_model.parameters():
    param.requires_grad = False
    
# Add antibody-specific head
head = nn.Linear(768, 1)
# Fine-tune head on antibody data

3. Active Learning

Iteratively select variants to test experimentally:

def select_variants_for_testing(model, candidates, n_select=10):
    predictions, uncertainties = model.predict_with_uncertainty(candidates)
    # Lower predicted ddG means stronger binding, so negate predictions;
    # the uncertainty bonus encourages exploration (UCB-style acquisition)
    scores = -predictions.squeeze(-1) + 0.5 * uncertainties.squeeze(-1)
    return candidates[torch.topk(scores, n_select).indices]

Applications in Affinity Maturation

A complete ML-driven affinity maturation workflow:

  1. Starting point: Antibody with moderate affinity ($K_D \approx 10$ nM)
  2. Generate candidates: Use ML to predict effects of all possible CDR mutations
  3. Filter: Remove variants with predicted developability issues
  4. Priority ranking: Select top candidates for experimental testing
  5. Iterate: Use new data to retrain models, repeat

This approach can reduce the number of variants that need experimental testing from thousands to hundreds.
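
Step 2 of the workflow is easy to make concrete. The hypothetical helper below enumerates every single-point variant at a chosen set of CDR positions, producing the candidate pool the model then scores:

```python
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"

def enumerate_single_mutants(seq, positions):
    """All single-point variants of `seq` at the given 0-based positions.

    Returns (mutation_code, mutant_sequence) pairs, with codes in the
    conventional 1-based wt/position/mut notation (e.g. 'E1A').
    """
    variants = []
    for pos in positions:
        for aa in AA_ALPHABET:
            if aa != seq[pos]:
                variants.append((f"{seq[pos]}{pos + 1}{aa}",
                                 seq[:pos] + aa + seq[pos + 1:]))
    return variants
```

Each position yields 19 substitutions, so even a modest CDR footprint of 30 positions gives 570 single mutants to rank in silico.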

Challenges and Future Directions

  1. Data scarcity: High-quality affinity measurement data is limited
  2. Generalization: Models may not transfer across targets
  3. Epistasis: Interactions between multiple mutations are hard to capture
  4. Structural context: Modeling both bound and unbound states

Future directions include:

  • Foundation models for antibody variants
  • Integration with structure prediction (AlphaFold)
  • Diffusion models for variant generation
  • In silico screening with physics-based refinement

Conclusion

Building an effective variant effect predictor for antibody affinity requires careful attention to data quality, feature engineering, model architecture, and training strategy. While challenging, ML-based approaches offer transformative potential for antibody engineering, enabling rational design of improved therapeutics and dramatically accelerating the affinity maturation process. As training data accumulates and models improve, ML-guided affinity maturation will become standard practice in therapeutic antibody development.

Brook Tilahun
Computational Biology Scientist

Applying machine learning and AI to accelerate therapeutic antibody discovery and protein engineering.