model_collection

Contains the ModelCollection class, which takes a collection of models and merges the results for comparison and visualization.

exception multidms.model_collection.ModelCollectionFitError

Bases: Exception

Error fitting models.

multidms.model_collection.fit_one_model(dataset, ge_type='Sigmoid', l2reg=0.0, fusionreg=0.0, beta0_ridge=0.0, loss_type='functional_score_loss', maxiter=10, tol=1e-06, warmstart=True, beta0_init=None, beta_init=None, alpha_init=None, beta_clip_range=None, ge_kwargs=None, cal_kwargs=None, loss_kwargs=None, verbose=False, **kwargs)

Fit a single multidms model to a dataset.

This is a wrapper around Model construction and fitting that saves all hyperparameters for bookkeeping. Used by fit_models() for parallel fitting across parameter sweeps.

Parameters:
  • dataset (multidms.Data) – The dataset to fit to. dataset.name is saved for bookkeeping.

  • ge_type (str) – Global epistasis type: 'Identity' or 'Sigmoid'.

  • l2reg (float) – L2 regularization strength for mutation effects.

  • fusionreg (float) – Fusion (shift lasso) regularization strength.

  • beta0_ridge (float) – Ridge penalty for beta0 differences from reference condition.

  • loss_type (str) – Loss function: 'functional_score_loss' or 'count_loss'.

  • maxiter (int) – Maximum block coordinate descent iterations.

  • tol (float) – Convergence tolerance.

  • warmstart (bool) – Whether to use Ridge regression for initialization.

  • beta0_init (dict, optional) – Initial parameter values per condition.

  • beta_init (dict, optional) – Initial parameter values per condition.

  • alpha_init (dict, optional) – Initial parameter values per condition.

  • beta_clip_range (tuple, optional) – (min, max) clipping for beta parameters.

  • ge_kwargs (dict, optional) – Kwargs for sub-optimizers and loss function.

  • cal_kwargs (dict, optional) – Kwargs for sub-optimizers and loss function.

  • loss_kwargs (dict, optional) – Kwargs for sub-optimizers and loss function.

  • verbose (bool) – Print progress during fitting.

  • **kwargs (dict) – Additional keyword arguments saved for bookkeeping.

Returns:

fit_series – A series containing reference to the fit multidms.Model object and the associated parameters used for the fit, including 'dataset_name' and 'fit_time'.

Return type:

pandas.Series

multidms.model_collection.stack_fit_models(fit_models_list)

Given a list of pd.Series objects returned by fit_one_model, stack them into a single pd.DataFrame with one row per fit.
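The stacking step can be pictured with plain pandas. This is a minimal sketch, not the library's implementation: the two series and their string "model" entries are hypothetical stand-ins for what fit_one_model returns.

```python
import pandas as pd

# Two hypothetical fit records standing in for fit_one_model() output;
# a plain string replaces the real multidms.Model reference.
fit_a = pd.Series({"model": "model-a", "dataset_name": "rep1", "fusionreg": 0.0})
fit_b = pd.Series({"model": "model-b", "dataset_name": "rep2", "fusionreg": 1.0})

# Concatenate the series as columns, then transpose so each fit becomes one row.
stacked = pd.concat([fit_a, fit_b], axis=1).T.reset_index(drop=True)
print(stacked)
```

Each series index entry becomes a column, so every fit must carry the same set of keys for the resulting table to be free of missing values.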

multidms.model_collection.fit_models(params, gpu_ids=None, n_processes=1, n_threads=None, failures='error')

Fit collection of multidms.model.Model models.

Enables fitting of multiple models simultaneously. Most commonly, this function is used to fit a set of models across combinations of replicate training datasets and regularization coefficients for model selection and evaluation. The returned dataframe is meant to be passed into the ModelCollection class for comparison and visualization.

There are two parallelism modes, controlled by mutually exclusive parameters:

  • GPU mode (gpu_ids): Round-robin models across the specified GPUs using jax.default_device and a ThreadPoolExecutor, one model per GPU at a time.

  • CPU mode (n_processes): Spawn independent processes via multiprocessing.Pool with the spawn context.
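The round-robin assignment used in GPU mode can be illustrated without JAX. The helper below, assign_round_robin, is a hypothetical name for illustration only; it shows which device ID each model would land on and does not reproduce the library's use of jax.default_device or ThreadPoolExecutor.

```python
from itertools import cycle


def assign_round_robin(n_models, gpu_ids):
    """Map each model index to a device ID, cycling through gpu_ids in order."""
    if not gpu_ids:
        raise ValueError("gpu_ids must contain at least one device ID")
    devices = cycle(gpu_ids)
    return {model_idx: next(devices) for model_idx in range(n_models)}


# Five models spread over two devices.
assignment = assign_round_robin(n_models=5, gpu_ids=[0, 1])
print(assignment)  # {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}
```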

Parameters:
  • params (dict) – Dictionary which defines the parameter space of all models you wish to run. Each value in the dictionary must be a list of values, even in the case of singletons. This function will compute all combinations of the parameter space and pass each combination to fit_one_model() to be run in parallel, thus only key-value pairs which match the kwargs are allowed. See the docstring of fit_one_model() for a description of the allowed parameters.

  • gpu_ids (list of int, optional) –

    GPU device IDs to use for fitting. Models are round-robin assigned across GPUs, one model per GPU at a time. Uses jax.default_device to pin each fit to a specific GPU. Mutually exclusive with n_processes.

    Note

    The IDs correspond to JAX device IDs from jax.devices("gpu"), which are determined by the CUDA_VISIBLE_DEVICES environment variable at the time JAX is first imported. To use multiple GPUs, ensure CUDA_VISIBLE_DEVICES includes all desired GPU IDs (e.g., export CUDA_VISIBLE_DEVICES=0,1,2,3) before starting Python or Jupyter.

  • n_processes (int) – Number of parallel CPU processes for fitting. Default is 1 (sequential, no multiprocessing overhead). Uses multiprocessing.Pool with the spawn context when > 1. Mutually exclusive with gpu_ids.

  • n_threads (int, optional) –

    Deprecated: use gpu_ids for GPU fitting or n_processes for CPU fitting.

  • failures ({"error", "tolerate"}) – How to handle a model whose fit fails. If "error", raise an error; if "tolerate", return None for models that failed optimization.

Returns:

Number of models that fit successfully, number of models that failed, and a dataframe which contains a row for each of the multidms.Model object references along with the parameters each was fit with for convenience. The dataframe is ultimately meant to be passed into the ModelCollection class for comparison and visualization.

Return type:

(n_fit, n_failed, fit_models)
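The way params expands into individual fits can be sketched with itertools.product. This is an illustration of the "all combinations" behaviour described above, not the library's code; expand_grid is a hypothetical helper, and the strings under "dataset" stand in for multidms.Data objects.

```python
from itertools import product

# Hypothetical sweep; string placeholders stand in for multidms.Data objects.
params = {
    "dataset": ["rep1_data", "rep2_data"],
    "fusionreg": [0.0, 1.0, 2.0],
    "maxiter": [10],  # singletons must still be wrapped in a list
}


def expand_grid(params):
    """Return one kwargs dict per combination of the listed values."""
    keys = list(params)
    return [dict(zip(keys, combo)) for combo in product(*(params[k] for k in keys))]


grid = expand_grid(params)
print(len(grid))  # 6
print(grid[0])    # {'dataset': 'rep1_data', 'fusionreg': 0.0, 'maxiter': 10}
```

Because the number of fits is the product of the list lengths, adding one more swept parameter multiplies the total run count.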

class multidms.model_collection.ModelCollection(fit_models)

Bases: object

A class for the comparison and visualization of multiple multidms.Model fits. The respective collection of training datasets for each fit must share the same reference sequence and conditions. Additionally, the inferred site maps must agree upon condition wildtypes for all shared sites.

The utility function multidms.model_collection.fit_models is used to fit the collection of models, and the resulting dataframe is passed to the constructor of this class.

Parameters:

fit_models (pandas.DataFrame) – A dataframe containing the fit attributes and pickled model objects as returned by multidms.model_collection.fit_models.

property site_map_union: DataFrame

The union of all site maps of all datasets used for fitting.

property conditions: list

The conditions (shared by each fitting dataset) used for fitting.

property reference: str

The reference condition (shared by each fitting dataset) used for fitting.

property shared_mutations: tuple

The mutations shared by each fitting dataset.

property all_mutations: tuple

All mutations observed in at least one fitting dataset (the union across datasets).
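The difference between the two mutation properties can be shown with plain sets. This sketch assumes, based on the property names, that shared_mutations is the intersection of the per-dataset mutations and all_mutations is their union; the dataset names and mutations are made up.

```python
# Hypothetical mutations observed in two replicate datasets.
muts_by_dataset = {
    "rep1": {"A15G", "K20R", "A15*"},
    "rep2": {"A15G", "K20R", "T30S"},
}

# Mutations present in every dataset (assumed meaning of shared_mutations).
shared = tuple(sorted(set.intersection(*muts_by_dataset.values())))

# Mutations present in at least one dataset (assumed meaning of all_mutations).
union = tuple(sorted(set.union(*muts_by_dataset.values())))

print(shared)
print(union)
```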

split_apply_combine_muts(groupby=('dataset_name', 'fusionreg'), aggregate_func='mean', inner_merge_dataset_muts=True, query=None, **kwargs)

Wrapper to split-apply-combine the set of mutational dataframes harbored by each of the fits in the collection.

Here, we group the collection of fits using attributes (columns in ModelCollection.fit_models) specified using the groupby parameter. The individual fits within each group may then be filtered via **kwargs and aggregated via aggregate_func, before the function stacks all the groups back together in a tall-style dataframe. The resulting dataframe will have a multiindex with the mutation and the groupby attributes.

Parameters:
  • groupby (str or tuple of str or None, optional) – The attributes to group the fits by. If None, then group by all attributes except for the model, data, and step_loss attributes. The default is (“dataset_name”, “fusionreg”).

  • aggregate_func (str or callable, optional) – The function to aggregate the mutational dataframes within each group. The default is “mean”.

  • inner_merge_dataset_muts (bool, optional) – Whether to toss mutations which are _not_ shared across all datasets before aggregation of group mutation parameter values. The default is True.

  • query (str, optional) – The pandas query to apply to the ModelCollection.fit_models dataframe before splitting. The default is None.

  • **kwargs (dict) – Keyword arguments to pass to the multidms.Model.get_mutations_df() method (phenotype_as_effect and times_seen_threshold). See the method docstring for details.

Returns:

A dataframe containing the aggregated mutational parameter values

Return type:

pandas.DataFrame
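The split-apply-combine pattern itself can be sketched with pandas groupby. This is not the method's implementation: the tall table below is hypothetical, standing in for the per-fit mutation dataframes, and the grouping is by "fusionreg" alone so that replicate datasets are averaged together.

```python
import pandas as pd

# Hypothetical shift values for two replicates at two penalty strengths.
tall = pd.DataFrame(
    [
        ("rep1", 0.0, "A15G", 0.4),
        ("rep1", 0.0, "K20R", -0.2),
        ("rep2", 0.0, "A15G", 0.6),
        ("rep2", 0.0, "K20R", 0.0),
        ("rep1", 1.0, "A15G", 0.1),
        ("rep2", 1.0, "A15G", 0.3),
    ],
    columns=["dataset_name", "fusionreg", "mutation", "shift"],
)

# Split by grouping attribute and mutation, apply the aggregate, and combine
# the groups into one frame indexed by (fusionreg, mutation).
combined = tall.groupby(["fusionreg", "mutation"])["shift"].agg("mean").to_frame()
print(combined)
```

Passing a different string or a callable to agg corresponds to changing aggregate_func.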

add_eval_loss(test_data, overwrite=False)

Add evaluation (validation) loss to the fit collection dataframe.

Parameters:
  • test_data (pd.DataFrame or dict(str, pd.DataFrame)) – The testing dataframe to compute validation loss with respect to, must have columns “aa_substitutions”, “condition”, and “func_score”. If a dictionary is passed, there should be a key for each unique dataset_name factor in the self.fit_models dataframe - with the value being the respective testing dataframe.

  • overwrite (bool, optional) – Whether to overwrite the validation_loss column if it already exists. The default is False.

Return type:

None

loss_df(query=None)

Return a long form dataframe with columns “dataset_name”, “fusionreg”, “split” (“training” or “validation”), “loss” (actual value), and “condition”.

The condition column includes "total" for the summed loss.

Parameters:

query (str, optional) – The query to apply to the fit_models dataframe before formatting the loss dataframe. The default is None.

convergence_trajectory_df(query=None, id_vars=('dataset_name', 'fusionreg'))

Combine the convergence trajectory dataframes of all fits in the queried collection.

mut_param_heatmap(query=None, mut_param='shift', aggregate_func='mean', inner_merge_dataset_muts=True, times_seen_threshold=0, phenotype_as_effect=True, **kwargs)

Create a lineplot and heatmap altair chart across replicate datasets. This function optionally applies a given pandas.query on the fit_models dataframe that should result in a subset of fits which make sense to aggregate mutational data across, e.g. replicate datasets. It then computes the mean or median mutational parameter value (“beta”, “shift”, or “predicted_func_score”) between the remaining fits and creates an interactive altair chart.

Note that this will throw an error if the queried fits have more than one unique hyper-parameter besides “dataset_name”.

Parameters:
  • query (str) – The query to apply to the fit_models dataframe. This should be used to subset the fits to only those which make sense to aggregate mutational data across, e.g. replicate datasets. For example, if you have a collection of fits with different epistatic models, you may want to query for only those fits with the same epistatic model, e.g. query="ge_type == 'Sigmoid'". For more on the query syntax, see the pandas.query documentation.

  • mut_param (str, optional) – The mutational parameter to plot. The default is “shift”. Must be one of “shift”, “predicted_func_score”, or “beta”.

  • aggregate_func (str, optional) – The function to aggregate the mutational parameter values between dataset fits. The default is “mean”.

  • inner_merge_dataset_muts (bool, optional) – Whether to toss mutations which are _not_ shared across all datasets before aggregation of group mutation parameter values. The default is True.

  • times_seen_threshold (int, optional) – The minimum number of times a mutation must be seen across all conditions within a single fit to be included in the aggregation. The default is 0.

  • phenotype_as_effect (bool, optional) – Passed to Model.get_mutations_df(). Only applies if mut_param=“predicted_func_score”.

  • **kwargs (dict) – Keyword arguments to pass to multidms.plot._lineplot_and_heatmap().

Returns:

A chart object which can be displayed in a jupyter notebook or saved to a file.

Return type:

altair.Chart
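The query and the single-hyperparameter requirement can be checked ahead of time on any dataframe shaped like fit_models. This is a sketch under assumptions: the table is a made-up stand-in with only three columns, and the check reads the requirement as "all queried fits share every hyperparameter except dataset_name".

```python
import pandas as pd

# Hypothetical stand-in for ModelCollection.fit_models.
fit_models = pd.DataFrame(
    {
        "dataset_name": ["rep1", "rep2", "rep1", "rep2"],
        "ge_type": ["Sigmoid", "Sigmoid", "Identity", "Identity"],
        "fusionreg": [1.0, 1.0, 1.0, 1.0],
    }
)

# Keep only the fits that share one epistatic model.
queried = fit_models.query("ge_type == 'Sigmoid'")

# List hyperparameters that still vary across the remaining fits.
hyperparams = queried.drop(columns="dataset_name")
varying = [col for col in hyperparams.columns if hyperparams[col].nunique() > 1]
print(varying)  # [] means the fits differ only by dataset
```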

mut_param_traceplot(mutations, mut_param='shift', x='fusionreg', width_scalar=100, height_scalar=100, **kwargs)

Visualize mutation parameter values across the lasso penalty weights (by default) of a given subset of the mutations in the form of an altair.FacetChart. This is useful when you would like to confirm that a reported mutational parameter value carries through across the individual fits.

Returns:

A chart object which can be displayed in a jupyter notebook or saved to a file.

Return type:

altair.Chart

shift_sparsity(x='fusionreg', width_scalar=100, height_scalar=100, return_data=False, **kwargs)

Visualize shift parameter set sparsity across the lasso penalty weights (by default) in the form of an altair.FacetChart. We group the mutations according to their status as either a “stop” (e.g. A15*) or “nonsynonymous” (e.g. A15G) mutation before calculating the sparsity. This is because mutations to stop codons act as a proxy for the false positive rate: we expect their mutational effect to be equally deleterious in all experiments, and thus to have a shift parameter value of zero.

Returns:

A chart object which can be displayed in a jupyter notebook or saved to a file. If return_data=True, then a tuple containing the chart and the underlying data will be returned.

Return type:

altair.Chart or Tuple(altair.Chart, pd.DataFrame)
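The grouping into stop and nonsynonymous mutations can be sketched in a few lines. This example assumes sparsity means the fraction of shift parameters that are exactly zero, which the text above does not state explicitly; shift_sparsity_by_type and the shift values are hypothetical.

```python
from collections import defaultdict


def shift_sparsity_by_type(shifts):
    """Return the fraction of exactly-zero shifts per mutation type."""
    counts = defaultdict(lambda: [0, 0])  # mutation type -> [n_zero, n_total]
    for mutation, shift in shifts.items():
        # A trailing "*" marks a mutation to a stop codon, e.g. A15*.
        mut_type = "stop" if mutation.endswith("*") else "nonsynonymous"
        counts[mut_type][0] += shift == 0.0
        counts[mut_type][1] += 1
    return {t: n_zero / n_total for t, (n_zero, n_total) in counts.items()}


# Hypothetical shift parameters from a single fit.
shifts = {"A15*": 0.0, "K20*": 0.0, "A15G": 0.3, "K20R": 0.0, "T30S": -0.1, "T30A": 0.0}
sparsity = shift_sparsity_by_type(shifts)
print(sparsity)  # {'stop': 1.0, 'nonsynonymous': 0.5}
```

A stop sparsity below 1.0 at a given penalty would suggest that some nonzero shifts at that penalty are noise rather than signal.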

mut_param_dataset_correlation(x='fusionreg', width_scalar=200, height=200, return_data=False, r=2, **kwargs)

Visualize the correlation between replicate datasets across the lasso penalty weights (by default) in the form of an altair.FacetChart. We compute the correlation of mutation parameters across each pair of datasets in the collection.

Parameters:
  • x (str, optional) – The parameter to plot on the x-axis. The default is “fusionreg”.

  • width_scalar (int, optional) – The width of the chart. The default is 200.

  • height (int, optional) – The height of the chart. The default is 200.

  • return_data (bool, optional) – Whether to return the underlying data. The default is False.

  • r (int, optional) – The exponent applied to the reported correlation coefficient. May be either 1 for Pearson's r or 2 for the coefficient of determination (r-squared). The default is 2.

  • **kwargs (dict) – The keyword arguments to pass to the multidms.model_collection.ModelCollection.split_apply_combine_muts() method. See the method docstring for details.

Returns:

A chart object which can be displayed in a jupyter notebook or saved to a file. If return_data=True, then a tuple containing the chart and the underlying data will be returned.

Return type:

altair.Chart or Tuple(altair.Chart, pd.DataFrame)
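The effect of the r parameter can be seen by computing a Pearson correlation by hand and raising it to the chosen power. This is a standalone sketch with made-up parameter values for two replicates; pearson_r is a hypothetical helper and not part of the library.

```python
import math


def pearson_r(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Hypothetical mutation parameter values from two replicate fits.
rep1 = [0.0, 1.0, 2.0, 3.0]
rep2 = [0.1, 0.9, 2.2, 2.8]

corr = pearson_r(rep1, rep2)
for r in (1, 2):
    # r=1 reports the coefficient itself, r=2 reports r-squared.
    print(r, corr ** r)
```

Squaring discards the sign, so r=1 is the choice when anticorrelated replicates need to be distinguished from correlated ones.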