Model Fitting

Fit multidms models to simulated functional-score data across a grid of fusion-regularization values.

Outline

  1. Load simulated functional scores from disk

  2. Create multidms.Data objects (one per library × func_score_type combination)

  3. Fit models across the regularization grid via fit_models()

  4. Post-process and save the fit collection

[1]:
import warnings

# Suppress third-party warnings (e.g. from jax/multidms) to keep notebook output readable.
warnings.filterwarnings("ignore")

import os
import pickle
import sys

# Make the notebook-local helper module (_common) importable when the kernel
# is started from the repository root.
sys.path.insert(0, "notebooks")

import pandas as pd
import multidms
from multidms.model_collection import fit_models

from _common import load_config, build_fit_params
[2]:
config_path = "config/config.yaml"
[3]:
# Parameters
# (papermill-injected parameters cell; values here override the defaults above)
config_path = "config/config.yaml"

[4]:
# Load the YAML configuration and pull out the simulation/fitting sections
# used throughout the rest of the notebook.
config = load_config(config_path)
sim = config["simulation"]
fit_config = sim["fitting"]
output_dir = sim["output_dir"]

# Ensure the results directory exists before any file is written.
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
Output directory: results

Load simulated functional scores

[5]:
# Read the simulated functional scores and give func_score_type a fixed,
# ordered category order so downstream groupbys/plots sort consistently.
scores_path = os.path.join(output_dir, "simulated_func_scores.csv")
func_scores = pd.read_csv(scores_path)

score_type_order = ["observed_phenotype", "loose_bottle", "tight_bottle"]
func_scores["func_score_type"] = pd.Categorical(
    func_scores["func_score_type"],
    categories=score_type_order,
    ordered=True,
)

print(f"Loaded {len(func_scores)} rows")
func_scores.head()
Loaded 349866 rows
[5]:
library homolog aa_substitutions func_score_type func_score pre_sample func_score_var pre_count post_count pre_count_wt post_count_wt pseudocount n_aa_substitutions variant_class latent_phenotype
0 lib_1 h1 S12F P43F observed_phenotype -0.338826 NaN NaN NaN NaN NaN NaN NaN NaN >1 nonsynonymous 2.696775
1 lib_1 h1 L33P observed_phenotype 0.008957 NaN NaN NaN NaN NaN NaN NaN NaN 1 nonsynonymous 5.253866
2 lib_1 h1 N3V F16S observed_phenotype -0.031679 NaN NaN NaN NaN NaN NaN NaN NaN >1 nonsynonymous 4.413085
3 lib_1 h1 N22W L49I observed_phenotype -5.121204 NaN NaN NaN NaN NaN NaN NaN NaN >1 nonsynonymous -1.817175
4 lib_1 h1 K21N observed_phenotype -0.032459 NaN NaN NaN NaN NaN NaN NaN NaN 1 nonsynonymous 4.402149

Create Data objects

One multidms.Data object per (library, func_score_type) combination.

[6]:
# Build one multidms.Data object per (library, func_score_type) combination,
# using "homolog" as the multidms condition column.
renamed_scores = func_scores.rename(columns={"homolog": "condition"})

data_objects = []
for (lib, fst), sub_df in renamed_scores.groupby(["library", "func_score_type"]):
    variants = sub_df.copy()
    # Replace missing substitution strings with "" before constructing Data
    # (presumably NaN marks unmutated variants — confirm against the simulator).
    variants["aa_substitutions"] = variants["aa_substitutions"].fillna("")
    data_objects.append(
        multidms.Data(
            variants,
            reference="h1",
            alphabet=multidms.AAS_WITHSTOP_WITHGAP,
            verbose=False,
            name=f"{lib}_{fst}_func_score",
        )
    )

print(f"Created {len(data_objects)} Data objects:")
for d in data_objects:
    print(f"  {d.name}")
Created 6 Data objects:
  lib_1_observed_phenotype_func_score
  lib_1_loose_bottle_func_score
  lib_1_tight_bottle_func_score
  lib_2_observed_phenotype_func_score
  lib_2_loose_bottle_func_score
  lib_2_tight_bottle_func_score

Build fitting parameters and fit models

[7]:
# Expand the fitting configuration into the parameter grid handed to fit_models().
fitting_params = build_fit_params(fit_config, data_objects)

print("Fitting parameters:")
for key, value in fitting_params.items():
    # The "dataset" entry holds the Data objects themselves; skip it in the printout.
    if key == "dataset":
        continue
    print(f"  {key}: {value}")
Fitting parameters:
  maxiter: [100]
  tol: [1e-06]
  fusionreg: [0.0, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8]
  l2reg: [0.0001]
  beta0_ridge: [1e-05]
  ge_type: ['Sigmoid']
  ge_kwargs: [{'tol': 1e-05, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
  cal_kwargs: [{'tol': 0.0001, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
  loss_kwargs: [{'δ': 1.0}]
  warmstart: [False]
  beta0_init: [{'h1': 5.0, 'h2': 0.0}]
  alpha_init: [{'h1': 6.0, 'h2': 6.0}]
  beta_clip_range: [(-10, 10)]
[8]:
from multidms.utils import explode_params_dict

# Determine n_processes: null in config = auto-detect.
# (os was already imported in the top imports cell; the duplicate import is removed.)
n_models = len(explode_params_dict(fitting_params))
cfg_n_processes = fit_config.get("n_processes")

if cfg_n_processes is None:
    # os.cpu_count() can return None on some platforms; fall back so the
    # floor-division below never raises a TypeError.
    cpu_count = os.cpu_count() or 2
    n_processes = min(cpu_count // 2, n_models)
else:
    n_processes = min(int(cfg_n_processes), n_models)

n_processes = max(n_processes, 1)  # ensure at least 1
print(f"Fitting {n_models} models with n_processes={n_processes} (cpus={os.cpu_count()})")

n_fit, n_failed, fit_collection_df = fit_models(
    fitting_params, n_processes=n_processes
)

# Convert dict-valued columns to strings for groupby compatibility
for col in fit_collection_df.columns:
    if fit_collection_df[col].apply(lambda x: isinstance(x, dict)).any():
        fit_collection_df[col] = fit_collection_df[col].apply(str)

print(f"Fit {n_fit} models successfully, {n_failed} failed")

Fitting 42 models with n_processes=16 (cpus=32)
Fit 42 models successfully, 0 failed

Post-process and save

Extract library and measurement_type from the dataset name.

[9]:
# Dataset names look like "lib_<n>_<measurement_type>_func_score"; split once
# and recover the library ("lib_N") and measurement_type tokens.
name_tokens = fit_collection_df["dataset_name"].str.split("_")
fit_collection_df = fit_collection_df.assign(
    library=name_tokens.str[0:2].str.join("_"),
    measurement_type=name_tokens.str[2:4].str.join("_"),
)

# Order measurement_type categorically for consistent sorting/plotting.
measurement_order = ["observed_phenotype", "loose_bottle", "tight_bottle"]
fit_collection_df["measurement_type"] = pd.Categorical(
    fit_collection_df["measurement_type"],
    categories=measurement_order,
    ordered=True,
)

print(f"Post-processed {len(fit_collection_df)} models")
Post-processed 42 models
[10]:
# Persist the fit collection (DataFrame of fitted models + hyperparameters)
# for the downstream analysis notebooks.
output_path = os.path.join(output_dir, "fit_collection.pkl")
with open(output_path, "wb") as handle:
    pickle.dump(fit_collection_df, handle)
print(f"Saved {output_path} ({len(fit_collection_df)} models)")
Saved results/fit_collection.pkl (42 models)

Summary

[11]:
# Show a compact per-model summary, keeping only the columns actually present.
summary_cols = ["dataset_name", "library", "measurement_type", "fusionreg", "fit_time"]
display_cols = [col for col in summary_cols if col in fit_collection_df.columns]
fit_collection_df[display_cols]
[11]:
dataset_name library measurement_type fusionreg fit_time
0 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 0.0 1100
1 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 0.4 753
2 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 0.8 583
3 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 1.6 828
4 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 3.2 1579
5 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 6.4 1560
6 lib_1_observed_phenotype_func_score lib_1 observed_phenotype 12.8 1491
7 lib_1_loose_bottle_func_score lib_1 loose_bottle 0.0 1183
8 lib_1_loose_bottle_func_score lib_1 loose_bottle 0.4 829
9 lib_1_loose_bottle_func_score lib_1 loose_bottle 0.8 653
10 lib_1_loose_bottle_func_score lib_1 loose_bottle 1.6 645
11 lib_1_loose_bottle_func_score lib_1 loose_bottle 3.2 1347
12 lib_1_loose_bottle_func_score lib_1 loose_bottle 6.4 1520
13 lib_1_loose_bottle_func_score lib_1 loose_bottle 12.8 1192
14 lib_1_tight_bottle_func_score lib_1 tight_bottle 0.0 1012
15 lib_1_tight_bottle_func_score lib_1 tight_bottle 0.4 761
16 lib_1_tight_bottle_func_score lib_1 tight_bottle 0.8 487
17 lib_1_tight_bottle_func_score lib_1 tight_bottle 1.6 433
18 lib_1_tight_bottle_func_score lib_1 tight_bottle 3.2 1419
19 lib_1_tight_bottle_func_score lib_1 tight_bottle 6.4 1390
20 lib_1_tight_bottle_func_score lib_1 tight_bottle 12.8 1394
21 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 0.0 1189
22 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 0.4 868
23 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 0.8 690
24 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 1.6 829
25 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 3.2 845
26 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 6.4 1208
27 lib_2_observed_phenotype_func_score lib_2 observed_phenotype 12.8 1136
28 lib_2_loose_bottle_func_score lib_2 loose_bottle 0.0 700
29 lib_2_loose_bottle_func_score lib_2 loose_bottle 0.4 676
30 lib_2_loose_bottle_func_score lib_2 loose_bottle 0.8 605
31 lib_2_loose_bottle_func_score lib_2 loose_bottle 1.6 818
32 lib_2_loose_bottle_func_score lib_2 loose_bottle 3.2 818
33 lib_2_loose_bottle_func_score lib_2 loose_bottle 6.4 825
34 lib_2_loose_bottle_func_score lib_2 loose_bottle 12.8 715
35 lib_2_tight_bottle_func_score lib_2 tight_bottle 0.0 597
36 lib_2_tight_bottle_func_score lib_2 tight_bottle 0.4 341
37 lib_2_tight_bottle_func_score lib_2 tight_bottle 0.8 297
38 lib_2_tight_bottle_func_score lib_2 tight_bottle 1.6 262
39 lib_2_tight_bottle_func_score lib_2 tight_bottle 3.2 465
40 lib_2_tight_bottle_func_score lib_2 tight_bottle 6.4 459
41 lib_2_tight_bottle_func_score lib_2 tight_bottle 12.8 439