jaxmodels
A simple API for global epistasis modeling.
- class multidms.jaxmodels.Data(x_wt: Int[Array, 'n_mutations'], X: Int[Array, 'n_variants n_mutations'], functional_scores: Float[Array, 'n_variants'], pre_count_wt: Int[Array, ''] | None = None, post_count_wt: Int[Array, ''] | None = None, pre_counts: Int[Array, ' n_variants'] | None = None, post_counts: Int[Array, ' n_variants'] | None = None)
Bases: Module
Data for a DMS experiment.
- x_wt: Int[Array, 'n_mutations']
Binary encoding of the wildtype sequence.
- X: Int[Array, 'n_variants n_mutations']
Variant encoding matrix (sparse format).
- functional_scores: Float[Array, 'n_variants']
Functional scores for each variant.
- pre_counts: Int[Array, ' n_variants'] | None = None
Pre-selection counts for each variant (optional).
- post_counts: Int[Array, ' n_variants'] | None = None
Post-selection counts for each variant (optional).
- static from_multidms(multidms_data: Data, condition: str) Self
Create data from a multidms data object.
- Parameters:
multidms_data – The data to use. Note that the wildtype (WT) must be the first variant in each condition.
condition – The condition to extract data for.
- Returns:
Data object with count data if available in the source.
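To make the array shapes above concrete, the following sketch builds a binary encoding by hand. It assumes that each column of X corresponds to one mutation and that a row holds 1 wherever the variant carries that mutation; the mutation names, the `encode` helper, and the scores are hypothetical, and plain Python lists stand in for the JAX arrays that Data actually stores.

```python
# Hypothetical mutation vocabulary; column j of X corresponds to mutations[j].
mutations = ["A1G", "K2R", "L3P"]


def encode(variant_mutations, vocabulary):
    """Return a 0/1 list marking which vocabulary mutations a variant carries."""
    carried = set(variant_mutations)
    return [1 if m in carried else 0 for m in vocabulary]


# Assumption: this condition's wildtype carries no mutations in the vocabulary.
x_wt = encode([], mutations)

# One row per variant, with the wildtype listed first.
variants = [[], ["A1G"], ["K2R", "L3P"]]
X = [encode(v, mutations) for v in variants]

# Made-up functional scores, one per variant (length n_variants).
functional_scores = [0.0, -0.4, -1.7]
```

Here `x_wt` has length n_mutations, `X` has n_variants rows of length n_mutations, and `functional_scores` has one entry per row of `X`.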
- class multidms.jaxmodels.Latent(β0: Any, β: Any)
Bases: Module
Model a latent phenotype.
- β0: Float[Array, '']
Intercept.
- β: Float[Array, 'n_mutations']
Mutation effects.
- static from_params(β0: Float, β: Float[Array, 'n_mutations']) Self
Create a latent model from explicit parameters.
- Parameters:
β0 – Intercept value.
β – Mutation effects array.
- Returns:
Latent model with specified parameters.
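As a rough illustration of how an intercept and a vector of mutation effects combine, the sketch below assumes the usual additive form, in which the latent phenotype of a variant is β0 plus the sum of β over the mutations it carries. The function `latent_phenotype` and the numbers are hypothetical, and ASCII names replace β0 and β; the library's own evaluation code is not shown in this reference.

```python
def latent_phenotype(beta0, beta, x):
    """Additive latent phenotype: intercept plus effects of carried mutations."""
    return beta0 + sum(b * xi for b, xi in zip(beta, x))


beta0 = 0.5                # intercept
beta = [-1.0, 0.25, -2.0]  # one effect per mutation

wildtype = [0, 0, 0]       # carries no mutations
double_mutant = [1, 1, 0]  # carries the first two mutations

phi_wt = latent_phenotype(beta0, beta, wildtype)       # 0.5
phi_dm = latent_phenotype(beta0, beta, double_mutant)  # 0.5 - 1.0 + 0.25 = -0.25
```

Under this assumption the wildtype's latent phenotype equals the intercept, and each carried mutation shifts it by its own effect.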
- class multidms.jaxmodels.Identity
Bases: GlobalEpistasis
Identity function.
- class multidms.jaxmodels.Sigmoid
Bases: GlobalEpistasis
Sigmoid function.
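The two global epistasis choices can be contrasted with a small sketch. It assumes that a global epistasis function maps a scalar latent phenotype to a predicted phenotype, and it uses the plain logistic function for the sigmoid case; any scaling or offset that the library's Sigmoid applies is not documented here, so `identity` and `sigmoid` below are illustrative stand-ins rather than the library's classes.

```python
import math


def identity(phi):
    """Pass the latent phenotype through unchanged (no global epistasis)."""
    return phi


def sigmoid(phi):
    """Logistic function 1 / (1 + exp(-phi)), written to avoid overflow."""
    if phi >= 0:
        return 1.0 / (1.0 + math.exp(-phi))
    z = math.exp(phi)  # phi < 0, so z lies in [0, 1) and cannot overflow
    return z / (1.0 + z)
```

With the identity, equal changes in latent phenotype give equal changes in the prediction; with the sigmoid, the prediction saturates toward 0 and 1 at extreme latent values.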
- class multidms.jaxmodels.Model(φ: dict[str, Latent], α: dict[str, Float[Array, '']], logθ: dict[str, Float[Array, '']], reference_condition: str, global_epistasis: GlobalEpistasis = Identity())
Bases: Module
Model DMS data.
- φ: dict[str, multidms.jaxmodels.Latent]
Latent models for each condition.
- α: dict[str, jaxtyping.Float[Array, '']]
Fitness-functional score scaling factors for each condition.
- global_epistasis: GlobalEpistasis = Identity()
- multidms.jaxmodels.count_loss(model: Model, data_sets: dict[str, multidms.jaxmodels.Data]) dict[str, jaxtyping.Float[Array, '']]
Count-based loss.
- Parameters:
model – Model to evaluate.
data_sets – Data sets for each condition.
- Returns:
Loss for each condition.
- multidms.jaxmodels.functional_score_loss(model: Model, data_sets: dict[str, multidms.jaxmodels.Data], δ: Float = 1.0) dict[str, jaxtyping.Float[Array, '']]
Huber loss on functional scores.
- Parameters:
model – Model to evaluate.
data_sets – Data sets for each condition.
δ – Huber loss parameter.
- Returns:
Loss for each condition.
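For reference, the standard Huber loss with threshold δ is quadratic for residuals no larger than δ and linear beyond it. The sketch below implements that textbook definition in plain Python; `huber` and `mean_huber` are hypothetical helpers, and whether the library averages or sums over variants, or rescales by δ, is an assumption not confirmed by this page.

```python
def huber(residual, delta=1.0):
    """Textbook Huber loss for a single residual."""
    r = abs(residual)
    if r <= delta:
        return 0.5 * r * r             # quadratic region
    return delta * (r - 0.5 * delta)   # linear region, continuous at r == delta


def mean_huber(predicted, observed, delta=1.0):
    """Average Huber loss over paired predictions and observations (assumed reduction)."""
    losses = [huber(p - o, delta) for p, o in zip(predicted, observed)]
    return sum(losses) / len(losses)
```

A larger δ makes the loss behave more like squared error, while a smaller δ reduces the influence of outlying functional scores.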
- multidms.jaxmodels.fit(data_sets: dict[str, Data], reference_condition: str, l2reg: Float = 0.0, fusionreg: Float = 0.0, beta0_ridge: Float = 0.0, block_iters: int = 10, block_tol: Float = 1e-06, ge_kwargs: dict[str, Any] = {}, cal_kwargs: dict[str, Any] = {}, global_epistasis: GlobalEpistasis = Identity(), loss_fn: Callable[[Model, dict[str, Data]], dict[str, Float[Array, '']]] = <function functional_score_loss>, loss_kwargs: dict[str, Any] = {'δ': 1.0}, warmstart: bool = True, beta0_init: dict[str, Float] | None = None, beta_init: dict[str, Float[Array, ' n_mutations']] | None = None, alpha_init: dict[str, Float] | None = None, beta_clip_range: tuple[Float, Float] | None = None, verbose: bool = True) tuple[Model, list[float]]
Fit a model to data.
- Parameters:
data_sets – Data to fit to. Each key is a condition.
reference_condition – The condition to use as a reference.
l2reg – L2 regularization strength for mutation effects.
fusionreg – Fusion (shift lasso) regularization strength.
beta0_ridge – Ridge penalty for β0 differences from reference condition.
block_iters – Number of iterations for block coordinate descent.
block_tol – Tolerance on objective function for block coordinate descent.
ge_kwargs – Keyword arguments for the global epistasis model optimizer.
cal_kwargs – Keyword arguments for the experimental calibration parameter optimizer.
global_epistasis – Global epistasis model.
loss_fn – Loss function.
loss_kwargs – Keyword arguments for the loss function.
warmstart – Whether to initialize parameters with a ridge regression warm start (default: True). Any explicit values provided in beta0_init or beta_init override the warm-start values.
beta0_init – Initial β0 (intercept) values for each condition. If None, uses zeros (or warmstart values if warmstart=True). If dict provided, uses those values for specified conditions.
beta_init – Initial β (mutation effects) values for each condition. If None, uses zeros (or warmstart values if warmstart=True). If dict provided, uses those values for specified conditions.
alpha_init – Initial α (fitness-functional score scaling) values for each condition. If None, uses 1.0 for all conditions. If dict provided, uses those values for specified conditions.
beta_clip_range – Optional tuple of (min, max) values for clipping β parameters. If None, no clipping is applied. Example: (-10.0, 10.0). This constrains mutation effect parameters during optimization to prevent extreme values.
verbose – Whether to print progress information during fitting (default: True). If False, suppresses all print output.
- Returns:
Tuple of (fitted model, loss trajectory).
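The initialization and clipping rules described for warmstart, beta_init, and beta_clip_range can be sketched as follows. This is a minimal illustration of the documented precedence (explicit values, then warm-start values, then zeros) and of clipping to a (min, max) range; `resolve_init` and `clip` are hypothetical helpers, plain lists stand in for arrays, and the point during optimization at which fit applies the clipping is not specified on this page.

```python
def resolve_init(condition, n_mutations, warmstart_values=None, explicit=None):
    """Pick initial mutation effects: explicit value, else warm start, else zeros."""
    if explicit is not None and condition in explicit:
        return list(explicit[condition])
    if warmstart_values is not None and condition in warmstart_values:
        return list(warmstart_values[condition])
    return [0.0] * n_mutations


def clip(values, clip_range=None):
    """Clamp each value to (min, max); leave values unchanged when no range is given."""
    if clip_range is None:
        return list(values)
    lo, hi = clip_range
    return [min(max(v, lo), hi) for v in values]


# Hypothetical warm-start and user-supplied values for two conditions.
warm = {"a": [0.1, 0.2], "b": [0.3, 0.4]}
user = {"b": [12.0, -15.0]}

beta_a = clip(resolve_init("a", 2, warm, user), (-10.0, 10.0))  # warm-start values
beta_b = clip(resolve_init("b", 2, warm, user), (-10.0, 10.0))  # explicit, then clipped
```

In this example condition "a" keeps its warm-start values, while condition "b" takes the explicit values, which are then clamped to the range (-10.0, 10.0).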