Source code for tdm.model.model

"""
Base class for all models.

A Model fits two models per cell_type, one for division and one for death.
"""

from typing import Literal, Callable, Any, Mapping
import pandas as pd
import numpy as np

from tdm.utils import dict_to_dataframe
from tdm.dataset.dataset import Dataset
from tdm.model.constant import ConstantProbabilityModel
from abc import ABC, abstractmethod

# Module-level string labels; presumably used to tag fixed points / dynamics as stable or
# unstable elsewhere in the package (not referenced in this file) — TODO confirm.
STABLE, UNSTABLE = "stable", "unstable"


class Model(ABC):
    """
    Base class for all models.

    A Model fits two models per cell_type, one for division and one for death.
    Subclasses implement :meth:`fit` (the division model) and :meth:`_predict`.
    """

    def __init__(
        self,
        dataset: Dataset,
        fit_cell_types: list[str] | None = None,
        death_estimation: Literal["mean"] = "mean",
        truncate_division_rate: bool = True,
        **kwargs,
    ) -> None:
        """Base class for models.

        Args:
            dataset (Dataset): :class:`~tdm.dataset.Dataset` used to fit the model.
            fit_cell_types (list[str] | None, optional): fit a model to this subset.
                Defaults to all cell types of the dataset (also used when an empty list is given).
            death_estimation (Literal["mean"], optional): method for death rate estimation.
                "mean" models death as a constant probability equal to the mean observed
                division rate of the cell type. Defaults to "mean".
            truncate_division_rate (bool, optional): truncate division rates at the maximal
                rate predicted for the fitted cells in the data. Prevents extreme values in
                extrapolated regions. Defaults to True.
            **kwargs: not used by the base class.

        Raises:
            ValueError: if ``death_estimation`` is not a supported estimation method.
        """
        self.ds = dataset
        self._cell_types = fit_cell_types or self.ds.cell_types()
        self.death_estimation = death_estimation

        # ensure non-positive flux at max density (set via set_maximal_density_enforcement)
        self.maximal_density_enforcer = None
        # _debug mode applies only enforcer component of dynamics
        self._debug_maximal_density_enforcer = False

        # fit models:
        self.models: dict[str, dict[str, Any]] = {}
        for cell_type in self.cell_types():
            features, obs = dataset.fetch(cell_type)
            self.models[cell_type] = {}
            # division:
            self.models[cell_type]["division"] = self.fit(features, obs.division, cell_type=cell_type)
            # death (the estimation method defaults to self.death_estimation):
            self.models[cell_type]["death"] = self._estimate_death(features, obs, cell_type=cell_type)

        # truncate division rate at the maximal value:
        self.truncate_division_rate = truncate_division_rate
        if self.truncate_division_rate:
            self._max_observed_fluxes: dict[str, float] = self._init_max_observed_fluxes(dataset)

    @abstractmethod
    def fit(self, features: pd.DataFrame, obs: pd.Series, cell_type: str) -> object:
        """
        Fits a single model to X=features, y=obs
        """
        raise NotImplementedError

    def cell_types(self):
        """
        The cell types the model has been fitted to.
        """
        return self._cell_types

    def _estimate_death(
        self,
        features: pd.DataFrame,
        obs: pd.DataFrame,
        cell_type: str,
        method: str | None = None,
    ):
        """Estimate the death model for a single cell type.

        Args:
            features (pd.DataFrame): features of the cells (unused by the "mean" method).
            obs (pd.DataFrame): observations; must provide a ``division`` column.
            cell_type (str): the cell type being fit (unused by the "mean" method).
            method (str | None, optional): estimation method. Defaults to
                ``self.death_estimation`` (the value passed to ``__init__``).

        Returns:
            ConstantProbabilityModel: for "mean", a constant death probability equal to the
            mean of ``obs.division``.

        Raises:
            ValueError: if ``method`` is not a supported estimation method.
        """
        if method is None:
            # honor the constructor's death_estimation argument
            method = self.death_estimation
        if method == "mean":
            return ConstantProbabilityModel(p=obs.division.mean())
        raise ValueError(f"{method} is not a valid death estimation method")

    def predict(
        self,
        cell_type: str,
        obs: Literal["death", "division", "division_minus_death"],
        features: pd.DataFrame,
    ) -> np.ndarray:
        """
        Uses the model fit on cell_type data to predict death / division rates.

        predict() implements shared logic for all Model classes:
        - in debug mode (``_set_debug_maximal_density_enforcer``) it returns zeros,
        - division rates are truncated when ``truncate_division_rate`` is enabled.

        Classes inheriting Model should implement the _predict function.

        Raises:
            ValueError: for an unknown ``obs`` value.
        """
        if self._debug_maximal_density_enforcer:
            # only the enforcer component of the dynamics is active in debug mode
            return np.zeros(features.shape[0])

        if obs == "death":
            return self._predict(cell_type=cell_type, obs=obs, features=features)
        elif obs == "division":
            p_div = self._predict(cell_type=cell_type, obs=obs, features=features)
            return self._truncate_if(p_div, cell_type)
        elif obs == "division_minus_death":
            p_death = self._predict(cell_type=cell_type, obs="death", features=features)
            p_div = self._predict(cell_type=cell_type, obs="division", features=features)
            p_div = self._truncate_if(division_rates=p_div, cell_type=cell_type)
            return p_div - p_death
        else:
            raise ValueError(f"Invalid argument: predict(obs={obs})")

    @abstractmethod
    def _predict(
        self,
        cell_type: str,
        obs: Literal["death", "division"],
        features: pd.DataFrame,
    ) -> np.ndarray:
        """Raw (untruncated) model prediction; implemented by subclasses."""
        pass

    def delta_cells(
        self,
        cell_counts: Mapping[str, float | np.ndarray],
        return_order: list[str],
        mode: Literal["cells", "rates"] = "cells",
    ) -> list[np.ndarray] | np.ndarray:
        """
        Uses division and death rate to compute the absolute number of cells gained / lost.

        Note: This method works with ACTUAL CELL NUMBERS! Do not perform any transformation to cell vals.
        The dataset object used to fit the model is responsible for performing the transformation.

        Parameters:
            cell_counts (dict or pd.DataFrame): maps a cell type to an array or single integer value of
                non-transformed cell counts.
                Note: cell_counts must contain a value for every cell type of the dataset used for fit.
            return_order (list[str]): cell types to compute, and the order in which results are stacked.
            mode ("cells" or "rates"): "cells" returns the number of cells gained / lost
                (count * rate); "rates" returns the per-cell rate. Defaults to "cells".

        Returns:
            np.ndarray: results stacked in ``return_order``, squeezed with ``np.squeeze``
            (so a single point with a single cell type yields a 0-d array).
        """
        # cell_counts should contain exactly the cell types used during model fit:
        # return order should be a subset of these types.
        if isinstance(cell_counts, dict):
            assert set(self.ds.cell_types()) == set(cell_counts.keys()) >= set(return_order), (
                "cell_counts must contain exactly the dataset cell types and return_order must be a subset of them"
            )
        elif isinstance(cell_counts, pd.DataFrame):
            assert set(self.ds.cell_types()) == set(cell_counts.columns) >= set(return_order), (
                "cell_counts must contain exactly the dataset cell types and return_order must be a subset of them"
            )

        # perform once here for performance
        cell_counts = dict_to_dataframe(cell_counts, columns=self.ds.cell_types())

        delta_cells = [
            self._delta_cells(cell_counts=cell_counts, target_cell=cell_type, mode=mode)
            for cell_type in return_order
        ]

        # when computing a single point return the result as a single numpy array:
        return np.squeeze(delta_cells)

    def get_delta_cells_func(
        self,
        predicted_cell_types: list[str],
        fixed_cell_counts: dict[str, float] | None = None,
    ) -> Callable:
        """Return the dynamics function f, such that dx/dt = f(x).

        Useful for numerical ode solvers.

        Args:
            predicted_cell_types (list[str]): return order of predictions; ``x[i]`` is the count
                of ``predicted_cell_types[i]``.
            fixed_cell_counts (dict[str, float] | None, optional): provide these in the case of 2D dynamics
                with fixed cell counts for other cells. If a cell type appears in both, the fixed
                value takes precedence. Defaults to None.

        Returns:
            Callable: the dynamics function
        """
        if fixed_cell_counts is None:
            fixed_cell_counts = {}

        def f(x):
            # merge dictionaries of predicted and fixed counts (fixed counts win on overlap):
            predicted_cell_counts = {c: x[i] for i, c in enumerate(predicted_cell_types)}
            cell_counts = {**predicted_cell_counts, **fixed_cell_counts}

            # to list, ordered like the dataset used for fit
            cell_counts = np.array([cell_counts[c] for c in self.ds.cell_types()]).reshape(1, -1)

            # to dataframe - faster than dict-to-dataframe
            cell_counts = pd.DataFrame(cell_counts, index=[0], columns=self.ds.cell_types())

            return np.array(self.delta_cells(cell_counts=cell_counts, return_order=predicted_cell_types))

        return f

    def _delta_cells(
        self,
        cell_counts: pd.DataFrame,
        target_cell: str,
        mode: Literal["cells", "rates"],
    ) -> np.ndarray:
        """
        Computes the number or rate of change of cells using features based on the
        absolute numbers of cells provided in cell_counts.

        Raises:
            ValueError: if ``mode`` is neither "cells" nor "rates".
        """
        features = self._construct_features_from_counts(cell_counts, target_cell=target_cell)
        division_minus_death_rates = self.predict(target_cell, "division_minus_death", features)

        # term for ensuring a non-positive flux at maximal density:
        maximal_density_correction = self._get_maximal_density_correction(cell_counts, target_cell)

        rate = division_minus_death_rates + maximal_density_correction

        if mode == "rates":
            return rate
        elif mode == "cells":
            # return number of cells lost or gained:
            n_cells = cell_counts[target_cell]
            return n_cells * rate
        else:
            # previously the exception was constructed but never raised, returning None
            raise ValueError(f"Invalid argument for _delta_cells( mode = {mode})")

    def _construct_features_from_counts(self, cell_counts: pd.DataFrame, target_cell: str):
        """Hook for feature construction; delegates to the dataset used for fit."""
        return self.ds.construct_features_from_counts(cell_counts, target_cell=target_cell)

    def sample_observations(
        self,
        cell_type: str,
        cell_counts: dict | pd.DataFrame,
        obs: Literal["death", "division", "both"],
    ) -> np.ndarray:
        """
        Sample death or division according to modeled probabilities.

        Useful for Monte Carlo methods.

        Returns:
            np.ndarray (int8): one entry per cell: 1 for division, -1 for death, 0 otherwise.
            With obs="both", the maximal density correction (if an enforcer is set) is added
            to the death probability, and death overrides division where both are sampled.

        Raises:
            ValueError: for an unknown ``obs`` value.
        """
        if obs not in ("death", "division", "both"):
            # previously an unknown value silently returned all zeros
            raise ValueError(f"Invalid argument: sample_observations(obs={obs})")

        cell_counts = dict_to_dataframe(cell_counts)

        # compute division / death probabilities per cell:
        # NOTE(review): uses self.ds directly rather than the _construct_features_from_counts
        # hook used by _delta_cells — confirm whether subclasses rely on this difference.
        features = self.ds.construct_features_from_counts(cell_counts=cell_counts, target_cell=cell_type)
        p_div = self.predict(cell_type, "division", features)
        p_death = self.predict(cell_type, "death", features)

        # Note: only corrects if a maximal density enforcer was initialized
        abs_maximal_density_correction = np.abs(
            self._get_maximal_density_correction(cell_counts=cell_counts, target_cell=cell_type)
        )

        # sample observations:
        n = p_div.shape[0]
        u = np.random.uniform(size=n, low=0, high=1)
        sampled_obs = np.zeros(n, dtype=np.int8)

        if obs == "division":
            sampled_obs[u < p_div] = 1
        elif obs == "death":
            sampled_obs[u > 1 - p_death] = -1
        else:  # obs == "both"
            sampled_obs[u < p_div] = 1
            sampled_obs[u > 1 - (p_death + abs_maximal_density_correction)] = -1

        return sampled_obs

    def parameters(self, cell_type: str, obs: Literal["death"] | Literal["division"]) -> np.ndarray:
        """
        Returns parameter values.
        """
        raise NotImplementedError

    def parameter_stds(self, cell_type: str, obs: Literal["death"] | Literal["division"]) -> np.ndarray:
        """
        Returns standard deviation of parameters.
        """
        raise NotImplementedError

    def parameter_pvalues(self, cell_type: str, obs: Literal["death"] | Literal["division"]) -> np.ndarray:
        """
        Returns parameter pvalues.
        """
        raise NotImplementedError

    def parameter_names(self, cell_type: str, obs: Literal["death"] | Literal["division"]) -> np.ndarray:
        """
        Returns the parameter names associated with the death / division model for cells of type cell_type.
        """
        raise NotImplementedError

    def parameters_df(self, cell_type: str, obs: Literal["death"] | Literal["division"]) -> pd.DataFrame:
        """Return a table with columns name / val / std / pval for one fitted model.

        Relies on parameter_names, parameters, parameter_stds and parameter_pvalues,
        which subclasses must implement.
        """
        return pd.DataFrame(
            {
                "name": self.parameter_names(cell_type, obs),
                "val": self.parameters(cell_type, obs),
                "std": self.parameter_stds(cell_type, obs),
                "pval": self.parameter_pvalues(cell_type, obs),
            }
        )

    def death_prob(self, cell_type: str) -> float:
        """Return the constant death probability of ``cell_type``.

        Raises:
            AttributeError: if the death model has no constant probability ``p``.
        """
        try:
            return self.models[cell_type]["death"].p
        except AttributeError as err:
            raise AttributeError("Try fitting the model with `death_estimation = 'mean'`") from err

    def set_death_prob(self, cell_type: str, val: float):
        """Overwrite the constant death probability of ``cell_type``.

        Raises:
            AttributeError: if ``p`` cannot be set on the death model.
        """
        try:
            self.models[cell_type]["death"].p = val
        except AttributeError as err:
            raise AttributeError("Try fitting the model with `death_estimation = 'mean'`") from err

    def set_maximal_density_enforcement(self, enforcer):
        """
        Initializes a maximal density enforcer model that corrects for positive fluxes at maximal density.

        ``enforcer`` is called as ``enforcer(cell_counts_df, target_cell)`` and should return the
        signed (non-positive) correction rate per row.
        """
        self.maximal_density_enforcer = enforcer

    def reset_maximal_density_enforcement(self):
        """Remove the maximal density enforcer (the correction becomes zero)."""
        self.maximal_density_enforcer = None

    def _set_debug_maximal_density_enforcer(self, debug=True):
        """Toggle debug mode, in which predict() returns zeros so only the enforcer acts."""
        self._debug_maximal_density_enforcer = debug

    def _get_maximal_density_correction(self, cell_counts: dict | pd.DataFrame, target_cell: str) -> np.ndarray:
        """
        Computes the signed (negative) rate (fraction / dt) of cells lost via the maximal density correction.

        Parameters:
            cell_counts (dict or dataframe):
                - dict: maps a cell type to an array or single integer value of non-transformed cell counts.
                - dataframe: one column per cell type
                Note: cell_counts must contain a value for every cell type the model was fit to.
            target_cell (str): the cell type for which the maximal density correction is computed.

        Returns:
            np.ndarray: the signed rate (fraction / dt) of cells lost via the maximal density correction.
            Zeros when no enforcer is set.
        """
        cell_counts = dict_to_dataframe(cell_counts)
        if self.maximal_density_enforcer is None:
            return np.zeros(cell_counts.shape[0])
        return self.maximal_density_enforcer(cell_counts, target_cell)

    def _init_max_observed_fluxes(self, dataset: Dataset):
        """Compute, per fitted cell type, the maximal raw division rate predicted on its data.

        Only the fitted cell types are used: cell types excluded via ``fit_cell_types``
        have no model to predict with.
        """
        d = {}
        for t in self.cell_types():
            features = dataset.fetch(t)[0]
            rates = self._predict(t, obs="division", features=features)  # intentionally using the "raw" _predict
            d[t] = np.max(rates)
        return d

    def _truncate_if(self, division_rates: np.ndarray, cell_type: str) -> np.ndarray:
        """Clip division rates at the maximal observed rate when truncation is enabled."""
        if self.truncate_division_rate:
            _max = self._max_observed_fluxes[cell_type]
            return np.where(division_rates > _max, _max, division_rates)
        else:
            return division_rates