multidms package
multidms
multidms is a Python package for modeling deep mutational scanning data. In particular, it is designed to model data from more than one experiment, even if they don’t share the same wildtype amino acid sequence. It uses joint modeling to inform parameters across all experiments, while identifying experiment-specific mutation effects which differ.
Importing this package imports the following objects into the package namespace:
Data
Model
ModelCollection
fit_models
For a brief description of how the Model class works to compose, compile, and optimize the model parameters, as well as detailed code documentation for each of the equations described in the biophysical docs, see: biophysical
plot mostly contains code for interactive plotting at the moment.
It also imports the following alphabets:
AAS
AAS_WITHSTOP
AAS_WITHGAP
AAS_WITHSTOP_WITHGAP
- class multidms.Data(variants_df: DataFrame, reference: str, alphabet=('A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'), collapse_identical_variants=False, condition_colors=('#0072B2', '#CC79A7', '#009E73', '#17BECF', '#BCDB22'), letter_suffixed_sites=False, assert_site_integrity=False, verbose=False, name=None, include_counts=False)
Bases: object
Prep and store one-hot encoding of variant substitution data. Individual objects of this type can be shared by multiple multidms.Model objects for efficiently fitting various models to the same data.
Note
You can initialize a Data object with a pandas.DataFrame with a row for each variant sampled and annotations provided in the required columns:
- condition - Experimental condition from which a sample measurement was obtained.
- aa_substitutions - Defines each variant \(v\) as a string of substitutions (e.g., 'M3A K5G'). Note that while conditions may have differing wild types at a given site, the sites between conditions should reference the same site when alignment is performed between condition wild types.
- func_score - The functional score computed from experimental measurements.
- Parameters:
variants_df (pandas.DataFrame or None) – The variant-level information from all experiments you wish to analyze. Should have columns named 'condition', 'aa_substitutions', and 'func_score'. See the class note for descriptions of each of the features.
reference (str) – Name of the condition which annotates the reference variants. Note that for model fitting this class will convert all amino acid substitutions for non-reference condition groups to be relative to the reference condition. For example, if the wild type amino acid at site 30 is an A in the reference condition, and a G in a non-reference condition, then a G30Y mutation in the non-reference condition is recorded as an A30Y mutation relative to the reference. This way, each condition informs the exact same parameters, even at sites that differ in wild type amino acid. These are encoded in a binarymap.binarymap.BinaryMap object for each condition, where all sites that are non-identical to the reference are 1’s. For motivation, see the Model overview section in the multidms.Model class notes.
alphabet (array-like) – Allowed characters in mutation strings.
collapse_identical_variants ({'mean', 'median', False}) – If identical variants in variants_df (same ‘aa_substitutions’) exist within individual condition groups, collapse them by taking the mean or median of ‘func_score’, or (if False) do not collapse at all. Collapsing will make fitting faster, but is not a good idea if you are doing bootstrapping.
condition_colors (array-like or dict) – Maps each condition to the color used for plotting. Either a dict keyed by each condition, or an array of colors that are sequentially assigned to the conditions.
letter_suffixed_sites (bool) – True if site numbers may carry lowercase letter suffixes (e.g., 214a); False (default) if sites are sequential integers.
assert_site_integrity (bool) – If True, will assert that all sites in the data frame have the same wild type amino acid, grouped by condition.
verbose (bool) – If True, will print progress bars.
name (str or None) – Name of the data object. If None, will be assigned a unique name based upon the number of data objects instantiated.
include_counts (bool) – If True, expects ‘pre_count’ and ‘post_count’ columns in the input DataFrame and includes them in the data arrays. If False (default), these columns are not required and count data will not be available.
Example
Simple example with two conditions ('a' and 'b'):
>>> import pandas as pd
>>> import multidms
>>> func_score_data = {
...     'condition' : ["a","a","a","a", "b","b","b","b","b","b"],
...     'aa_substitutions' : [
...         'M1E', 'G3R', 'G3P', 'M1W', 'M1E',
...         'P3R', 'P3G', 'M1E P3G', 'M1E P3R', 'P2T'
...     ],
...     'func_score' : [2, -7, -0.5, 2.3, 1, -5, 0.4, 2.7, -2.7, 0.3],
... }
>>> func_score_df = pd.DataFrame(func_score_data)
>>> func_score_df
  condition aa_substitutions  func_score
0         a              M1E         2.0
1         a              G3R        -7.0
2         a              G3P        -0.5
3         a              M1W         2.3
4         b              M1E         1.0
5         b              P3R        -5.0
6         b              P3G         0.4
7         b          M1E P3G         2.7
8         b          M1E P3R        -2.7
9         b              P2T         0.3
Instantiate a Data object allowing for stop codon variants and declaring condition “a” as the reference condition.
>>> data = multidms.Data(
...     func_score_df,
...     alphabet = multidms.AAS_WITHSTOP,
...     reference = "a",
... )
...
Note that this may take some time due to the string operations that must be performed when converting amino acid substitutions to be with respect to a reference wild type sequence.
After the object has finished being instantiated, we have access to a few ‘static’ properties of our data. See the individual property docstrings for more information.
>>> data.reference
'a'
>>> data.conditions
('a', 'b')
>>> data.mutations
('M1E', 'M1W', 'G3P', 'G3R')
>>> data.site_map
   a  b
1  M  M
3  G  P
>>> data.mutations_df
  mutation wts  sites muts  times_seen_a  times_seen_b
0      M1E   M      1    E             1             3
1      M1W   M      1    W             1             0
2      G3P   G      3    P             1             4
3      G3R   G      3    R             1             2
>>> data.variants_df
  condition aa_substitutions  func_score var_wrt_ref
0         a              M1E         2.0         M1E
1         a              G3R        -7.0         G3R
2         a              G3P        -0.5         G3P
3         a              M1W         2.3         M1W
4         b              M1E         1.0     G3P M1E
5         b              P3R        -5.0         G3R
6         b              P3G         0.4
7         b          M1E P3G         2.7         M1E
8         b          M1E P3R        -2.7     G3R M1E
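To make the reference conversion concrete, the var_wrt_ref column above can be reproduced by a short pure-Python sketch. This is a hypothetical re-implementation for illustration, not the code behind convert_subs_wrt_ref_seq(): the helper subs_wrt_ref and its dict-based wild-type maps are assumptions. It also assumes every mutated site appears in both wild-type maps. The P2T variant, whose site is absent from the reference, does not appear in data.variants_df above; this sketch raises an error for such sites instead.

```python
import re

# Matches substitutions like "M1E" or "G3*": wild type, integer site, mutant.
SUB_RE = re.compile(r"([A-Z*-])(\d+)([A-Z*-])")


def subs_wrt_ref(aa_subs: str, ref_wt: dict[int, str], cond_wt: dict[int, str]) -> str:
    """Hypothetical helper: rewrite a condition's substitutions relative to the reference."""
    # Seed with the "bundle": every site where this condition's wild type
    # differs from the reference is already a mutation relative to the reference.
    subs = {
        site: f"{ref_wt[site]}{site}{cond_wt[site]}"
        for site in ref_wt
        if cond_wt[site] != ref_wt[site]
    }
    for sub in aa_subs.split():
        match = SUB_RE.fullmatch(sub)
        if match is None:
            raise ValueError(f"cannot parse substitution {sub!r}")
        wt, site, mut = match.group(1), int(match.group(2)), match.group(3)
        if site not in ref_wt or site not in cond_wt:
            raise ValueError(f"site {site} is not in the site map")
        if wt != cond_wt[site]:
            raise ValueError(f"{sub!r} does not match the condition wild type")
        if mut == ref_wt[site]:
            # Mutating back to the reference amino acid cancels any bundle mutation.
            subs.pop(site, None)
        else:
            # Assigning to an existing key keeps that site's original dict position.
            subs[site] = f"{ref_wt[site]}{site}{mut}"
    return " ".join(subs.values())


# Wild types taken from data.site_map in the example above.
ref_wt = {1: "M", 3: "G"}  # condition "a" (reference)
b_wt = {1: "M", 3: "P"}    # condition "b"

for variant in ["M1E", "P3R", "P3G", "M1E P3G", "M1E P3R"]:
    print(repr(variant), "->", repr(subs_wrt_ref(variant, ref_wt, b_wt)))
```

Run on the condition "b" variants, it yields the same strings as the var_wrt_ref column: for example, 'M1E' becomes 'G3P M1E', and 'P3G' becomes an empty string because mutating site 3 back to G cancels the bundle mutation. In this sketch, seeding the dict with the bundle sites is what places G3P/G3R first, matching the ordering shown above.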
- property mutations: tuple
A tuple of all mutations, ordered by their index into the binarymap.
- property site_map: DataFrame
A dataframe indexed by site, with columns for all conditions giving the wild type amino acid at each site.
- property non_identical_mutations: dict
A dictionary keyed by condition names with values being a string of all mutations that differ from the reference sequence.
- property non_identical_sites: dict
A dictionary keyed by condition names with values being a pandas.DataFrame indexed by site, with columns for the reference and non-reference amino acid at each site that differs.
- property bundle_idxs: dict
A dictionary keyed by condition names with values being the indices into the binarymap representing bundle (non_identical) mutations.
- property reference_sequence_conditions: list
A list of conditions that have the same wild type sequence as the reference condition.
- property binarymaps: dict
A dictionary keyed by condition names with values being a BinaryMap object for each condition.
- property mutparser: MutationParser
The polyclonal.utils.MutationParser used to parse mutations.
- property parse_mut: MutationParser
Returns a function that splits a single amino acid substitution into wildtype, site, and mutation using the mutation parser.
- property parse_muts: partial
A function that splits amino acid substitutions (a string of more than one) into wildtype, site, and mutation using the mutation parser.
- property single_mut_encodings
A dictionary keyed by condition names with values being the one-hot encoding of all single mutations.
- convert_subs_wrt_ref_seq(condition, aa_subs)
Convert amino acid substitutions to be with respect to the reference sequence.
- plot_times_seen_hist(saveas=None, show=True, **kwargs)
Plot a histogram of the number of times each mutation was seen.
- plot_func_score_boxplot(saveas=None, show=True, **kwargs)
Plot a boxplot of the functional scores for each condition.
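The relationships among site_map, non_identical_sites, and reference_sequence_conditions can be illustrated with a pandas sketch built on the site map from the Data example. The comparisons below assume that layout (one wild-type column per condition, indexed by site) and are an illustration, not code taken from multidms; they only build the non-reference entries of non_identical_sites.

```python
import pandas as pd

# Site map from the Data example: one wild-type column per condition, indexed by site.
site_map = pd.DataFrame({"a": ["M", "G"], "b": ["M", "P"]}, index=[1, 3])
reference = "a"

# Illustrative: conditions whose wild type matches the reference at every site.
reference_sequence_conditions = [
    cond for cond in site_map.columns
    if (site_map[cond] == site_map[reference]).all()
]

# Illustrative: for each non-reference condition, the sites whose wild type
# differs, with the reference and condition amino acids side by side.
non_identical_sites = {
    cond: site_map.loc[site_map[cond] != site_map[reference], [reference, cond]]
    for cond in site_map.columns
    if cond != reference
}

print(reference_sequence_conditions)
print(non_identical_sites["b"])
```

Here only condition "a" shares the reference sequence, and condition "b" differs only at site 3 (G in the reference, P in "b"), which is the site that produces the G3P bundle mutation in data.variants_df.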
- class multidms.Model(data: Data, loss_type: Literal['functional_score_loss', 'count_loss'] = 'functional_score_loss', ge_type: Literal['Identity', 'Sigmoid'] = 'Sigmoid', l2reg: float = 0.0, fusionreg: float = 0.0, beta0_ridge: float = 0.0)
Bases: object
Model for global epistasis analysis of DMS experiments.
This class wraps the jaxmodels backend to provide a user-friendly interface for fitting global epistasis models to deep mutational scanning data.
- Parameters:
data (multidms.Data) – Preprocessed DMS data object containing variants and functional scores.
loss_type ({'functional_score_loss', 'count_loss'}) – Type of loss function to use. ‘functional_score_loss’ for standard functional score fitting, ‘count_loss’ for count-based enrichment models.
ge_type ({'Identity', 'Sigmoid'}) – Global epistasis model type. ‘Identity’ for no global epistasis (linear), ‘Sigmoid’ for sigmoidal transformation.
l2reg (float) – L2 regularization strength for mutation effects (default: 0.0).
fusionreg (float) – Fusion regularization strength for shift parameters (default: 0.0).
beta0_ridge (float) – Ridge penalty for β0 offsets from reference condition (default: 0.0).
Example
>>> import pandas as pd
>>> from multidms import Data, Model
>>> df = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df, reference='a')
>>> model = Model(data, ge_type='Sigmoid', l2reg=0.01)
>>> model
Model(ge_type='Sigmoid', loss_type='functional_score_loss')
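To make the roles of β0, β, and the ge_type choice concrete, here is a NumPy sketch of an additive latent phenotype passed through an optional sigmoid link. The functional forms (latent phenotype β0 + Σ β·x, and a plain logistic curve for 'Sigmoid') are assumptions for illustration only; they omit the per-condition α scaling and other details, so consult the jaxmodels source for the exact parameterization multidms uses.

```python
import numpy as np


def latent_phenotype(X: np.ndarray, beta: np.ndarray, beta0: float) -> np.ndarray:
    """Additive latent phenotype: beta0 plus the summed effects of each variant's mutations."""
    # X is a (n_variants, n_mutations) one-hot matrix, like a BinaryMap encoding.
    return beta0 + X @ beta


def global_epistasis(phi: np.ndarray, ge_type: str = "Sigmoid") -> np.ndarray:
    """Map latent phenotypes to predicted functional scores (assumed, simplified forms)."""
    if ge_type == "Identity":
        return phi  # no global epistasis: the prediction is linear in phi
    if ge_type == "Sigmoid":
        return 1.0 / (1.0 + np.exp(-phi))  # plain logistic curve, for illustration only
    raise ValueError(f"unknown ge_type {ge_type!r}")


# Two variants over one mutation (M1A): the wildtype and the single mutant.
X = np.array([[0.0], [1.0]])
beta = np.array([1.2])
phi = latent_phenotype(X, beta, beta0=0.0)
print(phi)                                # latent phenotypes
print(global_epistasis(phi, "Identity"))  # unchanged
print(global_epistasis(phi, "Sigmoid"))   # squashed into (0, 1)
```

The sigmoid link is monotonic, so it preserves the ranking of variants by latent phenotype while compressing effects near the extremes; that compression is what distinguishes a global epistasis model from the linear 'Identity' case.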
- property params
Model parameters (available after fit).
- property convergence_trajectory_df: DataFrame
Convergence trajectory showing loss over iterations.
- Returns:
DataFrame with columns ‘iteration’, ‘loss’, ‘error’
- Return type:
pd.DataFrame
- fit(warmstart: bool = True, maxiter: int = 10, tol: float = 1e-06, beta0_init: dict | None = None, beta_init: dict | None = None, alpha_init: dict | None = None, beta_clip_range: tuple | None = None, ge_kwargs: dict | None = None, cal_kwargs: dict | None = None, loss_kwargs: dict | None = None, verbose: bool = True)
Fit the model to data.
- Parameters:
warmstart (bool) – Whether to use Ridge regression for parameter initialization (default: True).
maxiter (int) – Maximum number of optimization iterations (default: 10).
tol (float) – Convergence tolerance on objective function (default: 1e-6).
beta0_init (dict, optional) – Initial β0 values per condition.
beta_init (dict, optional) – Initial β values per condition.
alpha_init (dict, optional) – Initial α scaling values per condition.
beta_clip_range (tuple, optional) – Tuple of (min, max) values for clipping β parameters during optimization. Example: (-10.0, 10.0). If None, no clipping is applied.
ge_kwargs (dict, optional) – Keyword arguments for global epistasis optimizer (e.g., tol, maxiter, maxls).
cal_kwargs (dict, optional) – Keyword arguments for calibration (α) optimizer (e.g., tol, maxiter, maxls).
loss_kwargs (dict, optional) – Keyword arguments for the loss function (e.g., δ for Huber loss).
verbose (bool) – Whether to print progress information during fitting (default: True).
- Returns:
Returns self for method chaining.
- Return type:
self
- get_mutations_df(phenotype_as_effect: bool = True) → DataFrame
Extract mutation-level parameters in wide format.
- Parameters:
phenotype_as_effect (bool) – If True, report mutation effects. If False, report raw latent phenotypes.
- Returns:
DataFrame with mutations as rows (index) and columns:
- beta_{condition} for each condition
- shift_{condition} for each non-reference condition
Shift parameters represent the difference in beta values between each condition and the reference condition.
- Return type:
pd.DataFrame
Example
For a model with conditions [‘a’, ‘b’] where ‘a’ is the reference:
- Columns: beta_a, beta_b, shift_b
- One row per mutation
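The wide layout described above, including the relationship shift_b = beta_b - beta_a, can be illustrated with pandas. The β values below are made up, and this is not how get_mutations_df computes its output; it only shows how the columns relate to one another.

```python
import pandas as pd

reference = "a"
# Hypothetical per-condition mutation effects (made-up values, not fit results).
betas = {
    "a": pd.Series({"M1E": 0.5, "M1W": -1.0, "G3R": -2.0}),
    "b": pd.Series({"M1E": 0.5, "M1W": -0.4, "G3R": -2.0}),
}

# One beta_{condition} column per condition, one row per mutation.
mutations_df = pd.DataFrame({f"beta_{cond}": b for cond, b in betas.items()})
mutations_df.index.name = "mutation"
for cond in betas:
    if cond != reference:
        # Shift: difference between a condition's beta and the reference beta.
        mutations_df[f"shift_{cond}"] = (
            mutations_df[f"beta_{cond}"] - mutations_df[f"beta_{reference}"]
        )

print(mutations_df)
```

Mutations whose effect is the same in both conditions (M1E and G3R here) get a shift of zero; only M1W has a nonzero shift_b.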
- get_variants_df(phenotype_as_effect: bool = True) → DataFrame
Extract variant-level predictions.
- Parameters:
phenotype_as_effect (bool) – If True, report effects. If False, report raw latent phenotypes.
- Returns:
Variant-level predictions merged with original data.
- Return type:
pd.DataFrame
- add_phenotypes_to_df(df: DataFrame, substitutions_col: str = 'aa_substitutions', condition_col: str = 'condition', predicted_phenotype_col: str = 'predicted_func_score', overwrite_cols: bool = False) → DataFrame
Add model predictions to a DataFrame of variants.
- Parameters:
df (pd.DataFrame) – DataFrame with columns specified by condition_col and substitutions_col. Additional columns will be preserved in output.
substitutions_col (str) – Column in df giving variants as substitution strings. Default is ‘aa_substitutions’.
condition_col (str) – Column in df giving the condition for each variant. Values must exist in the model’s conditions. Default is ‘condition’.
predicted_phenotype_col (str) – Name of column to add containing predicted functional scores. Default is ‘predicted_func_score’.
overwrite_cols (bool) – Whether to overwrite the specified predicted phenotype column if it already exists in df. If False, raise an error.
- Returns:
A copy of df with predictions added.
- Return type:
pd.DataFrame
- Raises:
ValueError – If model is not fitted, required columns are missing, indices are not unique, conditions are invalid, or substitutions contain mutations not seen during training.
Example
>>> import pandas as pd
>>> from multidms import Data, Model
>>> df_train = pd.DataFrame({
...     'condition': ['a', 'a', 'b', 'b'],
...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
...     'func_score': [0.0, 1.2, 0.1, 1.5]
... })
>>> data = Data(df_train, reference='a')
>>> model = Model(data, ge_type='Identity', l2reg=0.01)
>>> _ = model.fit(maxiter=5, warmstart=False, verbose=False)
>>> df_new = pd.DataFrame({
...     'condition': ['a', 'b'],
...     'aa_substitutions': ['M1A', 'M1A']
... })
>>> result = model.add_phenotypes_to_df(df_new)
>>> 'predicted_func_score' in result.columns
True
>>> len(result)
2
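The Raises entry lists several checks on the input DataFrame. A hypothetical pre-flight helper that mirrors those documented checks before calling add_phenotypes_to_df might look like this. The function name and its arguments (the model's conditions and the set of mutations seen in training) are assumptions, and it checks substitutions as written rather than after conversion relative to the reference wild type, which is a simplification.

```python
import pandas as pd


def check_variants_for_prediction(
    df: pd.DataFrame,
    conditions: set[str],
    known_mutations: set[str],
    substitutions_col: str = "aa_substitutions",
    condition_col: str = "condition",
) -> None:
    """Hypothetical pre-flight check mirroring the documented ValueError cases."""
    missing = {substitutions_col, condition_col} - set(df.columns)
    if missing:
        raise ValueError(f"missing required columns: {sorted(missing)}")
    if not df.index.is_unique:
        raise ValueError("DataFrame index must be unique")
    bad_conditions = set(df[condition_col]) - conditions
    if bad_conditions:
        raise ValueError(f"unknown conditions: {sorted(bad_conditions)}")
    # Simplification: compares substitutions as written, not re-expressed
    # relative to the reference wild type.
    seen = {m for subs in df[substitutions_col] for m in subs.split()}
    unseen = seen - known_mutations
    if unseen:
        raise ValueError(f"mutations not seen during training: {sorted(unseen)}")


df_new = pd.DataFrame({"condition": ["a", "b"], "aa_substitutions": ["M1A", "M1A"]})
check_variants_for_prediction(df_new, conditions={"a", "b"}, known_mutations={"M1A"})
```

Running such a check first surfaces a descriptive error before any prediction work is attempted.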
- class multidms.ModelCollection(fit_models)
Bases: object
A class for the comparison and visualization of multiple multidms.Model fits. The respective collection of training datasets for each fit must share the same reference sequence and conditions. Additionally, the inferred site maps must agree upon condition wildtypes for all shared sites.
The utility function multidms.model_collection.fit_models is used to fit the collection of models, and the resulting dataframe is passed to the constructor of this class.
- Parameters:
fit_models (pandas.DataFrame) – A dataframe containing the fit attributes and pickled model objects as returned by multidms.model_collection.fit_models.
- property reference: str
The reference condition (shared by each fitting dataset) used for fitting.
- property shared_mutations
The mutations shared by each fitting dataset.
- split_apply_combine_muts(groupby=('dataset_name', 'scale_coeff_lasso_shift'), aggregate_func='mean', inner_merge_dataset_muts=True, query=None, **kwargs)
Wrapper to split-apply-combine the set of mutational dataframes harbored by each of the fits in the collection.
Here, we group the collection of fits using attributes (columns in ModelCollection.fit_models) specified by the groupby parameter. Each of the individual fits within a group may then be filtered via **kwargs and aggregated via aggregate_func, before the function stacks all the groups back together in a tall-style dataframe. The resulting dataframe will have a MultiIndex with the mutation and the groupby attributes.
- Parameters:
groupby (str or tuple of str or None, optional) – The attributes to group the fits by. If None, then group by all attributes except for the model, data, and step_loss attributes. The default is (“dataset_name”, “scale_coeff_lasso_shift”).
aggregate_func (str or callable, optional) – The function to aggregate the mutational dataframes within each group. The default is “mean”.
inner_merge_dataset_muts (bool, optional) – Whether to toss mutations which are _not_ shared across all datasets before aggregation of group mutation parameter values. The default is True.
query (str, optional) – The pandas query to apply to the ModelCollection.fit_models dataframe before splitting. The default is None.
**kwargs (dict) – Keyword arguments to pass to the multidms.Model.get_mutations_df() method (“phenotype_as_effect” and “times_seen_threshold”); see the method docstring for details.
- Returns:
A dataframe containing the aggregated mutational parameter values
- Return type:
pd.DataFrame
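Grouping only by scale_coeff_lasso_shift (rather than the default pair) aggregates each mutation across replicate datasets. The sketch below mimics that split-apply-combine pattern on a toy long-format table of per-fit shift values using plain pandas. The table layout and values are hypothetical; per the **kwargs description, the real method obtains each fit's mutations via Model.get_mutations_df() rather than from a precomputed table.

```python
import pandas as pd

# Hypothetical long-format table: one row per (fit, mutation) with a shift value.
fits = pd.DataFrame({
    "dataset_name": ["rep1", "rep1", "rep1", "rep2", "rep2"],
    "scale_coeff_lasso_shift": [1e-5] * 5,
    "mutation": ["M1E", "M1W", "G3R", "M1E", "G3R"],
    "shift_b": [0.2, 0.6, -0.1, 0.4, 0.1],
})

# "Inner merge": keep only mutations observed in every dataset.
n_datasets = fits["dataset_name"].nunique()
datasets_per_mut = fits.groupby("mutation")["dataset_name"].nunique()
shared = datasets_per_mut[datasets_per_mut == n_datasets].index
shared_fits = fits[fits["mutation"].isin(shared)]

# Split on the chosen attribute(s), aggregate each mutation across the
# replicate datasets, and combine into a MultiIndexed result.
groupby = ["scale_coeff_lasso_shift"]
agg = shared_fits.groupby(["mutation", *groupby])["shift_b"].agg("mean")
print(agg)
```

M1W is dropped because only rep1 observed it, mirroring the effect of inner_merge_dataset_muts=True; M1E and G3R are averaged across the two replicates.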
- add_validation_loss(test_data, overwrite=False)
Add validation loss to the fit collection dataframe.
- Parameters:
test_data (pd.DataFrame or dict(str, pd.DataFrame)) – The testing dataframe to compute validation loss with respect to; must have columns “aa_substitutions”, “condition”, and “func_score”. If a dictionary is passed, there should be a key for each unique dataset_name factor in the self.fit_models dataframe, with the value being the respective testing dataframe.
overwrite (bool, optional) – Whether to overwrite the validation_loss column if it already exists. The default is False.
- Returns:
The self.fit_models dataframe with the validation loss added.
- Return type:
pd.DataFrame
- get_conditional_loss_df(query=None)
Return a long form dataframe with columns “dataset_name”, “scale_coeff_lasso_shift”, “split” (“training” or “validation”), “loss” (actual value), and “condition”.
- Parameters:
query (str, optional) – The query to apply to the fit_models dataframe before formatting the loss dataframe. The default is None.
- convergence_trajectory_df(query=None, id_vars=('dataset_name', 'scale_coeff_lasso_shift'))
Combine the convergence trajectory dataframes of all fits in the queried collection.
- mut_param_heatmap(query=None, mut_param='shift', aggregate_func='mean', inner_merge_dataset_muts=True, times_seen_threshold=0, phenotype_as_effect=True, **kwargs)
Create a lineplot and heatmap altair chart across replicate datasets. This function optionally applies a given pandas query to the fit_models dataframe, which should result in a subset of fits that make sense to aggregate mutational data across, e.g. replicate datasets. It then computes the mean or median mutational parameter value (“beta”, “shift”, or “predicted_func_score”) across the remaining fits and creates an interactive altair chart.
Note that this will throw an error if the queried fits have more than one unique hyper-parameter besides “dataset_name”.
- Parameters:
query (str) – The query to apply to the fit_models dataframe. This should be used to subset the fits to only those which make sense to aggregate mutational data across, e.g. replicate datasets. For example, if you have a collection of fits with different epistatic models, you may want to query for only those fits with the same epistatic model, e.g. query="epistatic_model == 'Sigmoid'". For more on the query syntax, see the pandas.DataFrame.query documentation.
mut_param (str, optional) – The mutational parameter to plot. The default is “shift”. Must be one of “shift”, “predicted_func_score”, or “beta”.
aggregate_func (str, optional) – The function to aggregate the mutational parameter values between dataset fits. The default is “mean”.
inner_merge_dataset_muts (bool, optional) – Whether to toss mutations which are _not_ shared across all datasets before aggregation of group mutation parameter values. The default is True.
times_seen_threshold (int, optional) – The minimum number of times a mutation must be seen across all conditions within a single fit to be included in the aggregation. The default is 0.
phenotype_as_effect (bool, optional) – Passed to Model.get_mutations_df(); only applies if mut_param="predicted_func_score".
**kwargs (dict) – Keyword arguments to pass to multidms.plot._lineplot_and_heatmap().
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file.
- Return type:
altair.Chart
- mut_param_traceplot(mutations, mut_param='shift', x='scale_coeff_lasso_shift', width_scalar=100, height_scalar=100, **kwargs)
Visualize mutation parameter values across the lasso penalty weights (by default) of a given subset of the mutations in the form of an altair.FacetChart. This is useful when you would like to confirm that a reported mutational parameter value carries through across the individual fits.
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file.
- Return type:
altair.Chart
- shift_sparsity(x='scale_coeff_lasso_shift', width_scalar=100, height_scalar=100, return_data=False, **kwargs)
Visualize shift parameter set sparsity across the lasso penalty weights (by default) in the form of an altair.FacetChart. We group the mutations according to their status as either a “stop” (e.g. A15*) or “nonsynonymous” (e.g. A15G) mutation before calculating the sparsity. This is because mutations to stop codons serve as a false-positive control: we expect their mutational effect to be equally deleterious in all experiments, and thus their shift parameter values to be zero.
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file. If return_data=True, then a tuple containing the chart and the underlying data will be returned.
- Return type:
altair.Chart or Tuple(pd.DataFrame, altair.Chart)
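The sparsity calculation described for shift_sparsity can be sketched with pandas: classify each mutation as "stop" or "nonsynonymous" by its mutant character, then compute the fraction of shift values that are exactly zero in each group. The shift values are hypothetical, and both the classification rule (a trailing "*") and the exact-zero criterion are assumptions; the method itself may use a tolerance or a different grouping.

```python
import pandas as pd

# Hypothetical shift values from a single fit.
shifts = pd.DataFrame({
    "mutation": ["A15*", "G20*", "A15G", "M1E", "P3R"],
    "shift_b": [0.0, 0.0, 0.0, 0.7, -1.2],
})

# Assumed rule: a mutation to "*" is a stop, everything else is nonsynonymous.
shifts["mut_type"] = shifts["mutation"].str.endswith("*").map(
    {True: "stop", False: "nonsynonymous"}
)
# Sparsity: fraction of shift parameters that are exactly zero in each group.
sparsity = (shifts["shift_b"] == 0).groupby(shifts["mut_type"]).mean()
print(sparsity)
```

A well-regularized fit should drive the stop-codon sparsity toward 1.0; nonzero stop-codon shifts at a given lasso weight behave like false positives.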
- mut_param_dataset_correlation(x='scale_coeff_lasso_shift', width_scalar=200, height=200, return_data=False, r=2, **kwargs)
Visualize the correlation between replicate datasets across the lasso penalty weights (by default) in the form of an altair.FacetChart. We compute the correlation of mutation parameters across each pair of datasets in the collection.
- Parameters:
x (str, optional) – The parameter to plot on the x-axis. The default is “scale_coeff_lasso_shift”.
width_scalar (int, optional) – The width of the chart. The default is 200.
height (int, optional) – The height of the chart. The default is 200.
return_data (bool, optional) – Whether to return the underlying data. The default is False.
r (int, optional) – The exponent applied to the reported correlation coefficient. May be either 1 for Pearson's r or 2 for the coefficient of determination (r-squared). The default is 2.
**kwargs (dict) – The keyword arguments to pass to the multidms.model_collection.ModelCollection.split_apply_combine_muts() method. See the method docstring for details.
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file. If return_data=True, then a tuple containing the chart and the underlying data will be returned.
- Return type:
altair.Chart or Tuple(altair.Chart, pd.DataFrame)
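The r parameter of mut_param_dataset_correlation selects whether Pearson's r (r=1) or its square (r=2) is reported. A small NumPy sketch of the calculation for one pair of replicate datasets, using hypothetical parameter values already aligned on shared mutations:

```python
import numpy as np

# Hypothetical shift values for the same mutations, fit on two replicate datasets.
rep1 = np.array([0.0, 0.6, -0.1, 1.2, 0.3])
rep2 = np.array([0.1, 0.5, 0.0, 1.0, 0.2])

pearson_r = np.corrcoef(rep1, rep2)[0, 1]
for r in (1, 2):
    # r=1 reports Pearson's r; r=2 reports r**2 (coefficient of determination).
    print(r, pearson_r ** r)
```

Squaring discards the sign, so r=2 is convenient for comparing agreement magnitudes across lasso weights, while r=1 also reveals anticorrelation.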
- multidms.fit_models(params, n_threads=-1, failures='error')
Fit a collection of multidms.model.Model models.
Enables fitting of multiple models simultaneously using multiple threads. Most commonly, this function is used to fit a set of models across combinations of replicate training datasets and lasso coefficients for model selection and evaluation. The returned dataframe is meant to be passed into the multidms.model_collection.ModelCollection class for comparison and visualization.
- Parameters:
params (dict) – Dictionary which defines the parameter space of all models you wish to run. Each value in the dictionary must be a list of values, even in the case of singletons. This function will compute all combinations of the parameter space and pass each combination to multidms.model_collection.fit_one_model() to be run in parallel, thus only key-value pairs which match its kwargs are allowed. See the docstring of multidms.model_collection.fit_one_model() for a description of the allowed parameters.
n_threads (int) – Number of threads (CPUs, cores) to use for fitting. Set to -1 to use all CPUs available.
failures ({"error", "tolerate"}) – What to do if fitting fails for a model. If “error”, raise an error; if “tolerate”, return None for models that failed optimization.
- Returns:
Number of models that fit successfully, number of models that failed, and a dataframe which contains a row for each of the multidms.Model object references along with the parameters each was fit with for convenience. The dataframe is ultimately meant to be passed into the ModelCollection class for comparison and visualization.
- Return type:
(n_fit, n_failed, fit_models)
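As described for params, every value is a list and fit_models fits every combination. That expansion into one keyword set per model can be sketched with itertools.product. The helper name and the parameter keys below are hypothetical; the Utils submodule lists an explode_params_dict() function, whose exact behavior this sketch does not claim to reproduce.

```python
from itertools import product


def expand_params(params: dict[str, list]) -> list[dict]:
    """Hypothetical helper: expand a dict of lists into one dict per combination."""
    for key, values in params.items():
        if not isinstance(values, list):
            raise TypeError(f"value for {key!r} must be a list, even for a single value")
    keys = list(params)
    # product() varies the last key fastest, like nested for-loops.
    return [dict(zip(keys, combo)) for combo in product(*(params[k] for k in keys))]


# Hypothetical parameter space: two datasets x three lasso strengths x one ge_type.
params = {
    "dataset": ["rep1", "rep2"],
    "scale_coeff_lasso_shift": [0.0, 1e-5, 1e-4],
    "ge_type": ["Sigmoid"],
}
combos = expand_params(params)
print(len(combos))  # 2 * 3 * 1 = 6 models to fit
print(combos[0])
```

The number of fits is the product of the list lengths, so adding one more value to any key multiplies the total work; this is why singleton values must still be wrapped in a list.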
Submodules
- data
Data, Data.name, Data.conditions, Data.reference, Data.reference_index, Data.mutations, Data.mutations_df, Data.variants_df, Data.site_map, Data.non_identical_mutations, Data.non_identical_sites, Data.bundle_idxs, Data.reference_sequence_conditions, Data.arrays, Data.training_data, Data.scaled_arrays, Data.scaled_training_data, Data.binarymaps, Data.targets, Data.mutparser, Data.parse_mut, Data.parse_muts, Data.single_mut_encodings, Data.convert_subs_wrt_ref_seq(), Data.plot_times_seen_hist(), Data.plot_func_score_boxplot()
- jaxmodels
Data, Latent, GlobalEpistasis, Identity, Sigmoid, Model, count_loss(), functional_score_loss(), fit()
- model
Model
- model_collection
ModelCollectionFitError, fit_one_model(), stack_fit_models(), fit_models(), ModelCollection, ModelCollection.site_map_union, ModelCollection.conditions, ModelCollection.reference, ModelCollection.shared_mutations, ModelCollection.all_mutations, ModelCollection.split_apply_combine_muts(), ModelCollection.add_validation_loss(), ModelCollection.get_conditional_loss_df(), ModelCollection.convergence_trajectory_df(), ModelCollection.mut_param_heatmap(), ModelCollection.mut_param_traceplot(), ModelCollection.shift_sparsity(), ModelCollection.mut_param_dataset_correlation()
- plot
color_gradient_hex()
- Utils
explode_params_dict(), my_concat(), split_sub(), split_subs(), difference_matrix(), transform(), rereference()