Model Fitting

Fit multidms models to spike functional-score data across a grid of fusion-regularization values, independently for each replicate.

Outline

  1. Load training functional scores

  2. Aggregate per (condition, aa_substitutions) within each replicate

  3. Create multidms.Data objects (one per replicate)

  4. Fit models across the regularization grid via fit_models()

  5. Save the fit collection

[1]:
import warnings

# Silence library warnings (e.g. jax/multidms deprecation chatter) so the
# notebook output stays readable; remove to debug unexpected behavior.
warnings.filterwarnings("ignore")

import os
import pickle
import sys

# Make the sibling `notebooks/_common.py` helper module importable when the
# notebook is executed from the repository root.
sys.path.insert(0, "notebooks")

import pandas as pd
import multidms
from multidms.model_collection import fit_models
from multidms.utils import explode_params_dict

from _common import load_config, build_fit_params
[2]:
config_path = "config/config.yaml"
[3]:
# Parameters
# Papermill-injected parameters cell: values here intentionally shadow the
# defaults defined in the preceding cell.
config_path = "config/config.yaml"

[4]:
# Parse the YAML config and pull out the spike-specific fitting settings.
config = load_config(config_path)
spike_cfg = config["spike"]
fit_config = spike_cfg["fitting"]
reference = spike_cfg["reference"]

# All artifacts from this notebook land under results/.
output_dir = "results"
os.makedirs(output_dir, exist_ok=True)

Load training functional scores

[5]:
# Read the pre-computed training functional scores; variants with no amino-acid
# substitutions (wildtype) come through as NaN and are normalized to "".
scores_path = os.path.join(output_dir, "training_functional_scores.csv")
func_score_df = pd.read_csv(scores_path).fillna({"aa_substitutions": ""})
print(f"Loaded {len(func_score_df):,} variants")
Loaded 281,158 variants

Create Data objects

Aggregate functional scores per (condition, aa_substitutions) within each replicate, then create one multidms.Data object per replicate.

[6]:
# Build one multidms.Data object per replicate. Within a replicate, collapse
# duplicate (condition, aa_substitutions) entries to their mean functional
# score before handing the table to multidms.
data_objects = []
for rep_num, rep_df in func_score_df.groupby("replicate"):
    agg_df = (
        rep_df
        .groupby(["condition", "aa_substitutions"], dropna=False)
        .agg(func_score=("func_score", "mean"))
        .reset_index()
    )
    rep_data = multidms.Data(
        agg_df,
        alphabet=multidms.AAS_WITHSTOP_WITHGAP,
        reference=reference,
        assert_site_integrity=False,
        name=f"rep_{rep_num}",
    )
    data_objects.append(rep_data)
    print(f"rep_{rep_num}: {len(agg_df):,} variants, conditions={rep_data.conditions}")
rep_1: 145,970 variants, conditions=('Delta', 'Omicron_BA1', 'Omicron_BA2')
rep_2: 135,188 variants, conditions=('Delta', 'Omicron_BA1', 'Omicron_BA2')

Fit models

[7]:
# Assemble the grid of fit parameters; each value is a list that fit_models()
# will expand into one model per combination.
fitting_params = build_fit_params(fit_config, data_objects)
print("Fitting parameters:")
for key, value in fitting_params.items():
    # The dataset entry is a list of Data objects — too verbose to print.
    if key == "dataset":
        continue
    print(f"  {key}: {value}")
Fitting parameters:
  maxiter: [75]
  tol: [1e-06]
  fusionreg: [0.0, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8]
  l2reg: [0.01]
  beta0_ridge: [0]
  ge_type: ['Sigmoid']
  ge_kwargs: [{'tol': 1e-05, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
  cal_kwargs: [{'tol': 0.0001, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
  loss_kwargs: [{'δ': 1.0}]
  warmstart: [False]
  beta0_init: [{'Omicron_BA1': 0.0, 'Delta': 0.0, 'Omicron_BA2': 0.0}]
  alpha_init: [{'Omicron_BA1': 6.0, 'Delta': 6.0, 'Omicron_BA2': 6.0}]
  beta_clip_range: [(-10, 10)]
[8]:
n_models = len(explode_params_dict(fitting_params))
cfg_n_processes = fit_config.get("n_processes")

# os.cpu_count() is documented to return None when the CPU count cannot be
# determined; fall back to 1 so the integer division below cannot raise.
cpu_total = os.cpu_count() or 1
if cfg_n_processes is None:
    # Default: use half the machine, but never more workers than models.
    n_processes = min(cpu_total // 2, n_models)
else:
    n_processes = min(int(cfg_n_processes), n_models)
# Guard against zero workers (e.g. single-CPU machine where cpu_total // 2 == 0).
n_processes = max(n_processes, 1)

print(f"Fitting {n_models} models with n_processes={n_processes} (cpus={cpu_total})")

n_fit, n_failed, fit_collection_df = fit_models(
    fitting_params, n_processes=n_processes
)

# Convert dict-valued columns (e.g. ge_kwargs) to strings: dicts are
# unhashable and would break later groupby operations on the collection.
for col in fit_collection_df.columns:
    if fit_collection_df[col].apply(lambda x: isinstance(x, dict)).any():
        fit_collection_df[col] = fit_collection_df[col].apply(str)

print(f"Fit {n_fit} models successfully, {n_failed} failed")

print(f"Fit {n_fit} models successfully, {n_failed} failed")
Fitting 14 models with n_processes=14 (cpus=64)
Fit 14 models successfully, 0 failed

Save

[9]:
# Persist the full fit collection (models + metadata) for downstream
# model-selection notebooks.
output_path = os.path.join(output_dir, "fit_collection.pkl")
with open(output_path, "wb") as pkl_file:
    pickle.dump(fit_collection_df, pkl_file)
print(f"Saved {output_path} ({len(fit_collection_df)} models)")
Saved results/fit_collection.pkl (14 models)
[10]:
# Quick sanity check: show per-model metadata, keeping only the columns that
# this multidms version actually records.
preferred_cols = ["dataset_name", "fusionreg", "converged", "fit_time"]
available_cols = [col for col in preferred_cols if col in fit_collection_df.columns]
fit_collection_df[available_cols]
[10]:
dataset_name fusionreg fit_time
0 rep_1 0.0 979
1 rep_1 0.4 616
2 rep_1 0.8 1566
3 rep_1 1.6 627
4 rep_1 3.2 1075
5 rep_1 6.4 890
6 rep_1 12.8 1064
7 rep_2 0.0 1268
8 rep_2 0.4 652
9 rep_2 0.8 599
10 rep_2 1.6 1956
11 rep_2 3.2 1468
12 rep_2 6.4 808
13 rep_2 12.8 2079