Model Fitting¶
Fit multidms models to simulated functional-score data across a grid of fusion-regularization values.
Outline
1. Load simulated functional scores from disk
2. Create multidms.Data objects (one per library × func_score_type)
3. Fit models across the fusion-regularization grid via fit_models()
4. Post-process and save the fit collection
[1]:
import warnings
warnings.filterwarnings("ignore")
import os
import pickle
import sys
sys.path.insert(0, "notebooks")
import pandas as pd
import multidms
from multidms.model_collection import fit_models
from _common import load_config, build_fit_params
[2]:
config_path = "config/config.yaml"
[3]:
# Parameters
config_path = "config/config.yaml"
[4]:
config = load_config(config_path)
sim = config["simulation"]
fit_config = sim["fitting"]
output_dir = sim["output_dir"]
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
Output directory: results
Load simulated functional scores¶
[5]:
func_scores = pd.read_csv(os.path.join(output_dir, "simulated_func_scores.csv"))
func_scores["func_score_type"] = pd.Categorical(
func_scores["func_score_type"],
categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
ordered=True,
)
print(f"Loaded {len(func_scores)} rows")
func_scores.head()
Loaded 349866 rows
[5]:
| | library | homolog | aa_substitutions | func_score_type | func_score | pre_sample | func_score_var | pre_count | post_count | pre_count_wt | post_count_wt | pseudocount | n_aa_substitutions | variant_class | latent_phenotype |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | lib_1 | h1 | S12F P43F | observed_phenotype | -0.338826 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | >1 nonsynonymous | 2.696775 |
| 1 | lib_1 | h1 | L33P | observed_phenotype | 0.008957 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 1 nonsynonymous | 5.253866 |
| 2 | lib_1 | h1 | N3V F16S | observed_phenotype | -0.031679 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | >1 nonsynonymous | 4.413085 |
| 3 | lib_1 | h1 | N22W L49I | observed_phenotype | -5.121204 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | >1 nonsynonymous | -1.817175 |
| 4 | lib_1 | h1 | K21N | observed_phenotype | -0.032459 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 1 nonsynonymous | 4.402149 |
Create Data objects¶
One multidms.Data object per (library, func_score_type) combination.
[6]:
data_objects = []
for (lib, fst), group_df in func_scores.rename(
columns={"homolog": "condition"}
).groupby(["library", "func_score_type"]):
df = group_df.copy()
df["aa_substitutions"] = df["aa_substitutions"].fillna("")
data_objects.append(
multidms.Data(
df,
reference="h1",
alphabet=multidms.AAS_WITHSTOP_WITHGAP,
verbose=False,
name=f"{lib}_{fst}_func_score",
)
)
print(f"Created {len(data_objects)} Data objects:")
for d in data_objects:
print(f" {d.name}")
Created 6 Data objects:
lib_1_observed_phenotype_func_score
lib_1_loose_bottle_func_score
lib_1_tight_bottle_func_score
lib_2_observed_phenotype_func_score
lib_2_loose_bottle_func_score
lib_2_tight_bottle_func_score
Build fitting parameters and fit models¶
[7]:
fitting_params = build_fit_params(fit_config, data_objects)
print("Fitting parameters:")
for k, v in fitting_params.items():
if k != "dataset":
print(f" {k}: {v}")
Fitting parameters:
maxiter: [100]
tol: [1e-06]
fusionreg: [0.0, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8]
l2reg: [0.0001]
beta0_ridge: [1e-05]
ge_type: ['Sigmoid']
ge_kwargs: [{'tol': 1e-05, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
cal_kwargs: [{'tol': 0.0001, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
loss_kwargs: [{'δ': 1.0}]
warmstart: [False]
beta0_init: [{'h1': 5.0, 'h2': 0.0}]
alpha_init: [{'h1': 6.0, 'h2': 6.0}]
beta_clip_range: [(-10, 10)]
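Each list-valued parameter above is one axis of a Cartesian grid, so the 7 fusionreg values crossed with the 6 Data objects yield 42 models. Below is a minimal sketch of that expansion, using an illustrative pure-Python reimplementation with itertools.product rather than the actual multidms.utils.explode_params_dict code:

```python
from itertools import product

def explode_grid(params):
    """Cartesian product of list-valued parameters -> one dict per model."""
    keys = list(params)
    return [dict(zip(keys, combo)) for combo in product(*params.values())]

# 6 datasets x 7 fusionreg values; other parameters held at a single value
grid = {
    "dataset": list(range(6)),  # stand-ins for the 6 Data objects
    "fusionreg": [0.0, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8],
    "l2reg": [1e-4],
}
models = explode_grid(grid)
print(len(models))  # 42, matching the model count fit below
```

Single-valued axes (like l2reg here) do not multiply the grid size, which is why only fusionreg and the datasets contribute to the 42.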
[8]:
import os
from multidms.utils import explode_params_dict
# Determine n_processes: null in config = auto-detect
n_models = len(explode_params_dict(fitting_params))
cfg_n_processes = fit_config.get("n_processes")
if cfg_n_processes is None:
n_processes = min(os.cpu_count() // 2, n_models)
else:
n_processes = min(int(cfg_n_processes), n_models)
n_processes = max(n_processes, 1) # ensure at least 1
print(f"Fitting {n_models} models with n_processes={n_processes} (cpus={os.cpu_count()})")
n_fit, n_failed, fit_collection_df = fit_models(
fitting_params, n_processes=n_processes
)
# Convert dict-valued columns to strings for groupby compatibility
for col in fit_collection_df.columns:
if fit_collection_df[col].apply(lambda x: isinstance(x, dict)).any():
fit_collection_df[col] = fit_collection_df[col].apply(str)
print(f"Fit {n_fit} models successfully, {n_failed} failed")
Fitting 42 models with n_processes=16 (cpus=32)
Fit 42 models successfully, 0 failed
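The string conversion in the cell above matters because pandas cannot group on a dict-valued column: dicts are unhashable, so groupby raises a TypeError. A small demonstration on toy data (not the real fit collection):

```python
import pandas as pd

df = pd.DataFrame({
    "fusionreg": [0.0, 0.0, 0.4],
    "ge_kwargs": [{"tol": 1e-5}] * 3,
})

# Grouping directly on a dict-valued column fails: dicts are unhashable
try:
    df.groupby("ge_kwargs").size()
    dict_groupby_ok = True
except TypeError:
    dict_groupby_ok = False
print(f"groupby on raw dicts succeeded: {dict_groupby_ok}")

# Converting dicts to their string representation makes the column
# hashable, so groupby works
df["ge_kwargs"] = df["ge_kwargs"].apply(str)
counts = df.groupby("ge_kwargs").size()
print(counts)
```

The str representation is lossy for ordering-sensitive comparisons, but for identical dicts produced by the same parameter grid it is a safe grouping key.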
Post-process and save¶
Extract library and measurement_type from the dataset name.
[9]:
fit_collection_df = fit_collection_df.assign(
library=(
fit_collection_df["dataset_name"]
.str.split("_").str[0:2].str.join("_")
),
measurement_type=(
fit_collection_df["dataset_name"]
.str.split("_").str[2:4].str.join("_")
),
)
fit_collection_df["measurement_type"] = pd.Categorical(
fit_collection_df["measurement_type"],
categories=["observed_phenotype", "loose_bottle", "tight_bottle"],
ordered=True,
)
print(f"Post-processed {len(fit_collection_df)} models")
Post-processed 42 models
[10]:
output_path = os.path.join(output_dir, "fit_collection.pkl")
with open(output_path, "wb") as f:
pickle.dump(fit_collection_df, f)
print(f"Saved {output_path} ({len(fit_collection_df)} models)")
Saved results/fit_collection.pkl (42 models)
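Downstream notebooks can recover the collection with pickle.load. A minimal round-trip sketch on a toy DataFrame written to a temporary directory (the real file is results/fit_collection.pkl):

```python
import os
import pickle
import tempfile

import pandas as pd

df = pd.DataFrame({"fusionreg": [0.0, 0.4], "library": ["lib_1", "lib_1"]})

# Dump and reload with the same open/pickle pattern used above
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "fit_collection.pkl")
    with open(path, "wb") as f:
        pickle.dump(df, f)
    with open(path, "rb") as f:
        reloaded = pickle.load(f)

print(reloaded.equals(df))  # True
```

Note that unpickling runs arbitrary code from the file, so only load pickles you produced yourself.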
Summary¶
[11]:
summary_cols = ["dataset_name", "library", "measurement_type", "fusionreg", "fit_time"]
display_cols = [c for c in summary_cols if c in fit_collection_df.columns]
fit_collection_df[display_cols]
[11]:
| | dataset_name | library | measurement_type | fusionreg | fit_time |
|---|---|---|---|---|---|
| 0 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 0.0 | 1100 |
| 1 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 0.4 | 753 |
| 2 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 0.8 | 583 |
| 3 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 1.6 | 828 |
| 4 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 3.2 | 1579 |
| 5 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 6.4 | 1560 |
| 6 | lib_1_observed_phenotype_func_score | lib_1 | observed_phenotype | 12.8 | 1491 |
| 7 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 0.0 | 1183 |
| 8 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 0.4 | 829 |
| 9 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 0.8 | 653 |
| 10 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 1.6 | 645 |
| 11 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 3.2 | 1347 |
| 12 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 6.4 | 1520 |
| 13 | lib_1_loose_bottle_func_score | lib_1 | loose_bottle | 12.8 | 1192 |
| 14 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 0.0 | 1012 |
| 15 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 0.4 | 761 |
| 16 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 0.8 | 487 |
| 17 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 1.6 | 433 |
| 18 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 3.2 | 1419 |
| 19 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 6.4 | 1390 |
| 20 | lib_1_tight_bottle_func_score | lib_1 | tight_bottle | 12.8 | 1394 |
| 21 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 0.0 | 1189 |
| 22 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 0.4 | 868 |
| 23 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 0.8 | 690 |
| 24 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 1.6 | 829 |
| 25 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 3.2 | 845 |
| 26 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 6.4 | 1208 |
| 27 | lib_2_observed_phenotype_func_score | lib_2 | observed_phenotype | 12.8 | 1136 |
| 28 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 0.0 | 700 |
| 29 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 0.4 | 676 |
| 30 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 0.8 | 605 |
| 31 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 1.6 | 818 |
| 32 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 3.2 | 818 |
| 33 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 6.4 | 825 |
| 34 | lib_2_loose_bottle_func_score | lib_2 | loose_bottle | 12.8 | 715 |
| 35 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 0.0 | 597 |
| 36 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 0.4 | 341 |
| 37 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 0.8 | 297 |
| 38 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 1.6 | 262 |
| 39 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 3.2 | 465 |
| 40 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 6.4 | 459 |
| 41 | lib_2_tight_bottle_func_score | lib_2 | tight_bottle | 12.8 | 439 |