model
Defines Model objects for global epistasis modeling using JAX.
- class multidms.model.Model(data: Data, loss_type: Literal['functional_score_loss', 'count_loss'] = 'functional_score_loss', ge_type: Literal['Identity', 'Sigmoid'] = 'Sigmoid', l2reg: float = 0.0, fusionreg: float = 0.0, beta0_ridge: float = 0.0)
Bases: object
Model for global epistasis analysis of DMS experiments.
This class wraps the jaxmodels backend to provide a user-friendly interface for fitting global epistasis models to deep mutational scanning data.
- Parameters:
data (multidms.Data) – Preprocessed DMS data object containing variants and functional scores.
loss_type ({'functional_score_loss', 'count_loss'}) – Type of loss function to use. ‘functional_score_loss’ for standard functional score fitting, ‘count_loss’ for count-based enrichment models.
ge_type ({'Identity', 'Sigmoid'}) – Global epistasis model type. ‘Identity’ for no global epistasis (linear), ‘Sigmoid’ for sigmoidal transformation.
l2reg (float) – L2 regularization strength for mutation effects (default: 0.0).
fusionreg (float) – Fusion regularization strength for shift parameters (default: 0.0).
beta0_ridge (float) – Ridge penalty on β0 offsets from the reference condition (default: 0.0).
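The ge_type option selects the global epistasis function g that maps a latent phenotype φ onto the fitness scale. The sketch below only illustrates that choice: make_ge_function is a hypothetical helper, and modelling 'Sigmoid' as a plain logistic curve (without any learned scale or offset parameters a fitted model may carry) is an assumption.

```python
import math
from typing import Callable, Literal


def _logistic(phi: float) -> float:
    # Numerically stable logistic: avoids overflow in exp() for large |phi|.
    if phi >= 0:
        return 1.0 / (1.0 + math.exp(-phi))
    e = math.exp(phi)
    return e / (1.0 + e)


def make_ge_function(
    ge_type: Literal["Identity", "Sigmoid"],
) -> Callable[[float], float]:
    """Return a global epistasis function g mapping latent phenotype to fitness.

    Hypothetical helper. 'Sigmoid' is modelled as a plain logistic curve;
    a fitted model may also learn scale and offset parameters.
    """
    if ge_type == "Identity":
        # No global epistasis: fitness is the latent phenotype itself.
        return lambda phi: phi
    if ge_type == "Sigmoid":
        # Monotone, saturating map from latent phenotype to fitness.
        return _logistic
    raise ValueError(f"unknown ge_type: {ge_type!r}")


g_identity = make_ge_function("Identity")
g_sigmoid = make_ge_function("Sigmoid")
print(g_identity(2.0))  # 2.0
print(g_sigmoid(0.0))   # 0.5
```

With 'Identity' the model is linear in the latent phenotype; with a sigmoid, large positive or negative latent effects saturate, which is what lets the model capture non-additive (global) epistasis.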
Example
>>> import pandas as pd
>>> from multidms import Data, Model
>>> df = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df, reference='a')
>>> model = Model(data, ge_type='Sigmoid', l2reg=0.01)
>>> model
Model(ge_type='Sigmoid', loss_type='functional_score_loss')
- property params
Model parameters (available after fit).
- property converged: bool
Whether the model fitting converged.
Convergence is determined by whether the objective error (relative change in the objective function) at the last iteration was below the tolerance used during fitting.
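For intuition, the convergence rule just described can be sketched in plain Python. The helpers relative_changes and is_converged are hypothetical, and the exact definition of the relative change used here (absolute difference divided by the previous objective value) is an assumption rather than the library's documented formula.

```python
import math
from typing import List


def relative_changes(objective: List[float]) -> List[float]:
    """Relative change of the objective between consecutive iterations.

    Assumed definition: |f_k - f_{k-1}| / |f_{k-1}|. The first iteration has
    no predecessor, so its error is reported as infinity.
    """
    errors = [math.inf]
    for prev, curr in zip(objective, objective[1:]):
        # Floor the denominator so a zero objective does not divide by zero.
        errors.append(abs(curr - prev) / max(abs(prev), 1e-12))
    return errors


def is_converged(objective: List[float], tol: float = 1e-6) -> bool:
    """True if the error at the last iteration is below tol."""
    return relative_changes(objective)[-1] < tol


trajectory = [10.0, 5.0, 4.0, 4.000001]
print(relative_changes(trajectory)[:3])     # [inf, 0.5, 0.2]
print(is_converged(trajectory, tol=1e-6))   # True
print(is_converged(trajectory, tol=1e-8))   # False
```

Because only the last iteration is checked, a fit stopped by maxiter while the objective is still moving reports converged as False, even if earlier iterations looked stable.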
- property convergence_trajectory_df: DataFrame
Convergence trajectory showing objective and loss over iterations.
- Returns:
DataFrame with columns iteration, objective_total_trajectory, objective_error_trajectory, loss_trajectory.
- Return type:
pd.DataFrame
- fit(warmstart: bool = True, maxiter: int = 10, tol: float = 1e-06, beta0_init: dict | None = None, beta_init: dict | None = None, alpha_init: dict | None = None, beta_clip_range: tuple | None = None, ge_kwargs: dict | None = None, cal_kwargs: dict | None = None, loss_kwargs: dict | None = None, verbose: bool = True)
Fit the model to data.
- Parameters:
warmstart (bool) – Whether to use Ridge regression for parameter initialization (default: True).
maxiter (int) – Maximum number of optimization iterations (default: 10).
tol (float) – Convergence tolerance on objective function (default: 1e-6).
beta0_init (dict, optional) – Initial β0 values per condition.
beta_init (dict, optional) – Initial β values per condition.
alpha_init (dict, optional) – Initial α scaling values per condition.
beta_clip_range (tuple, optional) – Tuple of (min, max) values for clipping β parameters during optimization. Example: (-10.0, 10.0). If None, no clipping is applied.
ge_kwargs (dict, optional) – Keyword arguments for global epistasis optimizer (e.g., tol, maxiter, maxls).
cal_kwargs (dict, optional) – Keyword arguments for calibration (α) optimizer (e.g., tol, maxiter, maxls).
loss_kwargs (dict, optional) – Keyword arguments for the loss function (e.g., δ for Huber loss).
verbose (bool) – Whether to print progress information during fitting (default: True).
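The loss_kwargs entry mentions δ for a Huber loss. As a reference point, the textbook Huber loss is quadratic for residuals up to δ and linear beyond, which limits the influence of outlying variants. The sketch below implements that standard definition; the function name huber_loss, the averaging over variants, and the key name the library expects for δ in loss_kwargs are all assumptions not stated here.

```python
from typing import Sequence


def huber_loss(residuals: Sequence[float], delta: float = 1.0) -> float:
    """Mean Huber loss over residuals (predicted minus measured).

    Quadratic for |r| <= delta and linear beyond, so outliers are
    penalized less than under squared error.
    """
    if delta <= 0:
        raise ValueError("delta must be positive")
    if len(residuals) == 0:
        raise ValueError("residuals must be non-empty")
    total = 0.0
    for r in residuals:
        a = abs(r)
        if a <= delta:
            total += 0.5 * a * a  # quadratic region
        else:
            total += delta * (a - 0.5 * delta)  # linear region
    return total / len(residuals)


print(huber_loss([0.5, 3.0], delta=1.0))  # 1.3125
```

The two branches meet at |r| = δ (both give 0.5·δ²), so the loss stays smooth; a smaller δ makes the fit more robust to noisy functional scores.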
- Returns:
Returns self for method chaining.
- Return type:
Model
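One way to picture beta_clip_range is as an element-wise clamp on β. The sketch below shows only the clamp itself, not where the optimizer applies it. It assumes β is held as a nested dict keyed by condition and then mutation (a hypothetical layout; the JAX backend presumably stores arrays), and clip_betas is a hypothetical helper.

```python
from typing import Dict, Optional, Tuple

Betas = Dict[str, Dict[str, float]]  # condition -> mutation -> beta


def clip_betas(beta: Betas, beta_clip_range: Optional[Tuple[float, float]]) -> Betas:
    """Return a copy of beta with every value clamped into (min, max).

    Hypothetical helper; None means no clipping, mirroring the fit() default.
    """
    if beta_clip_range is None:
        return {cond: dict(muts) for cond, muts in beta.items()}
    lo, hi = beta_clip_range
    if lo > hi:
        raise ValueError("beta_clip_range must be (min, max) with min <= max")
    return {
        cond: {mut: min(max(value, lo), hi) for mut, value in muts.items()}
        for cond, muts in beta.items()
    }


beta = {"a": {"M1A": 12.0, "M1C": -0.5}, "b": {"M1A": -15.0, "M1C": 0.3}}
print(clip_betas(beta, (-10.0, 10.0)))
# {'a': {'M1A': 10.0, 'M1C': -0.5}, 'b': {'M1A': -10.0, 'M1C': 0.3}}
```

Clipping bounds extreme mutation effects, which can otherwise drift far into the flat tails of a sigmoidal g(φ) where the loss barely constrains them.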
- get_mutations_df(phenotype_as_effect: bool = True, times_seen_threshold: int = 0) → DataFrame
Extract mutation-level parameters and predicted functional scores.
- Parameters:
- Returns:
DataFrame with mutations as rows (index) and the following columns:
- beta_{condition} for each condition
- shift_{condition} for each non-reference condition
- predicted_func_score_{condition} for each condition
Shift parameters represent the difference in beta values between each condition and the reference condition. Predicted functional scores are the model’s predictions for each single mutation on its condition-specific wild-type background.
- Return type:
pd.DataFrame
Example
For a model with conditions [‘a’, ‘b’] where ‘a’ is the reference:
- Columns: beta_a, beta_b, shift_b, predicted_func_score_a, predicted_func_score_b
- One row per mutation
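The column layout above can be illustrated with a small pure-Python sketch. It assumes an additive latent model in which a single mutant's latent phenotype in condition d is β0_d + β_d[mut], and it reuses the predicted-score formula α · (g(φ(X)) − g(φ(x_wt))) stated for get_variants_df(). The function mutation_table and its nested-dict inputs are hypothetical; the real method returns a pandas DataFrame.

```python
from typing import Callable, Dict

Row = Dict[str, float]


def mutation_table(
    beta0: Dict[str, float],
    beta: Dict[str, Dict[str, float]],
    alpha: Dict[str, float],
    g: Callable[[float], float],
    reference: str,
) -> Dict[str, Row]:
    """Per-mutation rows with beta_*, shift_* and predicted_func_score_* entries.

    Assumes an additive latent model (single mutant in condition d has latent
    phenotype beta0[d] + beta[d][mut]) and that every condition has the same
    mutation keys as the reference.
    """
    conditions = list(beta)
    table: Dict[str, Row] = {}
    for mut in sorted(beta[reference]):
        row: Row = {}
        for d in conditions:
            row[f"beta_{d}"] = beta[d][mut]
        for d in conditions:
            if d != reference:
                # Shift: condition-specific beta minus reference beta.
                row[f"shift_{d}"] = beta[d][mut] - beta[reference][mut]
        for d in conditions:
            # Score relative to the condition's own wild type.
            phi_wt = beta0[d]
            row[f"predicted_func_score_{d}"] = alpha[d] * (
                g(phi_wt + beta[d][mut]) - g(phi_wt)
            )
        table[mut] = row
    return table


table = mutation_table(
    beta0={"a": 0.0, "b": 0.1},
    beta={"a": {"M1A": 1.2}, "b": {"M1A": 1.5}},
    alpha={"a": 1.0, "b": 1.0},
    g=lambda phi: phi,  # Identity global epistasis
    reference="a",
)
print(list(table["M1A"]))
# ['beta_a', 'beta_b', 'shift_b', 'predicted_func_score_a', 'predicted_func_score_b']
```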
- get_variants_df(phenotype_as_effect: bool = True) → DataFrame
Extract variant-level predictions.
- Parameters:
phenotype_as_effect (bool) – If True, report effects. If False, report raw latent phenotypes.
- Returns:
Variant-level predictions merged with original data. Includes columns:
- predicted_func_score: model-predicted functional score, i.e. α * (g(φ(X)) - g(φ(x_wt)))
- predicted_latent: latent phenotype φ(X)
- predicted_fitness: predicted fitness in g(φ) space, i.e. predicted_func_score / α + g(φ(x_wt))
- measured_fitness: measured fitness in g(φ) space, i.e. func_score / α + g(φ(x_wt))
- Return type:
pd.DataFrame
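To make the column definitions concrete, the hypothetical helper below evaluates each formula for a single variant from scalar inputs. It also shows that predicted_fitness reduces algebraically to g(φ(X)), since α cancels and g(φ(x_wt)) is added back. Using a logistic curve for g is an assumption made only for this example, and α must be nonzero.

```python
import math
from typing import Callable, Dict


def variant_columns(
    phi_x: float,
    phi_wt: float,
    func_score: float,
    alpha: float,
    g: Callable[[float], float],
) -> Dict[str, float]:
    """Evaluate the get_variants_df() column formulas for one variant."""
    predicted_func_score = alpha * (g(phi_x) - g(phi_wt))
    return {
        "predicted_func_score": predicted_func_score,
        "predicted_latent": phi_x,
        # Undo the alpha scaling and wild-type offset to return to g(phi) space.
        "predicted_fitness": predicted_func_score / alpha + g(phi_wt),
        "measured_fitness": func_score / alpha + g(phi_wt),
    }


def logistic(phi: float) -> float:
    return 1.0 / (1.0 + math.exp(-phi))  # assumed sigmoid, for illustration


cols = variant_columns(phi_x=1.0, phi_wt=0.0, func_score=0.8, alpha=2.0, g=logistic)
print(math.isclose(cols["predicted_fitness"], logistic(1.0)))  # True
print(math.isclose(cols["measured_fitness"], 0.9))             # True
```

Placing measured and predicted values on the same g(φ) scale is what makes them directly comparable in a global epistasis plot.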
- property training_loss: dict
Per-condition and total loss on training data.
- Returns:
Dictionary mapping condition names and "total" to their training loss values.
- Return type:
dict
- Raises:
ValueError – If model has not been fitted.
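The shape of the returned dictionary can be sketched as follows. This is illustrative only: loss_by_condition is hypothetical, and both the per-condition aggregation (a mean over variants, squared error by default) and the definition of "total" as the sum of per-condition losses are assumptions, not documented behaviour.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Tuple


def loss_by_condition(
    records: Iterable[Tuple[str, float, float]],
    loss_fn: Callable[[float], float] = lambda r: r * r,
) -> Dict[str, float]:
    """Map each condition to its mean per-variant loss, plus a 'total' entry.

    records yields (condition, predicted, measured). Assumptions: the
    per-condition loss is a mean over variants and 'total' is the sum of
    the per-condition losses.
    """
    residuals: Dict[str, List[float]] = defaultdict(list)
    for condition, predicted, measured in records:
        residuals[condition].append(predicted - measured)
    losses = {c: sum(loss_fn(r) for r in rs) / len(rs) for c, rs in residuals.items()}
    losses["total"] = sum(losses.values())
    return losses


records = [("a", 1.0, 1.2), ("a", 0.0, 0.0), ("b", 1.5, 1.0)]
losses = loss_by_condition(records)
print(sorted(losses))  # ['a', 'b', 'total']
```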
- eval_loss(df)
Evaluate the model’s loss on an arbitrary DataFrame.
- Parameters:
df (pd.DataFrame) – DataFrame with columns ‘condition’, ‘aa_substitutions’, ‘func_score’.
- Returns:
Per-condition losses and the "total" loss.
- Return type:
- Raises:
ValueError – If model is not fitted, required columns are missing, conditions are invalid, or substitutions contain unseen mutations.
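The input checks listed under Raises can be sketched as a standalone validator. validate_eval_input is hypothetical and uses a column-oriented dict in place of a DataFrame; treating substitution strings as space-separated with '' for wild type is an assumption consistent with the examples in this page. The "not fitted" check is omitted because it depends on model state.

```python
from typing import Dict, List, Set

REQUIRED_COLUMNS = ("condition", "aa_substitutions", "func_score")


def validate_eval_input(
    columns: Dict[str, List],
    conditions: Set[str],
    known_mutations: Set[str],
) -> None:
    """Raise ValueError for the input problems eval_loss() documents.

    columns maps column name -> list of values. Substitution strings are
    assumed to be space-separated, with '' meaning wild type.
    """
    missing = [c for c in REQUIRED_COLUMNS if c not in columns]
    if missing:
        raise ValueError(f"missing required columns: {missing}")
    bad_conditions = set(columns["condition"]) - conditions
    if bad_conditions:
        raise ValueError(f"invalid conditions: {sorted(bad_conditions)}")
    unseen = {
        mut
        for subs in columns["aa_substitutions"]
        for mut in subs.split()
        if mut not in known_mutations
    }
    if unseen:
        raise ValueError(f"unseen mutations: {sorted(unseen)}")


good = {"condition": ["a", "b"], "aa_substitutions": ["", "M1A"], "func_score": [0.0, 1.5]}
validate_eval_input(good, {"a", "b"}, {"M1A"})  # passes silently
```

Unseen mutations are rejected because the model has no fitted β for them, so no prediction, and hence no loss, can be computed.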
- add_phenotypes_to_df(df: DataFrame, substitutions_col: str = 'aa_substitutions', condition_col: str = 'condition', predicted_phenotype_col: str = 'predicted_func_score', overwrite_cols: bool = False) → DataFrame
Add model predictions to a DataFrame of variants.
- Parameters:
df (pd.DataFrame) – DataFrame with columns specified by condition_col and substitutions_col. Additional columns will be preserved in output.
substitutions_col (str) – Column in df giving variants as substitution strings. Default is ‘aa_substitutions’.
condition_col (str) – Column in df giving the condition for each variant. Values must exist in the model’s conditions. Default is ‘condition’.
predicted_phenotype_col (str) – Name of column to add containing predicted functional scores. Default is ‘predicted_func_score’.
overwrite_cols (bool) – Whether to overwrite the predicted phenotype column if it already exists in df. If False, an existing column raises an error. Default is False.
- Returns:
A copy of df with predictions added. Always includes:
- predicted_func_score (or the name given by predicted_phenotype_col): predicted functional score
- predicted_latent: latent phenotype φ(X)
- predicted_fitness: predicted fitness in g(φ) space
If a func_score column is present in df, also includes:
- measured_fitness: measured fitness in g(φ) space
- Return type:
pd.DataFrame
- Raises:
ValueError – If model is not fitted, required columns are missing, indices are not unique, conditions are invalid, or substitutions contain mutations not seen during training.
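The column contract described above (index and overwrite checks, always-added columns, and the conditional measured_fitness column) can be summarized in a small sketch. plan_added_columns is hypothetical and works on plain column names and index labels rather than a DataFrame; it does not compute any predictions.

```python
from typing import List, Sequence


def plan_added_columns(
    existing_columns: Sequence[str],
    index: Sequence,
    predicted_phenotype_col: str = "predicted_func_score",
    overwrite_cols: bool = False,
) -> List[str]:
    """Names of the columns add_phenotypes_to_df() would add, after its checks.

    Hypothetical helper mirroring the documented behaviour.
    """
    if len(set(index)) != len(index):
        raise ValueError("df must have a unique index")
    if predicted_phenotype_col in existing_columns and not overwrite_cols:
        raise ValueError(f"column {predicted_phenotype_col!r} already exists in df")
    added = [predicted_phenotype_col, "predicted_latent", "predicted_fitness"]
    if "func_score" in existing_columns:
        # Measured fitness can only be derived when measurements are present.
        added.append("measured_fitness")
    return added


print(plan_added_columns(["condition", "aa_substitutions"], index=[0, 1]))
# ['predicted_func_score', 'predicted_latent', 'predicted_fitness']
```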
Example
>>> import pandas as pd
>>> from multidms import Data, Model
>>> df_train = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df_train, reference='a')
>>> model = Model(data, ge_type='Identity', l2reg=0.01)
>>> _ = model.fit(maxiter=5, warmstart=False, verbose=False)
>>> df_new = pd.DataFrame({
...     'condition': ['a', 'b'],
...     'aa_substitutions': ['M1A', 'M1A']
... })
>>> result = model.add_phenotypes_to_df(df_new)
>>> 'predicted_func_score' in result.columns
True
>>> 'predicted_latent' in result.columns
True
>>> 'predicted_fitness' in result.columns
True
>>> len(result)
2
- get_ge_landscape_df(n_curve_points: int = 200) → tuple
Get data for plotting the global epistasis landscape.
Returns a tuple of (variants_df, ge_curve_df). The variants DataFrame contains per-variant latent phenotype and fitness columns from get_variants_df(), plus a wildtype_latent column for reference-line plotting. The curve DataFrame contains the global epistasis function evaluated over the observed latent phenotype range.
- Parameters:
n_curve_points (int) – Number of points in the g(φ) curve grid (default: 200).
- Returns:
(variants_df, ge_curve_df), where variants_df has all columns from get_variants_df() plus wildtype_latent, and ge_curve_df has columns predicted_latent and ge_curve_value.
- Return type:
tuple[pd.DataFrame, pd.DataFrame]
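The curve half of the returned tuple can be pictured as g evaluated on a grid spanning the observed latent phenotypes. This is a sketch under assumptions: even spacing between the observed minimum and maximum is inferred from the description, a dict of lists stands in for the DataFrame, and ge_curve_grid is hypothetical. Only the column names predicted_latent and ge_curve_value come from the documentation above.

```python
from typing import Callable, Dict, List, Sequence


def ge_curve_grid(
    latents: Sequence[float],
    g: Callable[[float], float],
    n_curve_points: int = 200,
) -> Dict[str, List[float]]:
    """Evaluate g on an evenly spaced grid spanning the observed latent range."""
    if n_curve_points < 2:
        raise ValueError("n_curve_points must be at least 2")
    lo, hi = min(latents), max(latents)
    step = (hi - lo) / (n_curve_points - 1)
    grid = [lo + i * step for i in range(n_curve_points)]
    grid[-1] = hi  # pin the endpoint exactly despite floating-point rounding
    return {"predicted_latent": grid, "ge_curve_value": [g(x) for x in grid]}


curve = ge_curve_grid([-1.5, 0.0, 2.5], g=lambda phi: phi, n_curve_points=5)
print(curve["predicted_latent"])  # [-1.5, -0.5, 0.5, 1.5, 2.5]
```

Plotting ge_curve_value against predicted_latent as a line, with each variant's predicted_latent and measured_fitness as points, gives the usual global epistasis landscape; wildtype_latent marks where the wild type sits on that curve.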