Source code for tdm.plot.model.calibration

"""
Model calibration plot.
"""

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve
from copy import deepcopy

from tdm.analysis import Analysis
from tdm.model.model import Model
from tdm.dataset.dataset import Dataset
from tdm.cell_types import cell_type_to_full_name


def plot_calibration(ana: Analysis, n_cells_per_bin: int = 2000, max_p: dict | float | None = None):
    """Plots model calibration for all cell types.

    Draws one calibration panel per cell type in ``ana.cell_types``, side by side
    in a single row, using ``ana.model`` evaluated on ``ana.pds``.

    Args:
        ana (Analysis): an Analysis object.
        n_cells_per_bin (int, optional): number of cells used in each point. A small
            number introduces greater error in estimating the true probability.
            Defaults to 2000.
        max_p (dict | float | None, optional): upper limit of both axes of each panel
            (a small margin is added). Either a single value applied to every cell
            type, or a dict mapping each cell type to its own value (every cell type
            must then be a key, otherwise a KeyError is raised). None derives the
            limit from the data. Defaults to None.
    """
    n_types = len(ana.cell_types)
    # squeeze=False always returns a 2D array of Axes; without it, a single cell
    # type yields a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(ncols=n_types, figsize=(2 * n_types, 2), squeeze=False)
    for ax, cell_type in zip(axs[0], ana.cell_types):
        _p = max_p[cell_type] if isinstance(max_p, dict) else max_p
        _plot_model_calibration(
            cell_type=cell_type,
            model=ana.model,
            dataset=ana.pds,
            n_cells_per_bin=n_cells_per_bin,
            ax=ax,
            max_p=_p,
        )
    fig.tight_layout()
def plot_calibration_against_permuted_divisions(
    ana: Analysis,
    n_cells_per_bin: int = 2000,
    seed: int | None = None,
):
    """Plots model calibration for all cell types, next to a permuted-division control.

    For every cell type, the rows of the division table are shuffled while the
    features are kept in place. A new model is then fit on this permuted dataset
    via ``ana._fit_model``. One figure per cell type shows the calibration of the
    original model on the original data (left) and of the permuted model on the
    permuted data (right). The right panel reuses the axis limits of the left one.

    Args:
        ana (Analysis): an Analysis object.
        n_cells_per_bin (int, optional): number of cells used in each point. A small
            number introduces greater error in estimating the true probability.
            Defaults to 2000.
        seed (int | None, optional): seed for shuffling the divisions, for
            reproducible plots. None uses NumPy's global random state
            (``np.random.permutation``). Defaults to None.
    """
    # original model and dataset:
    pds = ana.pds
    m = ana.model

    # None keeps the legacy global-RNG behavior; a seed gives a private, reproducible stream.
    permute = np.random.permutation if seed is None else np.random.default_rng(seed).permutation

    # permuted model and dataset:
    pds_permuted = deepcopy(pds)
    for cell_type in pds.cell_types():
        counts, div = pds_permuted.fetch(cell_type)
        div_permuted = div.iloc[permute(len(div))].reset_index(drop=True)
        pds_permuted.dataset_dict[cell_type] = counts, div_permuted
    m_permuted = ana._fit_model(pds_permuted)

    # plot calibration plots for all types:
    for cell_type in m.cell_types():
        fig, axs = plt.subplots(figsize=(6, 3), ncols=2)
        min_p, max_p = _plot_model_calibration(
            cell_type,
            m,
            pds,
            n_cells_per_bin=n_cells_per_bin,
            ax=axs[0],
            max_p=None,
            plot_min_max_lines=True,
        )
        # Reuse the original panel's limits so both panels share one scale.
        _plot_model_calibration(
            cell_type,
            m_permuted,
            pds_permuted,
            n_cells_per_bin=n_cells_per_bin,
            ax=axs[1],
            min_p=min_p,
            max_p=max_p,
            plot_min_max_lines=True,
        )
        axs[1].set_title("Calibration over permuted divisions")
        fig.tight_layout()
def _plot_model_calibration(
    cell_type: str,
    model: Model,
    dataset: Dataset,
    n_cells_per_bin: int = 2000,
    ax=None,
    min_p: float | None = None,
    max_p: float | None = None,
    obs_type: str = "division",
    plot_min_max_lines: bool = False,
) -> tuple[float, float]:
    """Draws a calibration scatter plot for a single cell type's model.

    The model's predicted probabilities for ``obs_type`` are split into
    quantile bins, each holding at least roughly ``n_cells_per_bin`` cells. If
    there are fewer cells than that, everything goes into a single bin. Each
    point shows a bin's mean predicted probability (x) against the observed
    fraction of positives in that bin (y). A dashed identity line marks perfect
    calibration.

    Args:
        cell_type (str): cell type whose features/observations are fetched and plotted.
        model (Model): model used to predict probabilities.
        dataset (Dataset): dataset providing ``(features, obs)`` for ``cell_type``.
        n_cells_per_bin (int, optional): target number of cells per point. Defaults to 2000.
        ax (optional): matplotlib Axes to draw on. None creates a new 5x5 figure.
        min_p (float | None, optional): lower limit of both axes (a small margin is
            added). None uses the smallest observed per-bin fraction. Defaults to None.
        max_p (float | None, optional): upper limit of both axes (a small margin is
            added). None uses the largest observed per-bin fraction. Defaults to None.
        obs_type (str, optional): observation column to evaluate. Presumably a binary
            0/1 label, since ``calibration_curve`` requires binary targets.
            Defaults to "division".
        plot_min_max_lines (bool, optional): draw red dashed horizontal lines at the
            smallest and largest observed per-bin fractions. Defaults to False.

    Returns:
        tuple[float, float]: the ``(min_p, max_p)`` axis limits actually used, so a
        caller can put another panel on the same scale.
    """
    features, obs = dataset.fetch(cell_type)
    probs = model.predict(cell_type, obs=obs_type, features=features)

    # Clamp to one bin: with fewer cells than n_cells_per_bin, floor division gives 0,
    # which is not a valid bin count.
    n_bins = max(1, len(probs) // n_cells_per_bin)
    prob_true, prob_pred = calibration_curve(
        obs[obs_type],
        probs,
        n_bins=n_bins,
        strategy="quantile",
    )

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    sns.scatterplot(y=prob_true, x=prob_pred, color="#53AC69", edgecolor="#252526", ax=ax)

    true_min = prob_true.min()
    true_max = prob_true.max()
    # Compare against None rather than truthiness so an explicit 0.0 limit is honored.
    if min_p is None:
        min_p = true_min
    if max_p is None:
        max_p = true_max

    # Identity line = perfect calibration.
    sns.lineplot(
        x=np.linspace(0, max_p, 10),
        y=np.linspace(0, max_p, 10),
        color="black",
        linestyle="--",
        ax=ax,
    )

    if plot_min_max_lines:
        ax.axhline(y=true_min, color="red", linestyle="--")
        ax.axhline(y=true_max, color="red", linestyle="--")

    ax.set_xlabel("predicted probability")
    ax.set_ylabel("true probability")
    ax.set_title(f"{cell_type_to_full_name(cell_type)} model")

    eps = 0.01
    ax.set_xlim(min_p - eps, max_p + eps)
    ax.set_ylim(min_p - eps, max_p + eps)
    sns.despine(ax=ax)
    return min_p, max_p