model

Defines Model objects for global epistasis modeling using JAX.

class multidms.model.Model(data: Data, loss_type: Literal['functional_score_loss', 'count_loss'] = 'functional_score_loss', ge_type: Literal['Identity', 'Sigmoid'] = 'Sigmoid', l2reg: float = 0.0, fusionreg: float = 0.0, beta0_ridge: float = 0.0)

Bases: object

Model for global epistasis analysis of DMS experiments.

This class wraps the jaxmodels backend to provide a user-friendly interface for fitting global epistasis models to deep mutational scanning data.

Parameters:
  • data (multidms.Data) – Preprocessed DMS data object containing variants and functional scores.

  • loss_type ({'functional_score_loss', 'count_loss'}) – Type of loss function to use. ‘functional_score_loss’ for standard functional score fitting, ‘count_loss’ for count-based enrichment models.

  • ge_type ({'Identity', 'Sigmoid'}) – Global epistasis model type. ‘Identity’ for no global epistasis (linear), ‘Sigmoid’ for sigmoidal transformation.

  • l2reg (float) – L2 regularization strength for mutation effects (default: 0.0).

  • fusionreg (float) – Fusion regularization strength for shift parameters (default: 0.0).

  • beta0_ridge (float) – Ridge penalty for β0 offsets from reference condition (default: 0.0).

Example

>>> import pandas as pd
>>> from multidms import Data, Model
>>> df = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df, reference='a')  
>>> model = Model(data, ge_type='Sigmoid', l2reg=0.01)
>>> model  
Model(ge_type='Sigmoid', loss_type='functional_score_loss')
property data: Data

The Data object used for model fitting.

property params

Model parameters (available after fit).

property converged: bool

Whether the model fitting converged.

Convergence is determined by whether the objective error (relative change in the objective function) at the last iteration was below the tolerance used during fitting.

property convergence_trajectory_df: DataFrame

Convergence trajectory showing objective and loss over iterations.

Returns:

DataFrame with columns iteration, objective_total_trajectory, objective_error_trajectory, loss_trajectory.

Return type:

pd.DataFrame
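As a sketch of how this trajectory can be inspected, the snippet below builds a mock DataFrame with the documented columns (the values are illustrative; a real trajectory comes from a fitted model's convergence_trajectory_df) and applies the convergence criterion described under the converged property:

```python
import pandas as pd

# Mock convergence trajectory with the documented columns.
# Real values come from model.convergence_trajectory_df after fitting.
traj = pd.DataFrame({
    "iteration": [0, 1, 2, 3],
    "objective_total_trajectory": [10.0, 4.0, 3.1, 3.05],
    "objective_error_trajectory": [1.0, 0.6, 0.225, 1.6e-7],
    "loss_trajectory": [10.0, 3.9, 3.0, 2.95],
})

tol = 1e-6
# Convergence criterion: the relative objective change at the last
# iteration is below the tolerance used during fitting.
converged = traj["objective_error_trajectory"].iloc[-1] < tol
print(converged)  # → True
```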

fit(warmstart: bool = True, maxiter: int = 10, tol: float = 1e-06, beta0_init: dict | None = None, beta_init: dict | None = None, alpha_init: dict | None = None, beta_clip_range: tuple | None = None, ge_kwargs: dict | None = None, cal_kwargs: dict | None = None, loss_kwargs: dict | None = None, verbose: bool = True)

Fit the model to data.

Parameters:
  • warmstart (bool) – Whether to use Ridge regression for parameter initialization (default: True).

  • maxiter (int) – Maximum number of optimization iterations (default: 10).

  • tol (float) – Convergence tolerance on the relative change in the objective function (default: 1e-6).

  • beta0_init (dict, optional) – Initial β0 values per condition.

  • beta_init (dict, optional) – Initial β values per condition.

  • alpha_init (dict, optional) – Initial α scaling values per condition.

  • beta_clip_range (tuple, optional) – Tuple of (min, max) values for clipping β parameters during optimization. Example: (-10.0, 10.0). If None, no clipping is applied.

  • ge_kwargs (dict, optional) – Keyword arguments for global epistasis optimizer (e.g., tol, maxiter, maxls).

  • cal_kwargs (dict, optional) – Keyword arguments for calibration (α) optimizer (e.g., tol, maxiter, maxls).

  • loss_kwargs (dict, optional) – Keyword arguments for the loss function (e.g., δ for Huber loss).

  • verbose (bool) – Whether to print progress information during fitting (default: True).

Returns:

Returns self for method chaining.

Return type:

Model

get_mutations_df(phenotype_as_effect: bool = True, times_seen_threshold: int = 0) → DataFrame

Extract mutation-level parameters and predicted functional scores.

Parameters:
  • phenotype_as_effect (bool) – If True, report mutation effects. If False, report raw latent phenotypes.

  • times_seen_threshold (int) – Minimum number of times a mutation must be seen in ALL conditions to be included. Default is 0 (no filtering).

Returns:

DataFrame with mutations as rows (index) and the following columns:

  • beta_{condition} for each condition

  • shift_{condition} for each non-reference condition

  • predicted_func_score_{condition} for each condition

Shift parameters represent the difference in beta values between each condition and the reference condition. Predicted functional scores are the model’s predictions for each single mutation on its condition-specific wild-type background.

Return type:

pd.DataFrame

Example

For a model with conditions [‘a’, ‘b’] where ‘a’ is the reference:

  • Columns: beta_a, beta_b, shift_b, predicted_func_score_a, predicted_func_score_b

  • One row per mutation
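A common follow-up is to screen the returned table for condition-dependent mutations. The sketch below uses a mock mutations_df with the documented column layout for conditions [‘a’, ‘b’] (‘a’ as reference); the mutation names and parameter values are made up:

```python
import pandas as pd

# Mock mutations_df mirroring the documented columns; real values
# come from model.get_mutations_df().
muts = pd.DataFrame(
    {
        "beta_a": [1.2, -0.3, 0.8],
        "beta_b": [1.1, -2.1, 0.9],
        "shift_b": [-0.1, -1.8, 0.1],
        "predicted_func_score_a": [1.0, -0.2, 0.7],
        "predicted_func_score_b": [0.9, -1.9, 0.8],
    },
    index=["M1A", "G2C", "K3R"],
)

# Mutations whose effect shifts substantially between conditions:
shifted = muts[muts["shift_b"].abs() > 1.0]
print(shifted.index.tolist())  # → ['G2C']
```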

get_variants_df(phenotype_as_effect: bool = True) → DataFrame

Extract variant-level predictions.

Parameters:

phenotype_as_effect (bool) – If True, report effects. If False, report raw latent phenotypes.

Returns:

Variant-level predictions merged with original data. Includes columns:

  • predicted_func_score: model-predicted functional score, i.e. α * (g(φ(X)) - g(φ(x_wt)))

  • predicted_latent: latent phenotype φ(X)

  • predicted_fitness: predicted fitness in g(φ) space, i.e. predicted_func_score / α + g(φ(x_wt))

  • measured_fitness: measured fitness in g(φ) space, i.e. func_score / α + g(φ(x_wt))

Return type:

pd.DataFrame
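The fitness columns above are linked by simple arithmetic. With illustrative numbers (alpha and the wild-type fitness g(φ(x_wt)) are made up for this sketch), the documented relationships can be checked directly:

```python
# Illustrative check of the documented column relationships.
# alpha and g_wt (= g(phi(x_wt))) are hypothetical values.
alpha = 2.0
g_wt = 0.5        # wild-type fitness in g(phi) space
g_variant = 1.5   # variant fitness g(phi(X)) in the same space

# predicted_func_score = alpha * (g(phi(X)) - g(phi(x_wt)))
predicted_func_score = alpha * (g_variant - g_wt)

# predicted_fitness = predicted_func_score / alpha + g(phi(x_wt))
predicted_fitness = predicted_func_score / alpha + g_wt

print(predicted_func_score)  # → 2.0
print(predicted_fitness)     # → 1.5 (recovers g(phi(X)))
```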

property training_loss: dict

Per-condition and total loss on training data.

Returns:

Dictionary mapping condition names and "total" to their training loss values.

Return type:

dict[str, float]

Raises:

ValueError – If model has not been fitted.

eval_loss(df)

Evaluate the model’s loss on an arbitrary DataFrame.

Parameters:

df (pd.DataFrame) – DataFrame with columns ‘condition’, ‘aa_substitutions’, ‘func_score’.

Returns:

Per-condition losses and "total" loss.

Return type:

dict[str, float]

Raises:

ValueError – If model is not fitted, required columns are missing, conditions are invalid, or substitutions contain unseen mutations.
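One natural use of eval_loss is held-out evaluation. The sketch below prepares a train/holdout split with the required columns (variants and scores are made up); the eval_loss call itself is shown in a comment since it needs a fitted model:

```python
import pandas as pd

# Toy variant table with the columns eval_loss requires.
df = pd.DataFrame({
    "condition": ["a", "a", "b", "b"],
    "aa_substitutions": ["", "M1A", "", "M1A"],
    "func_score": [0.0, 1.2, 0.1, 1.5],
})

# Hold out half of the variants for evaluation.
holdout = df.sample(frac=0.5, random_state=0)
train = df.drop(holdout.index)

# After fitting a model on `train` (via multidms.Data / Model.fit),
# the held-out loss would be computed as:
#     losses = model.eval_loss(holdout)
#     print(losses["total"])

print(len(train), len(holdout))  # → 2 2
```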

add_phenotypes_to_df(df: DataFrame, substitutions_col: str = 'aa_substitutions', condition_col: str = 'condition', predicted_phenotype_col: str = 'predicted_func_score', overwrite_cols: bool = False) → DataFrame

Add model predictions to a DataFrame of variants.

Parameters:
  • df (pd.DataFrame) – DataFrame with columns specified by condition_col and substitutions_col. Additional columns will be preserved in output.

  • substitutions_col (str) – Column in df giving variants as substitution strings. Default is ‘aa_substitutions’.

  • condition_col (str) – Column in df giving the condition for each variant. Values must exist in the model’s conditions. Default is ‘condition’.

  • predicted_phenotype_col (str) – Name of column to add containing predicted functional scores. Default is ‘predicted_func_score’.

  • overwrite_cols (bool) – If the specified predicted phenotype column already exists in df, overwrite it? If False, raise an error.

Returns:

A copy of df with predictions added. Always includes:

  • predicted_func_score (or custom name): predicted functional score

  • predicted_latent: latent phenotype φ(X)

  • predicted_fitness: predicted fitness in g(φ) space

If func_score column is present in df, also includes:

  • measured_fitness: measured fitness in g(φ) space

Return type:

pd.DataFrame

Raises:

ValueError – If model is not fitted, required columns are missing, indices are not unique, conditions are invalid, or substitutions contain mutations not seen during training.

Example

>>> import pandas as pd
>>> from multidms import Data, Model
>>> df_train = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df_train, reference='a')  
>>> model = Model(data, ge_type='Identity', l2reg=0.01)
>>> _ = model.fit(maxiter=5, warmstart=False, verbose=False)
>>> df_new = pd.DataFrame({
...     'condition': ['a', 'b'],
...     'aa_substitutions': ['M1A', 'M1A']
... })
>>> result = model.add_phenotypes_to_df(df_new)
>>> 'predicted_func_score' in result.columns
True
>>> 'predicted_latent' in result.columns
True
>>> 'predicted_fitness' in result.columns
True
>>> len(result)
2
get_ge_landscape_df(n_curve_points: int = 200) → tuple

Get data for plotting the global epistasis landscape.

Returns a tuple of (variants_df, ge_curve_df). The variants DataFrame contains per-variant latent phenotype and fitness columns from get_variants_df(), plus a wildtype_latent column for reference-line plotting. The curve DataFrame contains the global epistasis function evaluated over the observed latent phenotype range.

Parameters:

n_curve_points (int) – Number of points for the g(φ) curve grid.

Returns:

(variants_df, ge_curve_df) where:

  • variants_df has all columns from get_variants_df() plus wildtype_latent.

  • ge_curve_df has columns predicted_latent and ge_curve_value.

Return type:

tuple[pd.DataFrame, pd.DataFrame]

get_ge_curve(grid_min: float = -5.0, grid_max: float = 5.0, n_points: int = 200) → DataFrame

Evaluate the global epistasis function over a latent phenotype grid.

Parameters:
  • grid_min (float) – Minimum latent phenotype value for the grid.

  • grid_max (float) – Maximum latent phenotype value for the grid.

  • n_points (int) – Number of points in the grid.

Returns:

DataFrame with columns ‘latent’ and ‘observed’.

Return type:

pd.DataFrame
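As an illustration of the grid this method evaluates, the sketch below builds the same ‘latent’/‘observed’ layout using a generic sigmoid. The actual ge_type='Sigmoid' parameterization lives in the jaxmodels backend, so the curve here is shape-illustrative only:

```python
import numpy as np
import pandas as pd

# Illustrative stand-in for get_ge_curve: evaluate a generic sigmoid
# over the latent grid. The real curve uses the fitted ge_type.
grid_min, grid_max, n_points = -5.0, 5.0, 200
latent = np.linspace(grid_min, grid_max, n_points)
observed = 1.0 / (1.0 + np.exp(-latent))  # generic sigmoid shape

curve = pd.DataFrame({"latent": latent, "observed": observed})
print(curve.shape)  # → (200, 2)
```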

property wildtype_latent: dict

Wildtype latent phenotype for each condition.

Returns:

Dictionary mapping condition names to the wildtype’s latent phenotype value in that condition.

Return type:

dict[str, float]