jaxmodels

A simple API for global epistasis modeling.

class multidms.jaxmodels.Data(x_wt: Int[Array, 'n_mutations'], X: Int[Array, 'n_variants n_mutations'], functional_scores: Float[Array, 'n_variants'], pre_count_wt: Int[Array, ''] | None = None, post_count_wt: Int[Array, ''] | None = None, pre_counts: Int[Array, ' n_variants'] | None = None, post_counts: Int[Array, ' n_variants'] | None = None)

Bases: Module

Data for a DMS experiment.

x_wt: Int[Array, 'n_mutations']

Binary encoding of the wildtype sequence.

X: Int[Array, 'n_variants n_mutations']

Variant encoding matrix (sparse format).

functional_scores: Float[Array, 'n_variants']

Functional scores for each variant.

pre_count_wt: Int[Array, ''] | None = None

Wildtype pre-selection count (optional).

post_count_wt: Int[Array, ''] | None = None

Wildtype post-selection count (optional).

pre_counts: Int[Array, ' n_variants'] | None = None

Pre-selection counts for each variant (optional).

post_counts: Int[Array, ' n_variants'] | None = None

Post-selection counts for each variant (optional).
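The sketch below shows the binary encoding behind x_wt and X. It is a minimal illustration, not the library's constructor. The helper encode_variants, the mutation labels and the space-separated variant strings are hypothetical, and the functional scores are toy values. A dense NumPy array stands in for the sparse JAX array described above.

```python
import numpy as np


def encode_variants(variants, mutations):
    """Hypothetical helper: X[v, m] = 1 if variant v carries mutation m, else 0."""
    index = {mut: j for j, mut in enumerate(mutations)}
    X = np.zeros((len(variants), len(mutations)), dtype=np.int32)
    for i, variant in enumerate(variants):
        for mut in variant.split():  # an empty string (the wildtype) has no mutations
            X[i, index[mut]] = 1
    return X


mutations = ["M1A", "K2R", "L3P"]  # hypothetical mutation labels
variants = ["", "M1A", "M1A L3P", "K2R"]  # first row is the wildtype
X = encode_variants(variants, mutations)
x_wt = np.zeros(len(mutations), dtype=np.int32)  # a reference wildtype carries no mutations
functional_scores = np.array([0.0, -0.4, -2.1, 0.3])  # toy values, one per variant
```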

static from_multidms(multidms_data: Data, condition: str) → Self

Create data from a multidms data object.

Parameters:
  • multidms_data – The data to use. Note: the wildtype (WT) must be the first variant in each condition.

  • condition – The condition to extract data for.

Returns:

Data object, including count data if the source provides it.

class multidms.jaxmodels.Latent(β0: Any, β: Any)

Bases: Module

Model a latent phenotype.

β0: Float[Array, '']

Intercept.

β: Float[Array, 'n_mutations']

Mutation effects.
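This sketch assumes the latent phenotype is additive: the intercept plus the summed effects of the mutations a variant carries. Under that assumption it can be computed for all variants at once. The function latent_phenotype is a hypothetical NumPy sketch and is not part of multidms.jaxmodels.

```python
import numpy as np


def latent_phenotype(beta0, beta, X):
    """Additive latent phenotype (assumed form): phi = beta0 + X @ beta, one value per variant."""
    return beta0 + X @ beta


X = np.array([[0, 0, 0], [1, 0, 0], [1, 0, 1]])  # wildtype, single mutant, double mutant
beta0 = 2.0
beta = np.array([-0.5, 0.1, -1.5])
phi = latent_phenotype(beta0, beta, X)
```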

static from_params(β0: Float, β: Float[Array, 'n_mutations']) → Self

Create a latent model from explicit parameters.

Parameters:
  • β0 – Intercept value.

  • β – Mutation effects array.

Returns:

Latent model with specified parameters.

static zeros(n_mutations: int, β0: Float = 0.0) → Self

Create a zero-initialized latent model with optional intercept.

Parameters:
  • n_mutations – Number of mutations.

  • β0 – Intercept value (default: 0.0).

Returns:

Latent model with β set to zeros and specified β0.

static warmstart(data: Data, l2reg: float = 0.0) → Self

Warmstart the latent model by fitting it to the data with ridge (L2-regularized) regression.

Parameters:
  • data – Data to initialize the model for.

  • l2reg – L2 regularization strength for warmstart.

Returns:

Latent model initialized with warmstart parameters.
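A ridge-regression warmstart might look like the sketch below. It fits an unpenalized intercept and L2-penalized mutation effects by solving the normal equations on centered data. Two points are assumptions: that Latent.warmstart regresses functional scores on X, and that it does so in exactly this way. ridge_warmstart is a hypothetical helper.

```python
import numpy as np


def ridge_warmstart(X, y, l2reg=0.0):
    """Ridge regression of y on X with an unpenalized intercept.

    Returns (beta0, beta) minimizing ||y - beta0 - X @ beta||^2 + l2reg * ||beta||^2.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean = X.mean(axis=0)
    y_mean = y.mean()
    Xc = X - x_mean  # centering removes the intercept from the penalized problem
    yc = y - y_mean
    n_mutations = X.shape[1]
    beta = np.linalg.solve(Xc.T @ Xc + l2reg * np.eye(n_mutations), Xc.T @ yc)
    beta0 = y_mean - x_mean @ beta  # recover the intercept on the original scale
    return beta0, beta


X = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
y = 1.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]  # exact linear toy data
intercept_ols, effects_ols = ridge_warmstart(X, y, l2reg=0.0)
intercept_shrunk, effects_shrunk = ridge_warmstart(X, y, l2reg=1.0)
```

Centering X and y before solving keeps the intercept out of the penalty, so l2reg shrinks only β.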

class multidms.jaxmodels.GlobalEpistasis

Bases: Module, ABC

Global epistasis model.

class multidms.jaxmodels.Identity

Bases: GlobalEpistasis

Identity function.

class multidms.jaxmodels.Sigmoid

Bases: GlobalEpistasis

Sigmoid function.
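The two global epistasis shapes can be sketched as plain functions of the latent phenotype φ. This sketch assumes Sigmoid is a standard logistic curve. The library's class may carry its own learnable scale or offset parameters, which the sketch omits. The example shows equal latent steps that become unequal observed steps after the sigmoid.

```python
import numpy as np


def identity(phi):
    """Identity global epistasis: the observed phenotype equals the latent one."""
    return phi


def sigmoid(phi):
    """Standard logistic curve 1 / (1 + exp(-phi)), saturating at 0 and 1."""
    return 1.0 / (1.0 + np.exp(-phi))


# Two mutations with latent effects of -2 each, starting from a wildtype at phi = 4.
phi = np.array([4.0, 2.0, 0.0])  # wildtype, single mutant, double mutant
latent_steps = np.diff(identity(phi))  # equal steps on the latent scale
observed_steps = np.diff(sigmoid(phi))  # unequal steps after the nonlinearity
```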

class multidms.jaxmodels.Model(φ: dict[str, Latent], α: dict[str, Float[Array, '']], logθ: dict[str, Float[Array, '']], reference_condition: str, global_epistasis: GlobalEpistasis = Identity())

Bases: Module

Model DMS data.

φ: dict[str, multidms.jaxmodels.Latent]

Latent models for each condition.

α: dict[str, jaxtyping.Float[Array, '']]

Fitness-functional score scaling factors for each condition.

logθ: dict[str, jaxtyping.Float[Array, '']]

Log overdispersion parameter for each condition.

reference_condition: str

The condition to use as a reference.

global_epistasis: GlobalEpistasis = Identity()

Global epistasis model applied to the latent phenotype (default: Identity()).

predict_score(data_sets: dict[str, multidms.jaxmodels.Data]) → dict[str, jaxtyping.Float[Array, 'n_variants']]

Predict functional scores, interpreted as \(\log_e\) enrichment relative to the wildtype (WT).

Parameters:

data_sets – Data sets for each condition.
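Predicted scores are read as natural-log enrichment relative to the wildtype, so the matching observed quantity can be computed from the count fields of Data. The helper log_enrichment is a hypothetical sketch. It omits the pseudocounts that real pipelines usually add to avoid division by zero.

```python
import numpy as np


def log_enrichment(pre_counts, post_counts, pre_count_wt, post_count_wt):
    """Natural-log enrichment of each variant relative to the wildtype."""
    pre_counts = np.asarray(pre_counts, dtype=float)
    post_counts = np.asarray(post_counts, dtype=float)
    variant_ratio = post_counts / pre_counts  # post/pre ratio for each variant
    wt_ratio = post_count_wt / pre_count_wt  # the same ratio for the wildtype
    return np.log(variant_ratio / wt_ratio)


scores = log_enrichment([100, 200, 50], [100, 50, 100], pre_count_wt=100, post_count_wt=100)
```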

predict_post_count(data_sets: dict[str, multidms.jaxmodels.Data]) → dict[str, jaxtyping.Float[Array, 'n_variants']]

Predict post-selection counts.

Parameters:

data_sets – Data sets for each condition.
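One way to turn a predicted log-enrichment score into an expected post-selection count is to invert the enrichment definition. Each variant's pre-selection count is scaled by the wildtype's post/pre ratio and by exp(score). The helper expected_post_counts is hypothetical. predict_post_count may normalize differently, for example to the total post-selection sequencing depth.

```python
import numpy as np


def expected_post_counts(scores, pre_counts, pre_count_wt, post_count_wt):
    """Invert log enrichment (assumed form): post_v = pre_v * (post_wt / pre_wt) * exp(score_v)."""
    pre_counts = np.asarray(pre_counts, dtype=float)
    wt_ratio = post_count_wt / pre_count_wt
    return pre_counts * wt_ratio * np.exp(np.asarray(scores, dtype=float))


post = expected_post_counts(np.log([1.0, 0.25, 2.0]), [100, 200, 50], 100, 100)
```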

multidms.jaxmodels.count_loss(model: Model, data_sets: dict[str, multidms.jaxmodels.Data]) → dict[str, jaxtyping.Float[Array, '']]

Count-based loss (requires the optional count fields of Data to be set).

Parameters:
  • model – Model to evaluate.

  • data_sets – Data sets for each condition.

Returns:

Loss for each condition.
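A negative binomial likelihood is a common count-based loss when a per-condition overdispersion parameter such as logθ is available. The sketch below assumes that form, with θ = exp(logθ) and variance μ + μ²/θ. Whether count_loss uses exactly this likelihood and this mean reduction is an assumption, and negative_binomial_nll is a hypothetical name.

```python
import math


def negative_binomial_nll(observed, expected, log_theta):
    """Mean negative log-likelihood of counts under a negative binomial.

    Parameterized by mean mu (> 0) and overdispersion theta = exp(log_theta),
    so Var[y] = mu + mu**2 / theta; large theta approaches a Poisson.
    """
    theta = math.exp(log_theta)
    total = 0.0
    for y, mu in zip(observed, expected):
        log_pmf = (
            math.lgamma(y + theta)
            - math.lgamma(theta)
            - math.lgamma(y + 1)
            + theta * math.log(theta / (theta + mu))
            + y * math.log(mu / (theta + mu))
        )
        total -= log_pmf
    return total / len(observed)


# The loss is lower when the expected count matches the observed one.
good_fit = negative_binomial_nll([10], [10.0], log_theta=2.0)
poor_fit = negative_binomial_nll([10], [1.0], log_theta=2.0)
```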

multidms.jaxmodels.functional_score_loss(model: Model, data_sets: dict[str, multidms.jaxmodels.Data], δ: Float = 1.0) → dict[str, jaxtyping.Float[Array, '']]

Huber loss on functional scores.

Parameters:
  • model – Model to evaluate.

  • data_sets – Data sets for each condition.

  • δ – Huber loss threshold: residuals smaller than δ in magnitude are penalized quadratically, larger ones linearly.

Returns:

Loss for each condition.
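The Huber loss on score residuals can be sketched as below. The mean reduction over variants is an assumption, since functional_score_loss may sum instead. huber_loss is a hypothetical helper.

```python
import numpy as np


def huber_loss(predicted, observed, delta=1.0):
    """Mean Huber loss: 0.5 * r**2 for |r| <= delta, delta * (|r| - 0.5 * delta) beyond."""
    r = np.abs(np.asarray(predicted, dtype=float) - np.asarray(observed, dtype=float))
    quadratic = 0.5 * r**2
    linear = delta * (r - 0.5 * delta)
    return float(np.mean(np.where(r <= delta, quadratic, linear)))


small_residual = huber_loss([0.5], [0.0])  # quadratic regime
large_residual = huber_loss([3.0], [0.0])  # linear regime, less sensitive to outliers
```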

multidms.jaxmodels.fit(data_sets: dict[str, Data], reference_condition: str, l2reg: Float = 0.0, fusionreg: Float = 0.0, beta0_ridge: Float = 0.0, block_iters: int = 10, block_tol: Float = 1e-06, ge_kwargs: dict[str, Any] = {}, cal_kwargs: dict[str, Any] = {}, global_epistasis: GlobalEpistasis = Identity(), loss_fn: Callable[[Model, dict[str, Data]], dict[str, Float[Array, '']]] = <function functional_score_loss>, loss_kwargs: dict[str, Any] = {'δ': 1.0}, warmstart: bool = True, beta0_init: dict[str, Float] | None = None, beta_init: dict[str, Float[Array, ' n_mutations']] | None = None, alpha_init: dict[str, Float] | None = None, beta_clip_range: tuple[Float, Float] | None = None, verbose: bool = True) → tuple[Model, list[float]]

Fit a model to data.

Parameters:
  • data_sets – Data to fit to. Each key is a condition.

  • reference_condition – The condition to use as a reference.

  • l2reg – L2 regularization strength for mutation effects.

  • fusionreg – Fusion (shift lasso) regularization strength.

  • beta0_ridge – Ridge penalty on β0 differences from the reference condition.

  • block_iters – Number of iterations for block coordinate descent.

  • block_tol – Convergence tolerance on the objective function for block coordinate descent.

  • ge_kwargs – Keyword arguments for the global epistasis model optimizer.

  • cal_kwargs – Keyword arguments for the experimental calibration parameter optimizer.

  • global_epistasis – Global epistasis model.

  • loss_fn – Loss function (e.g. functional_score_loss or count_loss).

  • loss_kwargs – Keyword arguments for the loss function.

  • warmstart – Whether to use a ridge regression warmstart (default: True). If True, ridge regression is used to initialize parameters. Any explicit values provided in beta0_init or beta_init override these warmstart values.

  • beta0_init – Initial β0 (intercept) values for each condition. If None, zeros are used (or the warmstart values if warmstart=True). If a dict is provided, its values are used for the conditions it specifies.

  • beta_init – Initial β (mutation effects) values for each condition. If None, zeros are used (or the warmstart values if warmstart=True). If a dict is provided, its values are used for the conditions it specifies.

  • alpha_init – Initial α (fitness-functional score scaling) values for each condition. If None, 1.0 is used for all conditions. If a dict is provided, its values are used for the conditions it specifies.

  • beta_clip_range – Optional tuple of (min, max) values for clipping β parameters. If None, no clipping is applied. Example: (-10.0, 10.0). This constrains mutation effect parameters during optimization to prevent extreme values.

  • verbose – Whether to print progress information during fitting (default: True). If False, suppresses all print output.
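The regularization parameters above can be read as penalty terms added to the data loss. The sketch below is a hypothetical reading based only on the parameter descriptions:

- l2reg is a ridge penalty on every condition's β.
- fusionreg is an L1 (shift lasso) penalty on each non-reference condition's shift in β from the reference.
- beta0_ridge is a squared penalty on β0 differences from the reference.
- beta_clip_range is elementwise clipping.

The exact form used by fit is not confirmed here, including whether l2reg applies to all conditions.

```python
import numpy as np


def penalty(beta, beta0, reference, l2reg=0.0, fusionreg=0.0, beta0_ridge=0.0):
    """Hypothetical regularization term added to the data loss.

    beta maps condition -> mutation-effect array; beta0 maps condition -> intercept.
    """
    total = 0.0
    for cond in beta:
        total += l2reg * np.sum(beta[cond] ** 2)  # ridge on mutation effects
        if cond != reference:
            shift = beta[cond] - beta[reference]  # per-mutation shift from the reference
            total += fusionreg * np.sum(np.abs(shift))  # shift lasso (L1)
            total += beta0_ridge * (beta0[cond] - beta0[reference]) ** 2
    return float(total)


def clip_effects(beta, clip_range=None):
    """Elementwise clipping of mutation effects; no-op when clip_range is None."""
    if clip_range is None:
        return beta
    low, high = clip_range
    return np.clip(beta, low, high)


beta = {"ref": np.array([1.0, -1.0]), "other": np.array([1.5, -1.0])}
beta0 = {"ref": 0.0, "other": 0.2}
value = penalty(beta, beta0, "ref", l2reg=0.1, fusionreg=1.0, beta0_ridge=10.0)
```

An L1 term like the shift lasso is not differentiable at zero. In practice it is usually handled with a proximal update rather than plain gradient steps.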

Returns:

Tuple of (fitted model, loss trajectory).
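The block_iters and block_tol parameters suggest a loop like the one below. It cycles through parameter blocks and records the objective after each sweep. The blocks are presumably the latent, global epistasis and calibration blocks configured by ge_kwargs and cal_kwargs. This is a generic, hypothetical sketch on a toy two-block quadratic. The absolute-change stopping rule and the contents of the loss trajectory are assumptions.

```python
def block_coordinate_descent(objective, block_updates, state, block_iters=10, block_tol=1e-6):
    """Cycle through block updates, stopping when a full sweep barely changes the objective.

    Returns the final state and the objective before the first sweep and after each sweep.
    """
    trajectory = [objective(state)]
    for _ in range(block_iters):
        for update in block_updates:
            state = update(state)  # optimize one block with the others held fixed
        trajectory.append(objective(state))
        if abs(trajectory[-2] - trajectory[-1]) < block_tol:
            break
    return state, trajectory


# Toy objective with two coupled scalar blocks, a and b.
def objective(s):
    return (s["a"] - 1.0) ** 2 + (s["b"] + 2.0) ** 2 + 0.5 * s["a"] * s["b"]


def update_a(s):
    # Exact minimizer over a with b fixed: 2(a - 1) + 0.5 b = 0.
    return {**s, "a": 1.0 - 0.25 * s["b"]}


def update_b(s):
    # Exact minimizer over b with a fixed: 2(b + 2) + 0.5 a = 0.
    return {**s, "b": -2.0 - 0.25 * s["a"]}


state, losses = block_coordinate_descent(
    objective, [update_a, update_b], {"a": 0.0, "b": 0.0}, block_iters=100, block_tol=1e-12
)
```

Because each block update here is an exact minimization, the recorded objective never increases from one sweep to the next.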