Simulation validation

Overview

The Bloom lab previously developed a pipeline that simulates deep mutational scanning (DMS) data to test global epistasis fitting in dms_variants. Taking inspiration from that work, we extend the pipeline to simulate data for two homologs, introducing shifts in mutational effects at a subset of sites, with the goal of validating the multidms joint-fitting approach.

This notebook has a few major steps involved:

  1. Simulate sequences and their respective mutational effects for two homologs, introducing shifts at a subset of sites in the non-reference homolog.

  2. Define replicate libraries of variants for each homolog.

  3. Simulate pre- and post-selection variant counts, using a simulated bottleneck to introduce realistic noise into the fitting targets (functional scores).

  4. multidms model fitting.

  5. Model fit comparison and selection criteria.

  6. Model evaluation.

Import Python modules

For this analysis, we use the dms_variants package for simulation. All required dependencies can be installed by cloning the multidms repository and running the following command from its root directory:

$ pip install '.[dev]'
[176]:
import warnings
import pprint
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import Bio.Seq
import itertools
import random
import dms_variants.codonvarianttable
import dms_variants.globalepistasis
import dms_variants.plotnine_themes
import dms_variants.simulate
from dms_variants.constants import CBPALETTE, CODONS_NOSTOP, AAS_NOSTOP, AA_TO_CODONS, AAS_WITHSTOP
from plotnine import *
import multidms
from collections import defaultdict
from multidms.utils import split_sub

Define notebook parameters

Set the parameters that define this notebook’s behavior. Any of these can be changed to other reasonable values (for example, using papermill) and the notebook should still execute as expected.

[177]:
seed = 2  # random number seed

# define the reference protein sequence we're going to simulate
genelength = 50  # gene length in codons
wt_latent = 5.0 # wildtype latent phenotype
stop_effect = -10 # -15 is the default for `SigmoidPhenotypeSimulator` object
norm_weights=((0.4, -.33, .8), (0.6, -2.666, 2)) # See `SigmoidPhenotypeSimulator` for more details

# define the non reference sequence attributes.
n_non_identical_sites = 10 # number of amino-acid mutations separating homologs
min_muteffect_in_bundle = -1.0 # lower bound (exclusive) on the h1 effect of each mutation separating the homologs
max_muteffect_in_bundle = 1.0 # upper bound (exclusive) on the h1 effect of each mutation separating the homologs
n_shifted_non_identical_sites = 6 # number of non-identical sites (in the bundle) at which mutational effects are shifted
n_shifted_identical_sites = 4 # number of identical sites (outside the bundle) at which mutational effects are shifted
shift_gauss_variance = 0.666 # despite the name, the standard deviation (numpy `scale`) of the Gaussian from which shifts are drawn

# define the sigmoid phenotype parameters. See `multidms.biophysical.sigmoidal_global_epistasis` for more details.
sigmoid_phenotype_scale = 6

# define the libraries
variants_per_lib_genelength_scalar = 1000  # variants per library per codon (multiplied by genelength)
avgmuts = 2.0  # average codon mutations per variant
bclen = 16  # length of nucleotide barcode for each variant
variant_error_rate = 0.0 #0.005  # rate at which variant sequence mis-called
avgdepth_per_variant = 200  # average per-variant sequencing depth
lib_uniformity = 5  # uniformity of library pre-selection
noise = 0.0 # 0.05 # noise level for the post-selection samples (passed as `noise` to `simulateSampleCounts`)

# bottlenecks from pre- to post-selection
bottleneck_variants_per_lib_scalar = {
    "tight_bottle": 5,
    "loose_bottle": 100,
}

# define the fitting parameters
maxiter = 100000 # default 20000
tol=1e-3
coef_lasso_shift = [0.0, 5.00e-6, 1.00e-05, 2.00e-05, 4.00e-05, 8.00e-05, 1.60e-04, 3.20e-04, 6.40e-04] # the sweep of lasso coefficient params
init_beta0 = 5.0 # we've found that starting with a higher beta0 is needed for the model to converge correctly
coef_ridge_beta = 1e-7 # ridge penalty coefficient on the beta parameters (a single value, not swept)
train_frac = 0.8 # fraction of data to use for cross validation training.
lasso_choice = 8e-5 # the lasso coefficient to use for the final model
csv_output_dir = 'scaled_down_parameter_output' # the directory to save the output csv files

Set a few configurations: seed the random number generators, silence warnings, and create the output directory if it does not already exist:

[178]:
%matplotlib inline

random.seed(seed)
np.random.seed(seed)
warnings.simplefilter('ignore')

os.makedirs(csv_output_dir, exist_ok=True)

Simulate a reference sequence

Start by simulating a sequence of nucleotides, then translate it to amino acids. We’ll use this sequence as the reference sequence for the simulated datasets.

[179]:
geneseq_h1 = "".join(random.choices(CODONS_NOSTOP, k=genelength))
aaseq_h1 = str(Bio.Seq.Seq(geneseq_h1).translate())
print(f"Wildtype gene of {genelength} codons:\n{geneseq_h1}")
print(f"Wildtype protein:\n{aaseq_h1}")
Wildtype gene of 50 codons:
TTCTTAAATACCTCAGTAGGACAGGCAGCCGATAGCCGGCGAGTATTTTTAGACCGTCAAAAGAACCTACATCCTTGCGAAGAGATGAACCATAGACTTTTTGGCAGTTGCTACGTATGGGTGTACCCCTTGTTCAGCGTCGGTCTAGAA
Wildtype protein:
FLNTSVGQAADSRRVFLDRQKNLHPCEEMNHRLFGSCYVWVYPLFSVGLE

Mutational effects

Simulate latent mutational effects, \(\beta_m\), storing the data in a SigmoidPhenotypeSimulator object. Also create an identical object for the second homolog; later, we’ll update that second object to include shifted mutational effects.

[180]:
mut_pheno_args = {
    "geneseq" : geneseq_h1,
    "wt_latent" : wt_latent,
    "seed" : seed,
    "stop_effect" : stop_effect,
    "norm_weights" : norm_weights
}

SigmoidPhenotype_h1 = dms_variants.simulate.SigmoidPhenotypeSimulator(**mut_pheno_args)
SigmoidPhenotype_h2 = dms_variants.simulate.SigmoidPhenotypeSimulator(**mut_pheno_args)

Plot the distribution of mutation effects for the first homolog.

[181]:
p = (
    ggplot(pd.DataFrame({'muteffects': SigmoidPhenotype_h1.muteffects}),
           aes(x='muteffects'))
    + geom_histogram(binwidth=0.1, color='black', fill='pink', alpha=0.7)
    + xlab("mutation effect, $β$")
    + ylab('number of mutations')
    + theme_classic()
    + theme(figure_size=(5, 3))
)
_ = p.draw(show=True)
[Figure: histogram of mutation effects, β (x-axis), against number of mutations (y-axis) for homolog 1]

As expected, the distribution of mutational effects is skewed toward negative values, since most mutations are expected to be deleterious to a protein.

Non-reference homolog sequence

Next, we simulate the DNA and protein sequences of the second homolog by making a defined number of random amino-acid mutations to the first homolog. When choosing the “bundle” of mutations that separate the two homologs, we only allow mutations whose effect in the first homolog lies strictly between min_muteffect_in_bundle and max_muteffect_in_bundle. This ensures that the latent phenotype of the second homolog is not too different from that of the first homolog, an important assumption of the multidms model.

[182]:
# Randomly choose the sites (1-indexed) at which the two homologs will differ
non_identical_sites = sorted(random.sample(range(1, len(aaseq_h1)+1), n_non_identical_sites))
aaseq_h2 = ''
geneseq_h2 = ''

# Iterate over each amino acid in the first homolog, and make a mutation if indicated.
for (aa_n, aa) in enumerate(aaseq_h1, 1):
    codon = geneseq_h1[(aa_n-1)*3:3*aa_n]
    if aa_n in non_identical_sites:
        # define the valid mutations for this site
        valid_bundle_mutations = [
            mut_aa for mut_aa in AAS_NOSTOP
            if (
                mut_aa != aa
                and SigmoidPhenotype_h1.muteffects[f'{aa}{aa_n}{mut_aa}'] > min_muteffect_in_bundle
                and SigmoidPhenotype_h1.muteffects[f'{aa}{aa_n}{mut_aa}'] < max_muteffect_in_bundle
            )
        ]
        assert len(valid_bundle_mutations) > 0, aa_n
        mut_aa = random.choice(valid_bundle_mutations)
        aaseq_h2 += mut_aa
        mut_codon = random.choice(AA_TO_CODONS[mut_aa])
        geneseq_h2 += mut_codon
    else:
        aaseq_h2 += aa
        geneseq_h2 += codon

# Store and summarize results
homolog_seqs_df = pd.DataFrame(
    {"wt_aa_h1" : list(aaseq_h1), "wt_aa_h2" : list(aaseq_h2)},
    index = range(1, len(aaseq_h1)+1),
)
homolog_seqs_df.index.name = 'site'

n_diffs = len(homolog_seqs_df.query('wt_aa_h1 != wt_aa_h2'))
print('Sequence alignment of homologs h1 and h2:')
print('h1', aaseq_h1)
print('h2', aaseq_h2)
print('Number of aa differences:', n_diffs)
assert len(aaseq_h1) == len(aaseq_h2)
assert aaseq_h2 == str(Bio.Seq.Seq(geneseq_h2).translate())
assert n_diffs == len(non_identical_sites)
Sequence alignment of homologs h1 and h2:
h1 FLNTSVGQAADSRRVFLDRQKNLHPCEEMNHRLFGSCYVWVYPLFSVGLE
h2 FLNTMVGWDADTRRVFLDFAKNLHPVEEMNCRLFGSCYVWVVPLFSYGLE
Number of aa differences: 10

Shifted mutational effects

Next, randomly choose a subset of (identical and non-identical) sites that will have shifted mutational effects. Do this independently for sites that are identical and non-identical between homologs, so that we are sure to have shifted sites in each category.

[183]:
# Non-identical sites
shifted_non_identical_sites = sorted(random.sample(
    non_identical_sites,
    n_shifted_non_identical_sites
))

# choose a subset of the identical sites to have shifted effects
shifted_identical_sites = sorted(random.sample(
    list(set(range(1, len(aaseq_h1)+1)) - set(non_identical_sites)),
    n_shifted_identical_sites
))

# Make a list of all shifted sites
shifted_sites = sorted(shifted_identical_sites + shifted_non_identical_sites)
assert len(shifted_sites) == len(set(shifted_sites))
print('Sites with shifts that are...')
print(f'identical (n={len(shifted_identical_sites)}):', ', '.join(map(str, shifted_identical_sites)))
print(f'non-identical (n={len(shifted_non_identical_sites)}):', ', '.join(map(str, shifted_non_identical_sites)))
Sites with shifts that are...
identical (n=4): 6, 14, 25, 39
non-identical (n=6): 8, 9, 12, 19, 31, 47
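Sampling the two pools separately is a form of stratified sampling: it guarantees the requested number of shifted sites in each category, whereas a single draw over all sites could, by chance, place few or none of the shifts at non-identical sites. The sketch below isolates that step. `sample_shifted_sites` is a hypothetical helper, and the non-identical sites used as example input are the ten bundle sites of this simulation (5, 8, 9, 12, 19, 20, 26, 31, 42, 47).

```python
import random

def sample_shifted_sites(all_sites, non_identical, n_nonident, n_ident, seed=0):
    """Hypothetical helper: draw shifted sites separately from each pool."""
    rng = random.Random(seed)
    identical = sorted(set(all_sites) - set(non_identical))
    # random.sample needs a sequence (not a set), so sort the pools first
    shifted_nonident = sorted(rng.sample(sorted(non_identical), n_nonident))
    shifted_ident = sorted(rng.sample(identical, n_ident))
    return shifted_ident, shifted_nonident

non_ident = [5, 8, 9, 12, 19, 20, 26, 31, 42, 47]
ident_shift, nonident_shift = sample_shifted_sites(range(1, 51), non_ident, 6, 4)
print("identical:", ident_shift)
print("non-identical:", nonident_shift)
```

With a different seed the specific sites change, but the 4/6 split between categories does not.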

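Next, we create a dataframe of all mutational effects and shifts. For every mutation at the shifted sites above, we simulate a shift in its effect by drawing from a Gaussian distribution with mean 0 and standard deviation shift_gauss_variance; mutations at unshifted sites and mutations to stop codons get a shift of 0. The effect in the second homolog, beta_h2, is then beta_h1 plus the shift: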

[184]:
def sim_mut_shift(shifted_site, mutation):
    if (not shifted_site) or ('*' in mutation):
        return 0
    else:
        return np.random.normal(loc=0, scale=shift_gauss_variance, size=1)[0]

# Make a dataframe of mutation effects in the first homolog
mut_effects_df = (
    pd.DataFrame.from_dict(
        SigmoidPhenotype_h1.muteffects,
        orient='index',
        columns=['beta_h1']
    )
    .reset_index()
    .rename(columns={'index':'mutation'})
    .assign(
        wt_aa=lambda x: x['mutation'].apply(lambda y: split_sub(y)[0]),
        site=lambda x: x['mutation'].apply(lambda y: int(split_sub(y)[1])),
        mut_aa=lambda x: x['mutation'].apply(lambda y: split_sub(y)[2]),
        shifted_site=lambda x: x['site'].isin(shifted_sites),
        shift=lambda x: x.apply(
            lambda row: sim_mut_shift(
                row['shifted_site'],
                row['mutation']
            ),
            axis=1
        ),
        beta_h2=lambda x: x['beta_h1'] + x['shift']
    )
    .merge(homolog_seqs_df, left_on='site', right_index=True, how='left')
    .assign(
        bundle_mut = lambda x: x['mut_aa'] == x['wt_aa_h2']
    )
)

# Show data for a subset of sites with shifts
mut_effects_df[mut_effects_df['shifted_site']][[
    'site', 'wt_aa_h1', 'wt_aa_h2', 'mutation',
    'beta_h1', 'shift', 'beta_h2',
    'shifted_site'
]]
[184]:
site wt_aa_h1 wt_aa_h2 mutation beta_h1 shift beta_h2 shifted_site
100 6 V V V6A -0.406196 -0.277561 -0.683757 True
101 6 V V V6C -3.915089 -0.037474 -3.952563 True
102 6 V V V6D -0.409342 -1.422707 -1.832049 True
103 6 V V V6E -4.280136 1.092420 -3.187716 True
104 6 V V V6F -0.862349 -1.194428 -2.056777 True
... ... ... ... ... ... ... ... ...
935 47 V Y V47S -0.027698 -0.967941 -0.995638 True
936 47 V Y V47T -2.240557 -0.225120 -2.465677 True
937 47 V Y V47W -0.178289 0.239183 0.060894 True
938 47 V Y V47Y -0.491120 0.414399 -0.076721 True
939 47 V Y V47* -10.000000 0.000000 -10.000000 True

200 rows × 8 columns
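Because numpy's normal sampler interprets `scale` as a standard deviation, the `shift_gauss_variance` parameter actually sets the standard deviation of the shifts; their variance is its square (about 0.44). A quick standalone check of that, using numpy's `default_rng` Generator rather than the legacy global `np.random` state used in the notebook:

```python
import numpy as np

# `scale` is the standard deviation of the Gaussian, not its variance
rng = np.random.default_rng(2)
shift_scale = 0.666  # same value as shift_gauss_variance above
shifts = rng.normal(loc=0.0, scale=shift_scale, size=100_000)

print(shifts.std())  # close to 0.666 (the standard deviation)
print(shifts.var())  # close to 0.666**2, i.e. about 0.44, not 0.666
```

The 200 rows in the table above correspond to the 10 shifted sites times 20 possible mutations per site (19 amino acids plus stop), with stop mutations assigned a shift of 0.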

Plot the distribution of all simulated shifts, excluding mutations to stop codons and sites that have no shifted effects.

[185]:
non_stop_shifted_muts = mut_effects_df[
    (mut_effects_df['site'].isin(shifted_sites)) &
    ~(mut_effects_df['mutation'].str.contains('*', regex=False))
]

p = (
    ggplot(non_stop_shifted_muts, aes(x='shift'))
    + geom_histogram(binwidth=0.1, color='black', fill='pink', alpha=0.7)
    + xlab("shift")
    + ylab('number of mutations')
    + theme_classic()
    + theme(figure_size=(5, 3))
)

_ = p.draw(show=True)
[Figure: histogram of simulated shifts (x-axis: shift; y-axis: number of mutations)]

We can also plot the shifts as a heatmap, marking the wildtype amino acid at each site (x for the reference homolog, o for the non-reference homolog at sites where it differs):

[186]:
max_abs_value = mut_effects_df['shift'].abs().max()
x_scale = list(range(1, len(homolog_seqs_df)+1))
y_scale = sorted([
    "R","K","H","D","E","Q","N","S",
    "T","Y","W","F","A","I","L","M",
    "V","G","P","C","*",
])

p = (
    ggplot(mut_effects_df)
    + geom_tile(
        aes(x='factor(site)', y='mut_aa', fill='shift', width=.9, height=.9)
    )
    + geom_point(
        homolog_seqs_df.reset_index(),
        aes(
            x='factor(site)', y='wt_aa_h1'
        ),
        color='black',
        size=2,
        shape='x'
    )
    + geom_point(
        homolog_seqs_df.reset_index().query('wt_aa_h1 != wt_aa_h2'),
        aes(
            x='factor(site)', y='wt_aa_h2'
        ),
        color='black',
        size=2,
        shape='o'
    )
    + scale_x_discrete(
        limits = x_scale,
        labels = x_scale
    )
    + scale_y_discrete(
        limits = y_scale,
        labels = y_scale
    )
    + scale_fill_gradientn(
        colors=['red', 'white', 'blue'],
        limits=[-max_abs_value, max_abs_value],
    )
    + theme_classic()
    + theme(
        axis_text_x=element_text(angle=90),
        figure_size=(7, 3.5)
    )
    + xlab('site')
    + ylab('mutation')
)
_ = p.draw(show=True)
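[Figure: heatmap of shifts by site (x-axis) and mutant amino acid (y-axis); x marks the reference wildtype, o marks the non-reference wildtype where it differs]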

Recall that we already created a SigmoidPhenotypeSimulator object for the second homolog as an identical copy of the one for the first homolog. This object stores the homolog’s wildtype latent phenotype and the latent effects of individual mutations. Below, we update both of these based on the shifts simulated above. To update the wildtype latent phenotype of the second homolog, we add the effects of all mutations that separate the homologs to the wildtype latent phenotype of the first homolog. In computing this sum, we use the \(\beta\) parameters for the second homolog, which already include the shifted effects.

[187]:
# Update individual mutational effects
assert sum(mut_effects_df['mutation'].duplicated()) == 0
for mutation in SigmoidPhenotype_h2.muteffects.keys():
    SigmoidPhenotype_h2.muteffects[mutation] = mut_effects_df.loc[
        mut_effects_df['mutation']==mutation, 'beta_h2'
    ].values[0]

wt_latent_phenotype_shift = mut_effects_df.query("bundle_mut")["beta_h2"].sum()
SigmoidPhenotype_h2.wt_latent = SigmoidPhenotype_h1.wt_latent + wt_latent_phenotype_shift

print('Characteristics of mutations separating homologs:\n')
for metric in ['beta_h1', 'shift', 'beta_h2']:
    print(f'  Sum of {metric}:', round(sum(mut_effects_df.query("bundle_mut")[metric]), 2))
print('  Final WT latent phenotype of h2:', round(SigmoidPhenotype_h2.wt_latent, 2))
print('  Table of mutations that separate homologs:')
mut_effects_df.query("mut_aa == wt_aa_h2")
Characteristics of mutations separating homologs:

  Sum of beta_h1: -0.76
  Sum of shift: -0.83
  Sum of beta_h2: -1.59
  Final WT latent phenotype of h2: 3.41
  Table of mutations that separate homologs:
[187]:
mutation beta_h1 wt_aa site mut_aa shifted_site shift beta_h2 wt_aa_h1 wt_aa_h2 bundle_mut
90 S5M -0.209250 S 5 M False 0.000000 -0.209250 S M True
157 Q8W 0.216508 Q 8 W True 0.334237 0.550745 Q W True
161 A9D 0.226606 A 9 D True 0.361207 0.587813 A D True
235 S12T 0.586065 S 12 T True -0.254090 0.331975 S T True
364 R19F -0.550723 R 19 F True -1.560840 -2.111563 R F True
380 Q20A -0.188525 Q 20 A False 0.000000 -0.188525 Q A True
516 C26V -0.115867 C 26 V False 0.000000 -0.115867 C V True
601 H31C -0.628807 H 31 C True -0.122049 -0.750856 H C True
837 Y42V 0.393497 Y 42 V False 0.000000 0.393497 Y V True
938 V47Y -0.491120 V 47 Y True 0.414399 -0.076721 V Y True
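As a check on the arithmetic in the cell above, the second homolog's wildtype latent phenotype is the reference wildtype latent phenotype (5.0) plus the sum of the beta_h2 values of the ten bundle mutations. The values below are copied from the table above:

```python
# beta_h2 values of the ten bundle mutations, copied from the table above
bundle_beta_h2 = {
    "S5M": -0.209250, "Q8W": 0.550745, "A9D": 0.587813, "S12T": 0.331975,
    "R19F": -2.111563, "Q20A": -0.188525, "C26V": -0.115867, "H31C": -0.750856,
    "Y42V": 0.393497, "V47Y": -0.076721,
}
wt_latent_h1 = 5.0
# the h2 wildtype is the h1 wildtype plus every (shifted) bundle effect
wt_latent_h2 = wt_latent_h1 + sum(bundle_beta_h2.values())
print(round(wt_latent_h2, 2))  # 3.41
```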

At this point, all mutations in the second homolog’s SigmoidPhenotypeSimulator object are defined relative to the wildtype sequence of the first homolog, which we will call the “reference” homolog. However, to compute phenotypes for the second homolog, we still need to assign the effects of mutations at non-identical sites in the “non-reference” homolog. For instance, if the wildtype amino acid at site 30 is an A in the reference homolog, but is a Y in a non-reference homolog, then the effect of a Y30G is absent from the second homolog’s simulator object.

To add these missing entries, we use a strategy that assumes additivity between mutations at the same site. For instance, in the above example, the effect of a Y30G mutation is defined as the sum of the negated A30Y effect and the A30G effect, both taken from the second homolog’s \(\beta\) values, which are defined relative to the reference wildtype sequence. This approach assumes that mutational effects can be negated, such that A30Y has the opposite effect to Y30A. It also assumes that mutational effects are additive, such that the effect of Y30G is the sum of the effects of Y30A and A30G.

The below expression defines this approach more explicitly for an arbitrary site. It uses the notation \(\beta_{x,n,z}\) where \(x\) and \(z\) are amino acids. For a site \(n\), if \(aa_{\text{wt}}\) is the site’s wildtype amino acid in a non-reference homolog, \(aa_{\text{mut}}\) is a mutant amino acid in a variant of that homolog, and \(aa_{\text{ref}}\) is the site’s wildtype amino acid in the reference homolog, then the mutation’s effect is:

\[\begin{split}\beta_{aa_{\text{wt}},n,aa_{\text{mut}}} = \begin{cases} \beta_{aa_{\text{ref}},n,aa_{\text{mut}}} & \text{if } aa_{\text{wt}} = aa_{\text{ref}}\\ -\beta_{aa_{\text{ref}},n,aa_{\text{wt}}} & \text{if } aa_{\text{mut}} = aa_{\text{ref}}\\ - \beta_{aa_{\text{ref}},n,aa_{\text{wt}}} +\beta_{aa_{\text{ref}},n,aa_{\text{mut}}} & \text{otherwise}\\ \end{cases}\end{split}\]

The above \(\beta\) parameters are effects in the background of the non-reference homolog, such that they account for shifts in mutational effects between homologs.
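The three cases can be written as a small pure function, which makes the negation and additivity assumptions explicit. `non_ref_effect` is a hypothetical helper, and the site-30 effects are made-up values matching the A/Y example above; the next cell applies the same logic directly to SigmoidPhenotype_h2.muteffects.

```python
def non_ref_effect(beta_ref, site, aa_wt, aa_mut, aa_ref):
    """Effect of aa_wt -> aa_mut at `site` in the non-reference background.

    `beta_ref` maps reference-style mutation strings such as "A30G" to beta
    values defined relative to the reference wildtype. Hypothetical helper.
    """
    if aa_wt == aa_ref:
        # identical site: the effect is used directly
        return beta_ref[f"{aa_ref}{site}{aa_mut}"]
    if aa_mut == aa_ref:
        # back-mutation to the reference amino acid: negate ref -> wt
        return -beta_ref[f"{aa_ref}{site}{aa_wt}"]
    # otherwise: undo ref -> wt, then apply ref -> mut
    return -beta_ref[f"{aa_ref}{site}{aa_wt}"] + beta_ref[f"{aa_ref}{site}{aa_mut}"]

# made-up effects at site 30, where the reference has A and the non-reference has Y
beta = {"A30Y": -0.4, "A30G": -1.5}
print(non_ref_effect(beta, 30, "Y", "G", "A"))  # Y30G = -(-0.4) + (-1.5)
print(non_ref_effect(beta, 30, "Y", "A", "A"))  # Y30A = -(-0.4)
```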

The below cell adds mutational effects for missing entries using the above strategy:

[188]:
# iterate over the "bundle" of mutations that distinguish the homologs
for idx, row in mut_effects_df.query("mut_aa == wt_aa_h2").iterrows():

    # compute the effect of each possible mutation at a bundle site
    for aa_mut in AAS_WITHSTOP:

        # skip the non-ref wt-to-wt
        if aa_mut == row.wt_aa_h2:
            continue

        non_ref_mutation = f'{row.wt_aa_h2}{row.site}{aa_mut}'

        # stop effect is the same for all homologs
        if aa_mut == '*':
            SigmoidPhenotype_h2.muteffects[non_ref_mutation] = stop_effect

        # back-mutation to reference
        elif aa_mut == row.wt_aa_h1:
            SigmoidPhenotype_h2.muteffects[non_ref_mutation] = -row.beta_h2

        # all other mutations
        else:
            ref_mut = f'{row.wt_aa_h1}{row.site}{aa_mut}'
            ref_mut_effect = mut_effects_df.loc[
                mut_effects_df['mutation'] == ref_mut, 'beta_h2'
            ].values[0]
            SigmoidPhenotype_h2.muteffects[non_ref_mutation] = -row.beta_h2 + ref_mut_effect
[189]:
mut_effects_df.to_csv(f"{csv_output_dir}/simulated_muteffects.csv", index=False)
mut_effects_df.head()
[189]:
mutation beta_h1 wt_aa site mut_aa shifted_site shift beta_h2 wt_aa_h1 wt_aa_h2 bundle_mut
0 F1A -2.019912 F 1 A False 0.0 -2.019912 F F False
1 F1C -0.417888 F 1 C False 0.0 -0.417888 F F False
2 F1D -2.928096 F 1 D False 0.0 -2.928096 F F False
3 F1E -1.516192 F 1 E False 0.0 -1.516192 F F False
4 F1G -4.732575 F 1 G False 0.0 -4.732575 F F False

Variant libraries

For each homolog, simulate two replicate variant libraries of the kind one might use in an actual experiment.

[190]:
libs = ["lib_1", "lib_2"]  # names of the two replicate libraries
variants_per_lib = variants_per_lib_genelength_scalar * genelength

CodonVariantTable_h1 = dms_variants.simulate.simulate_CodonVariantTable(
    geneseq=geneseq_h1,
    bclen=bclen,
    library_specs={
        lib: {"avgmuts": avgmuts, "nvariants": variants_per_lib} for lib in libs
    },
    seed=seed,
)
CodonVariantTable_h2 = dms_variants.simulate.simulate_CodonVariantTable(
    geneseq=geneseq_h2,
    bclen=bclen,
    library_specs={
        lib: {"avgmuts": avgmuts, "nvariants": variants_per_lib} for lib in libs
    },
    seed=seed,
)
[191]:
CodonVariantTable_h1.n_variants_df(samples=None)
[191]:
library sample count
0 lib_1 barcoded variants 50000
1 lib_2 barcoded variants 50000
2 all libraries barcoded variants 100000

Plot a histogram of variant call support (the number of sequences supporting each barcoded variant's call) for homolog 1:

[192]:
p = CodonVariantTable_h1.plotVariantSupportHistogram(max_support=10)
p += theme_classic()
p += theme(
    panel_grid_major_x=element_blank(),  # no vertical grid lines
    figure_size=(5, 3)
)
_ = p.draw(show=True)
[Figure: variant call support histogram for homolog 1]
[193]:
CodonVariantTable_h2.n_variants_df(samples=None)
[193]:
library sample count
0 lib_1 barcoded variants 50000
1 lib_2 barcoded variants 50000
2 all libraries barcoded variants 100000
[194]:
p = CodonVariantTable_h2.plotVariantSupportHistogram(max_support=10)
p += theme_classic()
p += theme(
    panel_grid_major_x=element_blank(),  # no vertical grid lines
    figure_size=(5, 3)
)
_ = p.draw(show=True)

[Figure: variant call support histogram for homolog 2]

Variant phenotypes and enrichments

Now that we have simulated libraries of variants \(v_i \in V\), we’ll assign their respective ground truth phenotypes and enrichments.

We’ll start by computing the latent phenotype, \(\phi(v_i)\), of each variant. Recall that each of the two homologs has a SigmoidPhenotypeSimulator object that stores the latent effects of individual mutations. We will use these objects to compute the latent phenotypes of all variants in the libraries.

[195]:
# we'll store a collection of functions for each homolog
phenotype_fxn_dict_h1 = {"latentPhenotype" : SigmoidPhenotype_h1.latentPhenotype}
phenotype_fxn_dict_h2 = {"latentPhenotype" : SigmoidPhenotype_h2.latentPhenotype}

# add the latent phenotype to the barcode variant table
CodonVariantTable_h1.barcode_variant_df['latent_phenotype'] = CodonVariantTable_h1.barcode_variant_df['aa_substitutions'].apply(
    phenotype_fxn_dict_h1["latentPhenotype"]
)
CodonVariantTable_h2.barcode_variant_df['latent_phenotype'] = CodonVariantTable_h2.barcode_variant_df['aa_substitutions'].apply(
    phenotype_fxn_dict_h2["latentPhenotype"]
)

# print a sample of the variant table
CodonVariantTable_h1.barcode_variant_df[["library", "aa_substitutions", "latent_phenotype"]].head()
[195]:
library aa_substitutions latent_phenotype
0 lib_1 S12F P43F 2.696775
1 lib_1 L33P 5.253866
2 lib_1 N3V F16S 4.413085
3 lib_1 N22W L49I -1.817175
4 lib_1 K21N 4.402149
[196]:
CodonVariantTable_h2.barcode_variant_df[["library", "aa_substitutions", "latent_phenotype"]].head()
[196]:
library aa_substitutions latent_phenotype
0 lib_1 T12F P43F 0.687919
1 lib_1 L33P 3.665115
2 lib_1 N3V F16S 2.824335
3 lib_1 N22W L49I -3.405926
4 lib_1 K21N 2.813398
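The outputs above are consistent with an additive latent model, \(\phi(v) = \phi(v_{WT}) + \sum_m \beta_m\): for variants whose mutations fall at sites that are neither shifted nor non-identical (for example L33P or K21N), the h1 and h2 latent phenotypes differ by exactly the wildtype gap, 5.0 - 3.411249. A minimal sketch of that additive computation, assuming (not verified here against the dms_variants source) that `SigmoidPhenotypeSimulator.latentPhenotype` works this way; `latent_phenotype` and the effect values are hypothetical:

```python
def latent_phenotype(subs, wt_latent, muteffects):
    """Additive latent phenotype: wildtype latent plus the sum of mutation effects.

    `subs` is a space-delimited string such as "S12F P43F" ('' for wildtype).
    """
    return wt_latent + sum(muteffects[m] for m in subs.split())

# made-up effects, for illustration only
effects = {"L33P": 0.25, "K21N": -0.6}
print(latent_phenotype("L33P K21N", 5.0, effects))  # 5.0 + 0.25 - 0.6
print(latent_phenotype("", 5.0, effects))           # wildtype: 5.0

# consistency check against the tables above: L33P has the same effect in both
# homologs, so its h1 - h2 difference equals the wildtype latent gap
print(round(5.253866 - 3.665115, 6), round(5.0 - 3.411249, 6))
```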

Next, we compute an observed phenotype, \(p(v, d) = g_{\theta}(\phi(v, d))\), for each variant \(v\) in homolog \(d\), where \(g_{\theta}\) is the global epistasis function, implemented as a flexible sigmoid with two free parameters, \(\theta_{s}\) and \(\theta_{b}\), that serve as a scale and bias, respectively. Concretely, \(g_{\theta}(z) = \theta_{b} + \frac{\theta_{s}}{1 + e^{-z}}\). For this analysis, we fix \(\theta_{b}\) based on the specified values of \(\theta_{s}\) (sigmoid_phenotype_scale) and \(\phi(v_{WT})\) (wt_latent), namely \(\theta_{b} = -\frac{\theta_{s}}{1 + e^{-\phi(v_{WT})}}\), so that the reference wildtype phenotype is exactly 0.

Note that while the SigmoidPhenotypeSimulator object (as the name suggests) provides its own global epistasis function for computing observed phenotypes and enrichments, we only use it to obtain mutational effects and the associated latent phenotypes. To compute observed phenotypes and enrichments, we instead use the sigmoid from the multidms biophysical model (reimplemented in numpy below), so that we simulate data under the same model that we later fit.

[197]:
# numpy reimplementation of the multidms sigmoidal global epistasis function
def g(z: float, ge_scale=sigmoid_phenotype_scale, wt_latent=wt_latent):

    # choose the bias so that g(wt_latent) == 0, i.e. the reference phenotype is 0
    ge_bias = -ge_scale / (1 + np.exp(-wt_latent))
    return ge_bias + (ge_scale / (1 + np.exp(-z)))
[198]:
wt_latent
[198]:
5.0
[199]:
g(SigmoidPhenotype_h2.wt_latent)
[199]:
np.float64(-0.15151731588729156)
[200]:
g(5, wt_latent=SigmoidPhenotype_h2.wt_latent)
[200]:
np.float64(0.15151731588729156)
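The last two outputs illustrate a property of this parameterization: since \(g(z) = \frac{\theta_s}{1+e^{-z}} - \frac{\theta_s}{1+e^{-\phi(v_{WT})}}\), swapping the evaluated latent phenotype with the centering wildtype latent phenotype negates the result. A standalone sketch, duplicating the notebook's `g` with the bias computed from `ge_scale`, and hard-coding the h2 wildtype latent phenotype from the outputs above:

```python
import numpy as np

def g(z, ge_scale=6.0, wt_latent=5.0):
    """Sigmoid global epistasis with the bias chosen so that g(wt_latent) == 0."""
    ge_bias = -ge_scale / (1 + np.exp(-wt_latent))
    return ge_bias + ge_scale / (1 + np.exp(-z))

wt_h1, wt_h2 = 5.0, 3.411249
print(g(wt_h1))                   # 0.0: the reference wildtype is centered
print(g(wt_h2))                   # about -0.1515
print(g(wt_h1, wt_latent=wt_h2))  # about +0.1515: swapping the roles flips the sign
```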

Visualize the sigmoid that we’ll use to map latent phenotypes to observed phenotypes, with a solid vertical line at the reference (“h1”) wildtype latent phenotype and a dotted line at the non-reference (“h2”) wildtype latent phenotype:

[201]:
resolution = (-10, 10, 100)
p = (
    ggplot(
        pd.DataFrame(
            {
                "x": np.linspace(*resolution),
                "y": g(np.linspace(*resolution)),  # g is vectorized over numpy arrays
            }
        ),
    )
    + geom_line(aes(x="x", y="y"))
    + geom_vline(xintercept=SigmoidPhenotype_h1.wt_latent)
    + geom_vline(xintercept=SigmoidPhenotype_h2.wt_latent, linetype="dotted")
    + labs(x="latent phenotype", y="observed phenotype")
    + theme_classic()
    + theme(figure_size=(4, 4))
)
_ = p.draw(show=True)

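[Figure: sigmoid mapping latent phenotype (x-axis) to observed phenotype (y-axis), with a solid line at the h1 wildtype latent phenotype and a dotted line at the h2 wildtype latent phenotype]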

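Next, assign each barcoded variant an observed phenotype, \(p(v, d)\), and an observed enrichment, \(2^{p(v, d)}\). For the non-reference homolog (h2), we also subtract the observed phenotype of its own wildtype, so that each homolog’s wildtype has an observed phenotype of 0 and an enrichment of 1.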

[202]:
# homolog 1: g() is already centered so that the h1 wildtype has an observed phenotype of 0
subs = CodonVariantTable_h1.barcode_variant_df['aa_substitutions']
phenotype_fxn_dict_h1["observedPhenotype"] = lambda x: g(float(phenotype_fxn_dict_h1["latentPhenotype"](x)))
CodonVariantTable_h1.barcode_variant_df['observed_phenotype'] = subs.apply(
    phenotype_fxn_dict_h1["observedPhenotype"]
)
phenotype_fxn_dict_h1["observedEnrichment"] = lambda x: 2 ** (phenotype_fxn_dict_h1["observedPhenotype"](x))
CodonVariantTable_h1.barcode_variant_df['observed_enrichment'] = subs.apply(
    phenotype_fxn_dict_h1["observedEnrichment"]
)

# homolog 2: subtract the observed phenotype of the h2 wildtype so that it is also 0
subs = CodonVariantTable_h2.barcode_variant_df['aa_substitutions']
phenotype_fxn_dict_h2["observedPhenotype"] = lambda x: g(float(phenotype_fxn_dict_h2["latentPhenotype"](x))) - g(float(SigmoidPhenotype_h2.wt_latent))
CodonVariantTable_h2.barcode_variant_df['observed_phenotype'] = subs.apply(
    phenotype_fxn_dict_h2["observedPhenotype"]
)
phenotype_fxn_dict_h2["observedEnrichment"] = lambda x: 2 ** (phenotype_fxn_dict_h2["observedPhenotype"](x))
CodonVariantTable_h2.barcode_variant_df['observed_enrichment'] = subs.apply(
    phenotype_fxn_dict_h2["observedEnrichment"]
)
variants_df = pd.concat(
    [
        CodonVariantTable_h1.barcode_variant_df.assign(homolog="h1"),
        CodonVariantTable_h2.barcode_variant_df.assign(homolog="h2")
    ]
)

variants_df[["aa_substitutions", "latent_phenotype", "observed_phenotype", "observed_enrichment", "homolog"]].head()
[202]:
aa_substitutions latent_phenotype observed_phenotype observed_enrichment homolog
0 S12F P43F 2.696775 -0.338826 0.790684 h1
1 L33P 5.253866 0.008957 1.006228 h1
2 N3V F16S 4.413085 -0.031679 0.978281 h1
3 N22W L49I -1.817175 -5.121204 0.028732 h1
4 K21N 4.402149 -0.032459 0.977752 h1
[203]:
variants_df.query("aa_substitutions == '' & homolog == 'h2'")
[203]:
index | library | barcode | variant_call_support | codon_substitutions | aa_substitutions | n_codon_substitutions | n_aa_substitutions | latent_phenotype | observed_phenotype | observed_enrichment | homolog
5 | lib_1 | AAAAAACGATCGAGCT | 3 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
12 | lib_1 | AAAAAATCCGCGCCCG | 3 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
13 | lib_1 | AAAAAATCGAATTCAG | 1 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
14 | lib_1 | AAAAAATCTCTGATCT | 1 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
29 | lib_1 | AAAAACTAATAGAACC | 4 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ...
99965 | lib_2 | TTTTTCAACAAACTCA | 3 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
99968 | lib_2 | TTTTTCGCCCGCCTTC | 1 | TTC45TTT |  | 1 | 0 | 3.411249 | 0.0 | 1.0 | h2
99973 | lib_2 | TTTTTGAATCTAGTGA | 1 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
99980 | lib_2 | TTTTTGCTGCTCACCG | 1 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2
99984 | lib_2 | TTTTTGGCGTGGACAC | 1 |  |  | 0 | 0 | 3.411249 | 0.0 | 1.0 | h2

14676 rows × 11 columns

Note that enrichments are always positive, and here they fall between roughly 0 and 1 (variants fitter than their wildtype can slightly exceed 1). They can therefore be used to simulate post-selection counts using a multinomial distribution in which the probability of observing a variant is weighted by its enrichment. Below, we plot each variant’s latent phenotype against its observed enrichment:

[204]:

p = (
    ggplot(
        variants_df,
        aes(x="latent_phenotype", y="observed_enrichment", color="homolog")
    )
    + geom_point()
    + geom_vline(xintercept=SigmoidPhenotype_h1.wt_latent)
    + geom_vline(xintercept=SigmoidPhenotype_h2.wt_latent, linetype="dotted")
    + theme_classic()
    + theme(
        figure_size=(5, 3),
        axis_text_x=element_text(angle=90),
        panel_grid_major_x=element_blank(),  # no vertical grid lines
    )
    + labs(x=r"$\phi$", y="observed enrichment")
    # zoom in on the x axis limits
    + scale_x_continuous(limits=(-10, 10))
)
_ = p.draw(show=True)
[Figure: observed enrichment against latent phenotype, φ, colored by homolog]
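Because the sigmoid is bounded, so are the enrichments in this plot: an h1 observed phenotype lies strictly between \(\theta_b\) and \(\theta_b + \theta_s\), and h2 phenotypes are shifted up by \(-g(\phi_{WT,h2})\) because each is re-centered on its own wildtype. Enrichments are therefore always positive but can slightly exceed 1. A short calculation of those bounds, using this notebook's parameter values and the h2 wildtype latent phenotype from the outputs above:

```python
import numpy as np

ge_scale, wt_h1, wt_h2 = 6.0, 5.0, 3.411249
ge_bias = -ge_scale / (1 + np.exp(-wt_h1))          # theta_b: centers the h1 wildtype at 0
g_wt_h2 = ge_bias + ge_scale / (1 + np.exp(-wt_h2))  # h2 wildtype phenotype before re-centering

# phenotype bounds as the latent phenotype goes to -inf and +inf
h1_bounds = (ge_bias, ge_bias + ge_scale)
# h2 phenotypes have g(wt_h2) subtracted, so both bounds move up by -g_wt_h2
h2_bounds = (ge_bias - g_wt_h2, ge_bias + ge_scale - g_wt_h2)

for name, (lo, hi) in [("h1", h1_bounds), ("h2", h2_bounds)]:
    print(f"{name}: phenotype in ({lo:.3f}, {hi:.3f}), "
          f"enrichment in ({2 ** lo:.3f}, {2 ** hi:.3f})")
```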

Pre and post-selection variant read counts

Next, we simulate both pre- and post-selection counts for each variant subjected to an experimental selection step. Post-selection counts are simulated by applying each specified bottleneck to the pre-selection library. To do this, we use dms_variants.simulate.simulateSampleCounts and pass it our custom observed-enrichment function as the phenotype function.

[205]:
bottlenecks = {
    name : variants_per_lib * scalar
    for name, scalar in bottleneck_variants_per_lib_scalar.items()
}

# do this independently for each of the homologs.
counts_h1, counts_h2 = [
    dms_variants.simulate.simulateSampleCounts(
        variants=variants,
        phenotype_func=pheno_fxn_dict["observedEnrichment"],
        variant_error_rate=variant_error_rate,
        pre_sample={
            "total_count": variants_per_lib * np.random.poisson(avgdepth_per_variant),
            "uniformity": lib_uniformity,
        },
        pre_sample_name="pre-selection",
        post_samples={
            name: {
                "noise": noise,
                "total_count": variants_per_lib * np.random.poisson(avgdepth_per_variant),
                "bottleneck": bottleneck,
            }
            for name, bottleneck in bottlenecks.items()
        },
        seed=seed,
    )
    for variants, pheno_fxn_dict in zip(
        [CodonVariantTable_h1, CodonVariantTable_h2], [phenotype_fxn_dict_h1, phenotype_fxn_dict_h2]
    )
]
CodonVariantTable_h1.add_sample_counts_df(counts_h1)
CodonVariantTable_h2.add_sample_counts_df(counts_h2)
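To make the role of the bottleneck concrete, here is a deliberately simplified sketch of the general idea: draw an uneven pre-selection library, pass a limited number of molecules through a bottleneck, re-weight the survivors by enrichment, and sample post-selection reads. It is not how dms_variants.simulate.simulateSampleCounts is implemented internally: variant-calling errors and noise are omitted, the Dirichlet model for library uniformity is an assumption, and `simulate_counts_sketch` is a hypothetical name.

```python
import numpy as np

def simulate_counts_sketch(enrichments, pre_total, post_total, bottleneck,
                           uniformity=5, seed=0):
    """Simplified pre/post-selection count simulation (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    enrichments = np.asarray(enrichments, dtype=float)
    n = enrichments.size
    # uneven pre-selection library (assumed Dirichlet; larger `uniformity` = more even)
    pre_freqs = rng.dirichlet(np.full(n, float(uniformity)))
    pre_counts = rng.multinomial(pre_total, pre_freqs)
    # bottleneck: only `bottleneck` molecules make it into the selection
    survivors = rng.multinomial(bottleneck, pre_freqs)
    # selection: survivors are re-weighted by their enrichment, then sequenced
    post_weights = survivors * enrichments
    post_counts = rng.multinomial(post_total, post_weights / post_weights.sum())
    return pre_counts, post_counts

# toy library: 500 wildtype-like variants followed by 500 strongly deleterious ones
enrich = np.r_[np.ones(500), np.full(500, 0.02)]
pre, post = simulate_counts_sketch(enrich, pre_total=200_000,
                                   post_total=200_000, bottleneck=100_000)
print(pre[:500].sum() / pre.sum(), post[:500].sum() / post.sum())
```

Making `bottleneck` smaller, as in the tight_bottle sample, lets fewer distinct variants through, so more variants end up with zero or very few post-selection counts; that is the extra noise the tighter bottleneck is meant to introduce.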

Plot the distribution of counts per variant in each sample. The horizontal dashed line shows the total number of variants. The plot shows that all variants are well sampled in the pre-selection libraries, but that after selection some variants are sampled much more or much less than others. This is expected, since selection increases the frequency of some variants and decreases the frequency of others:

[206]:
for variants, title in zip([CodonVariantTable_h1, CodonVariantTable_h2], ['homolog 1', 'homolog 2']):
    p = variants.plotCumulVariantCounts()
    p += theme_classic()
    p += theme(
        panel_grid_major_x=element_blank(),  # no vertical grid lines
        figure_size=(5, 5)
    )
    p += labs(title=title)
    _ = p.draw(show=True)
[Figure: cumulative counts per variant in each sample, homolog 1]
[Figure: cumulative counts per variant in each sample, homolog 2]

Plot the distribution of the number of amino-acid mutations per variant in each sample. As expected, variants carry fewer mutations after selection, because many mutations are deleterious and are depleted by selection:

[207]:
for variants, title in zip([CodonVariantTable_h1, CodonVariantTable_h2], ['homolog 1', 'homolog 2']):
    p = variants.plotNumMutsHistogram(mut_type="aa")
    p += theme_classic()
    p += theme(figure_size=(5, 5))
    p += labs(title=title)
    _ = p.draw(show=True)
[Figure: number of amino-acid mutations per variant in each sample, homolog 1]
[Figure: number of amino-acid mutations per variant in each sample, homolog 2]

Plot how thoroughly amino-acid mutations are sampled. The plots below show that stop mutations are sampled very poorly after selection, because they are eliminated during selection:

[208]:
for variants, title in zip([CodonVariantTable_h1, CodonVariantTable_h2], ['homolog 1', 'homolog 2']):
    p = variants.plotCumulMutCoverage(variant_type="all", mut_type="aa")
    p += theme_classic()
    p += theme(figure_size=(6, 5))
    p += theme(axis_text_x=element_text(angle=90))
    p += labs(title=title)
    _ = p.draw(show=True)
[Figure: cumulative amino-acid mutation coverage, homolog 1]
[Figure: cumulative amino-acid mutation coverage, homolog 2]

Prep training data

Prepare the training dataframes for fitting our multidms models.

As this is a joint-fitting approach, for each library replicate we combine the variants of both homologs into a single training dataset. Additionally, for each replicate dataset, we’ll train models on each of the following targets, which represent different levels of noise in the data:

  1. Ground truth observed phenotype - this target acts as a control. It benchmarks the model’s ability to recover the true latent effects and shifts in the absence of other sources of noise.

  2. Loose bottleneck functional scores - these scores are computed from counts simulated with a loose bottleneck. They are used to assess model performance in the context of realistic experimental noise.

  3. Tight bottleneck functional scores - the same as above, but with a tighter bottleneck, which introduces more extreme noise.

For further details on how to prepare data and fit models, see the multidms quick-start tutorial.

Start by creating training data with the ground truth phenotype target. Barcode replicates of the same variant share the same ground truth phenotype, so we can collapse them across replicates by simply dropping duplicates.
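
Here is a toy illustration of the collapse, using made-up barcodes. Once the barcode column is dropped, rows for barcodes that carry the same variant become identical, so `drop_duplicates()` keeps one row per unique variant in each library.

```python
import pandas as pd

# Hypothetical toy barcode table: two barcodes carry the same variant (A1C)
barcodes = pd.DataFrame({
    "library": ["lib_1", "lib_1", "lib_1"],
    "barcode": ["AAAA", "CCCC", "GGGG"],
    "aa_substitutions": ["A1C", "A1C", "D2E"],
    "observed_phenotype": [-0.5, -0.5, 0.2],
})

# drop the barcode column, then keep one row per (library, variant, phenotype)
collapsed = (
    barcodes[["library", "aa_substitutions", "observed_phenotype"]]
    .drop_duplicates()
    .reset_index(drop=True)
)
print(collapsed)
```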

[209]:
(variants.variant_count_df.barcode.value_counts().values == 3).all()
[209]:
np.False_
[210]:
variants.variant_count_df["sample"].value_counts().sort_index()
[210]:
sample
pre-selection    100000
loose_bottle     100000
tight_bottle     100000
Name: count, dtype: int64
[211]:
variants.barcode_variant_df.library.value_counts()
[211]:
library
lib_1    50000
lib_2    50000
Name: count, dtype: int64
[212]:
req_cols = ["library", "homolog", "aa_substitutions", "func_score_type", "func_score"]
ground_truth_training_set = (
    pd.concat(
        [
            variants.barcode_variant_df[["library", "aa_substitutions", "observed_phenotype", "latent_phenotype"]]
            .drop_duplicates()
            .assign(homolog=homolog)
            for variants, homolog in zip([CodonVariantTable_h1, CodonVariantTable_h2], ['h1', 'h2'])
        ]
    )
    .melt(
        id_vars=["library", "aa_substitutions", "homolog", "latent_phenotype"],
        value_vars=["observed_phenotype"],
        var_name="func_score_type",
        value_name="func_score",
    )
    [req_cols]
)
ground_truth_training_set.round(2).head()
[212]:
   library  homolog  aa_substitutions     func_score_type  func_score
0    lib_1       h1         S12F P43F  observed_phenotype       -0.34
1    lib_1       h1              L33P  observed_phenotype        0.01
2    lib_1       h1          N3V F16S  observed_phenotype       -0.03
3    lib_1       h1         N22W L49I  observed_phenotype       -5.12
4    lib_1       h1              K21N  observed_phenotype       -0.03
[213]:
ground_truth_training_set.aa_substitutions.value_counts()
[213]:
aa_substitutions
N3L P25A               4
R14D                   4
D11I                   4
L33W                   4
W40Y                   4
                      ..
V39L V47D              1
S5G H24R E28Y L33Y     1
Q8L S12T L23M P43L     1
K21R H31L G35I Y42I    1
W8Y P43V               1
Name: count, Length: 83677, dtype: int64

Next, compute functional scores from the pre- and post-selection counts for each bottleneck, after aggregating the barcode replicate counts for each unique variant. The dms_variants.CodonVariantTable.func_scores method computes functional scores from pre- and post-selection counts. We’ll use it to compute functional scores for the two bottleneck targets; the ground truth target was created above:
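
For reference, a functional score compares a variant's enrichment to the wildtype enrichment on a log2 scale. The sketch below makes one simplification: it adds a fixed pseudocount of 0.5 to every count, whereas the dms_variants documentation describes how the pseudocount is scaled by sample depth. The helper name `func_score` is hypothetical. With the counts for V15G (lib_1, homolog 1, loose bottleneck) from the output below, it reproduces the reported score (-2.168313) and variance (0.001447) to about four decimal places.

```python
import math


def func_score(pre_count, post_count, pre_count_wt, post_count_wt, pseudocount=0.5):
    """Hypothetical helper: log2 enrichment of a variant relative to wildtype.

    Simplified: the same pseudocount is added to every count.
    """
    pre, post = pre_count + pseudocount, post_count + pseudocount
    pre_wt, post_wt = pre_count_wt + pseudocount, post_count_wt + pseudocount
    score = math.log2(post / post_wt) - math.log2(pre / pre_wt)
    # Poisson-style variance of the log-ratio, converted from natural-log to log2 units
    var = (1 / pre + 1 / post + 1 / pre_wt + 1 / post_wt) / math.log(2) ** 2
    return score, var


# counts for V15G (lib_1, homolog 1, loose bottleneck)
score, var = func_score(4802, 2057, 1508174, 2904366)
print(round(score, 4), round(var, 6))
```

The wildtype row gets a score of exactly zero, because its counts appear in both the numerator and the denominator.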

[214]:
bottle_cbf = pd.concat(
    [
        (
            variants
            .func_scores("pre-selection", by="aa_substitutions", libraries=libs, syn_as_wt=True)
            .assign(homolog=homolog)
            .rename({"post_sample":"func_score_type"}, axis=1)
            # [req_cols]
            .astype({c:str for c in req_cols[:-1]})
        )
        for variants, homolog in zip([CodonVariantTable_h1, CodonVariantTable_h2], ['h1', 'h2'])
    ]
)
bottle_cbf.head()
[214]:
   library     pre_sample  func_score_type  aa_substitutions  func_score  func_score_var  pre_count  post_count  pre_count_wt  post_count_wt  pseudocount  n_aa_substitutions  homolog
0    lib_1  pre-selection     loose_bottle      S5I G7R F34A   -0.167084        0.003633        907        1556       1508174        2904366          0.5                   3       h1
1    lib_1  pre-selection     loose_bottle              V15G   -2.168313        0.001447       4802        2057       1508174        2904366          0.5                   1       h1
2    lib_1  pre-selection     loose_bottle         S12F V41G    0.094183        0.003961        781        1606       1508174        2904366          0.5                   2       h1
3    lib_1  pre-selection     loose_bottle              M29S   -0.006238        0.000505       6292       12065       1508174        2904366          0.5                   1       h1
4    lib_1  pre-selection     loose_bottle                      0.000000        0.000004    1508174     2904366       1508174        2904366          0.5                   0       h1
[215]:
bottle_cbf.to_csv(f"{csv_output_dir}/simulated_bottleneck_cbf.csv", index=False)

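Finally, combine the two dataframes computed above. Classify each variant by its amino-acid substitutions (wildtype, stop, one nonsynonymous, or more than one nonsynonymous), and merge in each variant’s true latent phenotype.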

[216]:
def classify_variant(aa_subs):
    if "*" in aa_subs:
        return "stop"
    elif aa_subs == "":
        return "wildtype"
    elif len(aa_subs.split()) == 1:
        return "1 nonsynonymous"
    elif len(aa_subs.split()) > 1:
        return ">1 nonsynonymous"
    else:
        raise ValueError(f"unexpected aa_subs: {aa_subs}")

func_scores = (
    pd.concat([ground_truth_training_set, bottle_cbf])
    .assign(variant_class = lambda x: x['aa_substitutions'].apply(classify_variant))
    .merge(
        variants_df[["aa_substitutions", "homolog", "latent_phenotype", "library"]].drop_duplicates(),
        on=["aa_substitutions", "homolog", "library"],
        how="inner"
    )
)

# categorical for plotting
func_scores["func_score_type"] = pd.Categorical(
    func_scores["func_score_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)

func_scores
[216]:
        library  homolog     aa_substitutions     func_score_type  func_score     pre_sample  func_score_var  pre_count  post_count  pre_count_wt  post_count_wt  pseudocount  n_aa_substitutions     variant_class  latent_phenotype
     0    lib_1       h1            S12F P43F  observed_phenotype   -0.338826            NaN             NaN        NaN         NaN           NaN            NaN          NaN                 NaN  >1 nonsynonymous          2.696775
     1    lib_1       h1                 L33P  observed_phenotype    0.008957            NaN             NaN        NaN         NaN           NaN            NaN          NaN                 NaN   1 nonsynonymous          5.253866
     2    lib_1       h1             N3V F16S  observed_phenotype   -0.031679            NaN             NaN        NaN         NaN           NaN            NaN          NaN                 NaN  >1 nonsynonymous          4.413085
     3    lib_1       h1            N22W L49I  observed_phenotype   -5.121204            NaN             NaN        NaN         NaN           NaN            NaN          NaN                 NaN  >1 nonsynonymous         -1.817175
     4    lib_1       h1                 K21N  observed_phenotype   -0.032459            NaN             NaN        NaN         NaN           NaN            NaN          NaN                 NaN   1 nonsynonymous          4.402149
   ...      ...      ...                  ...                 ...         ...            ...             ...        ...         ...           ...            ...          ...                 ...               ...               ...
349861    lib_2       h2        T4C N22A E28V        tight_bottle   -0.872689  pre-selection        0.225674       17.0        19.0     1631741.0      3329303.0          0.5                 3.0  >1 nonsynonymous          0.016912
349862    lib_2       h2        L2S A10S E28L        tight_bottle   -0.428415  pre-selection        0.222853       15.0        23.0     1631741.0      3329303.0          0.5                 3.0  >1 nonsynonymous          0.329871
349863    lib_2       h2             T4H D18S        tight_bottle   -1.590687  pre-selection        0.332509       15.0        10.0     1631741.0      3329303.0          0.5                 2.0  >1 nonsynonymous          0.215730
349864    lib_2       h2       T12P A20I V41E        tight_bottle   -3.564861  pre-selection        0.976092       14.0         2.0     1631741.0      3329303.0          0.5                 3.0  >1 nonsynonymous         -1.403989
349865    lib_2       h2  T12A L23T E27A V42G        tight_bottle   -1.559323  pre-selection        0.782739        6.0         4.0     1631741.0      3329303.0          0.5                 4.0  >1 nonsynonymous         -6.215249

349866 rows × 15 columns

[217]:
func_scores.to_csv(f"{csv_output_dir}/simulated_functional_scores.csv", index=False)

Let’s plot the functional scores as a function of the latent phenotype.

[218]:
p = (
    ggplot(func_scores, aes(x="latent_phenotype", y="func_score"))
    + geom_point(alpha=0.05)
    + facet_wrap("~func_score_type")
    + theme_classic()
    + theme(figure_size=(6, 3))
)
_ = p.draw(show=True)
[Figure: functional score vs. latent phenotype, faceted by fitting target]

Use a pair plot to compare the fitting targets across variants.

[171]:
fig = sns.pairplot(
    (
        func_scores
        .pivot(
            index=["library", "homolog", "aa_substitutions", "variant_class"],
            columns="func_score_type",
            values="func_score"
        )
        .reset_index()
        .sample(frac=0.1, random_state=seed)  # subsample 10% of variants for faster plotting
    ),
    hue='homolog',
    plot_kws = {"alpha":0.25},
    corner=True
)
fig.fig.set_size_inches(6, 5)
plt.tight_layout()
plt.show()
[Figure: pair plot of fitting targets, colored by homolog]

The tight bottleneck indeed introduces more noise in the data: its functional scores are less correlated with the ground truth phenotype than the loose bottleneck scores are.

Plot the distributions of the functional scores computed from the bottleneck counts. Recall that we collapsed barcode replicates, so each dataset contains only a single wildtype variant. We therefore omit wildtype from the plots below.

[172]:
p = (
    ggplot(
        func_scores.query("(func_score_type != 'observed_phenotype') & (variant_class != 'wildtype')"),
        aes("variant_class", "func_score")
    )
    + geom_violin(fill="lightblue")
    + ylab("functional score")
    + xlab("")
    + facet_grid("func_score_type ~ library + homolog")
    + theme_classic()
    + theme(
        figure_size=(6, 5),
        axis_text_x=element_text(angle=90),
        panel_grid_major_x=element_blank(),  # no vertical grid lines
    )
    # + scale_fill_manual(values=CBPALETTE[1:])
)
_ = p.draw(show=True)
[Figure: functional score distributions by variant class, faceted by fitting target, library, and homolog]
[173]:
func_scores.to_csv(f"{csv_output_dir}/simulated_func_scores.csv", index=False)
func_scores.round(2).head()
[173]:
library homolog aa_substitutions func_score_type func_score pre_sample func_score_var pre_count post_count pre_count_wt post_count_wt pseudocount n_aa_substitutions variant_class latent_phenotype
0 lib_1 h1 L33P observed_phenotype 0.01 NaN NaN NaN NaN NaN NaN NaN NaN 1 nonsynonymous 5.25
1 lib_1 h1 N3V F16S observed_phenotype -0.03 NaN NaN NaN NaN NaN NaN NaN NaN >1 nonsynonymous 4.41
2 lib_1 h1 observed_phenotype 0.00 NaN NaN NaN NaN NaN NaN NaN NaN wildtype 5.00
3 lib_1 h1 L2Q H31* observed_phenotype -5.93 NaN NaN NaN NaN NaN NaN NaN NaN stop -5.28
4 lib_1 h1 N3Y L17N D18I E27R observed_phenotype -5.04 NaN NaN NaN NaN NaN NaN NaN NaN >1 nonsynonymous -1.71

Create Data objects for each library replicate and fitting target

multidms model fitting is generally a two-step process:

  1. Create Data objects from functional score dataframes. Among other things, these encode the variant data into one-hot encoded matrices.

  2. Fit a collection of models across a grid of Data objects and hyperparameters.

Here, we’ll create one Data object for each library replicate / fitting target combination:
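
The grouping step produces one dataset per library replicate and fitting target. Each dataset contains both homologs, so they are fit jointly. Below is a toy pandas sketch of how the grouping keys map to dataset names; the data frame is made up and only mimics the relevant columns of `func_scores`.

```python
import itertools

import pandas as pd

# Hypothetical toy stand-in for `func_scores`: every library x target x homolog combination
libraries = ["lib_1", "lib_2"]
targets = ["observed_phenotype", "loose_bottle", "tight_bottle"]
toy = pd.DataFrame(
    [
        {"library": lib, "func_score_type": target, "condition": cond, "func_score": 0.0}
        for lib, target, cond in itertools.product(libraries, targets, ["h1", "h2"])
    ]
)

# one training dataset per (library, target) pair, named like the Data objects above
dataset_names = [
    f"{lib}_{target}_func_score"
    for (lib, target), _ in toy.groupby(["library", "func_score_type"])
]
print(dataset_names)
```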

[ ]:
data_objects = [
    multidms.Data(
        fs_df,
        reference="h1",
        alphabet = multidms.AAS_WITHSTOP_WITHGAP,
        verbose=False,
        name = f"{lib}_{target}_func_score"
    )
    for (lib, target), fs_df in func_scores.rename(columns={"homolog":"condition"}).groupby(['library', 'func_score_type'])
]

data_objects

Fit models to training data (lasso sweep)

Next, we’ll fit a set of models to each of the datasets defined above using the multidms.model_collection.fit_models function. For each dataset, we independently fit one model per lasso penalty coefficient on the shift parameters (\(\lambda_{L1}\)) defined at the top of this notebook.
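
Each key of the parameter dictionary below holds a list of values, and we assume one model is fit for every combination of those values (their Cartesian product). The sketch shows that expansion with `itertools.product`. It reflects that assumption rather than the code of multidms' private `_explode_params_dict`, and `explode_params` is a hypothetical name.

```python
import itertools


def explode_params(params):
    """Hypothetical helper: expand a dict of lists into one dict per combination."""
    keys = list(params)
    return [dict(zip(keys, values)) for values in itertools.product(*params.values())]


# made-up values for illustration only
grid = explode_params({
    "dataset": ["lib_1_loose_bottle", "lib_2_loose_bottle"],
    "coef_lasso_shift": [0.0, 1e-5, 4e-5],
    "maxiter": [10000],
})
print(len(grid))  # number of models that would be fit
print(grid[0])
```

Because single-valued settings (such as `maxiter`) are still wrapped in lists, they simply multiply the grid size by one.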

[ ]:
fitting_params = {
    # "num_training_steps": [num_training_steps], # default 100
    "maxiter": [10000], # default 20000
    "tol":[0.0001],
    "coef_lasso_shift": coef_lasso_shift, # the sweep of lasso coefficient params
    "init_beta0" : [init_beta0], # we've found that a higher initial beta0 helps the model converge correctly
    "coef_ridge_beta" : [coef_ridge_beta], # ridge penalty on the beta parameters (a single value)
    "acceleration" : [True], # default True
    "maxls" : [50]
}

fitting_params["dataset"] = data_objects
pprint.pprint(fitting_params)
[ ]:
# inspect the first hyperparameter combination in the grid
pprint.pprint(multidms.model_collection._explode_params_dict(fitting_params)[0])
[ ]:
# fit a single model as a quick test before running the full sweep
multidms.model_collection.fit_one_model(**multidms.model_collection._explode_params_dict(fitting_params)[0])

Fit the models:

[ ]:
_, _, fit_collection_df = multidms.model_collection.fit_models(fitting_params)

Note that multidms.model_collection.fit_models returns a tuple(int, int, pd.DataFrame). The first two values are the number of fits that succeeded and the number that failed, respectively. The third value is a dataframe in which each row contains a fit Model object and the hyperparameters used to fit it.

[ ]:
fit_collection_df[["model", "dataset_name", "coef_lasso_shift", "fit_time"]].info()

Add a few helpful features to this dataframe for plotting by splitting “dataset_name” (the name of the Data object used for fitting) into more understandable columns:
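
The `.str.split("_").str[...]` slicing used here relies on a naming convention: every library name and every target name contains exactly one underscore. Below is a plain-Python sketch of the same parsing; the helper name is hypothetical.

```python
def parse_dataset_name(name):
    """Split a '<library>_<target>_func_score' name into (library, target)."""
    parts = name.split("_")
    # fields 0-1 form the library (e.g. 'lib', '1'); fields 2-3 form the target
    return "_".join(parts[0:2]), "_".join(parts[2:4])


for name in ["lib_1_observed_phenotype_func_score", "lib_2_tight_bottle_func_score"]:
    print(parse_dataset_name(name))
```

A library named `lib1`, for example, would silently shift every field. Keeping names to this pattern, or storing library and target as separate columns, avoids that.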

[ ]:
fit_collection_df = fit_collection_df.assign(
    measurement_type = fit_collection_df["dataset_name"].str.split("_").str[2:4].str.join("_"),
    library = fit_collection_df["dataset_name"].str.split("_").str[0:2].str.join("_")
)

# convert measurement type to categorical ordered by 'observed', 'loose', 'tight' for plotting
fit_collection_df["measurement_type"] = pd.Categorical(
    fit_collection_df["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)

fit_collection_df[["model", "measurement_type", "library", "coef_lasso_shift"]].head()

Model convergence

[ ]:
model_collection = multidms.model_collection.ModelCollection(fit_collection_df)
model_collection.fit_models.converged.value_counts()
[ ]:
data = model_collection.convergence_trajectory_df(
    id_vars = ("coef_lasso_shift", "measurement_type", "library", "acceleration", "maxls")
)
data
[ ]:
data.index.name = "step"
data.reset_index(inplace=True)
data.head()
[ ]:
p = (
    ggplot(
        data.query("~acceleration & step > 90000 & library == 'lib_1' & coef_lasso_shift == 4e-05"),
        aes(
            x="step",
            y="error",
            linetype="library"
        )
    )
    + geom_line()
    + facet_wrap("~coef_lasso_shift+measurement_type", scales="free_y", ncol=3)
    # optionally constrain the y axis, e.g. to (0, 0.1)
    # + scale_y_continuous(limits=(0, .1))
    # make figure size taller
    + theme(figure_size=(10, 4))
    # add title
    + labs(
        # title="acceleration=False"
    )
)
_ = p.draw(show=True)
[ ]:
p = (
    ggplot(
        data.query("acceleration & step > 3000 & library == 'lib_1'"),
        aes(
            x="step",
            y="error",
            linetype="library"
        )
    )
    + geom_line()
    + facet_wrap("~coef_lasso_shift+measurement_type", scales="free_y", ncol=3)
    # optionally constrain the y axis, e.g. to (0, 0.1)
    # + scale_y_continuous(limits=(0, .1))
    # make figure size taller
    + theme(figure_size=(10, 10))
)
_ = p.draw(show=True)
[ ]:
p = (
    ggplot(
        data.query("acceleration & step > 3000 & library == 'lib_1'"),
        aes(
            x="step",
            y="loss",
            linetype="library"
        )
    )
    + geom_line()
    + facet_wrap("~coef_lasso_shift+measurement_type", scales="free_y", ncol=3)
    # optionally constrain the y axis, e.g. to (0, 0.1)
    # + scale_y_continuous(limits=(0, .1))
    # make figure size taller
    + theme(figure_size=(10, 10))
)
_ = p.draw(show=True)

Model vs. truth mutational effects

The multidms.ModelCollection takes a dataframe such as the one created above and provides a few helpful methods for model selection and aggregation. We’ll use these methods to evaluate the fits and select the best model for each dataset.

[ ]:
model_collection = multidms.model_collection.ModelCollection(fit_collection_df)
[ ]:
model_collection.fit_models.converged.value_counts()
[ ]:
chart_no_bottle = model_collection.mut_param_dataset_correlation(query="measurement_type == 'observed_phenotype'")
chart_loose_bottle = model_collection.mut_param_dataset_correlation(query="measurement_type == 'loose_bottle'")
chart_tight_bottle = model_collection.mut_param_dataset_correlation(query="measurement_type == 'tight_bottle'")
chart_no_bottle | chart_loose_bottle | chart_tight_bottle
[ ]:
(
    model_collection
    .fit_models
    .query(
        f"coef_lasso_shift == 8.0e-5 and measurement_type == 'loose_bottle' and library == 'lib_1'"
    )
    .iloc[0]
    .model
    .plot_epistasis()
)

Get the mutational parameters for each of the models using the model_collection.split_apply_combine_muts() method, and merge them with the simulated ground truth parameters:

[ ]:
# the columns that distinguish fits are what we'll groupby such that none of the model collection parameters are aggregated
groupby=("library", "measurement_type", "coef_lasso_shift")
collection_muts_df = (
    model_collection.split_apply_combine_muts(
        groupby=groupby
    )
    .reset_index()
    .rename(
        {
            'beta' : 'predicted_beta',
            'shift_h2' : 'predicted_shift_h2',
        },
        axis=1
    )
    .merge(
        mut_effects_df.rename(
            {
                'beta_h1' : 'true_beta',
                'beta_h2' : 'true_beta_h2',
                'shift' : 'true_shift',
            },
            axis=1
        ),
        on='mutation'
    )
)
assert collection_muts_df.shape[0] == len(mut_effects_df) * len(fit_collection_df)
collection_muts_df[["mutation", "library", "measurement_type", "coef_lasso_shift", "predicted_shift_h2", "true_shift"]].head()

To compare the fitted parameter values to the simulated ground truth values, add mean absolute error (MAE) and Pearson correlation metrics to the ModelCollection.fit_models attribute:
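
A mean absolute error has to take the absolute value of each difference before averaging; otherwise over- and under-estimates cancel out. The toy example below uses made-up numbers in which every prediction is exactly twice the truth. The signed mean error is zero even though every prediction is wrong. Pearson's r is 1, because correlation measures linear agreement and ignores scale.

```python
import numpy as np

y_true = np.array([1.0, -1.0, 0.5, -0.5])
y_pred = np.array([2.0, -2.0, 1.0, -1.0])  # made-up predictions: 2x the truth

signed_mean_error = np.mean(y_true - y_pred)            # positive and negative errors cancel
mean_absolute_error = np.mean(np.abs(y_true - y_pred))  # every error counts
pearson_r = np.corrcoef(y_true, y_pred)[0, 1]           # blind to the 2x scaling
print(signed_mean_error, mean_absolute_error, pearson_r)
```

The two metrics are therefore complementary: correlation checks the ranking and linear trend of effects, while MAE checks their magnitude.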

[ ]:
def series_corr(y_true, y_pred):
    return np.corrcoef(y_true, y_pred)[0, 1]

def series_mae(y_true, y_pred):
    # mean absolute error: take absolute differences so errors of opposite sign don't cancel
    return np.mean(np.abs(y_true - y_pred))

# compute the new metric columns
new_fit_models_cols = defaultdict(list)
for group, model_mutations_df in collection_muts_df.groupby(list(groupby)):
    # add the grouping values as columns for merging
    for col, value in zip(groupby, group):
        new_fit_models_cols[col].append(value)

    for parameter in ["beta", "shift"]:
        for metric_fxn, name in zip([series_corr, series_mae], ["corr", "mae"]):

            # add the new metric columns
            postfix="_h2" if parameter == "shift" else ""
            y_pred = model_mutations_df[f"predicted_{parameter}{postfix}"]
            y_true = model_mutations_df[f"true_{parameter}"]
            new_fit_models_cols[f"{parameter}_{name}"].append(metric_fxn(y_true, y_pred))

# merge the new columns into the model collection
model_collection.fit_models = model_collection.fit_models.merge(
    pd.DataFrame(new_fit_models_cols),
    on=list(groupby)
)

# print the first few rows of the model collection
model_collection.fit_models[list(groupby) + [c for c in new_fit_models_cols.keys() if c not in groupby]].head()

Next, we’ll make summary plots comparing the predicted and simulated ground truth parameters across model fits. Set metric to "mae" to plot mean absolute error instead of correlation:

[ ]:
metric = "corr"

data = (
    model_collection.fit_models
    .assign(
        measurement_library = lambda x: x["measurement_type"].astype(str) + " " + x["library"]
    )
    .melt(
        id_vars=list(groupby) + ["measurement_library"],
        value_vars=[f"beta_{metric}", f"shift_{metric}"],
        var_name="parameter",
        value_name=metric,
    )
)
data["measurement_type"] = pd.Categorical(
    data["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)
data["parameter"] = data["parameter"].str.replace(f"_{metric}", "")
data["parameter"] = pd.Categorical(
    data["parameter"],
    categories=["shift", "beta"],
    ordered=True
)
for parameter, parameter_df in data.groupby("parameter", observed=True):
    p = (
        ggplot(parameter_df)
        + geom_line(
            aes(
                x="coef_lasso_shift",
                y=metric,
                group="measurement_library",
            ),
        )
        + geom_point(
            aes(
                x="coef_lasso_shift",
                y=metric,
                shape="library"
            ),
            size=4
        )
        + facet_wrap("measurement_type", scales="free_y")
        + theme_classic()
        + theme(
            figure_size=(6, 3.3),
            axis_text_x=element_text(angle=90),
            panel_grid_major_x=element_blank(),  # no vertical grid lines
        )
        + labs(
            title=f"Prediction vs. Simulated Ground Truth: {parameter}",
            x="lasso penalty coefficient (λ)",
            y="Mean Absolute Error" if metric == "mae" else "Correlation"
        )
    )
    _ = p.draw(show=True)
[ ]:
data.to_csv(f"{csv_output_dir}/model_vs_truth_beta_shift.csv", index=False)
data.round(2).head()

Shift sparsity

Another way to evaluate the fits is by their sparsity, i.e. the percentage of shift parameters that are exactly zero in each model. We compute this metric separately for mutations to stop codons and mutations to non-stop codons. Stop codons should be equally deleterious in both homologs, so in a “good” fit we expect all shift parameters associated with stop codon mutations to be zero.
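
Below is a pandas sketch of the sparsity computation on made-up shift estimates for a single fitted model. The column names and values are assumptions, and ModelCollection.shift_sparsity may categorize mutations differently. Counting exact zeros assumes that the lasso step of the optimizer sets shifts exactly to zero.

```python
import pandas as pd

# Made-up shift estimates for one fitted model, plus which mutations were truly shifted
shifts = pd.DataFrame({
    "mutation": ["A1C", "A1*", "D2E", "D2*", "K3R"],
    "shift_h2": [0.0, 0.0, 0.7, 0.1, 0.0],
    "shifted_site": [False, False, True, False, False],
})
shifts["mut_type"] = shifts["mutation"].str.endswith("*").map({True: "stop", False: "nonstop"})

# sparsity = fraction of shift parameters that are exactly zero, per mutation type
sparsity = (shifts["shift_h2"] == 0).groupby(shifts["mut_type"]).mean()

# the simulated ("true") sparsity: fraction of mutations that were not shifted
true_sparsity = 1 - shifts["shifted_site"].mean()
print(sparsity, true_sparsity)
```

Here the nonzero stop-codon shift (D2*) lowers the stop-codon sparsity below 1, which is the kind of spurious shift this metric is meant to flag.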

[ ]:
chart, data = model_collection.shift_sparsity(return_data=True)
data.head()
[ ]:
data = data.assign(
    library=data.dataset_name.str.split("_").str[:2].str.join("_"),
    library_type = data.dataset_name.str.split("_").str[:2].str.join("_") + "-" + data.mut_type,
    measurement_type = data.dataset_name.str.split("_").str[2:4].str.join("_")
)
# convert measurement type to categorical ordered by 'observed', 'loose', 'tight'
data["measurement_type"] = pd.Categorical(
    data["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)
data['True Sparsity'] = '' #dummy col for plotting legend
# data.sort_values("coef_lasso_shift", inplace=True)
data["coef_lasso_shift"] = data.coef_lasso_shift.astype(object)  # treat lasso coefficients as discrete x-axis values

true_sparsity = 1 - mut_effects_df.shifted_site.mean()

p = (
    ggplot(
        data,
        aes("coef_lasso_shift", "sparsity")
    )
    + geom_hline(aes(yintercept=true_sparsity, color="True Sparsity"), linetype="dashed")
    + geom_line(
        aes(
            group="library_type",
        ),
    )
    + geom_point(
        aes(
            fill="mut_type",
            shape="library"
        ),
        size=4
    )
    + scale_fill_manual(values=CBPALETTE)
    + facet_wrap("measurement_type")
    + theme_classic()
    + theme(
        figure_size=(6, 3.3),
        axis_text_x=element_text(angle=90),
        panel_grid_major_x=element_blank(),  # no vertical grid lines
    )
    + labs(
        title=f"",
        x="lasso penalty coefficient (λ)",
        y="sparsity"
    )
)
_ = p.draw(show=True)
[ ]:
data.to_csv(f"{csv_output_dir}/fit_sparsity.csv", index=False)
data.round(2).head()

Replicate correlations

Another common way to evaluate models is to look at the correlation of the inferred mutational effects between the two library replicates. We compute the Pearson correlation coefficient between the inferred mutational effects in the two libraries using the multidms.ModelCollection.mut_param_dataset_correlation method.
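
At its core, a replicate correlation aligns each mutation's inferred parameter across the two replicate fits and computes Pearson's r. Below is a pandas sketch with made-up values; it illustrates the idea and is not the ModelCollection implementation.

```python
import pandas as pd

# Made-up inferred shifts for the same mutations from two library replicates
params = pd.DataFrame({
    "mutation": ["A1C", "D2E", "K3R", "A1C", "D2E", "K3R"],
    "library": ["lib_1"] * 3 + ["lib_2"] * 3,
    "shift_h2": [0.0, 0.8, -0.3, 0.1, 0.7, -0.4],
})

# one column per replicate, rows aligned on mutation
wide = params.pivot(index="mutation", columns="library", values="shift_h2")
r = wide["lib_1"].corr(wide["lib_2"])  # Series.corr defaults to Pearson
print(round(r, 3))
```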

[ ]:
chart, data = model_collection.mut_param_dataset_correlation(return_data=True)
data.head()

By default, the method returns a dataframe of parameter correlations between every unique pair of datasets. It therefore includes correlations between models fit on different measurement types. We’ll only compare models trained on the same fitting target.

[ ]:
data = (
    data
    .assign(
        lib1 = data.datasets.str.split(",").str[0],
        lib2 = data.datasets.str.split(",").str[1],
        measurement_type_1 = lambda x: x["lib1"].str.split("_").str[2:4].str.join("_"),
        measurement_type_2 = lambda x: x["lib2"].str.split("_").str[2:4].str.join("_")
    )
    .query("(measurement_type_1 == measurement_type_2) & (~mut_param.str.contains('predicted_func_score'))")
    .rename(
        {
            "measurement_type_1" : "measurement_type",
        },
        axis=1
    )
    .replace({"shift_h2": "shift"})
    .drop(["lib1", "lib2", "datasets", "measurement_type_2"], axis=1)
)

data.head(10)
[ ]:
data["measurement_type"] = pd.Categorical(
    data["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)
data["mut_param"] = pd.Categorical(
    data["mut_param"],
    categories=["shift", "beta"],
    ordered=True
)
data["coef_lasso_shift"] = data.coef_lasso_shift.astype(object)
for parameter, parameter_df in data.groupby("mut_param", observed=True):
    p = (
        ggplot(
            parameter_df
        )
        + geom_line(
            aes(
                x="coef_lasso_shift",
                y="correlation",
                group="measurement_type"
            ),
        )
        + geom_point(
            aes(
                x="coef_lasso_shift",
                y="correlation",
            ),
            size=4
        )
        + facet_wrap("measurement_type", scales="free_y")
        + theme_classic()
        + theme(
            figure_size=(6, 3.3),
            axis_text_x=element_text(angle=90),
            panel_grid_major_x=element_blank(),  # no vertical grid lines
        )
        + labs(
            title=f"Library Replicate Correlations: {parameter}",
            x="lasso penalty coefficient (λ)",
            y="pearsonr"
        )
    )
    _ = p.draw(show=True)
[ ]:
data.to_csv(f"{csv_output_dir}/library_replicate_correlation.csv", index=False)
data.round(2).head()

Model vs. truth variant phenotypes

Let’s take a look at how the lasso constraint affects the models’ ability to predict the true latent and observed phenotypes.

First, we grab the model predictions of latent and observed phenotypes for each of the datasets using multidms.Model.get_variants_df.

[ ]:
variants_df = pd.concat(
    [
        row.model.get_variants_df(phenotype_as_effect=False)
        .assign(
            library=row.library,
            measurement_type=row.measurement_type,
            coef_lasso_shift=row.coef_lasso_shift,
        )
        .rename(
            {
                "predicted_func_score": "predicted_phenotype",
                "predicted_latent": "predicted_latent_phenotype",
                "func_score" : "measured_phenotype",
            },
            axis=1,
        )
        # add enrichments
        .assign(
            predicted_enrichment = lambda x: 2**x['predicted_phenotype'],
            measured_enrichment = lambda x: 2**x['measured_phenotype'],
            fit_idx = idx
        )
        for idx, row in fit_collection_df.iterrows()
    ]
)
variants_df.head()

Next, add the simulated ground truth phenotypes:

[ ]:
variants_df  = pd.concat(
    [
        variants_df.query("condition == @homolog")
        .assign(
            true_latent_phenotype = lambda x: x['aa_substitutions'].apply(phenotype_fxn_dict["latentPhenotype"]),
            true_observed_phenotype = lambda x: x['aa_substitutions'].apply(phenotype_fxn_dict["observedPhenotype"]),
            true_enrichment = lambda x: x['aa_substitutions'].apply(phenotype_fxn_dict["observedEnrichment"]),
        )
        for homolog, phenotype_fxn_dict in zip(["h1", "h2"], [phenotype_fxn_dict_h1, phenotype_fxn_dict_h2])
    ]
)

variants_df["measurement_type"] = pd.Categorical(
    variants_df["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)

# As a sanity check let's make sure the true phenotypes match the
# "measured phenotypes" (i.e. fitting targets) for the variants that were trained on ground truth targets
assert np.allclose(
    variants_df.query("measurement_type == 'observed_phenotype'")['measured_phenotype'],
    variants_df.query("measurement_type == 'observed_phenotype'")['true_observed_phenotype'],
)

For clarity, let’s define and arrange the columns we now have:

[ ]:
cols = [
    # unique combinations of these make up the model collection that we've fit
    "library", # replicate library,
    "measurement_type", # type of functional score
    "coef_lasso_shift", # lasso coefficient of model

    # variant defining columns
    "aa_substitutions", # variant substitutions
    "var_wrt_ref", # variant substitutions relative to reference wildtype
    "condition", # homolog

    "measured_phenotype", # the actual target functional score for training, in multidms, this is "func_score"
    "measured_enrichment", # 2 ** measured_phenotype

    "predicted_latent_phenotype", # predicted latent phenotype
    "predicted_phenotype", # predicted observed phenotype (after the global epistasis nonlinearity)
    "predicted_enrichment", # predicted enrichment

    "true_latent_phenotype", # true latent phenotype
    "true_observed_phenotype", # true observed phenotype
    "true_enrichment", # true enrichment
]
# variants_df[cols].round(2).head()

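Add the correlation and mean absolute error between each model’s predicted phenotypes and its fitting targets (the measured phenotypes) to the fit collection dataframe, and plot them. For the observed_phenotype datasets, the fitting targets are the ground truth: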

[ ]:
for idx, model_variants_df in variants_df.groupby("fit_idx"):
    for metric_fxn, metric_name in zip([series_corr, series_mae], ["corr", "mae"]):
        fit_collection_df.loc[idx, f"variant_phenotype_{metric_name}"] = metric_fxn(
            model_variants_df["measured_phenotype"],
            model_variants_df["predicted_phenotype"]
        )

fit_collection_df[["coef_lasso_shift", "dataset_name", "variant_phenotype_corr", "variant_phenotype_mae"]].head()

[ ]:
# for metric in ["corr", "mae"]:
data = fit_collection_df[["library", "measurement_type", "coef_lasso_shift", "dataset_name", "variant_phenotype_corr", "variant_phenotype_mae"]]
metric = "corr"
p = (
    ggplot(
        data
        .assign(
            measurement_library = lambda x: x["library"] + "_" + x["measurement_type"].astype(str)
        )
    )
    + geom_line(
        aes(
            x="coef_lasso_shift",
            y=f"variant_phenotype_{metric}",
            group="measurement_library",
        ),
    )
    + geom_point(
        aes(
            x="coef_lasso_shift",
            y=f"variant_phenotype_{metric}",
            shape="library"
        ),
        size=4
    )
    + facet_wrap("measurement_type", scales="free_y")
    + theme_classic()
    + theme(
        figure_size=(6.5, 3.3),
        axis_text_x=element_text(angle=90),
        panel_grid_major_x=element_blank(),  # no vertical grid lines
    )
    + labs(
        title=f"Predicted Vs. True Variant Phenotype",
        x=r"lasso penalty coefficient ($\lambda_{L1}$)",  # raw string avoids the invalid "\l" escape
        y="pearsonr" if metric == "corr" else "mean absolute error"
    )
)
_ = p.draw(show=True)

As expected, increasing the lasso penalty reduces how closely the models’ predictions match their fitting targets, because constraining the shift parameters limits the model’s freedom to fit the training data. Conversely, a less constrained model can fit the training data more closely, but part of that advantage may come from overfitting to noise.

[ ]:
data.to_csv(f"{csv_output_dir}/model_vs_truth_variant_phenotype.csv", index=False)
data.round(2).head()

Cross validation

Above, we saw evidence that the lasso penalty reduces model performance on the training data. To test whether the less constrained models are overfitting, we can perform cross validation. We fit models on a subset of each dataset and check whether the lasso-constrained models are more accurate on the held-out (unseen) data.
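
The split used below shuffles each dataset with a fixed seed and holds out the last `1 - train_frac` fraction of rows as unseen test data. Below is that logic as a hypothetical helper, applied to a made-up toy dataframe:

```python
import pandas as pd


def holdout_split(df, train_frac, seed):
    """Hypothetical helper: shuffle rows, then split into train and test parts."""
    shuffled = df.sample(frac=1, random_state=seed)
    n_train = int(len(shuffled) * train_frac)
    return shuffled.iloc[:n_train], shuffled.iloc[n_train:]


toy = pd.DataFrame({
    "aa_substitutions": [f"M1{aa}" for aa in "ACDEFGHIKL"],
    "func_score": range(10),
})
train_df, test_df = holdout_split(toy, train_frac=0.8, seed=2)
print(len(train_df), len(test_df))
```

This is a single hold-out split rather than k-fold cross validation. Averaging over several splits would give a less noisy estimate of validation loss, at the cost of more model fits.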

[ ]:
train, test = [], {}
for (library, measurement), fs_df in func_scores.rename(columns={"homolog": "condition"}).groupby(["library", "func_score_type"]):

    # skip enrichment-based measurements; only the remaining score types are used as fitting targets
    if "enrichment" in measurement:
        continue

    # shuffle the variants, then keep the first train_frac for training and hold out the rest
    dataset = fs_df.sample(frac=1, random_state=seed)
    n_train = int(len(dataset) * train_frac)
    train_split, test_split = dataset.iloc[:n_train], dataset.iloc[n_train:]
    name = f"{library}_{measurement}"

    train.append(
        multidms.Data(
            train_split,
            reference="h1",
            alphabet=multidms.AAS_WITHSTOP_WITHGAP,
            verbose=False,
            name=name,
        )
    )

    # held-out variants, keyed by dataset name for add_validation_loss
    test[name] = test_split
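The cell above uses a single random train/validation split. If a more stable estimate were wanted, a k-fold scheme rotates the held-out fold so that every variant is used for validation exactly once. The sketch below is a hypothetical helper (`kfold_splits`) applied to a toy DataFrame; it is not part of multidms, and the toy column names are illustrative only.

```python
import numpy as np
import pandas as pd

def kfold_splits(df, k=5, seed=0):
    """Yield (train_df, test_df) pairs for k-fold cross validation (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(len(df))  # random order of row positions
    folds = np.array_split(shuffled, k)  # k nearly equal-sized folds
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield df.iloc[train_idx], df.iloc[test_idx]

# toy dataset of ten single-mutant variants
toy = pd.DataFrame({
    "aa_substitutions": [f"M1{aa}" for aa in "ACDEFGHIKL"],
    "func_score": np.linspace(-1.0, 0.0, 10),
})
for train_df, test_df in kfold_splits(toy, k=5, seed=2):
    print(len(train_df), len(test_df))
```

Each fold would then be fit and scored separately, at the cost of k times as many model fits.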
[ ]:
fitting_params["dataset"] = train
_, _, fit_collection_cv = multidms.model_collection.fit_models(fitting_params)

The multidms.ModelCollection.add_validation_loss method takes unseen data (here, the dictionary of held-out variants keyed by dataset name), computes each model’s loss on that data, and adds the result as new columns of the ModelCollection.fit_models dataframe.
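The losses compared in this section are Huber losses without the penalty term. For intuition, here is a minimal sketch of a mean Huber loss on held-out predictions; the `delta` value and the use of a mean (rather than a sum) are assumptions, and multidms may normalize differently.

```python
import numpy as np

def huber_loss(y_true, y_pred, delta=1.0):
    """Mean Huber loss: quadratic for |residual| <= delta, linear beyond it."""
    r = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    quadratic = 0.5 * r ** 2
    linear = delta * (r - 0.5 * delta)
    return float(np.mean(np.where(r <= delta, quadratic, linear)))

# one small residual (0.5) and one large outlier residual (3.0)
print(huber_loss([0.0, 0.0], [0.5, 3.0]))
```

The linear tail means a few badly measured variants in the validation set inflate the loss less than they would under squared error.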

[ ]:
mc = multidms.model_collection.ModelCollection(fit_collection_cv)
mc.add_validation_loss(test, overwrite=True)
[ ]:
mc.fit_models[[c for c in mc.fit_models.columns if "loss" in c]].head()
[ ]:
data = mc.fit_models.copy()  # copy so the collection's own dataframe is not modified below

data["library"] = data["dataset_name"].str.split("_").str[0:2].str.join("_")
data["measurement_type"] = data["dataset_name"].str.split("_").str[2:4].str.join("_")
data["measurement_type"] = pd.Categorical(
    data["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)

data = (
    data
    .melt(
        id_vars=["coef_lasso_shift", "library", "measurement_type"],
        value_vars=["total_loss_training", "total_loss_validation"],
        var_name="dataset",
        value_name="loss",
    )
    .assign(
        dataset = lambda x: x["dataset"].str.replace("total_loss_", "")  # -> "training" / "validation"
    )
    .assign(
        lib_dataset = lambda x: x["library"] + " " + x["dataset"]
    )
)
p = (
    ggplot(data)
    + geom_line(
        aes(
            x="coef_lasso_shift",
            y="loss",
            group="lib_dataset",
        )
    )
    + geom_point(
        aes(
            x="coef_lasso_shift",
            y="loss",
            fill="dataset",
            shape="library"
        ),
        size=4
    )
    + facet_wrap("measurement_type", scales="free_y")
    + theme_classic()
    + theme(
        figure_size=(6.5, 3.3),
        axis_text_x=element_text(angle=90),
        panel_grid_major_x=element_blank(),  # no vertical grid lines
    )
    + labs(
        x="lasso penalty coefficient (λ)",
        y="Huber loss w/o penalty"
    )
)
_ = p.draw(show=True)

For both the loose and tight bottleneck datasets, the lasso constraint indeed yields models that generalize better to unseen data, as judged by validation loss. We don’t, however, see this effect for the models trained on the ground-truth observed phenotypes, since those targets contain no noise for the models to overfit to.

[ ]:
data.to_csv(f"{csv_output_dir}/cross_validation_loss.csv", index=False)
data.round(2).head()

Final Selection

The analyses above provide several clues to inform our choice of model. With empirical data we won’t have ground-truth values to compare against, so the choice of lasso penalty will largely depend on shift-parameter sparsity, the correlation of inferred mutational effects between replicate libraries, and cross-validation performance. Focusing on the loose bottleneck training dataset (a fairly realistic level of noise that we often observe in real experiments), a lasso penalty of \(10^{-4}\) appears to give a false positive rate of nearly zero (i.e. a stop codon sparsity of \(\approx 100\%\), since stop codons are expected to have no true shift), a good correlation of inferred mutational effects between replicate libraries (Pearson r > 0.95), and a relatively low validation loss compared with the other models.
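Stop codon sparsity, used above as a proxy for the false positive rate, is the fraction of stop-codon mutations whose inferred shift is exactly zero. A minimal sketch, assuming a mutations table with hypothetical columns `mutation` (e.g. `"M1*"`) and `shift_h2`; the real column names in a multidms mutations dataframe may differ.

```python
import pandas as pd

def stop_codon_sparsity(muts_df, shift_col="shift_h2"):
    """Fraction of stop-codon mutations whose inferred shift is exactly zero."""
    stops = muts_df[muts_df["mutation"].str.endswith("*")]
    return float((stops[shift_col] == 0).mean())

# toy table: four stop codons (one with a spurious non-zero shift) and one missense mutation
toy = pd.DataFrame({
    "mutation": ["M1*", "A2*", "G3*", "K4*", "A2V"],
    "shift_h2": [0.0, 0.0, 0.0, -0.3, 0.8],
})
print(stop_codon_sparsity(toy))
```

Here three of the four stop-codon shifts are zero, giving a sparsity of 0.75; the missense shift is ignored because it can legitimately be non-zero.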

To check whether the models correctly infer the shape of global epistasis, we’ll plot the inferred global epistasis function of each model fit with the chosen lasso penalty (lasso_choice) and compare it to the true global epistasis function.
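A global epistasis function maps the additive latent phenotype onto the observed phenotype through a nonlinearity. The sketch below uses a generic sigmoid of the form \(g(z) = \text{scale} / (1 + e^{-z}) + \text{bias}\); this parameterization is an assumption for illustration and need not match the exact one multidms or the simulation uses.

```python
import numpy as np

def sigmoid_global_epistasis(latent, scale=1.0, bias=0.0):
    """Map additive latent phenotypes to observed phenotypes through a sigmoid."""
    return scale / (1.0 + np.exp(-np.asarray(latent, dtype=float))) + bias

z = np.linspace(-6, 6, 13)
observed = sigmoid_global_epistasis(z, scale=2.0, bias=-1.0)  # bounded between -1 and 1

# the same +1 step in latent phenotype has a much smaller observed effect near saturation
step_middle = sigmoid_global_epistasis(0.0) - sigmoid_global_epistasis(-1.0)
step_saturated = sigmoid_global_epistasis(4.0) - sigmoid_global_epistasis(3.0)
print(step_middle, step_saturated)
```

This saturation is why the latent-vs-observed scatter in the plots bends rather than following a straight line.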

[ ]:
data = variants_df.query(f"coef_lasso_shift == {lasso_choice}")
for x, y in itertools.combinations(
    ["predicted_latent_phenotype", "predicted_phenotype", "measured_phenotype"], 2
):
    p = (
        ggplot(data.sample(frac=1.0))
        + geom_point(
            aes(
                x=x,
                y=y
            ),
            alpha=0.05,
            color="grey"
        )
        + ylab(y)
        + xlab(x)
        + facet_grid("~measurement_type")
        + theme_classic()
        + theme(
            figure_size=(6.5, 3.5),
            axis_text_x=element_text(angle=90),
        )
    )

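    if x == "predicted_latent_phenotype":
        # overlay the true global epistasis function (true latent vs. true observed phenotype)
        p += geom_point(
            aes(
                x="true_latent_phenotype",
                y="true_observed_phenotype"
            ),
            size=0.5,
        )
    else:
        # both axes are on the observed scale, so add the y = x reference line
        p += geom_abline(
            slope=1,
            intercept=0,
        )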

    _ = p.draw(show=True)

Next, plot the model-predicted and measured (functional score) enrichments against the ground-truth enrichments.
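The measured enrichments come from simulated pre- and post-selection counts. As a simplified sketch (not dms_variants’ exact `func_scores` computation, which handles pseudocounts in its own way), a variant’s enrichment relative to wild type is the ratio of its post- to pre-selection frequency, normalized by the same ratio for wild type; its log2 is on the functional-score scale. The `pseudocount` default here is an assumption.

```python
import numpy as np

def count_enrichment(pre_variant, post_variant, pre_wt, post_wt, pseudocount=0.5):
    """Variant enrichment relative to wild type from pre/post-selection counts.

    A pseudocount is added to every count so that zero counts do not give 0 or infinity.
    """
    post_ratio = (post_variant + pseudocount) / (post_wt + pseudocount)
    pre_ratio = (pre_variant + pseudocount) / (pre_wt + pseudocount)
    return post_ratio / pre_ratio

# a variant whose counts halve after selection while wild type stays constant
e = count_enrichment(100, 50, 1000, 1000, pseudocount=0.0)
print(e, np.log2(e))

# a variant that drops out entirely still gets a finite log2 enrichment thanks to the pseudocount
print(np.log2(count_enrichment(5, 0, 1000, 1000)))
```

With low counts, small random fluctuations change these ratios substantially, which is the source of noise in the measured enrichments.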

[ ]:
p = (
    ggplot(
        data.melt(
            id_vars=["measurement_type", "library", "true_enrichment"],
            value_vars=["predicted_enrichment", "measured_enrichment"],
            var_name="enrichment_type",
            value_name="enrichment",
        ),
        aes("true_enrichment", "enrichment"),
    )
    + geom_abline(
        slope=1,
        intercept=0,
    )
    + geom_point(alpha=0.05, size=0.5)
    + facet_grid("enrichment_type~measurement_type", scales="free_y")
    + theme_classic()
    + theme(
        figure_size=(6.5, 5),
    )
)

_ = p.draw(show=True)

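Remarkably, the model predicts the true enrichments better than even the count-based functional scores do. A likely reason is that each prediction is built from mutational effects shared across many variants, which averages out the counting noise of any single variant.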
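The pooling effect can be reproduced with a toy additive model. In this sketch (plain least squares with no global epistasis or shifts, so not the multidms model), each variant carries a random subset of mutations, its measurement has independent noise, and predictions rebuilt from the fitted per-mutation effects land much closer to the truth than the raw measurements.

```python
import numpy as np

rng = np.random.default_rng(2)
n_variants, n_muts = 2000, 20

# binary design matrix: each variant carries each mutation with probability 0.1
X = (rng.random((n_variants, n_muts)) < 0.1).astype(float)
true_effects = rng.normal(scale=1.0, size=n_muts)
true_pheno = X @ true_effects

# per-variant measurement noise, standing in for counting noise
measured = true_pheno + rng.normal(scale=1.0, size=n_variants)

# fit additive mutation effects by least squares, then predict every variant
effects_hat, *_ = np.linalg.lstsq(X, measured, rcond=None)
predicted = X @ effects_hat

mae_measured = np.mean(np.abs(measured - true_pheno))
mae_predicted = np.mean(np.abs(predicted - true_pheno))
print(f"MAE measured vs. truth:  {mae_measured:.3f}")
print(f"MAE predicted vs. truth: {mae_predicted:.3f}")
```

Each effect is estimated from the roughly 200 variants that carry it, so its error is far smaller than the noise on any single measurement.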

Finally, let’s plot the relationship between the model’s inferred shifts and the ground-truth shifts.
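Besides the scatter plots, shift recovery can be summarized numerically: the fraction of truly shifted mutations that receive a non-zero inferred shift (true positive rate), the fraction of unshifted mutations that wrongly receive one (false positive rate), and the correlation between true and inferred values. This is a hypothetical helper applied to toy arrays standing in for the `true_shift` and `predicted_shift_h2` columns.

```python
import numpy as np

def shift_recovery(true_shift, predicted_shift, tol=0.0):
    """Summarize how well non-zero shifts are recovered (hypothetical helper)."""
    true_shift = np.asarray(true_shift, dtype=float)
    predicted_shift = np.asarray(predicted_shift, dtype=float)
    true_nz = np.abs(true_shift) > tol
    pred_nz = np.abs(predicted_shift) > tol
    tpr = (true_nz & pred_nz).sum() / max(true_nz.sum(), 1)
    fpr = (~true_nz & pred_nz).sum() / max((~true_nz).sum(), 1)
    r = np.corrcoef(true_shift, predicted_shift)[0, 1]
    return {"tpr": float(tpr), "fpr": float(fpr), "pearson_r": float(r)}

true = [0.0, 0.0, 0.0, 1.0, -2.0, 0.0]
pred = [0.0, 0.1, 0.0, 0.8, -1.5, 0.0]
print(shift_recovery(true, pred))
```

Both true shifts are detected (TPR 1.0), while one of the four unshifted mutations gets a small spurious shift (FPR 0.25).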

[ ]:
# plot correlations of true and predicted shifts
collection_muts_df["measurement_type"] = pd.Categorical(
    collection_muts_df["measurement_type"],
    categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
    ordered=True
)

# compare the true shifts with those inferred for the non-reference homolog (h2)
x, y = "true_shift", "predicted_shift_h2"
p = (
    ggplot(collection_muts_df.query(f"coef_lasso_shift == {lasso_choice}"))
    + geom_point(
        aes(
            x=x,
            y=y
        ),
        alpha=0.5,
    )
    + geom_abline(
        slope=1,
        intercept=0,
    )
    + ylab(y)
    + xlab(x)
    + facet_grid("library~measurement_type")
    + theme_classic()
    + theme(
        figure_size=(6.5, 5),
        axis_text_x=element_text(angle=90),
    )
)
_ = p.draw(show=True)