model¶
Defines Model objects for global epistasis modeling using JAX.
- class multidms.model.Model(data: Data, loss_type: Literal['functional_score_loss', 'count_loss'] = 'functional_score_loss', ge_type: Literal['Identity', 'Sigmoid'] = 'Sigmoid', l2reg: float = 0.0, fusionreg: float = 0.0, beta0_ridge: float = 0.0)¶
Bases: object
Model for global epistasis analysis of DMS experiments.
This class wraps the jaxmodels backend to provide a user-friendly interface for fitting global epistasis models to deep mutational scanning data.
- Parameters:
data (multidms.Data) – Preprocessed DMS data object containing variants and functional scores.
loss_type ({'functional_score_loss', 'count_loss'}) – Type of loss function to use. ‘functional_score_loss’ for standard functional score fitting, ‘count_loss’ for count-based enrichment models.
ge_type ({'Identity', 'Sigmoid'}) – Global epistasis model type. ‘Identity’ for no global epistasis (linear), ‘Sigmoid’ for sigmoidal transformation.
l2reg (float) – L2 regularization strength for mutation effects (default: 0.0).
fusionreg (float) – Fusion regularization strength for shift parameters (default: 0.0).
beta0_ridge (float) – Ridge penalty for β0 offsets from reference condition (default: 0.0).
Example
>>> import pandas as pd
>>> from multidms import Data, Model
>>> df = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df, reference='a')
>>> model = Model(data, ge_type='Sigmoid', l2reg=0.01)
>>> model
Model(ge_type='Sigmoid', loss_type='functional_score_loss')
- property params¶
Model parameters (available after fit).
- property convergence_trajectory_df: DataFrame¶
Convergence trajectory showing loss over iterations.
- Returns:
DataFrame with columns ‘iteration’, ‘loss’, ‘error’
- Return type:
pd.DataFrame
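A common use of this property is a quick sanity check that optimization made progress. The sketch below builds a hypothetical DataFrame of the documented shape ('iteration', 'loss', 'error' columns); the values are illustrative, not output from an actual fit.

```python
import pandas as pd

# Hypothetical convergence trajectory with the documented columns.
traj = pd.DataFrame({
    "iteration": [0, 1, 2],
    "loss": [10.0, 4.2, 4.1],
    "error": [2.0, 0.5, 0.05],
})

# Confirm the optimizer made progress: loss should not increase.
loss_decreasing = traj["loss"].is_monotonic_decreasing
```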
- fit(warmstart: bool = True, maxiter: int = 10, tol: float = 1e-06, beta0_init: dict | None = None, beta_init: dict | None = None, alpha_init: dict | None = None, beta_clip_range: tuple | None = None, ge_kwargs: dict | None = None, cal_kwargs: dict | None = None, loss_kwargs: dict | None = None, verbose: bool = True)¶
Fit the model to data.
- Parameters:
warmstart (bool) – Whether to use Ridge regression for parameter initialization (default: True).
maxiter (int) – Maximum number of optimization iterations (default: 10).
tol (float) – Convergence tolerance on objective function (default: 1e-6).
beta0_init (dict, optional) – Initial β0 values per condition.
beta_init (dict, optional) – Initial β values per condition.
alpha_init (dict, optional) – Initial α scaling values per condition.
beta_clip_range (tuple, optional) – Tuple of (min, max) values for clipping β parameters during optimization. Example: (-10.0, 10.0). If None, no clipping is applied.
ge_kwargs (dict, optional) – Keyword arguments for global epistasis optimizer (e.g., tol, maxiter, maxls).
cal_kwargs (dict, optional) – Keyword arguments for calibration (α) optimizer (e.g., tol, maxiter, maxls).
loss_kwargs (dict, optional) – Keyword arguments for the loss function (e.g., δ for Huber loss).
verbose (bool) – Whether to print progress information during fitting (default: True).
- Returns:
Returns self for method chaining.
- Return type:
self
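For reference, the δ passed via loss_kwargs controls where the Huber loss switches from quadratic to linear penalties. The sketch below shows the textbook Huber definition; the exact form used by the multidms backend is an assumption here, so treat this as an illustration of what δ does, not the backend's implementation.

```python
import numpy as np

def huber(residual, delta=1.0):
    """Textbook Huber loss: quadratic for |r| <= delta, linear beyond.

    Reference sketch only; the multidms backend's form may differ.
    """
    r = np.abs(residual)
    return np.where(r <= delta, 0.5 * r ** 2, delta * (r - 0.5 * delta))

# Small residuals are penalized quadratically, large ones linearly.
losses = huber(np.array([0.5, 3.0]), delta=1.0)
```

A larger δ makes the loss behave more like ordinary least squares; a smaller δ makes the fit more robust to outlier functional scores.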
- get_mutations_df(phenotype_as_effect: bool = True) DataFrame¶
Extract mutation-level parameters in wide format.
- Parameters:
phenotype_as_effect (bool) – If True, report mutation effects. If False, report raw latent phenotypes.
- Returns:
DataFrame with mutations as rows (index) and columns:
- beta_{condition} for each condition
- shift_{condition} for each non-reference condition
Shift parameters represent the difference in beta values between each condition and the reference condition.
- Return type:
pd.DataFrame
Example
For a model with conditions ['a', 'b'] where 'a' is reference:
- Columns: beta_a, beta_b, shift_b
- One row per mutation
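Because shifts are defined as differences from the reference condition, the shift columns are recoverable from the beta columns. A minimal pandas sketch of that relationship, using hypothetical beta values (not output from a real fit):

```python
import pandas as pd

# Hypothetical fitted betas for conditions 'a' (reference) and 'b',
# in the wide layout described above: one row per mutation.
muts = pd.DataFrame(
    {"beta_a": [1.2, -0.4], "beta_b": [1.5, -0.1]},
    index=["M1A", "G2C"],
)

# Shift for a non-reference condition = its beta minus the reference beta.
muts["shift_b"] = muts["beta_b"] - muts["beta_a"]  # both shifts ~0.3 here
```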
- get_variants_df(phenotype_as_effect: bool = True) DataFrame¶
Extract variant-level predictions.
- Parameters:
phenotype_as_effect (bool) – If True, report effects. If False, report raw latent phenotypes.
- Returns:
Variant-level predictions merged with original data.
- Return type:
pd.DataFrame
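"Merged with original data" means the returned frame keeps the input variant columns alongside the model's predictions, so observed and predicted scores can be compared row by row. A pandas sketch of that shape, with hypothetical prediction values and an illustrative residual column (the actual column set returned by get_variants_df may differ):

```python
import pandas as pd

# Hypothetical training variants (as passed to Data) and predictions.
variants = pd.DataFrame({
    "condition": ["a", "a", "b"],
    "aa_substitutions": ["", "M1A", "M1A"],
    "func_score": [0.0, 1.2, 1.5],
})
preds = [0.05, 1.15, 1.40]  # illustrative model predictions

# Merge predictions onto the original data, one value per variant row.
merged = variants.assign(predicted_func_score=preds)
merged["residual"] = merged["func_score"] - merged["predicted_func_score"]
```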
- add_phenotypes_to_df(df: DataFrame, substitutions_col: str = 'aa_substitutions', condition_col: str = 'condition', predicted_phenotype_col: str = 'predicted_func_score', overwrite_cols: bool = False) DataFrame¶
Add model predictions to a DataFrame of variants.
- Parameters:
df (pd.DataFrame) – DataFrame with columns specified by condition_col and substitutions_col. Additional columns will be preserved in output.
substitutions_col (str) – Column in df giving variants as substitution strings. Default is ‘aa_substitutions’.
condition_col (str) – Column in df giving the condition for each variant. Values must exist in the model’s conditions. Default is ‘condition’.
predicted_phenotype_col (str) – Name of column to add containing predicted functional scores. Default is ‘predicted_func_score’.
overwrite_cols (bool) – If True, overwrite the predicted phenotype column when it already exists in df; if False, raise an error (default: False).
- Returns:
A copy of df with predictions added.
- Return type:
pd.DataFrame
- Raises:
ValueError – If model is not fitted, required columns are missing, indices are not unique, conditions are invalid, or substitutions contain mutations not seen during training.
Example
>>> import pandas as pd
>>> from multidms import Data, Model
>>> df_train = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df_train, reference='a')
>>> model = Model(data, ge_type='Identity', l2reg=0.01)
>>> _ = model.fit(maxiter=5, warmstart=False, verbose=False)
>>> df_new = pd.DataFrame({
...     'condition': ['a', 'b'],
...     'aa_substitutions': ['M1A', 'M1A']
... })
>>> result = model.add_phenotypes_to_df(df_new)
>>> 'predicted_func_score' in result.columns
True
>>> len(result)
2