from typing import Literal
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import seaborn as sns
from sklearn.metrics.pairwise import rbf_kernel
from tdm.utils import log2_1p, inv_log2_1p
from tdm.cell_types import CELL_TYPE_TO_COLOR, CELL_TYPE_TO_FULL_NAME, FIBROBLAST_COLOR
from tdm.analysis import Analysis
from tdm.dataset import NeighborsDataset
from tdm.model import Model
from tdm.style import STABILITIY_TYPE_TO_MARKER_STYLE_KWARGS, STABLE, SEMI_STABLE, UNSTABLE
from tdm.plot.utils import log_density_axis
from tdm.numerical.phase_portrait_analysis import (
get_nullclines,
compute_change_in_each_direction,
get_classified_fixed_points_inbounds,
)
from tdm.simulate.deterministic import compute_trajectory
def plot_bla():
    """Placeholder that performs no work and returns ``None``."""
    return None
[docs]
def plot_phase_portrait(
    ana: Analysis,
    plot_nullclines: bool = True,
    cell_a: str | None = None,
    cell_b: str | None = None,
    fixed_cell_counts: dict[str, float] | None = None,
    phase_portrait_style: Literal["stream", "quiver"] = "stream",
    mode: Literal["cells", "rates"] = "cells",
    streamplot_density: float = 0.6,
    nullcline_width: float = 1.5,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    step: float = 0.35,
    ax: plt.Axes | None = None,
    nds_for_kdeplot: NeighborsDataset | None = None,
    streamplot_color: str = "black",
    add_fixed_points: bool = True,
):
    """Plot a 2D phase-portrait with nullclines and fixed-points.

    The layers are drawn on a single axis in this order: the vector field,
    the nullclines, an optional cell-density outline and the fixed points.
    Nothing is returned; when ``ax`` is omitted a new 5x5 figure is created
    and drawn on.

    Args:
        ana (Analysis): Source of the model and of the defaults for
            ``cell_a``, ``cell_b``, ``xlim`` and ``ylim``.
        plot_nullclines (bool, optional): Draw the nullclines of both cell
            types. Defaults to True.
        cell_a (str | None, optional): Cell type on the x axis. Falls back to
            ``ana.cell_a``. Defaults to None.
        cell_b (str | None, optional): Cell type on the y axis. Falls back to
            ``ana.cell_b``. Defaults to None.
        fixed_cell_counts (dict[str, float] | None, optional): Forwarded
            unchanged to the vector-field, nullcline and fixed-point helpers.
            Defaults to None.
        phase_portrait_style (Literal["stream", "quiver"], optional): Render
            the vector field with ``streamplot`` or ``quiver``.
            Defaults to "stream".
        mode (Literal["cells", "rates"], optional): Forwarded to the
            vector-field computation. Defaults to "cells".
        streamplot_density (float, optional): Streamline density used by the
            "stream" style. Defaults to 0.6.
        nullcline_width (float, optional): Line width of the nullclines.
            Defaults to 1.5.
        xlim (tuple[float, float] | None, optional): x-axis limits. Falls back
            to ``ana.xlim``. Defaults to None.
        ylim (tuple[float, float] | None, optional): y-axis limits. Falls back
            to ``ana.ylim``. Defaults to None.
        step (float, optional): Step forwarded to the vector-field
            computation. Defaults to 0.35.
        ax (plt.Axes | None, optional): Axis to draw on. Defaults to None.
        nds_for_kdeplot (NeighborsDataset | None, optional): When given, the
            density of its cells is outlined behind the portrait.
            Defaults to None.
        streamplot_color (str, optional): Colour of the vector field.
            Defaults to "black".
        add_fixed_points (bool, optional): Mark the fixed points found inside
            the plotted bounds. Defaults to True.

    Examples:
        >>> from tdm.analysis import Analysis
        >>> from tdm.plot.two_cells.phase_portrait import plot_phase_portrait
        >>> ana = Analysis.load("fm.pkl")
        >>> plot_phase_portrait(ana)
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    # Arguments that every model-based layer below receives.
    shared = dict(
        model=ana.model,
        cell_a=cell_a or ana.cell_a,
        cell_b=cell_b or ana.cell_b,
        xlim=xlim or ana.xlim,
        ylim=ylim or ana.ylim,
        ax=ax,
        fixed_cell_counts=fixed_cell_counts,
    )

    _plot_phase_portrait(
        **shared,
        style=phase_portrait_style,
        mode=mode,
        streamplot_density=streamplot_density,
        streamplot_color=streamplot_color,
        step=step,
    )

    if plot_nullclines:
        _plot_nullclines(**shared, linewidth=nullcline_width)

    if nds_for_kdeplot is not None:
        plot_two_cell_density(
            nds=nds_for_kdeplot,
            cell_a=shared["cell_a"],
            cell_b=shared["cell_b"],
            ax=ax,
        )

    if add_fixed_points:
        plot_fixed_points(**shared)

    # Re-apply the limits last: the layers above may have changed them.
    ax.set_xlim(*shared["xlim"])
    ax.set_ylim(*shared["ylim"])
    ax.set_title("Phase Portrait")
[docs]
def plot_growth_rate(
    ana: Analysis,
    kde_bw: float = 0.3,
    cell_a_rbf_gamma: float = 1.2,  # value for F in FM analysis
    cell_b_rbf_gamma: float = 0.3,  # value for M in FM analysis
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    fixed_cell_counts: dict[str, float] | None = None,
    include_titles: bool = True,
    plot_nullclines: bool = True,
    streamplot_density: float = 0.6,
    _plot_contours: bool = True,
) -> Figure:
    """Display the growth rate of both cell-types over the phase-plane.

    Plots 3 axes: a phase-portrait (left) and the growth rate of ``cell_a``
    (middle) and ``cell_b`` (right).

    Hint:
        **How to interpret these plots?**

        The growth rate is computed by taking a smoothed mean of the binary cell-division events over the phase-plane.
        We then subtract the mean division rate so that areas with high proliferation are positive (plotted red)
        and areas with low proliferation are negative (plotted blue).
        See `paper <https://www.biorxiv.org/content/10.1101/2024.04.22.590503v1>`_ for more on why the mean division
        rate is a good approximation of the death rate.

        For ``cell_a`` (x axis), regions with positive growth (red) correspond with flow to the right on the
        phase-portrait, and regions with negative growth (blue) correspond with flow to the left.
        For ``cell_b`` (y axis), positive and negative growth correspond with flow up or down respectively.

        Thus, this plot connects the measured division events directly with the modeled dynamics.

    Args:
        ana (Analysis): Provides the model, the dataset (``ana.rnds``), the
            two cell types and the default axis limits.
        kde_bw (float, optional): Bandwidth of the cell-density contours drawn
            over the growth-rate axes. Defaults to 0.3.
        cell_a_rbf_gamma (float, optional): ``gamma`` of the RBF kernel that
            smooths the ``cell_a`` division events. Defaults to 1.2.
        cell_b_rbf_gamma (float, optional): ``gamma`` of the RBF kernel that
            smooths the ``cell_b`` division events. Defaults to 0.3.
        xlim (tuple[float, float] | None, optional): x-axis limits applied to
            all three axes. Falls back to ``ana.xlim``. Defaults to None.
        ylim (tuple[float, float] | None, optional): y-axis limits applied to
            all three axes. Falls back to ``ana.ylim``. Defaults to None.
        fixed_cell_counts (dict[str, float] | None, optional): Forwarded to
            the phase-portrait and nullcline helpers; ``None`` is replaced by
            an empty dict. Defaults to None.
        include_titles (bool, optional): Keep the axis titles; when False all
            three titles are cleared. Defaults to True.
        plot_nullclines (bool, optional): Draw both nullclines on the
            phase-portrait and the matching single nullcline on each
            growth-rate axis. Defaults to True.
        streamplot_density (float, optional): Streamline density of the
            phase-portrait. Defaults to 0.6.
        _plot_contours (bool, optional): Draw the cell-density contours on
            the growth-rate axes. Defaults to True.

    Returns:
        Figure: The figure holding the three axes.
    """
    model = ana.model
    nds = ana.rnds
    cell_a = ana.cell_a
    cell_b = ana.cell_b
    xlim = xlim or ana.xlim
    ylim = ylim or ana.ylim
    fixed_cell_counts = fixed_cell_counts or {}

    fig, axs = plt.subplots(ncols=3, figsize=(10, 5), sharex=False, sharey=False)
    ax0, ax1, ax2 = axs

    nbrs_a, obs_a = nds.fetch(cell_a)
    nbrs_b, obs_b = nds.fetch(cell_b)
    nbrs_a, nbrs_b = log2_1p(nbrs_a), log2_1p(nbrs_b)

    for ax in axs:
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

    # ax1 - cell a kde plot:
    _plot_growth_rate_axis(
        nbrs_a,
        obs_a,
        cell_a,
        cell_b,
        kde_bw=kde_bw,
        rbf_gamma=cell_a_rbf_gamma,  # 0.1 for restricted dataset, 1.2 for fibroblast macrophage analysis
        ax=ax1,
        _plot_contours=_plot_contours,
    )
    ax1.set_title(f"{CELL_TYPE_TO_FULL_NAME[cell_a]} growth Rate")

    # ax2 - cell b kde plot:
    _plot_growth_rate_axis(
        nbrs_b,
        obs_b,
        cell_a,
        cell_b,
        kde_bw=kde_bw,
        rbf_gamma=cell_b_rbf_gamma,  # 0.8 for restricted dataset
        ax=ax2,
        _plot_contours=_plot_contours,
    )
    ax2.set_title(f"{CELL_TYPE_TO_FULL_NAME[cell_b]} growth Rate")

    # ax0 - two-cell Phase Portrait. It draws its own nullclines (when
    # requested) and fixed points, so they are not drawn a second time below.
    # xlim is forwarded together with ylim so the portrait covers the same
    # window as the other two axes.
    plot_phase_portrait(
        ana,
        streamplot_density=streamplot_density,
        ax=ax0,
        fixed_cell_counts=fixed_cell_counts,
        plot_nullclines=plot_nullclines,
        xlim=xlim,
        ylim=ylim,
        add_fixed_points=True,
    )

    # Each growth-rate axis only gets the nullcline of the cell type it shows.
    if plot_nullclines:
        _plot_nullclines(
            model,
            cell_a,
            cell_b,
            ax1,
            xlim,
            ylim,
            fixed_cell_counts,
            plot_b=False,
        )
        _plot_nullclines(
            model,
            cell_a,
            cell_b,
            ax2,
            xlim,
            ylim,
            fixed_cell_counts,
            plot_a=False,
        )

    _add_fixed_point_legend(ax0)

    if not include_titles:
        for ax in axs:
            ax.set_title("")

    fig.tight_layout()
    return fig
def plot_trajectory(
    ana: Analysis,
    state0: tuple[float, float],
    odeint_timepoints: np.ndarray,
    arrow_timepoint_idxs: list[int],
    color: str = "#3B76BC",
    ax: plt.Axes | None = None,
    zorder: int = 100,
):
    """Draw one simulated trajectory, with direction arrows, on ``ax``.

    Args:
        ana (Analysis): Analysis handed to ``compute_trajectory``.
        state0 (tuple[float, float]): Starting point. It is passed through
            ``inv_log2_1p`` before simulating, so it is presumably given in
            the log-transformed units of the plot axes.
        odeint_timepoints (np.ndarray): Timepoints handed to
            ``compute_trajectory``.
        arrow_timepoint_idxs (list[int]): Row indices of the solution at which
            an arrow is drawn, pointing towards the following row. Each index
            therefore needs a successor row.
        color (str, optional): Colour of the trajectory line.
            Defaults to "#3B76BC".
        ax (plt.Axes | None, optional): Axis to draw on; a new 4x4 figure is
            created when omitted. Defaults to None.
        zorder (int, optional): z-order of the line and arrows.
            Defaults to 100.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))

    start_counts = inv_log2_1p(np.array(state0))
    raw_solution = compute_trajectory(
        ana,
        initial_cell_counts=start_counts,
        odeint_timepoints=odeint_timepoints,
    )

    # The solution is log-transformed for plotting.
    trajectory = log2_1p(raw_solution)
    xs, ys = trajectory[:, 0], trajectory[:, 1]

    ax.plot(xs, ys, color=color, linewidth=4, zorder=zorder)

    # Arrows accentuate the direction of travel along the line.
    arrow_style = dict(
        shape="full",
        lw=1.0,
        length_includes_head=False,
        head_width=0.5,
        facecolor="black",
        zorder=zorder,
        edgecolor=None,
        clip_on=False,
    )
    for idx in arrow_timepoint_idxs:
        following = idx + 1
        ax.arrow(
            xs[idx],
            ys[idx],
            xs[following] - xs[idx],
            ys[following] - ys[idx],
            **arrow_style,
        )
def plot_tissues_over_phase_portrait(
    ana: Analysis,
    ax: plt.Axes | None = None,
):
    """Scatter the log-transformed tissue states over a phase-portrait and show it.

    The portrait is drawn in "rates" mode without nullclines, the states
    returned by ``ana.tissue_states()`` are scattered beneath it, and the
    figure is displayed with ``plt.show()``.

    Args:
        ana (Analysis): Analysis providing the model and the tissue states.
        ax (plt.Axes | None, optional): Axis to draw on; a new 4x4 figure is
            created when omitted. Defaults to None.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))

    plot_phase_portrait(
        ana,
        ax=ax,
        streamplot_color="black",
        mode="rates",
        plot_nullclines=False,
    )

    tissue_states = log2_1p(ana.tissue_states())
    # A negative z-order keeps the points beneath the streamlines.
    marker_style = dict(c=FIBROBLAST_COLOR, edgecolor="black", zorder=-10)
    sns.scatterplot(
        tissue_states,
        x=ana.cell_a,
        y=ana.cell_b,
        ax=ax,
        **marker_style,
    )

    sns.despine(ax=ax)
    plt.show()
"""
Internal plotting functions:
"""
def _plot_phase_portrait(
    model: Model,
    cell_a: str,
    cell_b: str,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    ax: plt.Axes,
    fixed_cell_counts: dict[str, float] | None = None,
    style: Literal["stream", "quiver"] = "stream",
    mode: Literal["cells", "rates"] = "cells",
    max_density: int | dict | None = None,
    streamplot_density: float = 0.6,
    streamplot_linewidth: float = 1.0,
    streamplot_color: str = "black",
    quiver_arrow_scale: float | None = None,
    step: float = 0.35,
):
    """Draw the model's vector field on ``ax`` as streamlines or arrows.

    Args:
        model (Model): Model whose dynamics are evaluated.
        cell_a (str): Cell type on the x axis.
        cell_b (str): Cell type on the y axis.
        xlim (tuple[float, float]): x-axis limits.
        ylim (tuple[float, float]): y-axis limits.
        ax (plt.Axes): Axis to draw on.
        fixed_cell_counts (dict[str, float] | None, optional): Forwarded to
            ``compute_change_in_each_direction``. Defaults to None.
        style (Literal["stream", "quiver"], optional): ``"stream"`` draws a
            streamplot, ``"quiver"`` draws arrows. Defaults to "stream".
        mode (Literal["cells", "rates"], optional): Forwarded to
            ``compute_change_in_each_direction``. Defaults to "cells".
        max_density (int | dict | None, optional): maximal density of cells
            (untransformed count), only used by the "quiver" style. A scalar
            bounds the sum of both cell counts; a dict bounds each cell type
            separately and is keyed by ``cell_a`` / ``cell_b``.
            Defaults to None.
        streamplot_density (float, optional): Streamline density.
            Defaults to 0.6.
        streamplot_linewidth (float, optional): Streamline width.
            Defaults to 1.0.
        streamplot_color (str, optional): Colour of the streamlines or
            arrows. Defaults to "black".
        quiver_arrow_scale (float | None, optional): ``scale`` passed to
            ``quiver``. Defaults to None.
        step (float, optional): Step forwarded to
            ``compute_change_in_each_direction``. Defaults to 0.35.

    Raises:
        ValueError: If ``style`` is neither "stream" nor "quiver".
    """
    # Reject an unknown style before the vector field is evaluated.
    if style not in ("stream", "quiver"):
        raise ValueError(
            f"{style!r} is not a valid phase portrait style, expected 'stream' or 'quiver'"
        )

    # x,y,u,v is typical streamplot naming for meshgrid locations x,y and
    # rates of change in x and y directions u and v respectively
    x, y, u, v = compute_change_in_each_direction(
        model, cell_a, cell_b, xlim, ylim, fixed_cell_counts, step=step, mode=mode
    )

    if style == "stream":
        ax.streamplot(
            x,
            y,
            u,
            v,
            broken_streamlines=True,
            color=streamplot_color,
            density=streamplot_density,
            linewidth=streamplot_linewidth,
        )
    else:
        # plot arrows beginning within max density: the boolean mask is handed
        # to quiver as a per-arrow alpha.
        alpha = None
        if max_density is not None:
            abs_x = inv_log2_1p(x).ravel()
            abs_y = inv_log2_1p(y).ravel()

            if np.isscalar(max_density):
                alpha = abs_x + abs_y <= max_density
            elif isinstance(max_density, dict):
                alpha = (abs_x <= max_density[cell_a]) & (abs_y <= max_density[cell_b])

        ax.quiver(
            x,
            y,
            u,
            v,
            color=streamplot_color,
            angles="uv",
            alpha=alpha,
            scale=quiver_arrow_scale,
        )

    ax.set_title("Phase-Portrait")
    ax.set_xlabel(log_density_axis(cell_a))
    ax.set_ylabel(log_density_axis(cell_b))
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

    sns.despine(ax=ax)
def _plot_nullclines(
    model: Model,
    cell_a: str,
    cell_b: str,
    ax: plt.Axes,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    fixed_cell_counts: dict[str, float] | None = None,
    plot_a: bool = True,
    plot_b: bool = True,
    linewidth: float = 1.5,
):
    """Draw the nullclines of one or both cell types on ``ax``.

    Each nullcline is drawn in the colour registered for its cell type in
    ``CELL_TYPE_TO_COLOR``.

    Args:
        model (Model): Model whose dynamics are evaluated.
        cell_a (str): Cell type on the x axis.
        cell_b (str): Cell type on the y axis.
        ax (plt.Axes): Axis to draw on.
        xlim (tuple[float, float]): x-range over which the rates are computed.
        ylim (tuple[float, float]): y-range over which the rates are computed.
        fixed_cell_counts (dict[str, float] | None, optional): Forwarded to
            ``compute_change_in_each_direction``. Defaults to None.
        plot_a (bool, optional): Draw the ``cell_a`` nullclines.
            Defaults to True.
        plot_b (bool, optional): Draw the ``cell_b`` nullclines.
            Defaults to True.
        linewidth (float, optional): Line width. Defaults to 1.5.
    """
    # The rates are evaluated with a much smaller step (0.01) than the
    # phase-portrait uses by default.
    grid_a, grid_b, rate_a, rate_b = compute_change_in_each_direction(
        model, cell_a, cell_b, xlim, ylim, fixed_cell_counts, step=0.01
    )
    curves_a, curves_b = get_nullclines(grid_a, grid_b, rate_a, rate_b)

    requested = (
        (plot_a, curves_a, cell_a),
        (plot_b, curves_b, cell_b),
    )
    for enabled, curves, cell_type in requested:
        if not enabled:
            continue
        for segment in curves.geoms:
            xs, ys = segment.xy
            ax.plot(xs, ys, color=CELL_TYPE_TO_COLOR[cell_type], linewidth=linewidth)
def plot_two_cell_density(
    nds: NeighborsDataset,
    cell_a: str,
    cell_b: str,
    ax: plt.Axes | None = None,
    kde_bw: float = 0.3,
    fraction_outlined: float = 0.95,
):
    """Shade the region of the phase-plane that holds most of the cells.

    The neighbour counts fetched for ``cell_a`` and ``cell_b`` are pooled,
    subsampled to at most 1000 rows with a fixed seed, log-transformed and
    passed to a filled 2D KDE plot drawn behind the other layers.

    Args:
        nds (NeighborsDataset): Dataset whose ``fetch`` provides the
            neighbour counts for each cell type.
        cell_a (str): Column plotted on the x axis.
        cell_b (str): Column plotted on the y axis.
        ax (plt.Axes | None, optional): Axis to draw on; a new 4x4 figure is
            created when omitted. Defaults to None.
        kde_bw (float, optional): Bandwidth passed to ``sns.kdeplot``.
            Defaults to 0.3.
        fraction_outlined (float, optional): Sets the lower KDE level to
            ``1 - fraction_outlined``. Defaults to 0.95.
    """
    nbrs_a, nbrs_b = nds.fetch(cell_a)[0], nds.fetch(cell_b)[0]
    pooled = pd.concat([nbrs_a, nbrs_b])

    # A local random_state keeps the draw reproducible without reseeding the
    # process-wide NumPy generator; capping n avoids an error when fewer than
    # 1000 rows are available.
    n_samples = min(1000, len(pooled))
    nbrs = log2_1p(pooled.sample(n=n_samples, random_state=42))

    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))

    sns.kdeplot(
        x=nbrs[cell_a],
        y=nbrs[cell_b],
        bw_method=kde_bw,
        ax=ax,
        levels=[1 - fraction_outlined, 1.0],
        fill=True,
        cbar=False,
        color="#F0F0F0",
        alpha=0.5,
        zorder=-100,
    )
def plot_fixed_points(
    model: Model,
    cell_a: str,
    cell_b: str,
    ax: plt.Axes,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    fixed_cell_counts: dict[str, float] | None = None,
):
    """Find the model's classified fixed points within the bounds and mark them on ``ax``.

    Args:
        model (Model): Model whose fixed points are computed.
        cell_a (str): Cell type on the x axis.
        cell_b (str): Cell type on the y axis.
        ax (plt.Axes): Axis to draw on.
        xlim (tuple[float, float]): x-bounds handed to the fixed-point search.
        ylim (tuple[float, float]): y-bounds handed to the fixed-point search.
        fixed_cell_counts (dict[str, float] | None, optional): Forwarded to
            ``get_classified_fixed_points_inbounds``. Defaults to None.
    """
    classified = get_classified_fixed_points_inbounds(
        model,
        cell_a,
        cell_b,
        xlim,
        ylim,
        fixed_cell_counts,
    )
    _plot_fixed_points(fixed_points=classified, ax=ax)
def _plot_fixed_points(
    fixed_points: dict,
    ax: plt.Axes,
):
    """Mark already-classified fixed points on ``ax``.

    Args:
        fixed_points (dict): Maps each of "stable", "semi-stable" and
            "unstable" to a sequence of ``(x, y)`` points. All three keys
            must be present; an empty sequence draws nothing for that type.
        ax (plt.Axes): Axis to draw on.
    """
    # Markers only (no connecting line), unclipped and above the other layers.
    marker_only = dict(linestyle="", clip_on=False, zorder=100)

    for kind in ("stable", "semi-stable", "unstable"):
        points = fixed_points[kind]
        if len(points) == 0:
            continue
        xs, ys = zip(*points)
        ax.plot(
            xs,
            ys,
            **STABILITIY_TYPE_TO_MARKER_STYLE_KWARGS[kind],  # type: ignore
            **marker_only,
        )
def _plot_growth_rate_axis(
    nbrs,
    obs,
    cell_a,
    cell_b,
    kde_bw,
    rbf_gamma,
    ax,
    return_densest_point=False,
    obs_type="division",
    _plot_contours: bool = True,
):
    """Plot the smoothed, mean-centred event rate of one cell type on ``ax``.

    Every cell is coloured by the RBF-weighted mean of ``obs[obs_type]`` over
    all cells minus the overall mean of ``obs[obs_type]``, on a diverging
    colormap centred at zero. A horizontal colorbar is appended below the
    axis and, optionally, KDE contours of the cell distribution are overlaid.

    Args:
        nbrs: DataFrame of cell positions in the phase-plane with columns
            ``cell_a`` and ``cell_b``. Its rows are assumed to be aligned
            with ``obs`` by position.
        obs: DataFrame with one row per cell and a column named ``obs_type``.
        cell_a: Column plotted on the x axis.
        cell_b: Column plotted on the y axis.
        kde_bw: Bandwidth passed to ``sns.kdeplot`` for the contours.
        rbf_gamma: ``gamma`` passed to ``rbf_kernel``.
        ax: Axis to draw on.
        return_densest_point: Not used by this function.
        obs_type: Column of ``obs`` that is smoothed. Defaults to "division".
        _plot_contours: Overlay the KDE contours. Defaults to True.
    """
    ax.set_xlabel(log_density_axis(cell_a))
    ax.set_ylabel(log_density_axis(cell_b))

    # --- Plot empirical growth density ---

    # Re-index copies rather than the caller's frames. With a fresh 0..n-1
    # index on both frames, the sampled index labels below are also row
    # positions, so ``iloc`` selects the matching rows of ``obs``.
    nbrs = nbrs.reset_index(drop=True)
    obs = obs.reset_index(drop=True)

    # to keep runtimes reasonable - subsample to 30K cells:
    max_cells = int(30e3)
    if nbrs.shape[0] > max_cells:
        nbrs = nbrs.sample(n=max_cells, random_state=42, replace=False)
        obs = obs.iloc[nbrs.index]
        nbrs = nbrs.reset_index(drop=True)
        obs = obs.reset_index(drop=True)

    weights = rbf_kernel(nbrs, gamma=rbf_gamma)
    division_density = weights @ obs[obs_type] / np.sum(weights, axis=1)
    division_density_minus_death = division_density - obs[obs_type].mean()

    # set colormap - red "hot" = division > death, blue "cold" = death > division
    vcenter = 0
    vmin, vmax = np.round([division_density_minus_death.min(), division_density_minus_death.max()], 3)
    # eps widens the range slightly on both sides of the rounded extremes
    eps = 1e-6
    normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin - eps, vmax=vmax + eps)
    colormap = plt.get_cmap("RdBu_r")

    sns.scatterplot(
        x=nbrs[cell_a],
        y=nbrs[cell_b],
        c=division_density_minus_death,
        cmap=colormap,
        norm=normalize,
        ax=ax,
    )

    # add colorbar:
    sm = cm.ScalarMappable(norm=normalize, cmap=colormap)
    sm.set_array([])
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("bottom", size="3%", pad=0.8)
    cbar = plt.colorbar(sm, cax=cax, orientation="horizontal")
    cbar.set_ticks(np.round([vmin, vcenter, vmax], 3))  # type: ignore
    cbar.ax.set_title("fraction (Ki67 > threshold) - mean")
    cbar.ax.tick_params(rotation=45)

    # --- Plot empirical cell distribution ---
    if _plot_contours:
        sns.kdeplot(
            x=nbrs[cell_a],
            y=nbrs[cell_b],
            bw_method=kde_bw,
            ax=ax,
            levels=5,
            fill=False,
            cbar=False,
            color="black",
            linestyles="-",
            linewidths=1,
        )

    # remove right and top limits
    sns.despine(ax=ax)
def _add_fixed_point_legend(ax):
    """Append a small marker legend for the three stability types below ``ax``.

    A frameless strip is attached to the bottom of ``ax``. It shows one
    marker per stability type, each followed by its label.

    Args:
        ax: Axis beneath which the legend strip is attached.
    """
    legend_ax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.8)
    legend_ax.set_xlim(-0.2, 1.5)

    # (stability type, marker x-position, label offset in points)
    entries = (
        (STABLE, 0.0, (25, -5)),
        (SEMI_STABLE, 0.5, (35, -5)),
        (UNSTABLE, 1.0, (30, -5)),
    )

    # Markers first, in the same style used on the phase-portrait.
    for stability, x_pos, _ in entries:
        legend_ax.plot(
            [x_pos],
            [1],
            **STABILITIY_TYPE_TO_MARKER_STYLE_KWARGS[stability],
            linestyle="",
            clip_on=False,
            zorder=100,
        )

    # Then the text label of each marker.
    for stability, x_pos, offset in entries:
        legend_ax.annotate(
            stability,
            xy=(x_pos, 1),
            xytext=offset,
            textcoords="offset points",
            ha="center",
            fontsize=10,
        )

    # Remove the ticks and tick labels for both axes
    legend_ax.set_xticks([])
    legend_ax.set_yticks([])
    legend_ax.set_xticklabels([])
    legend_ax.set_yticklabels([])

    # Remove the axis labels
    legend_ax.set_xlabel("")
    legend_ax.set_ylabel("")

    # Remove the grid lines
    legend_ax.grid(False)

    # Remove the frame around the plot
    legend_ax.set_frame_on(False)