Model Fitting¶
Fit multidms models to spike functional-score data across a grid of fusion-regularization values, independently for each replicate.
Outline

1. Load training functional scores
2. Aggregate per (condition, aa_substitutions) within each replicate
3. Create multidms.Data objects (one per replicate)
4. Fit models across the regularization grid via fit_models()
5. Save the fit collection
[1]:
import warnings
warnings.filterwarnings("ignore")
import os
import pickle
import sys
sys.path.insert(0, "notebooks")
import pandas as pd
import multidms
from multidms.model_collection import fit_models
from multidms.utils import explode_params_dict
from _common import load_config, build_fit_params
[2]:
config_path = "config/config.yaml"
[3]:
# Parameters
config_path = "config/config.yaml"
[4]:
config = load_config(config_path)
spike = config["spike"]
fit_config = spike["fitting"]
reference = spike["reference"]
output_dir = "results"
os.makedirs(output_dir, exist_ok=True)
Load training functional scores¶
[5]:
func_score_df = pd.read_csv(
    os.path.join(output_dir, "training_functional_scores.csv")
).fillna({"aa_substitutions": ""})
print(f"Loaded {len(func_score_df):,} variants")
Loaded 281,158 variants
Create Data objects¶
Aggregate functional scores per (condition, aa_substitutions) within each replicate, then create one multidms.Data object per replicate.
[6]:
data_objects = []
for rep_num, df_rep in func_score_df.groupby("replicate"):
    # Average functional scores over duplicate (condition, aa_substitutions) pairs
    df_agg = (
        df_rep.groupby(["condition", "aa_substitutions"], dropna=False)
        .agg({"func_score": "mean"})
        .reset_index()
    )
    data = multidms.Data(
        df_agg,
        alphabet=multidms.AAS_WITHSTOP_WITHGAP,
        reference=reference,
        assert_site_integrity=False,
        name=f"rep_{rep_num}",
    )
    data_objects.append(data)
    print(f"rep_{rep_num}: {len(df_agg):,} variants, conditions={data.conditions}")
rep_1: 145,970 variants, conditions=('Delta', 'Omicron_BA1', 'Omicron_BA2')
rep_2: 135,188 variants, conditions=('Delta', 'Omicron_BA1', 'Omicron_BA2')
Fit models¶
[7]:
fitting_params = build_fit_params(fit_config, data_objects)
print("Fitting parameters:")
for k, v in fitting_params.items():
    if k != "dataset":
        print(f"  {k}: {v}")
Fitting parameters:
maxiter: [75]
tol: [1e-06]
fusionreg: [0.0, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8]
l2reg: [0.01]
beta0_ridge: [0]
ge_type: ['Sigmoid']
ge_kwargs: [{'tol': 1e-05, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
cal_kwargs: [{'tol': 0.0001, 'maxiter': 1000, 'maxls': 40, 'jit': True, 'verbose': False}]
loss_kwargs: [{'δ': 1.0}]
warmstart: [False]
beta0_init: [{'Omicron_BA1': 0.0, 'Delta': 0.0, 'Omicron_BA2': 0.0}]
alpha_init: [{'Omicron_BA1': 6.0, 'Delta': 6.0, 'Omicron_BA2': 6.0}]
beta_clip_range: [(-10, 10)]
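Each list-valued parameter above contributes one axis to the model grid: 2 replicate datasets × 7 fusion-regularization values = 14 fits. The real expansion is handled by multidms's explode_params_dict / fit_models; the sketch below only illustrates the Cartesian-product idea with hypothetical, stripped-down parameter names.

```python
from itertools import product

# Illustrative parameter grid (names and values mirror the listing above,
# but this is a stand-in, not the multidms implementation).
grid = {
    "dataset": ["rep_1", "rep_2"],  # one multidms.Data per replicate
    "fusionreg": [0.0, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8],
    "l2reg": [0.01],
}

# Cartesian product over every axis -> one dict per model to fit.
combos = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(combos))  # 14
print(combos[0])    # {'dataset': 'rep_1', 'fusionreg': 0.0, 'l2reg': 0.01}
```

Singleton axes (like l2reg here) are held fixed across the grid, so only the replicate and the fusion penalty vary between fits.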
[8]:
n_models = len(explode_params_dict(fitting_params))
cfg_n_processes = fit_config.get("n_processes")
if cfg_n_processes is None:
    n_processes = min(os.cpu_count() // 2, n_models)
else:
    n_processes = min(int(cfg_n_processes), n_models)
n_processes = max(n_processes, 1)
print(f"Fitting {n_models} models with n_processes={n_processes} (cpus={os.cpu_count()})")
n_fit, n_failed, fit_collection_df = fit_models(
    fitting_params, n_processes=n_processes
)
# Convert dict-valued columns to strings for groupby compatibility
for col in fit_collection_df.columns:
    if fit_collection_df[col].apply(lambda x: isinstance(x, dict)).any():
        fit_collection_df[col] = fit_collection_df[col].apply(str)
print(f"Fit {n_fit} models successfully, {n_failed} failed")
Fitting 14 models with n_processes=14 (cpus=64)
Fit 14 models successfully, 0 failed
Save¶
[9]:
output_path = os.path.join(output_dir, "fit_collection.pkl")
with open(output_path, "wb") as f:
    pickle.dump(fit_collection_df, f)
print(f"Saved {output_path} ({len(fit_collection_df)} models)")
Saved results/fit_collection.pkl (14 models)
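Downstream notebooks can reload the pickled collection with a plain pickle.load on the same path. A minimal round-trip sketch, using a small hypothetical DataFrame in place of the real fit collection:

```python
import os
import pickle
import tempfile

import pandas as pd

# Tiny stand-in for the fit collection DataFrame (illustrative only).
df = pd.DataFrame({"dataset_name": ["rep_1"], "fusionreg": [0.4]})

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "fit_collection.pkl")
    # Save exactly as above ...
    with open(path, "wb") as f:
        pickle.dump(df, f)
    # ... and load it back for downstream analysis.
    with open(path, "rb") as f:
        loaded = pickle.load(f)

print(loaded.equals(df))  # True
```

Note that unpickling requires a compatible multidms (and pandas) version to be importable, since the stored models reference those classes.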
[10]:
display_cols = ["dataset_name", "fusionreg", "converged", "fit_time"]
display_cols = [c for c in display_cols if c in fit_collection_df.columns]
fit_collection_df[display_cols]
[10]:
| | dataset_name | fusionreg | fit_time |
|---|---|---|---|
| 0 | rep_1 | 0.0 | 979 |
| 1 | rep_1 | 0.4 | 616 |
| 2 | rep_1 | 0.8 | 1566 |
| 3 | rep_1 | 1.6 | 627 |
| 4 | rep_1 | 3.2 | 1075 |
| 5 | rep_1 | 6.4 | 890 |
| 6 | rep_1 | 12.8 | 1064 |
| 7 | rep_2 | 0.0 | 1268 |
| 8 | rep_2 | 0.4 | 652 |
| 9 | rep_2 | 0.8 | 599 |
| 10 | rep_2 | 1.6 | 1956 |
| 11 | rep_2 | 3.2 | 1468 |
| 12 | rep_2 | 6.4 | 808 |
| 13 | rep_2 | 12.8 | 2079 |