Source code for tdm.dataset.restrict

import pandas as pd
import numpy as np
from tdm.dataset import Dataset, NeighborsDataset


class RestrictedNeighborsDataset(NeighborsDataset):
    """
    Filters out cells according to lists of allowed cell types and cell types
    that must be neighboring each cell.
    """

    def __init__(
        self,
        nds: Dataset,
        allowed_neighbor_types: list[str] | None = None,
        required_neighbor_types: list[str] | None = None,
        keep_types: list[str] | None = None,
    ) -> None:
        """
        Build the restricted per-cell-type tables from an existing dataset.

        Parameters:
            nds (NeighborsDataset): instance of NeighborsDataset. It must
                provide ``fetch(cell_type)`` returning a ``(features, obs)``
                pair and ``cell_types()``.
            allowed_neighbor_types (list[str], optional): A list of cell types.
                Cells that are or have neighbors outside this list are excluded
                from the restricted dataset.
                Default behavior: allow all types.
            required_neighbor_types (list[str], optional): A list of cell
                types. Cells that do not have at least one neighbor of each
                type in this list are excluded from the restricted dataset.
                Note that a cell is by definition a neighbor of itself.
                Default behavior: don't require any type.
            keep_types (list[str], optional): A list of cell types. Drops all
                columns for cell types outside this list. If not provided (or
                empty), uses allowed_neighbor_types, because all columns
                outside allowed_neighbor_types will have only zeros.

        Raises:
            UserWarning: if neither keep_types nor allowed_neighbor_types is
                provided, since no cell type would be included.
        """
        self.nds = nds
        self.allowed_types = allowed_neighbor_types
        self.must_types = required_neighbor_types
        # An empty/None keep_types falls back to the allowed types.
        self.keep_types = keep_types or allowed_neighbor_types

        self.dataset_dict: dict[str, tuple[pd.DataFrame, pd.DataFrame]] = {}

        if self.keep_types is None:
            raise UserWarning("keep_types was None, no cell types included!")

        for cell_type in self.keep_types:
            features, obs = nds.fetch(cell_type)
            # Positional boolean mask: row i of features pairs with row i of obs.
            include_mask = self.get_include_mask(features)
            kept_features = features.loc[include_mask, self.keep_types].reset_index(drop=True)
            kept_obs = obs.loc[include_mask].reset_index(drop=True)
            self.dataset_dict[cell_type] = kept_features, kept_obs

    def get_include_mask(self, features: pd.DataFrame) -> np.ndarray:
        """
        Returns a boolean mask that includes cells that have:

        1. zero neighbors of unallowed types.
        2. at least one neighbor of each must type.

        Parameters:
            features: table with one row per cell and one column per cell
                type; each column is compared against zero.

        Returns:
            np.ndarray: boolean array of length ``features.shape[0]``. It is
            always a plain positional array (never an index-aligned pandas
            Series), so it can be applied to any table with the same row order.
        """
        # start by including all cells:
        mask = np.ones(features.shape[0], dtype=bool)

        # include only cells that have 0 neighbors of unallowed type:
        if self.allowed_types is not None:
            all_types = self.nds.cell_types()
            unallowed_types = [t for t in all_types if t not in self.allowed_types]
            for t in unallowed_types:
                # np.asarray drops the pandas index so the mask stays an ndarray
                mask = mask & np.asarray(features[t] == 0, dtype=bool)

        # include only cells that have at least 1 neighbor of each must-type:
        if self.must_types is not None:
            for t in self.must_types:
                mask = mask & np.asarray(features[t] > 0, dtype=bool)

        return mask