
analysis tools including classification

Functions that create or require the classifier clf object, and friends

build_gene_knn

build_gene_knn(adata, mask_var=None, mean_cluster=True, groupby='leiden', knn=5, use_knn=True, metric='euclidean', key='gene')

Compute a gene–gene kNN graph (hard or Gaussian‑weighted) and store sparse connectivities & distances in adata.uns.

Parameters:

  • adata –

    AnnData object (cells × genes). Internally transposed to (genes × cells).

  • mask_var (str, default: None ) –

    If not None, must be the name of a boolean column in adata.var. Only genes where adata.var[mask_var] == True are included. If None, all genes are used.

  • mean_cluster (bool, default: True ) –

    If True, aggregate cells by the clusters defined in adata.obs[groupby]. The kNN graph is computed on the mean‑expression profiles of the clusters (genes × n_clusters) rather than genes × n_cells.

  • groupby (str, default: 'leiden' ) –

    Column in adata.obs holding cluster labels. Only used if mean_cluster=True.

  • knn (int, default: 5 ) –

    How many neighbors per gene to consider. Passed as n_neighbors=knn to sc.pp.neighbors.

  • use_knn (bool, default: True ) –

    Passed to sc.pp.neighbors as knn=use_knn.

    • If True, builds a hard kNN graph (only the k nearest neighbors).
    • If False, uses a Gaussian kernel to weight neighbors up to the k-th neighbor.

  • metric (str, default: 'euclidean' ) –

    Distance metric for the kNN computation (e.g., "euclidean", "manhattan", "correlation"). If metric == "correlation" and the gene‑expression matrix is sparse, it is converted to dense.

  • key (str, default: 'gene' ) –

    Prefix under which to store results in adata.uns. The function sets:

    • adata.uns[f"{key}_gene_index"]
    • adata.uns[f"{key}_connectivities"]
    • adata.uns[f"{key}_distances"]

Source code in src/pySingleCellNet/tools/gene.py
def build_gene_knn(
    adata,
    mask_var: str = None,
    mean_cluster: bool = True,
    groupby: str = 'leiden',
    knn: int = 5,
    use_knn: bool = True,
    metric: str = "euclidean",
    key: str = "gene"
):
    """
    Compute a gene–gene kNN graph (hard or Gaussian‑weighted) and store sparse connectivities & distances in adata.uns.

    Parameters
    ----------
    adata
        AnnData object (cells × genes). Internally transposed to (genes × cells).
    mask_var
        If not None, must be a column name in adata.var of boolean values.
        Only genes where adata.var[mask_var] == True are included. If None, use all genes.
    mean_cluster
        If True, aggregate cells by cluster defined in adata.obs[groupby].
        The kNN graph is computed on the mean‑expression profiles of each cluster
        (genes × n_clusters) rather than genes × n_cells.
    groupby
        Column in adata.obs holding cluster labels. Only used if mean_cluster=True.
    knn
        Integer: how many neighbors per gene to consider.
        Passed as n_neighbors=knn to sc.pp.neighbors.
    use_knn
        Boolean: passed to sc.pp.neighbors as knn=use_knn. 
        - If True, builds a hard kNN graph (only k nearest neighbors).  
        - If False, uses a Gaussian kernel to weight up to the k-th neighbor.
    metric
        Distance metric for kNN computation (e.g. "euclidean", "manhattan", "correlation", etc.).
        If metric=="correlation" and the gene‑expression matrix is sparse, it will be converted to dense.
    key
        Prefix under which to store results in adata.uns. The function sets:
          - adata.uns[f"{key}_gene_index"]
          - adata.uns[f"{key}_connectivities"]
          - adata.uns[f"{key}_distances"]
    """
    # 1) Work on a shallow copy so we don’t overwrite adata.X prematurely
    adata_work = adata.copy()

    # 2) If mask_var is provided, subset to only those genes first
    if mask_var is not None:
        if mask_var not in adata_work.var.columns:
            raise ValueError(f"Column '{mask_var}' not found in adata.var.")
        gene_mask = adata_work.var[mask_var].astype(bool)
        selected_genes = adata_work.var.index[gene_mask].tolist()
        if len(selected_genes) == 0:
            raise ValueError(f"No genes found where var['{mask_var}'] is True.")
        adata_work = adata_work[:, selected_genes].copy()

    # 3) If mean_cluster=True, aggregate by cluster label in `groupby`
    if mean_cluster:
        if groupby not in adata_work.obs.columns:
            raise ValueError(f"Column '{groupby}' not found in adata.obs.")
        # Aggregate each cluster to its mean expression; stored in .layers['mean']
        adata_work = sc.get.aggregate(adata_work, by=groupby, func='mean')
        # Overwrite .X with the mean‑expression matrix
        adata_work.X = adata_work.layers['mean']

    # 4) Transpose so that each gene (or cluster‑mean) is one “observation”
    adata_genes = adata_work.T.copy()

    # 5) If metric=="correlation" and X is sparse, convert to dense
    if metric == "correlation" and sparse.issparse(adata_genes.X):
        adata_genes.X = adata_genes.X.toarray()

    # 6) Compute neighbors on the (genes × [cells or clusters]) matrix.
    #    Pass n_neighbors=knn and knn=use_knn. Default method selection in Scanpy will
    #    use 'umap' if use_knn=True, and 'gauss' if use_knn=False.
    sc.pp.neighbors(
        adata_genes,
        n_neighbors=knn,
        knn=use_knn,
        metric=metric,
        use_rep="X"
    )

    # 7) Extract the two sparse matrices from adata_genes.obsp:
    conn = adata_genes.obsp["connectivities"].copy()  # CSR: gene–gene adjacency weights
    dist = adata_genes.obsp["distances"].copy()       # CSR: gene–gene distances

    # 8) Record the gene‑order (after masking + optional aggregation)
    gene_index = np.array(adata_genes.obs_names)

    adata.uns[f"{key}_gene_index"]      = gene_index
    adata.uns[f"{key}_connectivities"] = conn
    adata.uns[f"{key}_distances"]      = dist

categorize_classification

categorize_classification(adata_c, thresholds, graph=None, k=3, columns_to_ignore=['rand'], inplace=True, class_obs_name='SCN_class_argmax')

Classify cells based on SCN scores and thresholds, then categorize multi-class cells as either 'Intermediate' or 'Hybrid'.

Classification rules
  • If exactly one cell type exceeds threshold: "Singular"
  • If zero cell types exceed threshold: "None"
  • If more than one cell type exceeds threshold:
    • If all pairs of high-scoring cell types are within k edges in the provided graph: "Intermediate"
    • Otherwise: "Hybrid"
  • If predicted cell type is 'rand': Set classification to "Rand"

Parameters:

  • adata_c (AnnData) –

    Annotated data matrix containing:

    • .obsm["SCN_score"]: DataFrame of SCN scores for each cell type.
    • .obs[class_obs_name]: Predicted cell type (argmax classification).

  • thresholds (DataFrame) –

    Thresholds for each cell type. Expected to match the columns in SCN_score.

  • graph (Graph, default: None ) –

    An iGraph describing relationships between cell types. Must have vertex names matching the cell-type columns in SCN_score.

  • k (int, default: 3 ) –

    Maximum graph distance to consider cell types "Intermediate". Defaults to 3.

  • columns_to_ignore (list, default: ['rand'] ) –

    List of SCN score columns to ignore. Defaults to ["rand"].

  • inplace (bool, default: True ) –

    If True, modify adata_c in place. Otherwise, return a new AnnData object. Defaults to True.

  • class_obs_name (str, default: 'SCN_class_argmax' ) –

    The name of the .obs column with argmax classification. Defaults to 'SCN_class_argmax'.

Raises:

  • ValueError

    If graph is None.

  • ValueError

    If "SCN_score" is missing in adata_c.obsm.

  • ValueError

    If class_obs_name is not found in adata_c.obs.

  • ValueError

    If the provided graph does not have vertex "name" attributes.

Returns:

  • AnnData or None: Returns modified AnnData if inplace is False, otherwise None.
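
Example (illustrative; assumes adata_c has already been classified so that .obsm["SCN_score"] and .obs["SCN_class_argmax"] exist, and that ct_graph is an igraph.Graph whose vertex "name" attributes match the SCN_score columns):

>>> thresholds = comp_ct_thresh(adata_c, qTile=0.05)  # see comp_ct_thresh below
>>> categorize_classification(adata_c, thresholds, graph=ct_graph, k=3)
>>> adata_c.obs["SCN_class_type"].value_counts()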

Source code in src/pySingleCellNet/tools/categorize.py
def categorize_classification(
    adata_c: AnnData,
    thresholds: pd.DataFrame,
    graph: ig.Graph = None,
    k: int = 3,
    columns_to_ignore: list = ["rand"],
    inplace: bool = True,
    class_obs_name: str = 'SCN_class_argmax'
):
    """Classify cells based on SCN scores and thresholds, then categorize 
    multi-class cells as either 'Intermediate' or 'Hybrid'.

    Classification rules:
      - If exactly one cell type exceeds threshold: "Singular"
      - If zero cell types exceed threshold: "None"
      - If more than one cell type exceeds threshold:
          * If all pairs of high-scoring cell types are within `k` edges 
            in the provided graph: "Intermediate"
          * Otherwise: "Hybrid"
      - If predicted cell type is 'rand': Set classification to "Rand"

    Args:
        adata_c (AnnData): Annotated data matrix containing:
            - `.obsm["SCN_score"]`: DataFrame of SCN scores for each cell type.
            - `.obs[class_obs_name]`: Predicted cell type (argmax classification).
        thresholds (pd.DataFrame): Thresholds for each cell type. Expected to 
            match the columns in `SCN_score`.
        graph (ig.Graph): An iGraph describing relationships between cell types. 
            Must have vertex names matching the cell-type columns in SCN_score.
        k (int, optional): Maximum graph distance to consider cell types "Intermediate". Defaults to 3.
        columns_to_ignore (list, optional): List of SCN score columns to ignore. Defaults to ["rand"].
        inplace (bool, optional): If True, modify `adata_c` in place. Otherwise, return a new AnnData object. Defaults to True.
        class_obs_name (str, optional): The name of the `.obs` column with argmax classification. Defaults to 'SCN_class_argmax'.

    Raises:
        ValueError: If `graph` is None.
        ValueError: If "SCN_score" is missing in `adata_c.obsm`.
        ValueError: If `class_obs_name` is not found in `adata_c.obs`.
        ValueError: If the provided graph does not have vertex "name" attributes.

    Returns:
        AnnData or None: Returns modified AnnData if `inplace` is False, otherwise None.
    """
    if graph is None:
        raise ValueError("A valid iGraph 'graph' must be provided. None was given.")

    if "SCN_score" not in adata_c.obsm:
        raise ValueError("No 'SCN_score' in adata_c.obsm. Please provide SCN scores.")

    SCN_scores = adata_c.obsm["SCN_score"].copy()
    SCN_scores.drop(columns=columns_to_ignore, inplace=True, errors='ignore')

    exceeded = SCN_scores.sub(thresholds.squeeze(), axis=1) > 0
    true_counts = exceeded.sum(axis=1)

    result_list = [
        [col for col in exceeded.columns[exceeded.iloc[row].values]]
        for row in range(exceeded.shape[0])
    ]

    class_type = pd.Series(["None"] * len(true_counts), index=true_counts.index, name="SCN_class_type")

    singular_mask = (true_counts == 1)
    class_type.loc[singular_mask] = "Singular"

    if "name" in graph.vs.attributes():
        type2index = {graph.vs[i]["name"]: i for i in range(graph.vcount())}
    else:
        raise ValueError("graph does not have a 'name' attribute for vertices.")

    def is_all_within_k_edges(cell_types):
        """Check if all pairs of cell types are within k edges in the graph.

        Args:
            cell_types (list): List of cell type names.

        Returns:
            bool: True if all pairs are within k edges, False otherwise.
        """
        if len(cell_types) <= 1:
            return True
        for i in range(len(cell_types)):
            for j in range(i + 1, len(cell_types)):
                ct1, ct2 = cell_types[i], cell_types[j]
                if ct1 not in type2index or ct2 not in type2index:
                    return False
                idx1 = type2index[ct1]
                idx2 = type2index[ct2]
                dist = graph.shortest_paths(idx1, idx2)[0][0]
                if dist >= k:
                    return False
        return True

    multi_mask = (true_counts > 1)
    multi_indices = np.where(multi_mask)[0]

    for i in multi_indices:
        c_types = result_list[i]
        if is_all_within_k_edges(c_types):
            class_type.iloc[i] = "Intermediate"
        else:
            class_type.iloc[i] = "Hybrid"

    ans = ['_'.join(lst) if lst else 'None' for lst in result_list]

    adata_c.obs['SCN_class_emp'] = ans
    adata_c.obs['SCN_class_type'] = class_type

    if class_obs_name not in adata_c.obs:
        raise ValueError(f"{class_obs_name} not found in adata_c.obs.")

    adata_c.obs['SCN_class_emp'] = adata_c.obs.apply(
        lambda row: 'Rand' if row[class_obs_name] == 'rand' else row['SCN_class_emp'],
        axis=1
    )
    adata_c.obs['SCN_class_type'] = adata_c.obs.apply(
        lambda row: 'Rand' if row[class_obs_name] == 'rand' else row['SCN_class_type'],
        axis=1
    )

    _add_scn_class_cat(adata_c)

    if inplace:
        return None
    else:
        return adata_c

classify_anndata

classify_anndata(adata, rf_tsp, nrand=0)

Classifies cells in the adata object based on gene expression and cross-pair information, using the random forest classifier in rf_tsp that was trained on the provided xpairs genes.

Parameters:

  • adata (AnnData) –

    An annotated data matrix containing the gene expression information for cells.

  • rf_tsp (List[float]) –

    A list of random forest classifier parameters used for classification.

  • nrand (int, default: 0 ) –

    Number of random permutations for the null distribution. Defaults to 0.

Returns:

Updates adata with classification results
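
Example (illustrative; assumes clf is the trained classifier object produced by the pySingleCellNet training step, passed here as rf_tsp):

>>> classify_anndata(adata, clf, nrand=0)
>>> adata.obsm["SCN_score"].head()        # per-cell-type classification scores
>>> adata.obs["SCN_class_argmax"].head()  # argmax class call per cell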

Source code in src/pySingleCellNet/tools/classifier.py
def classify_anndata(adata: AnnData, rf_tsp, nrand: int = 0):
    """
    Classifies cells in the `adata` object based on gene expression and cross-pair information, using the
    random forest classifier in rf_tsp that was trained on the provided xpairs genes.

    Parameters:
    -----------
    adata: `AnnData`
        An annotated data matrix containing the gene expression information for cells.
    rf_tsp: List[float]
        A list of random forest classifier parameters used for classification.
    nrand: int
        Number of random permutations for the null distribution. Default is 0.

    Returns:
    --------
    Updates adata with classification results 
    """

    # Classify cells using the `_scn_predict` function
    classRes = _scn_predict(rf_tsp, adata, nrand=nrand)

    # add the classification result as to `obsm`
    # adNew = AnnData(classRes, obs=adata.obs, var=pd.DataFrame(index=categories))
    adata.obsm['SCN_score'] = classRes

    # Get the categories (i.e., predicted cell types) from the classification result
    # categories = classRes.columns.values
    # possible_classes = rf_tsp['classifier'].classes_
    possible_classes = pd.Categorical(classRes.columns)
    # Add a new column to `obs` for the predicted cell types
    predicted_classes = classRes.idxmax(axis=1)
    adata.obs['SCN_class_argmax'] = pd.Categorical(predicted_classes, categories=possible_classes, ordered=True)

    # store this for consistent coloring
    # adata.uns['SCN_class_colors'] = rf_tsp['ctColors']        

    # import matplotlib.colors as mcolors
    # celltype_colors = rf_tsp['ctColors']
    # mycolors = [celltype_colors[ct] for ct in adata.obs['SCN_class_argmax'].cat.categories]
    # cmap = mcolors.ListedColormap(mycolors)
    adata.uns['SCN_class_argmax_colors'] = rf_tsp['ctColors']

cluster_alot

cluster_alot(adata, leiden_resolutions, prefix='autoc', pca_params=None, knn_params=None, random_state=None, overwrite=True, verbose=True)

Grid-search Leiden clusterings over (n_pcs, n_neighbors, resolution).

Runs a parameter sweep that combines different numbers of principal components, k-nearest-neighbor sizes, and Leiden resolutions. Optionally performs random PC subsampling (within the first N PCs) when constructing the KNN graph, repeating each configuration multiple times for robustness. Cluster labels are written to adata.obs under keys derived from prefix and the parameter settings.

Assumptions
  • adata.X is already log-transformed.
  • PCA has been computed and adata.obsm['X_pca'] is present; this is used as the base embedding for PC selection/subsampling.
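
These prerequisites can be met with standard Scanpy preprocessing, for example (illustrative):

>>> import scanpy as sc
>>> sc.pp.normalize_total(adata)
>>> sc.pp.log1p(adata)
>>> sc.pp.pca(adata, n_comps=50)  # populates adata.obsm['X_pca']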

Parameters:

  • adata

    AnnData object containing the log-transformed expression matrix. Must include obsm['X_pca'] (shape (n_cells, n_pcs_total)).

  • leiden_resolutions (Sequence[float]) –

    Leiden resolution values to evaluate (passed to sc.tl.leiden). Each resolution is combined with every KNN/PC configuration in the sweep.

  • prefix (str, default: 'autoc' ) –

    String prefix used to construct output keys for cluster labels in adata.obs (e.g., "{prefix}_pc{N}_k{K}_res{R}"). Defaults to "autoc".

  • pca_params (Optional[Dict[str, Any]], default: None ) –

    Configuration for PC selection and optional subsampling. Supported keys:

    • "top_n_pcs" (List[int], default [40]): Candidate values for the maximum PC index N (i.e., use the first N PCs).
    • "percent_of_pcs" (Optional[float], default None): If set with 0 < value <= 1, randomly select round(value * N) PCs from the first N for KNN construction. If None or 1, use the first N PCs without subsampling.
    • "n_random_samples" (Optional[int], default None): Number of random PC subsets to draw per (N, K) when percent_of_pcs is in (0, 1). If None or less than 1, no repeated subsampling is performed.

  • knn_params (Optional[Dict[str, Any]], default: None ) –

    KNN graph parameters. Supported keys:

    • "n_neighbors" (List[int], default [10]): Candidate values for K used in sc.pp.neighbors.

  • random_state (Optional[int], default: None ) –

    Random seed for PC subset sampling (when percent_of_pcs is used). Pass None for non-deterministic sampling. Defaults to None.

  • overwrite (bool, default: True ) –

    If True (default), overwrite existing adata.obs keys produced by previous runs that match the constructed names. If False, skip runs whose target keys already exist.

  • verbose (bool, default: True ) –

    If True (default), print progress messages for each run.

Returns:

  • runs (pd.DataFrame) –

    One row per clustering run with metadata columns such as:

    • obs_key: Name of the column in adata.obs that stores cluster labels.
    • neighbors_key: Name of the neighbors graph key used/created.
    • resolution: Leiden resolution value used for the run.
    • top_n_pcs: Number of leading PCs considered.
    • pct_pcs: Fraction of PCs used when subsampling (percent_of_pcs), or 1.0 if all were used.
    • sample_idx: Index of the PC subsampling repeat (0..n-1), or 0 if no subsampling.
    • n_neighbors: Number of neighbors (K) used in KNN construction.
    • n_clusters: Number of clusters returned by Leiden for that run.
    • pcs_used_count: Actual number of PCs used to build the KNN graph (round(pct_pcs * top_n_pcs), or top_n_pcs if no subsampling).

Raises:

  • KeyError

    If 'X_pca' is missing from adata.obsm.

  • ValueError

    If any provided parameter is out of range (e.g., percent_of_pcs not in (0, 1]; empty lists; non-positive n_neighbors).

  • RuntimeError

    If neighbor graph construction or Leiden clustering fails.

Notes
  • This function modifies adata in place by adding cluster label columns to adata.obs (and potentially adding or reusing neighbor graphs in adata.obsp / adata.uns with a constructed neighbors_key).
  • To ensure reproducibility when using PC subsampling, set random_state and keep other sources of randomness (e.g., parallel BLAS) controlled in your environment.

Examples:

>>> runs = cluster_alot(
...     adata,
...     leiden_resolutions=[0.1, 0.25, 0.5],
...     pca_params={"top_n_pcs": [20, 40],
...                 "percent_of_pcs": 0.5,
...                 "n_random_samples": 3},
...     knn_params={"n_neighbors": [10, 20]},
...     random_state=42,
... )
>>> runs[["obs_key", "n_clusters"]].head()
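
The returned table can then be filtered to choose a labeling, for example (illustrative, using the columns documented above):

>>> candidates = runs.query("8 <= n_clusters <= 12")
>>> candidates.sort_values("resolution")[
...     ["obs_key", "top_n_pcs", "n_neighbors", "resolution", "n_clusters"]
... ]
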
Source code in src/pySingleCellNet/tools/cluster.py
def cluster_alot(
    adata,
    leiden_resolutions: Sequence[float],
    prefix: str = "autoc",
    pca_params: Optional[Dict[str, Any]] = None,
    knn_params: Optional[Dict[str, Any]] = None,
    random_state: Optional[int] = None,
    overwrite: bool = True,
    verbose: bool = True,
) -> pd.DataFrame:
    """Grid-search Leiden clusterings over (n_pcs, n_neighbors, resolution).

    Runs a parameter sweep that combines different numbers of principal components,
    k-nearest-neighbor sizes, and Leiden resolutions. Optionally performs random
    *PC subsampling* (within the first ``N`` PCs) when constructing the KNN graph,
    repeating each configuration multiple times for robustness. Cluster labels
    are written to ``adata.obs`` under keys derived from ``prefix`` and the
    parameter settings.

    Assumptions:
        * ``adata.X`` is **already log-transformed**.
        * PCA has been computed and ``adata.obsm['X_pca']`` is present; this is
          used as the base embedding for PC selection/subsampling.

    Args:
        adata: AnnData object containing the log-transformed expression matrix.
            Must include ``obsm['X_pca']`` (shape ``(n_cells, n_pcs_total)``).
        leiden_resolutions: Leiden resolution values to evaluate (passed to
            ``sc.tl.leiden``). Each resolution is combined with every KNN/PC
            configuration in the sweep.
        prefix: String prefix used to construct output keys for cluster labels in
            ``adata.obs`` (e.g., ``"{prefix}_pc{N}_k{K}_res{R}"``). Defaults to
            ``"autoc"``.
        pca_params: Configuration for PC selection and optional subsampling.
            Supported keys:
            * ``"top_n_pcs"`` (List[int], default ``[40]``): Candidate values
              for the maximum PC index ``N`` (i.e., use the first ``N`` PCs).
            * ``"percent_of_pcs"`` (Optional[float], default ``None``): If set
              with ``0 < value <= 1``, randomly select
              ``round(value * N)`` PCs **from the first ``N``** for KNN
              construction. If ``None`` or ``1``, use the first ``N`` PCs
              without subsampling.
            * ``"n_random_samples"`` (Optional[int], default ``None``): Number
              of random PC subsets to draw **per (N, K)** when
              ``percent_of_pcs`` is set in ``(0, 1)``. If ``None`` or less than
              1, no repeated subsampling is performed.
        knn_params: KNN graph parameters. Supported keys:
            * ``"n_neighbors"`` (List[int], default ``[10]``): Candidate values
              for ``K`` used in ``sc.pp.neighbors``.
        random_state: Random seed for PC subset sampling (when
            ``percent_of_pcs`` is used). Pass ``None`` for non-deterministic
            sampling. Defaults to ``None``.
        overwrite: If ``True`` (default), overwrite existing ``adata.obs`` keys
            produced by previous runs that match the constructed names. If
            ``False``, skip runs whose target keys already exist.
        verbose: If ``True`` (default), print progress messages for each run.

    Returns:
        pd.DataFrame:  

        * **runs** (``pd.DataFrame``): One row per clustering run with metadata columns such as:
          - ``obs_key``: Name of the column in ``adata.obs`` that stores cluster labels.
          - ``neighbors_key``: Name of the neighbors graph key used/created.
          - ``resolution``: Leiden resolution value used for the run.
          - ``top_n_pcs``: Number of leading PCs considered.
          - ``pct_pcs``: Fraction of PCs used when subsampling (``percent_of_pcs``), or ``1.0`` if all were used.
          - ``sample_idx``: Index of the PC subsampling repeat (``0..n-1``) or ``0`` if no subsampling.
          - ``n_neighbors``: Number of neighbors (``K``) used in KNN construction.
          - ``n_clusters``: Number of clusters returned by Leiden for that run.
          - ``pcs_used_count``: Actual number of PCs used to build the KNN graph
            (``round(pct_pcs * top_n_pcs)`` or ``top_n_pcs`` if no subsampling).

    Raises:
        KeyError: If ``'X_pca'`` is missing from ``adata.obsm``.
        ValueError: If any provided parameter is out of range (e.g.,
            ``percent_of_pcs`` not in ``(0, 1]``; empty lists; non-positive
            ``n_neighbors``).
        RuntimeError: If neighbor graph construction or Leiden clustering fails.

    Notes:
        * This function **modifies** ``adata`` in place by adding cluster label
          columns to ``adata.obs`` (and potentially adding or reusing neighbor
          graphs in ``adata.obsp`` / ``adata.uns`` with a constructed
          ``neighbors_key``).
        * To ensure reproducibility when using PC subsampling, set
          ``random_state`` and keep other sources of randomness (e.g., parallel
          BLAS) controlled in your environment.

    Examples:
        >>> runs = cluster_alot(
        ...     adata,
        ...     leiden_resolutions=[0.1, 0.25, 0.5],
        ...     pca_params={"top_n_pcs": [20, 40],
        ...                 "percent_of_pcs": 0.5,
        ...                 "n_random_samples": 3},
        ...     knn_params={"n_neighbors": [10, 20]},
        ...     random_state=42,
        ... )
        >>> runs[["obs_key", "n_clusters"]].head()
    """

    # ---- Validate prerequisites ----
    if "X_pca" not in adata.obsm:
        raise ValueError("`adata.obsm['X_pca']` not found. Please run PCA first.")
    Xpca = adata.obsm["X_pca"]
    n_pcs_available = Xpca.shape[1]
    if n_pcs_available < 2:
        raise ValueError(f"Not enough PCs ({n_pcs_available}) in `X_pca`.")

    # ---- Normalize params ----
    pca_params = dict(pca_params or {})
    knn_params = dict(knn_params or {})
    top_n_pcs: List[int] = pca_params.get("top_n_pcs", [40])
    percent_of_pcs: Optional[float] = pca_params.get("percent_of_pcs", None)
    n_random_samples: Optional[int] = pca_params.get("n_random_samples", None)
    n_neighbors_list: List[int] = knn_params.get("n_neighbors", [10])

    # sanitize lists
    if isinstance(top_n_pcs, (int, np.integer)): top_n_pcs = [int(top_n_pcs)]
    if isinstance(n_neighbors_list, (int, np.integer)): n_neighbors_list = [int(n_neighbors_list)]
    top_n_pcs = [int(x) for x in top_n_pcs]
    n_neighbors_list = [int(x) for x in n_neighbors_list]

    # sanity checks
    if percent_of_pcs is not None:
        if not (0 < float(percent_of_pcs) <= 1.0):
            raise ValueError("`percent_of_pcs` must be in (0, 1] when provided.")
        if (n_random_samples is None) or (int(n_random_samples) < 1):
            raise ValueError("When using `percent_of_pcs`, set `n_random_samples` >= 1.")
        n_random_samples = int(n_random_samples)

    rng = np.random.default_rng(random_state)

    # ---- Helper: build neighbors from a given PC subspace ----
    def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_key: str):
        """Create a neighbors graph using the given PC column indices."""
        # Create a temporary representation name
        temp_rep_key = f"X_pca_sub_{neighbors_key}"
        adata.obsm[temp_rep_key] = Xpca[:, pc_idx]

        # Build neighbors; store under unique keys (in uns & obsp)
        sc.pp.neighbors(
            adata,
            n_neighbors=n_neighbors,
            use_rep=temp_rep_key,
            key_added=neighbors_key,
        )

        # Record which PCs were used (for provenance)
        if neighbors_key in adata.uns:
            adata.uns[neighbors_key]["pcs_indices"] = pc_idx.astype(int)

        # Clean up the temporary representation to save memory
        del adata.obsm[temp_rep_key]

    # ---- Iterate over parameter combinations ----
    rows = []

    for N, kn, res in itertools.product(top_n_pcs, n_neighbors_list, leiden_resolutions):
        if N > n_pcs_available:
            if verbose:
                print(f"[skip] top_n_pcs={N} > available={n_pcs_available}")
            continue

        # Decide how many runs per (N, kn): either 1 (no subsample) or n_random_samples
        do_subsample = (percent_of_pcs is not None) and (percent_of_pcs < 1.0)
        repeats = n_random_samples if do_subsample else 1
        pcs_target_count = int(round((percent_of_pcs or 1.0) * N))

        # guards
        pcs_target_count = max(1, min(pcs_target_count, N))

        for rep_idx in range(repeats):
            # Choose PC indices
            if do_subsample:
                chosen = rng.choice(N, size=pcs_target_count, replace=False)
                chosen.sort()
                pct_str = f"{percent_of_pcs:.2f}"
                sample_tag = f"s{rep_idx+1:02d}"
            else:
                chosen = np.arange(N, dtype=int)
                pct_str = "1.00"
                sample_tag = "s01"

            # Construct unique keys
            neighbors_key = f"{prefix}_nbrs_pc{N}_pct{pct_str}_{sample_tag}_k{kn}"
            obs_key      = f"{prefix}_pc{N}_pct{pct_str}_{sample_tag}_k{kn}_res{res:g}"

            # Skip or overwrite?
            if (not overwrite) and (obs_key in adata.obs.columns):
                if verbose:
                    print(f"[skip-existing] {obs_key}")
                # we still record an entry (marked as skipped) to keep accounting stable
                rows.append({
                    "obs_key": obs_key,
                    "neighbors_key": neighbors_key,
                    "resolution": res,
                    "top_n_pcs": N,
                    "pct_pcs": float(pct_str),
                    "sample_idx": rep_idx + 1,
                    "n_neighbors": kn,
                    "n_clusters": np.nan,
                    "pcs_used_count": int(chosen.size),
                    "status": "skipped_exists",
                })
                continue

            # Build neighbors
            _neighbors_from_pc_indices(chosen, n_neighbors=kn, neighbors_key=neighbors_key)

            # Cluster using THIS neighbors graph (very important: neighbors_key=...)
            if verbose:
                print(f"[leiden] res={res} | N={N} | pct={pct_str} | {sample_tag} | k={kn} -> {obs_key}")

            sc.tl.leiden(
                adata,
                resolution=float(res),
                flavor="igraph",
                n_iterations=2,
                directed=False,
                key_added=obs_key,
                neighbors_key=neighbors_key,
            )

            # Summaries
            n_clusters = int(pd.Series(adata.obs[obs_key]).nunique())
            rows.append({
                "obs_key": obs_key,
                "neighbors_key": neighbors_key,
                "resolution": float(res),
                "top_n_pcs": int(N),
                "pct_pcs": float(pct_str),
                "sample_idx": int(rep_idx + 1),
                "n_neighbors": int(kn),
                "n_clusters": n_clusters,
                "pcs_used_count": int(chosen.size),
                "status": "ok",
            })

    summary_df = pd.DataFrame(rows)
    # nice ordering
    cols = ["obs_key","neighbors_key","resolution","top_n_pcs","pct_pcs","sample_idx",
            "n_neighbors","pcs_used_count","n_clusters","status"]
    summary_df = summary_df[cols]

    return summary_df

clustering_quality_vs_nn_summary

clustering_quality_vs_nn_summary(adata, label_cols, n_genes=5, naive={'p_val': 0.01, 'fold_change': 0.5}, strict={'minpercentin': 0.2, 'maxpercentout': 0.1, 'p_val': 0.01}, n_pcs_for_nn=30, has_log1p=True, gene_mask_col=None, layer=None, p_adjust_method='fdr_bh', deduplicate_partitions=True, return_pairs=False)

Summarize clustering quality across multiple label columns.

Computes clustering-quality metrics for each .obs label column in label_cols and returns a single summary table (one row per labeling). Optionally returns per–cluster-pair differential-expression tables for each labeling. A single PCA/neighbor graph (using n_pcs_for_nn PCs) is reused across runs, and identical partitions (up to relabeling) can be deduplicated for speed.

The method evaluates per-cluster marker genes under two regimes:

  • Naive: rank by test statistic and select the top n_genes that meet the naive thresholds (e.g., unadjusted p_val and minimum fold_change).
  • Strict: apply stricter filters on expression prevalence inside vs. outside the cluster (minpercentin / maxpercentout) and an adjusted p-value cutoff (p_val after p_adjust_method), then count genes.

Parameters:

  • adata

    AnnData object containing count/expression data. Uses adata.X or the specified layer; cluster labels must be in adata.obs.

  • label_cols (Sequence[str]) –

    Names of adata.obs columns whose clusterings will be evaluated (e.g., ["leiden_0.2", "leiden_0.5"]).

  • n_genes (int, default: 5 ) –

    Number of top genes to consider per cluster in the naive regime (after applying naive thresholds). Defaults to 5.

  • naive (dict, default: {'p_val': 0.01, 'fold_change': 0.5} ) –

    Thresholds for the naive regime. Expected keys:

    • "p_val" (float): Maximum unadjusted p-value.
    • "fold_change" (float): Minimum log2 fold-change.

    Defaults to {"p_val": 1e-2, "fold_change": 0.5}.

  • strict (dict, default: {'minpercentin': 0.2, 'maxpercentout': 0.1, 'p_val': 0.01} ) –

    Thresholds for the strict regime. Expected keys:

    • "minpercentin" (float): Minimum fraction of cells within the cluster expressing the gene.
    • "maxpercentout" (float): Maximum fraction of cells outside the cluster expressing the gene.
    • "p_val" (float): Maximum adjusted p-value (per p_adjust_method).

    Defaults to {"minpercentin": 0.20, "maxpercentout": 0.10, "p_val": 0.01}.

  • n_pcs_for_nn (int, default: 30 ) –

    Number of principal components to use when building the neighbor graph used for nearest-neighbor detection. Defaults to 30.

  • has_log1p (bool, default: True ) –

    Whether the data are already log1p-transformed. If False, the implementation may log1p-transform counts before testing. Defaults to True.

  • gene_mask_col (Optional[str], default: None ) –

    Optional name of a boolean column in adata.var used to mask genes prior to testing (e.g., to restrict to HVGs or exclude mitochondrial genes). If None, no mask is applied. Defaults to None.

  • layer (Optional[str], default: None ) –

    Name of an adata.layers matrix to use instead of adata.X. For example, "log1p" or "counts". Defaults to None.

  • p_adjust_method (str, default: 'fdr_bh' ) –

    Method for multiple testing correction (e.g., "fdr_bh"). Passed to the underlying p-value adjustment routine. Defaults to "fdr_bh".

  • deduplicate_partitions (bool, default: True ) –

    If True, detect and skip evaluations for labelings that produce the same partition (up to label renaming), reusing the computed result. Defaults to True.

  • return_pairs (bool, default: False ) –

    If True, also return a dict of per–cluster-pair result tables keyed by the label column. Each value is a pd.DataFrame with pairwise statistics for that labeling. Defaults to False.

Returns:

  • summary (pd.DataFrame) –

    One row per labeling with columns such as:

    • label_col: The label column name.
    • n_clusters: Number of clusters in the labeling.
    • n_pairs: Number of cluster pairs evaluated.
    • tested_genes: Number of genes tested after masking.
    • unique_naive_genes / unique_strict_genes: Count of genes uniquely satisfying the naive/strict criteria.
    • frac_pairs_with_at_least_n_strict: Fraction of cluster pairs with ≥ n strict marker genes (the exact column name may reflect n).
    • Additional min/max/median summaries of naive/strict exclusivity per pair.

  • pairs_by_label (Dict[str, pd.DataFrame], optional) –

    Returned only when return_pairs=True. For each labeling, a DataFrame of per–cluster-pair statistics and gene sets.

Raises:

  • KeyError

    If any entry in label_cols is not found in adata.obs, or if gene_mask_col is provided but not found in adata.var.

  • ValueError

    If required keys are missing from naive or strict, if n_genes < 1, or if p_adjust_method is unsupported.

  • RuntimeError

    If neighbor graph construction or differential testing fails.

Notes
  • The function does not modify adata in place (beyond any cached neighbor graph/PCs if your implementation chooses to store them).
  • For reproducibility, set any random seeds used by the nearest-neighbor or clustering components upstream.

Examples:

>>> summary = clustering_quality_vs_nn_summary(
...     adata,
...     label_cols=["leiden_0.2", "leiden_0.5"],
...     n_genes=10,
...     strict={"minpercentin": 0.25, "maxpercentout": 0.05, "p_val": 0.01},
... )
>>> summary[["label_col", "n_clusters", "unique_strict_genes"]].head()
>>> summary, pairs = clustering_quality_vs_nn_summary(
...     adata,
...     label_cols=["leiden_0.5"],
...     return_pairs=True,
... )
>>> pairs["leiden_0.5"].head()
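
The summary can likewise be used to rank labelings, for example (illustrative, using the columns documented above):

>>> summary.sort_values("min_strict_per_pair", ascending=False)[
...     ["label_col", "n_clusters", "min_strict_per_pair", "unique_strict_genes"]
... ]
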
Source code in src/pySingleCellNet/tools/cluster_eval.py
def clustering_quality_vs_nn_summary(
    adata,
    label_cols: Sequence[str],
    n_genes: int = 5,
    naive: dict = {"p_val": 1e-2, "fold_change": 0.5},
    strict: dict = {"minpercentin": 0.20, "maxpercentout": 0.10, "p_val": 0.01},
    n_pcs_for_nn: int = 30,
    has_log1p: bool = True,
    gene_mask_col: Optional[str] = None,
    layer: Optional[str] = None,
    p_adjust_method: str = "fdr_bh",
    deduplicate_partitions: bool = True,
    return_pairs: bool = False,
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, Dict[str, pd.DataFrame]]]:
    """Summarize clustering quality across multiple label columns.

    Computes clustering-quality metrics for each `.obs` label column in
    ``label_cols`` and returns a single summary table (one row per labeling).
    Optionally returns per–cluster-pair differential-expression tables for each
    labeling. A single PCA/neighbor graph (using ``n_pcs_for_nn`` PCs) is
    reused across runs, and identical partitions (up to relabeling) can be
    deduplicated for speed.

    The method evaluates per-cluster marker genes under two regimes:

    * **Naive:** rank by test statistic and select the top ``n_genes`` that meet
      the naive thresholds (e.g., unadjusted ``p_val`` and minimum ``fold_change``).
    * **Strict:** apply stricter filters on expression prevalence inside vs.
      outside the cluster (``minpercentin`` / ``maxpercentout``) and an adjusted
      p-value cutoff (``p_val`` after ``p_adjust_method``), then count genes.

    Args:
        adata: AnnData object containing count/expression data. Uses
            ``adata.X`` or the specified ``layer``; cluster labels must be in
            ``adata.obs``.
        label_cols: Names of ``adata.obs`` columns whose clusterings will be
            evaluated (e.g., ``["leiden_0.2", "leiden_0.5"]``).
        n_genes: Number of top genes to consider per cluster in the naive regime
            (after applying naive thresholds). Defaults to ``5``.
        naive: Thresholds for the naive regime. Expected keys:
            - ``"p_val"`` (float): Maximum unadjusted p-value.
            - ``"fold_change"`` (float): Minimum log2 fold-change.
            Defaults to ``{"p_val": 1e-2, "fold_change": 0.5}``.
        strict: Thresholds for the strict regime. Expected keys:
            - ``"minpercentin"`` (float): Minimum fraction of cells within the
              cluster expressing the gene.
            - ``"maxpercentout"`` (float): Maximum fraction of cells outside the
              cluster expressing the gene.
            - ``"p_val"`` (float): Maximum **adjusted** p-value (per
              ``p_adjust_method``).
            Defaults to
            ``{"minpercentin": 0.20, "maxpercentout": 0.10, "p_val": 0.01}``.
        n_pcs_for_nn: Number of principal components to use when building the
            neighbor graph used for nearest-neighbor detection. Defaults to ``30``.
        has_log1p: Whether the data are already log1p-transformed. If ``False``,
            the implementation may log1p-transform counts before testing.
            Defaults to ``True``.
        gene_mask_col: Optional name of a boolean column in ``adata.var`` used to
            mask genes prior to testing (e.g., to restrict to HVGs or exclude
            mitochondrial genes). If ``None``, no mask is applied. Defaults to
            ``None``.
        layer: Name of an ``adata.layers`` matrix to use instead of ``adata.X``.
            For example, ``"log1p"`` or ``"counts"``. Defaults to ``None``.
        p_adjust_method: Method for multiple testing correction (e.g., ``"fdr_bh"``).
            Passed to the underlying p-value adjustment routine. Defaults to ``"fdr_bh"``.
        deduplicate_partitions: If ``True``, detect and skip evaluations for
            labelings that produce the same partition (up to label renaming),
            reusing the computed result. Defaults to ``True``.
        return_pairs: If ``True``, also return a dict of per–cluster-pair result
            tables keyed by the label column. Each value is a ``pd.DataFrame``
            with pairwise statistics for that labeling. Defaults to ``False``.

    Returns:
        Union[pd.DataFrame, Tuple[pd.DataFrame, Dict[str, pd.DataFrame]]]:  

        * **summary** (``pd.DataFrame``): One row per labeling with columns such as:
          - ``label_col``: The label column name.
          - ``n_clusters``: Number of clusters in the labeling.
          - ``n_pairs``: Number of cluster pairs evaluated.
          - ``tested_genes``: Number of genes tested after masking.
          - ``unique_naive_genes`` / ``unique_strict_genes``: Count of genes
            uniquely satisfying naive/strict criteria.
          - ``frac_pairs_with_at_least_n_strict``: Fraction of cluster pairs with
            ≥ *n* strict marker genes (exact column name may reflect *n*).
          - Additional min/max/median summaries for naive/strict exclusivity per pair.
        * **pairs_by_label** (``Dict[str, pd.DataFrame]``, optional): Returned
          only when ``return_pairs=True``. For each labeling, a DataFrame of
          per–cluster-pair statistics and gene sets.

    Raises:
        KeyError: If any entry in ``label_cols`` is not found in ``adata.obs``,
            or if ``gene_mask_col`` is provided but not found in ``adata.var``.
        ValueError: If required keys are missing from ``naive`` or ``strict``,
            if ``n_genes`` < 1, or if ``p_adjust_method`` is unsupported.
        RuntimeError: If neighbor graph construction or differential testing fails.

    Notes:
        * The function does not modify ``adata`` in place (beyond any cached
          neighbor graph/PCs if your implementation chooses to store them).
        * For reproducibility, set any random seeds used by the nearest-neighbor
          or clustering components upstream.

    Examples:
        >>> summary = clustering_quality_vs_nn_summary(
        ...     adata,
        ...     label_cols=["leiden_0.2", "leiden_0.5"],
        ...     n_genes=10,
        ...     strict={"minpercentin": 0.25, "maxpercentout": 0.05, "p_val": 0.01},
        ... )
        >>> summary[["label_col", "n_clusters", "unique_strict_genes"]].head()

        >>> summary, pairs = clustering_quality_vs_nn_summary(
        ...     adata,
        ...     label_cols=["leiden_0.5"],
        ...     return_pairs=True,
        ... )
        >>> pairs["leiden_0.5"].head()
    """

    if not label_cols:
        raise ValueError("Provide at least one column in `label_cols`.")

    # expression matrix & gene mask
    expr = adata.layers[layer] if layer is not None else adata.X
    if expr.shape[0] != adata.n_obs:
        raise ValueError("Selected expression matrix has wrong shape.")
    genes = adata.var_names.to_numpy()

    if gene_mask_col is None:
        gene_mask = np.ones(adata.n_vars, dtype=bool)
    else:
        if gene_mask_col not in adata.var.columns:
            raise ValueError(f"'{gene_mask_col}' not found in adata.var")
        gene_mask = adata.var[gene_mask_col].to_numpy().astype(bool)
        if gene_mask.sum() == 0:
            raise ValueError(f"Gene mask '{gene_mask_col}' selects 0 genes.")

    # shared embedding for NN detection
    rep = _compute_representation(adata, n_pcs_for_nn, layer=layer, has_log1p=has_log1p)

    # dedupe bookkeeping
    sig_to_result: Dict[str, Tuple[Dict[str, Any], pd.DataFrame]] = {}
    sig_to_example_col: Dict[str, str] = {}
    per_run_pairs: Dict[str, pd.DataFrame] = {}
    rows = []

    for col in label_cols:
        if col not in adata.obs.columns:
            raise ValueError(f"'{col}' not found in adata.obs")
        series = adata.obs[col]

        # build canonical signature to detect identical partitions (up to relabeling)
        codes = _canonical_codes(series)
        sig = _hash_codes(codes) if deduplicate_partitions else None

        if deduplicate_partitions and (sig in sig_to_result):
            # reuse
            summary, pair_df = sig_to_result[sig]
            row = {"label_col": col, **summary}
            rows.append(row)
            if return_pairs:
                per_run_pairs[col] = pair_df.copy()
            continue

        # evaluate once
        summary, pair_df = _evaluate_one_partition(
            adata=adata,
            codes=codes,
            rep=rep,
            expr=expr,
            genes=genes,
            gene_mask=gene_mask,
            n_genes=n_genes,
            naive=naive,
            strict=strict,
            has_log1p=has_log1p,
            p_adjust_method=p_adjust_method,
        )
        row = {"label_col": col, **summary}
        rows.append(row)

        if deduplicate_partitions:
            sig_to_result[sig] = (summary, pair_df)
            sig_to_example_col[sig] = col
        if return_pairs:
            per_run_pairs[col] = pair_df

    out_df = pd.DataFrame(rows)
    # nice column order
    ordered_cols = ["label_col","n_clusters","n_pairs","tested_genes",
                    "unique_naive_genes","unique_strict_genes",
                    "min_naive_per_pair","min_strict_per_pair",
                    "max_naive_per_pair","max_strict_per_pair",
                    "mean_naive_per_pair","mean_strict_per_pair",
                    "median_naive_per_pair","median_strict_per_pair",
                    "gini_naive_per_pair","gini_strict_per_pair",
                    "frac_pairs_with_at_least_n_naive","frac_pairs_with_at_least_n_strict",
                    "min_naive_exclusive_per_pair","min_strict_exclusive_per_pair",
                    "max_naive_exclusive_per_pair","max_strict_exclusive_per_pair"]
    out_df = out_df[ordered_cols]

    return (out_df, per_run_pairs) if return_pairs else out_df

collect_gsea_results_from_dict

collect_gsea_results_from_dict(gsea_dict2, fdr_thr=0.25, top_n=3)

Collect and filter GSEA results from a dictionary of GSEA objects.

For each cell type
  1. Sets NES=0 for any gene set with FDR > fdr_thr.
  2. Selects up to top_n sets with the largest positive NES and top_n with the most negative NES.

The final output is limited to the union of all such selected sets across all cell types, with zeroes preserved for cell types in which the pathway is not among the top_n or fails the FDR threshold.

Parameters:

  • gsea_dict2 (dict) –

    Dictionary mapping cell types to GSEA result objects. Each object has a .res2d DataFrame with columns ["Term", "NES", "FDR q-val"].

  • fdr_thr (float, default: 0.25 ) –

    FDR threshold above which NES values are set to 0. Defaults to 0.25.

  • top_n (int, default: 3 ) –

    Maximum number of positive and negative results (by NES) to keep per cell type. Defaults to 3.

Returns:

  • pd.DataFrame: A DataFrame whose rows are the union of selected gene sets across all cell types, and whose columns are cell types. Entries are filtered NES values (0 where FDR fails, or if not in the top_n).
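
Example (illustrative; assumes gsea_by_celltype maps cell-type names to GSEA result objects, e.g. from gseapy, each exposing a .res2d frame with "Term", "NES", and "FDR q-val" columns):

>>> nes = collect_gsea_results_from_dict(gsea_by_celltype, fdr_thr=0.25, top_n=3)
>>> nes.head()  # rows: selected gene sets; columns: cell types; values: filtered NES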

Source code in src/pySingleCellNet/tools/comparison.py
def collect_gsea_results_from_dict(
    gsea_dict2: dict,
    fdr_thr: float = 0.25,
    top_n: int = 3
):
    """
    Collect and filter GSEA results from a dictionary of GSEA objects.

    For each cell type:
      1. Sets NES=0 for any gene set with FDR > fdr_thr.
      2. Selects up to top_n sets with the largest positive NES and 
         top_n with the most negative NES.

    The final output is limited to the union of all such selected sets
    across all cell types, with zeroes preserved for cell types in which
    the pathway is not among the top_n or fails the FDR threshold.

    Args:
        gsea_dict2 (dict): Dictionary mapping cell types to GSEA result objects.
            Each object has a .res2d DataFrame with columns ["Term", "NES", "FDR q-val"].
        fdr_thr (float, optional): FDR threshold above which NES values are set to 0. 
            Defaults to 0.25.
        top_n (int, optional): Maximum number of positive and negative results 
            (by NES) to keep per cell type. Defaults to 3.

    Returns:
        pd.DataFrame: A DataFrame whose rows are the union of selected gene sets 
            across all cell types, and whose columns are cell types. Entries 
            are filtered NES values (0 where FDR fails, or if not in the top_n).
    """
    import copy

    # Make a copy of the input to avoid in-place modifications
    gsea_dict = copy.deepcopy(gsea_dict2)

    # Collect all possible gene set names and cell types
    pathways = pd.Index([])
    cell_types = list(gsea_dict.keys())

    for cell_type in cell_types:
        tmpRes = gsea_dict[cell_type].res2d
        gene_set_names = list(tmpRes['Term'])
        pathways = pathways.union(gene_set_names)

    # Initialize NES DataFrame
    nes_df = pd.DataFrame(0, columns=cell_types, index=pathways)

    # Apply FDR threshold and fill NES
    for cell_type in cell_types:
        ct_df = gsea_dict[cell_type].res2d.copy()
        ct_df.index = ct_df['Term']
        # Zero out NES where FDR is too high
        ct_df.loc[ct_df['FDR q-val'] > fdr_thr, "NES"] = 0
        nes_df[cell_type] = ct_df["NES"]

    # Convert NES to numeric just in case
    nes_df = nes_df.apply(pd.to_numeric, errors='coerce')

    # Determine top_n positive and top_n negative for each cell type
    selected_sets = set()
    for cell_type in cell_types:
        ct_values = nes_df[cell_type]
        # Filter non-zero for positives and negatives
        pos_mask = ct_values > 0
        neg_mask = ct_values < 0

        # Select top_n largest positive NES
        top_pos_index = ct_values[pos_mask].sort_values(ascending=False).head(top_n).index
        # Select top_n most negative NES (smallest ascending)
        top_neg_index = ct_values[neg_mask].sort_values(ascending=True).head(top_n).index

        selected_sets.update(top_pos_index)
        selected_sets.update(top_neg_index)

    # Restrict DataFrame to the union of selected sets, converting the set to a list
    selected_sets_list = list(selected_sets)
    nes_df = nes_df.loc[selected_sets_list]

    return nes_df

comp_ct_thresh

comp_ct_thresh(adata_c, qTile=0.05, obs_name='SCN_class_argmax')

Compute quantile thresholds for each cell type based on SCN scores.

For each cell type (excluding "rand"), this function calculates the qTile quantile of the SCN scores for cells predicted to belong to that type.

Parameters:

  • adata_c (AnnData) –

    Annotated data matrix with:

    • .obsm["SCN_score"]: DataFrame of SCN scores.
    • .obs: Observation metadata containing predictions.

  • qTile (float, default: 0.05 ) –

    The quantile to compute (e.g., 0.05 for 5th percentile). Defaults to 0.05.

  • obs_name (str, default: 'SCN_class_argmax' ) –

    The column in .obs containing cell type predictions. Defaults to 'SCN_class_argmax'.

Returns:

  • pd.DataFrame –

    A DataFrame where each row corresponds to a cell type (excluding 'rand') and contains the computed quantile threshold. Returns None if 'SCN_score' is not present in adata_c.obsm.
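
Example (illustrative; assumes adata_c was classified upstream so that .obsm["SCN_score"] and .obs["SCN_class_argmax"] are populated):

>>> thresholds = comp_ct_thresh(adata_c, qTile=0.05)
>>> thresholds  # one 5th-percentile score threshold per non-'rand' cell type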

Source code in src/pySingleCellNet/tools/categorize.py
def comp_ct_thresh(adata_c: AnnData, qTile: float = 0.05, obs_name='SCN_class_argmax') -> pd.DataFrame:
    """Compute quantile thresholds for each cell type based on SCN scores.

    For each cell type (excluding "rand"), this function calculates the qTile 
    quantile of the SCN scores for cells predicted to belong to that type.

    Args:
        adata_c (AnnData): Annotated data matrix with:
            - `.obsm["SCN_score"]`: DataFrame of SCN scores.
            - `.obs`: Observation metadata containing predictions.
        qTile (float, optional): The quantile to compute (e.g., 0.05 for 5th percentile). Defaults to 0.05.
        obs_name (str, optional): The column in `.obs` containing cell type predictions. Defaults to 'SCN_class_argmax'.

    Returns:
        pd.DataFrame: A DataFrame where each row corresponds to a cell type 
        (excluding 'rand') and contains the computed quantile threshold.
        Returns None if 'SCN_score' is not present in `adata_c.obsm`.
    """
    if "SCN_score" not in adata_c.obsm_keys():
        print("No .obsm['SCN_score'] was found in the AnnData provided. You may need to run PySingleCellNet.scn_classify()")
        return
    else:
        sampTab = adata_c.obs.copy()
        scnScores = adata_c.obsm["SCN_score"].copy()

        cts = scnScores.columns.drop('rand')
        thrs = pd.DataFrame(np.repeat(0, len(cts)), index=cts)

        for ct in cts:
            # print(ct)
            templocs = sampTab[sampTab[obs_name] == ct].index
            tempscores = scnScores.loc[templocs, ct]
            thrs.loc[ct, 0] = np.quantile(tempscores, q=qTile)

        return thrs
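
Example

A minimal usage sketch (the AnnData adata is hypothetical; assumes a prior classification step has populated .obsm["SCN_score"] and .obs['SCN_class_argmax']):

thrs = comp_ct_thresh(adata, qTile=0.05, obs_name='SCN_class_argmax')
# One row per cell type (excluding 'rand'); column 0 holds the 5th-percentile SCN score
print(thrs)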

convert_diffExp_to_dict

convert_diffExp_to_dict(adata, uns_name='rank_genes_groups')

Convert differential expression results from AnnData into a dictionary of DataFrames.

This function extracts differential expression results stored in adata.uns[uns_name] using Scanpy's get.rank_genes_groups_df, cleans the data, and organizes it into a dictionary where each key corresponds to a group and each value is a DataFrame of differential expression results for that group.

Parameters:

  • adata (AnnData) –

    Annotated data matrix containing differential expression results in adata.uns.

  • uns_name (str, default: 'rank_genes_groups' ) –

    Key in adata.uns where rank_genes_groups results are stored. Defaults to 'rank_genes_groups'.

Returns:

  • dict –

    Dictionary mapping each group to a DataFrame of its differential expression results, with rows corresponding to genes and relevant statistics for each gene.

Source code in src/pySingleCellNet/tools/comparison.py
def convert_diffExp_to_dict(
    adata,
    uns_name: str = 'rank_genes_groups'
):
    """Convert differential expression results from AnnData into a dictionary of DataFrames.

    This function extracts differential expression results stored in `adata.uns[uns_name]` 
    using Scanpy's `get.rank_genes_groups_df`, cleans the data, and organizes it into 
    a dictionary where each key corresponds to a group and each value is a DataFrame 
    of differential expression results for that group.

    Args:
        adata (AnnData): Annotated data matrix containing differential expression results 
            in `adata.uns`.
        uns_name (str, optional): Key in `adata.uns` where rank_genes_groups results 
            are stored. Defaults to 'rank_genes_groups'.

    Returns:
        dict: Dictionary mapping each group to a DataFrame of its differential 
        expression results, with rows corresponding to genes and relevant statistics 
        for each gene.
    """
    import scanpy as sc  # Ensure Scanpy is imported
    tempTab = sc.get.rank_genes_groups_df(adata, group=None, key=uns_name)
    tempTab = tempTab.dropna()
    groups = tempTab['group'].cat.categories.to_list()

    ans = {}
    for g in groups:
        ans[g] = tempTab[tempTab['group'] == g].copy()
    return ans
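
Example

A minimal usage sketch (the AnnData adata and the 'leiden' grouping are hypothetical; assumes rank_genes_groups results live under the default key):

import scanpy as sc
sc.tl.rank_genes_groups(adata, groupby='leiden', method='wilcoxon')
deg_by_group = convert_diffExp_to_dict(adata, uns_name='rank_genes_groups')
deg_by_group['0'].head()  # DE table for cluster '0'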

create_classifier_report

create_classifier_report(adata, ground_truth, prediction)

Generate a classification report as a pandas DataFrame from an AnnData object.

This function computes a classification report using ground truth and prediction columns in adata.obs. It supports both string and dictionary outputs from sklearn.metrics.classification_report and transforms them into a standardized DataFrame format.

Parameters:

  • adata (AnnData) –

    An annotated data matrix containing observations with categorical truth and prediction labels.

  • ground_truth (str) –

    The column name in adata.obs containing the true class labels.

  • prediction (str) –

    The column name in adata.obs containing the predicted class labels.

Returns:

  • DataFrame –

    pd.DataFrame: A DataFrame with columns ["Label", "Precision", "Recall", "F1-Score", "Support"] summarizing classification metrics for each class.

Raises:

  • ValueError

    If the classification report is neither a string nor a dictionary.

Source code in src/pySingleCellNet/tools/classifier.py
def create_classifier_report(adata: AnnData,
    ground_truth: str,
    prediction: str) -> pd.DataFrame:
    """
    Generate a classification report as a pandas DataFrame from an AnnData object.

    This function computes a classification report using ground truth and prediction
    columns in `adata.obs`. It supports both string and dictionary outputs from
    `sklearn.metrics.classification_report` and transforms them into a standardized
    DataFrame format.

    Args:
        adata (AnnData): An annotated data matrix containing observations with
            categorical truth and prediction labels.
        ground_truth (str): The column name in `adata.obs` containing the true
            class labels.
        prediction (str): The column name in `adata.obs` containing the predicted
            class labels.

    Returns:
        pd.DataFrame: A DataFrame with columns ["Label", "Precision", "Recall",
        "F1-Score", "Support"] summarizing classification metrics for each class.

    Raises:
        ValueError: If the classification report is neither a string nor a dictionary.
    """

    report = classification_report(
        adata.obs[ground_truth],
        adata.obs[prediction],
        labels=adata.obs[ground_truth].cat.categories,
        output_dict=True,
    )
    # Parse the sklearn classification report into a DataFrame
    if isinstance(report, str):
        lines = report.split('\n')
        rows = []
        for line in lines[2:]:
            if line.strip() == '':
                continue
            row = line.split()
            if row[0] == 'micro' or row[0] == 'macro' or row[0] == 'weighted':
                row[0] = ' '.join(row[:2])
                row = [row[0]] + row[2:]
            elif len(row) > 5:
                row[0] = ' '.join(row[:2])
                row = [row[0]] + row[2:]
            rows.append(row)

        df = pd.DataFrame(rows, columns=["Label", "Precision", "Recall", "F1-Score", "Support"])
        df["Precision"] = pd.to_numeric(df["Precision"], errors='coerce')
        df["Recall"] = pd.to_numeric(df["Recall"], errors='coerce')
        df["F1-Score"] = pd.to_numeric(df["F1-Score"], errors='coerce')
        df["Support"] = pd.to_numeric(df["Support"], errors='coerce')
    elif isinstance(report, dict):
        df = pd.DataFrame(report).T.reset_index()
        df.columns = ["Label", "Precision", "Recall", "F1-Score", "Support"]
    else:
        raise ValueError("Report must be a string or a dictionary.")
    return df
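
Example

A minimal usage sketch (the column names 'celltype' and 'SCN_class_argmax' are hypothetical; both .obs columns must be categorical):

report_df = create_classifier_report(adata, ground_truth='celltype', prediction='SCN_class_argmax')
report_df.sort_values('F1-Score', ascending=False)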

deg

deg(adata, sample_obsvals=[], limitto_obsvals=[], cellgrp_obsname='comb_cellgrp', groupby_obsname='comb_sampname', ncells_per_sample=30, test_name='t-test', mask_var='highly_variable')

Perform differential expression analysis on an AnnData object across specified cell groups and samples.

This function iterates over specified or all cell groups within the adata object and performs differential expression analysis using the specified statistical test (e.g., t-test). It filters groups based on the minimum number of cells per sample and returns the results in a structured dictionary.

Parameters:

  • adata (AnnData) –

    The annotated data matrix containing observations and variables.

  • sample_obsvals (list, default: [] ) –

    List of sample observation values to include; the first value is treated as the test group and the second as the reference, which determines the sign of the test statistic. Defaults to an empty list (all values used).

  • limitto_obsvals (list, default: [] ) –

    List of cell group observation values to limit the analysis to. If empty, all cell groups in adata are tested. Defaults to an empty list.

  • cellgrp_obsname (str, default: 'comb_cellgrp' ) –

    The .obs column name in adata that holds the cell sub-groups. Defaults to 'comb_cellgrp'.

  • groupby_obsname (str, default: 'comb_sampname' ) –

    The .obs column name in adata used to group observations for differential expression. Defaults to 'comb_sampname'.

  • ncells_per_sample (int, default: 30 ) –

    The minimum number of cells per sample required to perform the test. Groups with fewer cells are skipped. Defaults to 30.

  • test_name (str, default: 't-test' ) –

    The name of the statistical test to use for differential expression. Defaults to 't-test'.

  • mask_var (str, default: 'highly_variable' ) –

    The name of the .var column indicating highly variable genes. Defaults to 'highly_variable'.

Returns:

  • dict ( dict ) –

    A dictionary containing:

      - 'sample_names': List of sample names used in the analysis.
      - 'geneTab_dict': A dictionary where each key is a cell group name and each value is a DataFrame of differential expression results for that group.

Source code in src/pySingleCellNet/tools/comparison.py
def deg(
    adata: AnnData,
    sample_obsvals: list = [],  # Impacts the sign of the test statistic
    limitto_obsvals: list = [],  # Specifies which cell groups to test; if empty, tests all
    cellgrp_obsname: str = 'comb_cellgrp',  # .obs column name holding the cell sub-groups to iterate over
    groupby_obsname: str = 'comb_sampname',  # .obs column name to group by for differential expression
    ncells_per_sample: int = 30,  # Minimum number of cells per sample required to perform the test
    test_name: str = 't-test',  # Name of the statistical test to use
    mask_var: str = 'highly_variable'
) -> dict:
    """
    Perform differential expression analysis on an AnnData object across specified cell groups and samples.

    This function iterates over specified or all cell groups within the `adata` object and performs
    differential expression analysis using the specified statistical test (e.g., t-test). It filters
    groups based on the minimum number of cells per sample and returns the results in a structured dictionary.

    Args:
        adata (AnnData): The annotated data matrix containing observations and variables.
        sample_obsvals (list, optional): List of sample observation values to include. Defaults to an empty list.
            Impacts the sign of the test statistic.
        limitto_obsvals (list, optional): List of cell group observation values to limit the analysis to.
            If empty, all cell groups in `adata` are tested. Defaults to an empty list.
        cellgrp_obsname (str, optional): The `.obs` column name in `adata` that holds the cell sub-groups.
            Defaults to 'comb_cellgrp'.
        groupby_obsname (str, optional): The `.obs` column name in `adata` used to group observations for differential expression.
            Defaults to 'comb_sampname'.
        ncells_per_sample (int, optional): The minimum number of cells per sample required to perform the test.
            Groups with fewer cells are skipped. Defaults to 30.
        test_name (str, optional): The name of the statistical test to use for differential expression.
            Defaults to 't-test'.
        mask_var (str, optional): The name of the .var column indicating highly variable genes.
            Defaults to 'highly_variable'.

    Returns:
        dict: A dictionary containing:
            - 'sample_names': List of sample names used in the analysis.
            - 'geneTab_dict': A dictionary where each key is a cell group name and each value is a DataFrame
              of differential expression results for that group.
    """
    ans = dict()

    # Keys for the rank_genes_groups object
    subset_keys = ['names', 'scores', 'pvals', 'pvals_adj', 'logfoldchanges']

    # If no specific sample observation values are provided, use all unique values from adata
    if len(sample_obsvals) == 0:
        sample_obsvals = adata.obs[groupby_obsname].unique().tolist()

    # Store the sample names in the result dictionary for later ordering of differential expression DataFrame
    ans['sample_names'] = sample_obsvals

    # Retrieve unique cell group names from the AnnData object
    cellgroup_names_in_anndata = adata.obs[cellgrp_obsname].unique()

    # If limitto_obsvals is provided, validate and set the cell groups to test
    if len(limitto_obsvals) > 0:
        # Identify any provided cell groups that are not present in adata
        unique_to_input = [x for x in limitto_obsvals if x not in cellgroup_names_in_anndata]
        if len(unique_to_input) > 0:
            # Raise instead of printing: otherwise cellgroup_names would be unbound below
            raise ValueError(f"The argument limitto_obsvals has values that are not present in adata: {unique_to_input}")
        cellgroup_names = limitto_obsvals
    else:
        # If no limit is set, use all available cell groups
        cellgroup_names = cellgroup_names_in_anndata

    # Initialize a temporary dictionary to store differential expression results
    tmp_dict = dict()

    # Create a mask to filter adata for the specified sample observation values
    mask = adata.obs[groupby_obsname].isin(sample_obsvals)
    adata = adata[mask].copy()

    def convert_rankGeneGroup_to_df(rgg: dict, list_of_keys: list) -> pd.DataFrame:
        """
        Convert the rank_genes_groups result from AnnData to a pandas DataFrame.

        Args:
            rgg (dict): The rank_genes_groups result from AnnData.
            list_of_keys (list): List of keys to extract from the rank_genes_groups result.

        Returns:
            pd.DataFrame: A DataFrame containing the extracted rank genes information.
        """
        # Initialize a dictionary to hold arrays for each key
        arrays_dict = {}
        for key in list_of_keys:
            recarray = rgg[key]
            field_name = recarray.dtype.names[0]  # Get the first field name from the structured array
            arrays_dict[key] = recarray[field_name]

        # Convert the dictionary of arrays to a DataFrame
        return pd.DataFrame(arrays_dict)

    # Iterate over each cell group to perform differential expression analysis
    for cell_group in cellgroup_names:
        print(f"cell group: {cell_group}")

        # Subset the AnnData object for the current cell group
        adTmp = adata[adata.obs[cellgrp_obsname] == cell_group].copy()

        # Count the number of cells per sample within the cell group
        vcounts = adTmp.obs[groupby_obsname].value_counts()

        # Check if there are exactly two samples and each has at least ncells_per_sample cells
        if (len(vcounts) == 2) and (vcounts >= ncells_per_sample).all():
            # Perform differential expression analysis using the specified test
            sc.tl.rank_genes_groups(
                adTmp,
                use_raw=False,
                groupby=groupby_obsname,
                groups=[sample_obsvals[0]],
                reference=sample_obsvals[1],
                method=test_name,
                mask_var=mask_var
            )

            # Convert the rank_genes_groups result to a DataFrame and store it in tmp_dict
            tmp_dict[cell_group] = convert_rankGeneGroup_to_df(adTmp.uns['rank_genes_groups'].copy(), subset_keys)
            # Alternative method to get the DataFrame (commented out)
            # tmp_dict[cell_group] = sc.get.rank_genes_groups_df(adTmp, cell_group)

    # Store the differential expression results in the result dictionary
    ans['geneTab_dict'] = tmp_dict

    return ans
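
Example

A minimal usage sketch (the sample labels 'treated' and 'control' and the cell-group name 'Neurons' are hypothetical):

res = deg(
    adata,
    sample_obsvals=['treated', 'control'],  # 'treated' is the test group, 'control' the reference
    cellgrp_obsname='comb_cellgrp',
    groupby_obsname='comb_sampname',
    ncells_per_sample=30,
)
res['geneTab_dict']['Neurons'].head()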

discover_cell_cliques

discover_cell_cliques(adata, cluster_cols, k=None, mode='lenient', out_col='core_cluster', min_size=1, allow_missing=False, max_combinations=None, return_details=False)

Define 'core' clusters across multiple clustering runs.

Parameters

adata : AnnData
    The data object with clustering labels in .obs columns.
cluster_cols : list[str] or str
    One or more .obs columns, each containing a clustering.
k : int or None, default None
    Cells must be in the same cluster in at least k runs to be grouped. If None, uses all runs (k = n_runs), i.e., exact tuple agreement.
mode : {'lenient','strict'}, default 'lenient'
    'lenient' uses DSU over all k-run combinations (fast, transitive). 'strict' refines lenient components so every pair inside a final cluster agrees in >= k runs (complete-linkage on masked Hamming).
out_col : str, default 'core_cluster'
    Name of the output categorical column added to adata.obs.
min_size : int, default 1
    Minimum size to keep a core cluster; smaller groups get 'core_-1'.
allow_missing : bool, default False
    If False, raises if any clustering column has missing labels. If True, combinations containing missing labels for a cell are either skipped (lenient) or ignored in distance computations (strict).
max_combinations : int or None
    If set and the number of k-run combinations exceeds this, raises with guidance.
return_details : bool, default False
    If True, also returns a dict with bookkeeping info.

Returns

core_labels : pandas.Series (categorical)
details : dict (optional)

Source code in src/pySingleCellNet/tools/cluster_cliques.py
def discover_cell_cliques(
    adata,
    cluster_cols: Union[List[str], str],
    k: Optional[int] = None,           # number of runs that must agree; None => all runs
    mode: str = "lenient",             # 'lenient' (DSU over k-run signatures) or 'strict'
    out_col: str = "core_cluster",
    min_size: int = 1,
    allow_missing: bool = False,
    max_combinations: Optional[int] = None,   # safety for very large #runs
    return_details: bool = False,
) -> Union[pd.Series, Tuple[pd.Series, Dict]]:
    """
    Define 'core' clusters across multiple clustering runs.

    Parameters
    ----------
    adata : AnnData
        The data object with clustering labels in .obs columns.
    cluster_cols : list[str] or str
        One or more .obs columns, each containing a clustering.
    k : int or None, default None
        Cells must be in the same cluster in at least k runs to be grouped.
        If None, uses all runs (k = n_runs), i.e., exact tuple agreement.
    mode : {'lenient','strict'}, default 'lenient'
        'lenient' uses DSU over all k-run combinations (fast, transitive).
        'strict' refines lenient components so every pair inside a final cluster
        agrees in >= k runs (complete-linkage on masked Hamming).
    out_col : str, default 'core_cluster'
        Name of the output categorical column added to adata.obs.
    min_size : int, default 1
        Minimum size to keep a core cluster; smaller groups get 'core_-1'.
    allow_missing : bool, default False
        If False, raises if any clustering column has missing labels.
        If True, combinations containing missing labels for a cell are either skipped
        (lenient) or ignored in distance computations (strict).
    max_combinations : int or None
        If set and number of k-run combinations exceeds this, raises with guidance.
    return_details : bool, default False
        If True, also returns a dict with bookkeeping info.

    Returns
    -------
    core_labels : pandas.Series (categorical)
    details : dict (optional)
    """
    # ---------------- Validate & prep ----------------
    if isinstance(cluster_cols, str):
        cluster_cols = [cluster_cols]
    if not cluster_cols:
        raise ValueError("Provide at least one column in `cluster_cols`.")
    for c in cluster_cols:
        if c not in adata.obs.columns:
            raise ValueError(f"'{c}' not found in adata.obs")

    n = adata.n_obs
    n_runs = len(cluster_cols)
    if k is None:
        k = n_runs
    if not (1 <= k <= n_runs):
        raise ValueError(f"`k` must be between 1 and {n_runs} (inclusive).")
    if mode not in ("lenient", "strict"):
        raise ValueError("`mode` must be 'lenient' or 'strict'.")

    labels_df = adata.obs[cluster_cols].copy()
    if not allow_missing and labels_df.isna().any().any():
        missing_cols = labels_df.columns[labels_df.isna().any(axis=0)].tolist()
        raise ValueError(
            f"Missing labels detected in columns: {missing_cols}. "
            "Pass allow_missing=True to proceed."
        )
    labels_df = labels_df.astype("category")
    # Codes: (n_cells, n_runs); NaN -> -1 sentinel
    codes = np.vstack([labels_df[c].cat.codes.to_numpy() for c in cluster_cols]).T.astype(np.int64)
    codes[codes < 0] = -1

    # ---- Fast path: exact consensus (k == n_runs) => tuple equality ----
    if k == n_runs:
        key = pd.MultiIndex.from_frame(labels_df.apply(lambda s: s.astype(str))).to_numpy()
        grp_ids, _ = pd.factorize(key, sort=False)
        core = pd.Series([f"core_{i}" for i in grp_ids], index=adata.obs_names, name=out_col)

        sizes = core.value_counts()
        small_mask = core.map(sizes) < min_size
        core = core.mask(small_mask, "core_-1")
        adata.obs[out_col] = pd.Categorical(core)
        details = {
            'n_runs': n_runs, 'k': k, 'mode': mode,
            'component_sizes_before_strict': sizes.to_dict(),
            'component_sizes_after_strict': sizes.to_dict()
        }
        return (adata.obs[out_col], details) if return_details else adata.obs[out_col]

    # --------------- LENIENT (DSU over k-combos) ---------------
    dsu = _DSU(n)
    run_indices = list(range(n_runs))
    combos = list(itertools.combinations(run_indices, k))
    if (max_combinations is not None) and (len(combos) > max_combinations):
        raise RuntimeError(
            f"Number of combinations C({n_runs},{k}) = {len(combos)} exceeds "
            f"max_combinations={max_combinations}. Increase k, reduce runs, or raise max_combinations."
        )

    for comb in combos:
        cols = np.array(comb, dtype=int)
        sub = codes[:, cols]  # (n_cells, k)
        # Valid rows for this combination:
        # - If allow_missing=False: require all k labels present
        # - If allow_missing=True: require at least one present (others may be -1)
        valid = (sub != -1).all(axis=1) if not allow_missing else (sub != -1).any(axis=1)
        if not valid.any():
            continue

        # Build signatures as tuples for hashing (missing stays as -1 if allow_missing=True)
        sigs = np.full(n, None, dtype=object)
        for i, v in enumerate(valid):
            if v:
                sigs[i] = tuple(sub[i].tolist())

        # Group equal signatures and union their members
        bucket: Dict[Tuple[int, ...], List[int]] = {}
        for i, sig in enumerate(sigs):
            if sig is None:
                continue
            lst = bucket.get(sig)
            if lst is None:
                bucket[sig] = [i]
            else:
                lst.append(i)
        for members in bucket.values():
            if len(members) >= 2:
                base = members[0]
                for other in members[1:]:
                    dsu.union(base, other)

    # Extract lenient components
    roots = np.fromiter((dsu.find(i) for i in range(n)), dtype=np.int64, count=n)
    _, comp_codes = np.unique(roots, return_inverse=True)
    core_lenient = pd.Series([f"core_{c}" for c in comp_codes], index=adata.obs_names)

    # --------------- STRICT refinement (fixed block) ---------------
    if mode == "strict":
        # Require every pair inside a final cluster to agree in >= k runs.
        # This is equivalent to complete-linkage with cutoff d_max = 1 - (k/n_runs)
        d_max = 1.0 - (float(k) / float(n_runs))
        refined = pd.Series(index=adata.obs_names, dtype=object)

        comp_sizes_before: Dict[str, int] = {}
        comp_sizes_after: Dict[str, int] = {}

        groups = core_lenient.groupby(core_lenient).groups  # dict: {comp_id: Index(obs_names)}

        for comp_id, obs_keys in groups.items():
            # Map obs-name labels -> integer positions (FIX)
            idxs = adata.obs_names.get_indexer(pd.Index(obs_keys))
            if (idxs < 0).any():
                raise RuntimeError("Encountered unknown obs_names while mapping to positions.")

            comp_sizes_before[comp_id] = idxs.size

            if idxs.size <= 1:
                refined.iloc[idxs] = comp_id + "_0"
                comp_sizes_after[comp_id + "_0"] = idxs.size
                continue

            X = codes[idxs, :]  # (s, n_runs)
            D = _pairwise_masked_hamming(X, missing_val=-1)

            # If already satisfies the strict criterion, keep as one block
            if np.all(D <= d_max):
                refined.iloc[idxs] = comp_id + "_0"
                comp_sizes_after[comp_id + "_0"] = idxs.size
                continue

            Z = linkage(D, method="complete")
            labs = fcluster(Z, t=d_max, criterion="distance")  # 1..K within this component

            for lab in np.unique(labs):
                sel = idxs[labs == lab]  # integer positions
                out_name = f"{comp_id}_{int(lab)-1}"
                refined.iloc[sel] = out_name
                comp_sizes_after[out_name] = sel.size

        core = refined
    else:
        core = core_lenient
        vc = core.value_counts()
        comp_sizes_before = vc.to_dict()
        comp_sizes_after = vc.to_dict()

    # Enforce min_size
    sizes = core.value_counts()
    small_mask = core.map(sizes) < min_size
    core = core.mask(small_mask, "core_-1")
    core.name = out_col

    adata.obs[out_col] = pd.Categorical(core)

    details = {
        'n_runs': n_runs,
        'k': k,
        'mode': mode,
        'component_sizes_before_strict': comp_sizes_before,
        'component_sizes_after_strict': comp_sizes_after,
    }
    return (adata.obs[out_col], details) if return_details else adata.obs[out_col]
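
Example

A minimal usage sketch (the three Leiden columns are hypothetical clustering runs stored in .obs):

core = discover_cell_cliques(
    adata,
    cluster_cols=['leiden_r05', 'leiden_r10', 'leiden_r20'],
    k=2,                 # cells must co-cluster in at least 2 of the 3 runs
    mode='lenient',
    min_size=20,
)
adata.obs['core_cluster'].value_counts()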

find_gene_modules

find_gene_modules(adata, mean_cluster=True, groupby='leiden', mask_var=None, knn=5, leiden_resolution=0.5, prefix='gmod_', metric='euclidean', *, uns_key='knn_modules', layer=None, min_module_size=2, order_genes_by_within_module_connectivity=True, random_state=0)

Find gene modules by building a kNN graph over genes (or cluster-mean profiles) and clustering with Leiden.

Writes a dict {f"{prefix}{cluster_id}": [gene names]} to adata.uns[uns_key] and returns the same dict.

Parameters

mean_cluster
    If True, aggregate cells by groupby before building the gene kNN graph.
groupby
    Column in adata.obs used for aggregation when mean_cluster=True.
mask_var
    Boolean column in adata.var used to select a subset of genes. If None, use all genes.
knn
    Number of neighbors for the kNN graph on genes.
leiden_resolution
    Resolution for Leiden clustering.
prefix
    Prefix for module names.
metric
    Distance metric for kNN (e.g. 'euclidean', 'manhattan', 'cosine', 'correlation'). NOTE: If metric=='correlation' and the data are sparse, we densify for stability.
uns_key
    Top-level .uns key to store the resulting dict of modules (default 'knn_modules').
layer
    If provided, use adata.layers[layer] as expression, otherwise adata.X. (Aggregation honors this choice.)
min_module_size
    Remove modules smaller than this size after clustering.
order_genes_by_within_module_connectivity
    If True, sort each module's genes by their within-module connectivity (descending).
random_state
    Random seed passed to Leiden for reproducibility.

Source code in src/pySingleCellNet/tools/gene.py
def find_gene_modules(
    adata,
    mean_cluster: bool = True,
    groupby: str = 'leiden',
    mask_var: Optional[str] = None,
    knn: int = 5,
    leiden_resolution: float = 0.5,
    prefix: str = 'gmod_',
    metric: str = 'euclidean',
    *,
    # NEW:
    uns_key: str = 'knn_modules',                # where to store the dict of modules
    layer: Optional[str] = None,                 # use adata.layers[layer] instead of .X
    min_module_size: int = 2,                    # drop tiny modules (set to 1 to keep all)
    order_genes_by_within_module_connectivity: bool = True,
    random_state: Optional[int] = 0,             # for reproducible Leiden
) -> Dict[str, List[str]]:
    """
    Find gene modules by building a kNN graph over genes (or cluster-mean profiles)
    and clustering with Leiden.

    Writes a dict {f"{prefix}{cluster_id}": [gene names]} to `adata.uns[uns_key]`
    and returns the same dict.

    Parameters
    ----------
    mean_cluster
        If True, aggregate cells by `groupby` before building the gene kNN graph.
    groupby
        Column in adata.obs used for aggregation when `mean_cluster=True`.
    mask_var
        Boolean column in adata.var used to select a subset of genes. If None, use all genes.
    knn
        Number of neighbors for the kNN graph on genes.
    leiden_resolution
        Resolution for Leiden clustering.
    prefix
        Prefix for module names.
    metric
        Distance metric for kNN (e.g. 'euclidean', 'manhattan', 'cosine', 'correlation').
        NOTE: If `metric=='correlation'` and the data are sparse, we densify for stability.
    uns_key
        Top-level .uns key to store the resulting dict of modules (default 'knn_modules').
    layer
        If provided, use `adata.layers[layer]` as expression, otherwise `adata.X`.
        (Aggregation honors this choice.)
    min_module_size
        Remove modules smaller than this size after clustering.
    order_genes_by_within_module_connectivity
        If True, sort each module's genes by their within-module connectivity (descending).
    random_state
        Random seed passed to Leiden for reproducibility.
    """
    # ----------------- 1) Choose expression matrix via a copy -----------------
    adata_subset = adata.copy()
    if layer is not None:
        if layer not in adata.layers:
            raise ValueError(f"Layer '{layer}' not found in adata.layers.")
        adata_subset.X = adata.layers[layer].copy()

    # ----------------- 2) Optional gene mask -----------------
    if mask_var is not None:
        if mask_var not in adata_subset.var.columns:
            raise ValueError(f"Column '{mask_var}' not found in adata.var.")
        gene_mask = adata_subset.var[mask_var].astype(bool).to_numpy()
        if gene_mask.sum() == 0:
            raise ValueError(f"No genes where var['{mask_var}'] is True.")
        adata_subset = adata_subset[:, gene_mask].copy()

    # ----------------- 3) Optional per-cluster aggregation -----------------
    if mean_cluster:
        if groupby not in adata_subset.obs.columns:
            raise ValueError(f"Column '{groupby}' not found in adata.obs.")
        # Prefer Scanpy's aggregation helper if available (Scanpy ≥1.10):
        if hasattr(sc.get, "aggregate"):
            ad_agg = sc.get.aggregate(adata_subset, by=groupby, func='mean')
            adata_subset = ad_agg.copy()
            adata_subset.X = ad_agg.layers['mean']  # make means the working matrix
        else:
            # Fallback: manual aggregation (sparse-aware)
            groups = adata_subset.obs[groupby].astype("category")
            cat = groups.cat.codes.to_numpy()
            n_groups = groups.cat.categories.size
            # Build group indicator sparse matrix G (cells x groups), then G^T * X / counts
            rows = np.arange(adata_subset.n_obs)
            G = sparse.csr_matrix((np.ones_like(rows), (rows, cat)), shape=(adata_subset.n_obs, n_groups))
            if sparse.issparse(adata_subset.X):
                sums = G.T @ adata_subset.X
            else:
                sums = (G.T @ sparse.csr_matrix(adata_subset.X)).toarray()
            counts = np.asarray(G.sum(axis=0)).ravel() + 1e-12
            means = sums / counts[:, None]
            # Build a new AnnData with groups as observations and genes as variables
            adata_subset = sc.AnnData(
                X=means,
                obs=pd.DataFrame(index=groups.cat.categories),
                var=adata_subset.var.copy()
            )

    # ----------------- 4) Transpose: genes become observations -----------------
    adt = adata_subset.T.copy()

    # ----------------- 5) Correlation metric stability (densify if needed) ----
    if metric == 'correlation' and sparse.issparse(adt.X):
        adt.X = adt.X.toarray()

    # ----------------- 6) Build kNN on genes (no PCA) ------------------------
    sc.pp.neighbors(
        adt,
        n_neighbors=int(knn),
        metric=metric,
        n_pcs=0,                # work directly in the expression space
        key_added="gene_neighbors"
    )

    # ----------------- 7) Leiden on that graph --------------------------------
    sc.tl.leiden(
        adt,
        resolution=float(leiden_resolution),
        key_added="gene_modules",
        neighbors_key="gene_neighbors",
        random_state=random_state,
    )

    # ----------------- 8) Collect modules (optionally filter & order) ----------
    # Base groups: leiden label -> list of gene names
    base_groups = (
        adt.obs
        .groupby('gene_modules', observed=True)['gene_modules']
        .apply(lambda s: s.index.tolist())
        .to_dict()
    )

    # Filter tiny modules
    if min_module_size > 1:
        base_groups = {k: v for k, v in base_groups.items() if len(v) >= min_module_size}

    # Optionally order by within-module connectivity
    modules: Dict[str, List[str]] = {}
    if order_genes_by_within_module_connectivity and 'gene_neighbors_connectivities' in adt.obsp:
        C = adt.obsp['gene_neighbors_connectivities']  # sparse CSR
        name_to_idx = {g: i for i, g in enumerate(adt.obs_names)}
        for cluster_id, genes in base_groups.items():
            idx = np.array([name_to_idx[g] for g in genes], dtype=int)
            # sum of weights within the subgraph
            w = np.asarray(C[idx, :][:, idx].sum(axis=1)).ravel()
            order = np.argsort(-w)  # descending
            mod_name = f"{prefix}{cluster_id}"
            modules[mod_name] = [genes[i] for i in order]
    else:
        # keep the original (arbitrary) order
        modules = {f"{prefix}{cluster_id}": gene_list for cluster_id, gene_list in base_groups.items()}

    # ----------------- 9) Store results & return ------------------------------
    adata.uns[uns_key] = modules
    # (optional lightweight metadata alongside; keeps backward-compat for adata.uns[uns_key])
    meta_key = f"{uns_key}__meta"
    adata.uns[meta_key] = {
        "mean_cluster": bool(mean_cluster),
        "groupby": groupby,
        "mask_var": mask_var,
        "knn": int(knn),
        "leiden_resolution": float(leiden_resolution),
        "prefix": prefix,
        "metric": metric,
        "layer": layer,
        "min_module_size": int(min_module_size),
        "ordered_by_within_module_connectivity": bool(order_genes_by_within_module_connectivity),
        "random_state": random_state,
        "n_modules": len(modules),
        "module_sizes": {k: len(v) for k, v in modules.items()},
    }

    return modules
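
Example

A minimal usage sketch (assumes adata has a 'leiden' clustering in .obs and a boolean 'highly_variable' column in .var):

modules = find_gene_modules(
    adata,
    mean_cluster=True,
    groupby='leiden',
    mask_var='highly_variable',
    knn=5,
    leiden_resolution=0.5,
)
print(len(modules), 'modules:', {k: len(v) for k, v in modules.items()})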

graph_from_nodes_and_edges

graph_from_nodes_and_edges(edge_dataframe, node_dataframe, attribution_column_names, directed=True)

Create an iGraph graph from provided node and edge dataframes.

This function constructs an iGraph graph using nodes defined in node_dataframe and edges defined in edge_dataframe. Each vertex is assigned attributes based on specified columns, and edges are created according to 'from' and 'to' columns in the edge dataframe.

Parameters:

  • edge_dataframe (DataFrame) –

    A DataFrame containing edge information with at least 'from' and 'to' columns indicating source and target node identifiers.

  • node_dataframe (DataFrame) –

    A DataFrame containing node information. Must include an 'id' column for vertex identifiers and any other columns specified in attribution_column_names.

  • attribution_column_names (list of str) –

    List of column names from node_dataframe whose values will be assigned as attributes to the corresponding vertices in the graph.

  • directed (bool, default: True ) –

    Whether the graph should be directed. Defaults to True.

Returns:

  • ig.Graph –

    An iGraph graph constructed from the given nodes and edges, with vertex attributes and labels set according to the provided data.

Source code in src/pySingleCellNet/tools/categorize.py
def graph_from_nodes_and_edges(edge_dataframe, node_dataframe, attribution_column_names, directed=True):
    """Create an iGraph graph from provided node and edge dataframes.

    This function constructs an iGraph graph using nodes defined in 
    `node_dataframe` and edges defined in `edge_dataframe`. Each vertex 
    is assigned attributes based on specified columns, and edges are 
    created according to 'from' and 'to' columns in the edge dataframe.

    Args:
        edge_dataframe (pd.DataFrame): A DataFrame containing edge 
            information with at least 'from' and 'to' columns indicating 
            source and target node identifiers.
        node_dataframe (pd.DataFrame): A DataFrame containing node 
            information. Must include an 'id' column for vertex identifiers 
            and any other columns specified in `attribution_column_names`.
        attribution_column_names (list of str): List of column names from 
            `node_dataframe` whose values will be assigned as attributes 
            to the corresponding vertices in the graph.
        directed (bool, optional): Whether the graph should be directed. 
            Defaults to True.

    Returns:
        ig.Graph: An iGraph graph constructed from the given nodes and edges, 
        with vertex attributes and labels set according to the provided data.
    """
    gra = ig.Graph(directed=directed)
    attr = {}
    for attr_names in attribution_column_names:
        attr[attr_names] = node_dataframe[attr_names].to_numpy()

    gra.add_vertices(n=node_dataframe.id.to_numpy(), attributes=attr)
    # Add all edges in one call (equivalent to the per-row loop, but much faster)
    edge_list = list(zip(edge_dataframe['from'], edge_dataframe['to']))
    gra.add_edges(edge_list)

    gra.vs["label"] = gra.vs["id"]
    return gra
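
Example

A minimal self-contained sketch with toy node and edge tables (note that 'id' must be included in attribution_column_names so the vertex labels can be set):

import pandas as pd
nodes = pd.DataFrame({'id': ['A', 'B', 'C'], 'ncells': [120, 80, 40]})
edges = pd.DataFrame({'from': ['A', 'A'], 'to': ['B', 'C']})
g = graph_from_nodes_and_edges(edges, nodes, attribution_column_names=['id', 'ncells'])
print(g.summary())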

gsea_on_deg

gsea_on_deg(deg_res, genesets_name, genesets, permutation_num=100, threads=4, seed=3, min_size=10, max_size=500)

Perform Gene Set Enrichment Analysis (GSEA) on differential expression results.

Applies GSEA using gseapy.prerank for each group in the differential expression results dictionary against provided gene sets.

Parameters:

  • deg_res (dict) –

    Dictionary mapping cell group names to DataFrames of differential expression results. Each DataFrame must contain columns 'names' (gene names) and 'scores' (ranking scores).

  • genesets_name (str) –

    Name of the gene set collection (not actively used).

  • genesets (dict) –

    Dictionary of gene sets where keys are gene set names and values are lists of genes.

  • permutation_num (int, default: 100 ) –

    Number of permutations for GSEA. Defaults to 100.

  • threads (int, default: 4 ) –

    Number of parallel threads to use. Defaults to 4.

  • seed (int, default: 3 ) –

    Random seed for reproducibility. Defaults to 3.

  • min_size (int, default: 10 ) –

    Minimum gene set size to consider. Defaults to 10.

  • max_size (int, default: 500 ) –

    Maximum gene set size to consider. Defaults to 500.

Returns:

  • dict ( dict ) –

    Dictionary where keys are cell group names and values are GSEA result objects returned by gseapy.prerank.

Example

deg_results = {
    'Cluster1': pd.DataFrame({'names': ['GeneA', 'GeneB'], 'scores': [2.5, -1.3]}),
    'Cluster2': pd.DataFrame({'names': ['GeneC', 'GeneD'], 'scores': [1.2, -2.1]})
}
gene_sets = {'Pathway1': ['GeneA', 'GeneC'], 'Pathway2': ['GeneB', 'GeneD']}
results = gsea_on_deg(deg_results, 'ExampleGeneSets', gene_sets)

Source code in src/pySingleCellNet/tools/comparison.py
def gsea_on_deg(
    deg_res: dict,
    genesets_name: str,
    genesets: dict,
    permutation_num: int = 100,
    threads: int = 4,
    seed: int = 3,
    min_size: int = 10,
    max_size: int = 500
) -> dict:
    """Perform Gene Set Enrichment Analysis (GSEA) on differential expression results.

    Applies GSEA using `gseapy.prerank` for each group in the differential 
    expression results dictionary against provided gene sets.

    Args:
        deg_res (dict): Dictionary mapping cell group names to DataFrames 
            of differential expression results. Each DataFrame must contain 
            columns 'names' (gene names) and 'scores' (ranking scores).
        genesets_name (str): Name of the gene set collection (not actively used).
        genesets (dict): Dictionary of gene sets where keys are gene set 
            names and values are lists of genes.
        permutation_num (int, optional): Number of permutations for GSEA. 
            Defaults to 100.
        threads (int, optional): Number of parallel threads to use. Defaults to 4.
        seed (int, optional): Random seed for reproducibility. Defaults to 3.
        min_size (int, optional): Minimum gene set size to consider. Defaults to 10.
        max_size (int, optional): Maximum gene set size to consider. Defaults to 500.

    Returns:
        dict: Dictionary where keys are cell group names and values are 
            GSEA result objects returned by `gseapy.prerank`.

    Example:
        >>> deg_results = {
        ...     'Cluster1': pd.DataFrame({'names': ['GeneA', 'GeneB'], 'scores': [2.5, -1.3]}),
        ...     'Cluster2': pd.DataFrame({'names': ['GeneC', 'GeneD'], 'scores': [1.2, -2.1]})
        ... }
        >>> gene_sets = {'Pathway1': ['GeneA', 'GeneC'], 'Pathway2': ['GeneB', 'GeneD']}
        >>> results = gsea_on_deg(deg_results, 'ExampleGeneSets', gene_sets)
    """
    ans = dict()
    diff_gene_tables = deg_res
    cellgrp_vals = list(diff_gene_tables.keys())
    for cellgrp in cellgrp_vals:
        atab = diff_gene_tables[cellgrp]
        atab = atab[['names', 'scores']].copy()  # copy to avoid SettingWithCopyWarning on the column rename below
        atab.columns = ['0', '1']
        pre_res = gp.prerank(
            rnk=atab,
            gene_sets=genesets,
            permutation_num=permutation_num,
            ascending=False,
            threads=threads,
            no_plot=True,
            seed=seed,
            min_size=min_size,
            max_size=max_size
        )
        ans[cellgrp] = pre_res
    return ans

paga_connectivities_to_igraph

paga_connectivities_to_igraph(adInput, n_neighbors=10, use_rep='X_pca', n_comps=30, threshold=0.05, paga_key='paga', connectivities_key='connectivities', group_key='auto_cluster')

Convert a PAGA adjacency matrix to an undirected iGraph object and add 'ncells' attribute for each vertex based on the number of cells in each cluster.

This function extracts the PAGA connectivity matrix from adata.uns, thresholds the edges, constructs an undirected iGraph graph, and assigns vertex names and the number of cells in each cluster.

Parameters:

  • adInput (AnnData) –

    The AnnData object containing:

      - adata.uns[paga_key][connectivities_key]: The PAGA adjacency matrix (CSR format).
      - adata.obs[group_key].cat.categories: The node labels.

  • n_neighbors (int, default: 10 ) –

    Number of neighbors for computing nearest neighbors. Defaults to 10.

  • use_rep (str, default: 'X_pca' ) –

    The representation to use. Defaults to 'X_pca'.

  • n_comps (int, default: 30 ) –

    Number of principal components. Defaults to 30.

  • threshold (float, default: 0.05 ) –

    Minimum edge weight to include. Defaults to 0.05.

  • paga_key (str, default: 'paga' ) –

    Key in adata.uns for PAGA results. Defaults to "paga".

  • connectivities_key (str, default: 'connectivities' ) –

    Key for connectivity matrix in adata.uns[paga_key]. Defaults to "connectivities".

  • group_key (str, default: 'auto_cluster' ) –

    The .obs column name with cluster labels. Defaults to "auto_cluster".

Returns:

  • ig.Graph –

    An undirected graph with edges meeting the threshold, edge weights assigned, vertex names set to cluster categories when possible, and each vertex has an 'ncells' attribute.

Source code in src/pySingleCellNet/tools/categorize.py
def paga_connectivities_to_igraph(
    adInput,
    n_neighbors=10,
    use_rep='X_pca',
    n_comps=30,
    threshold=0.05, 
    paga_key="paga", 
    connectivities_key="connectivities", 
    group_key="auto_cluster"
):
    """Convert a PAGA adjacency matrix to an undirected iGraph object and add 'ncells' 
    attribute for each vertex based on the number of cells in each cluster.

    This function extracts the PAGA connectivity matrix from `adata.uns`, thresholds 
    the edges, constructs an undirected iGraph graph, and assigns vertex names and 
    the number of cells in each cluster.

    Args:
        adInput (AnnData): The AnnData object containing:
            - `adata.uns[paga_key][connectivities_key]`: The PAGA adjacency matrix (CSR format).
            - `adata.obs[group_key].cat.categories`: The node labels.
        n_neighbors (int, optional): Number of neighbors for computing nearest neighbors. Defaults to 10.
        use_rep (str, optional): The representation to use. Defaults to 'X_pca'.
        n_comps (int, optional): Number of principal components. Defaults to 30.
        threshold (float, optional): Minimum edge weight to include. Defaults to 0.05.
        paga_key (str, optional): Key in `adata.uns` for PAGA results. Defaults to "paga".
        connectivities_key (str, optional): Key for connectivity matrix in `adata.uns[paga_key]`. Defaults to "connectivities".
        group_key (str, optional): The `.obs` column name with cluster labels. Defaults to "auto_cluster".

    Returns:
        ig.Graph: An undirected graph with edges meeting the threshold, edge weights assigned, 
        vertex names set to cluster categories when possible, and each vertex has an 'ncells' attribute.
    """
    # Copy so as to avoid altering the original AnnData object
    adata = adInput.copy()

    # Compute PCA, knn, and PAGA
    sc.tl.pca(adata, n_comps=n_comps, mask_var='highly_variable')
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep, n_pcs=n_comps)
    sc.tl.paga(adata, groups=group_key)

    # Extract the PAGA connectivity matrix
    adjacency_csr = adata.uns[paga_key][connectivities_key]
    adjacency_coo = adjacency_csr.tocoo()

    # Build edge list based on threshold
    edges = []
    weights = []
    for i, j, val in zip(adjacency_coo.row, adjacency_coo.col, adjacency_coo.data):
        if i < j and val >= threshold:
            edges.append((i, j))
            weights.append(val)

    # Create the graph
    g = ig.Graph(n=adjacency_csr.shape[0], edges=edges, directed=False)
    g.es["weight"] = weights

    # Assign vertex names and 'ncells' attribute if group_key exists in adata.obs
    if group_key in adata.obs:
        # Get cluster categories
        categories = adata.obs[group_key].cat.categories

        # Calculate the number of cells per category
        cell_counts_series = adata.obs[group_key].value_counts().reindex(categories, fill_value=0)
        cell_counts = list(cell_counts_series)

        if len(categories) == adjacency_csr.shape[0]:
            # Assign vertex names and 'ncells' attribute
            g.vs["name"] = list(categories)
            g.vs["label"] = list(categories)
            g.vs["ncells"] = cell_counts
        else:
            print(
                f"Warning: adjacency matrix size ({adjacency_csr.shape[0]}) "
                f"differs from number of categories ({len(categories)}). "
                "Vertex names and 'ncells' will not be fully assigned."
            )
            # Even if the sizes don't match, still assign available 'ncells' for existing categories
            g.vs["ncells"] = cell_counts
    else:
        print(
            f"Warning: {group_key} not found in adata.obs; "
            "vertex names and 'ncells' will not be assigned."
        )

    return g
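
Example

A minimal usage sketch (assumes adata.obs['auto_cluster'] is categorical and adata.var has a boolean 'highly_variable' column, which the internal PCA call requires):

g = paga_connectivities_to_igraph(adata, group_key='auto_cluster', threshold=0.05)
print(g.vs['name'])
print(g.vs['ncells'])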

score_gene_sets

score_gene_sets(adata, gene_sets, *, layer=None, log_transform=False, clip_percentiles=(1.0, 99.0), agg='mean', top_p=0.5, top_k=None, rank_method=None, rank_universe=None, auc_max_rank=0.05, batch_size=2048, use_average_ranks=False, min_genes_per_set=1, case_insensitive=False, obs_prefix=None, return_dataframe=True)

Compute per-cell gene-set scores with both value-based and rank-based (AUCell/UCell) modes.

Value-based pipeline (when rank_method is None):

  1. Optional log1p.
  2. Per-gene percentile clipping (clip_percentiles).
  3. Per-gene min–max scaling to [0, 1].
  4. Aggregate across genes in each set per cell with 'mean' | 'median' | 'sum' | 'nonzero_mean' | 'top_p_mean' | 'top_k_mean' | callable.

Rank-based pipeline (when rank_method in {'auc','ucell'}):

  • For each cell, rank genes within a chosen universe (rank_universe).
  • 'auc' : AUCell-style AUC in the top L ranks (L = auc_max_rank).
  • 'ucell': normalized Mann–Whitney U statistic in [0,1].
  • Ranks are computed in batches (batch_size) for memory efficiency.

Parameters:

  • adata

    AnnData object.

  • gene_sets (GeneSetInput) –

    Dict[name -> genes], list of gene lists (auto-named), or name of adata.uns key.

  • layer (Optional[str], default: None ) –

    Use adata.layers[layer] instead of .X.

  • log_transform (bool, default: False ) –

    Apply np.log1p before scoring (safe monotone transform).

  • clip_percentiles (Tuple[float, float], default: (1.0, 99.0) ) –

    (low, high) clipping percentiles for value-based mode.

  • agg (Union[str, Callable[[ndarray], ndarray]], default: 'mean' ) –

    Aggregation for value-based mode or a callable: (cells×genes) -> (cells,).

  • top_p (Optional[float], default: 0.5 ) –

    Fraction for 'top_p_mean' (0<p<=1).

  • top_k (Optional[int], default: None ) –

    Count for 'top_k_mean' (>=1).

  • rank_method (Optional[str], default: None ) –

    None | 'auc' | 'ucell' to switch to rank-based scoring.

  • rank_universe (Optional[Union[str, Sequence[str]]], default: None ) –

    None=all genes; or a boolean var column name (e.g. 'highly_variable'); or an explicit list of gene names defining the ranking universe.

  • auc_max_rank (Union[int, float], default: 0.05 ) –

    AUCell top window (int L) or fraction (0,1].

  • batch_size (int, default: 2048 ) –

    Row batch size for rank computation.

  • use_average_ranks (bool, default: False ) –

    If True, uses average-tie ranks (scipy.stats.rankdata); slower.

  • min_genes_per_set (int, default: 1 ) –

    Require at least this many present genes to score a set (else NaN).

  • case_insensitive (bool, default: False ) –

    Case-insensitive gene matching against var_names.

  • obs_prefix (Optional[str], default: None ) –

    If provided, also writes scores to adata.obs[f"{obs_prefix}{name}"].

  • return_dataframe (bool, default: True ) –

    If True, return a DataFrame; else return ndarray.

Returns:

  • DataFrame

    DataFrame (cells × sets) of scores (and optionally writes to adata.obs).

Notes

• Rank-based scores ignore clipping/min–max (ranks are invariant to monotone transforms).
• AUCell output here is normalized to [0,1] within the top-L window.
• UCell output is the normalized U statistic in [0,1].
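
Example

A minimal usage sketch (assumes gene modules were stored in adata.uns['knn_modules'], e.g. by find_gene_modules; the 'gss_' prefix is arbitrary):

scores = score_gene_sets(
    adata,
    gene_sets='knn_modules',
    rank_method='ucell',       # rank-based; clipping/min–max are ignored
    min_genes_per_set=5,
    obs_prefix='gss_',         # also writes adata.obs['gss_<set>'] columns
)
scores.head()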

Source code in src/pySingleCellNet/tools/gene.py
def score_gene_sets(
    adata,
    gene_sets: GeneSetInput,
    *,
    layer: Optional[str] = None,
    # ---- value-based (existing) options ----
    log_transform: bool = False,
    clip_percentiles: Tuple[float, float] = (1.0, 99.0),
    agg: Union[str, Callable[[np.ndarray], np.ndarray]] = "mean",
    top_p: Optional[float] = 0.5,           # for agg="top_p_mean" (0<p<=1)
    top_k: Optional[int] = None,            # for agg="top_k_mean"
    # ---- rank-based (new) options ----
    rank_method: Optional[str] = None,      # None | "auc" | "ucell"
    rank_universe: Optional[Union[str, Sequence[str]]] = None,  # None | var column | list of genes
    auc_max_rank: Union[int, float] = 0.05, # AUCell window: int = L, float=(0,1] fraction of universe
    batch_size: int = 2048,                 # batch size for ranking
    use_average_ranks: bool = False,        # use scipy.stats.rankdata (average ties); slower
    # ---- misc ----
    min_genes_per_set: int = 1,
    case_insensitive: bool = False,
    obs_prefix: Optional[str] = None,
    return_dataframe: bool = True,
) -> pd.DataFrame:
    """Compute per-cell gene-set scores with both value-based and rank-based (AUCell/UCell) modes.

    Value-based pipeline (when `rank_method is None`):
      1) Optional log1p.
      2) Per-gene percentile clipping (`clip_percentiles`).
      3) Per-gene min–max scaling to [0, 1].
      4) Aggregate across genes in each set per cell with
         'mean' | 'median' | 'sum' | 'nonzero_mean' | 'top_p_mean' | 'top_k_mean' | callable.

    Rank-based pipeline (when `rank_method in {'auc','ucell'}`):
      • For each cell, rank genes within a chosen universe (`rank_universe`).
      • 'auc'  : AUCell-style AUC in the top L ranks (L = `auc_max_rank`).
      • 'ucell': normalized Mann–Whitney U statistic in [0,1].
      • Ranks are computed in batches (`batch_size`) for memory efficiency.

    Args:
        adata: AnnData object.
        gene_sets: Dict[name -> genes], list of gene lists (auto-named), or name of `adata.uns` key.
        layer: Use `adata.layers[layer]` instead of `.X`.
        log_transform: Apply `np.log1p` before scoring (safe monotone transform).
        clip_percentiles: (low, high) clipping percentiles for value-based mode.
        agg: Aggregation for value-based mode or a callable: (cells×genes) -> (cells,).
        top_p: Fraction for 'top_p_mean' (0<p<=1).
        top_k: Count for 'top_k_mean' (>=1).
        rank_method: None | 'auc' | 'ucell' to switch to rank-based scoring.
        rank_universe: None=all genes; or a boolean var column name (e.g. 'highly_variable');
                       or an explicit list of gene names defining the ranking universe.
        auc_max_rank: AUCell top window (int L) or fraction (0,1].
        batch_size: Row batch size for rank computation.
        use_average_ranks: If True, uses average-tie ranks (scipy.stats.rankdata); slower.
        min_genes_per_set: Require at least this many present genes to score a set (else NaN).
        case_insensitive: Case-insensitive gene matching against `var_names`.
        obs_prefix: If provided, also writes scores to `adata.obs[f"{obs_prefix}{name}"]`.
        return_dataframe: If True, return a DataFrame; else return ndarray.

    Returns:
        DataFrame (cells × sets) of scores (and optionally writes to `adata.obs`).

    Notes:
        • Rank-based scores ignore clipping/min–max (ranks are invariant to monotone transforms).
        • AUCell output here is normalized to [0,1] within the top-L window.
        • UCell output is the normalized U statistic in [0,1].
    """
    # ------- resolve gene_sets -> dict[name] -> list[str] -------
    if isinstance(gene_sets, str):
        if gene_sets not in adata.uns:
            raise ValueError(f"gene_sets='{gene_sets}' not found in adata.uns")
        gs_map = dict(adata.uns[gene_sets])
    elif isinstance(gene_sets, Mapping):
        gs_map = {str(k): list(v) for k, v in gene_sets.items()}
    else:
        gs_map = {f"set_{i+1}": list(v) for i, v in enumerate(gene_sets)}
    if not gs_map:
        raise ValueError("No gene sets provided.")

    X = adata.layers[layer] if layer is not None else adata.X
    n_cells, n_genes = X.shape
    var_names = adata.var_names.astype(str)

    # name lookup
    if case_insensitive:
        lut = {g.lower(): i for i, g in enumerate(var_names)}
        def _loc(g: str) -> int: return lut.get(g.lower(), -1)
    else:
        lut = {g: i for i, g in enumerate(var_names)}
        def _loc(g: str) -> int: return lut.get(g, -1)

    # map each set to present indices (deduped)
    present_idx: Dict[str, np.ndarray] = {}
    for name, genes in gs_map.items():
        idx = sorted({_loc(str(g)) for g in genes if _loc(str(g)) >= 0})
        present_idx[name] = np.array(idx, dtype=int)

    # ======================= RANK-BASED BRANCH =======================
    if rank_method is not None:
        method = rank_method.lower()
        if method not in {"auc", "ucell"}:
            raise ValueError("rank_method must be one of {None, 'auc', 'ucell'}.")

        # pick universe
        if rank_universe is None:
            U_idx = np.arange(n_genes, dtype=int)
        elif isinstance(rank_universe, str) and rank_universe in adata.var.columns:
            mask = adata.var[rank_universe].astype(bool).to_numpy()
            U_idx = np.where(mask)[0]
        else:
            # list-like of gene names
            names = pd.Index(rank_universe)  # raises if not list-like; OK
            U_idx = var_names.get_indexer(names)
            U_idx = U_idx[U_idx >= 0]
        if U_idx.size == 0:
            raise ValueError("rank_universe resolved to 0 genes.")

        # restrict sets to universe; build compact col map
        pos_in_U = {j: k for k, j in enumerate(U_idx)}
        set_cols_in_U: Dict[str, np.ndarray] = {}
        for name, idx in present_idx.items():
            idxU = idx[np.isin(idx, U_idx)]
            if idxU.size < min_genes_per_set:
                set_cols_in_U[name] = np.array([], dtype=int)
            else:
                set_cols_in_U[name] = np.array([pos_in_U[j] for j in idxU], dtype=int)

        # slice universe matrix (cells × |U|)
        Xu = X[:, U_idx].toarray() if sparse.issparse(X) else np.asarray(X)[:, U_idx]
        if log_transform:
            Xu = np.log1p(Xu)  # monotone; safe for ranks

        # AUCell window
        if method == "auc":
            if isinstance(auc_max_rank, float):
                if not (0 < auc_max_rank <= 1):
                    raise ValueError("If auc_max_rank is float, it must be in (0,1].")
                L = max(1, int(np.ceil(auc_max_rank * Xu.shape[1])))
            else:
                L = int(auc_max_rank)
                if L < 1 or L > Xu.shape[1]:
                    raise ValueError("auc_max_rank (int) must be in [1, n_universe].")

        # prepare output
        scores = {name: np.full(n_cells, np.nan, float) for name in gs_map.keys()}

        # optional average-tie ranks
        if use_average_ranks:
            from scipy.stats import rankdata  # local import; slower but exact ties

        # rank batches
        for start in range(0, n_cells, batch_size):
            end = min(n_cells, start + batch_size)
            A = Xu[start:end, :]  # (b × nU)

            if use_average_ranks:
                # ranks ascending: 1..nU (average ties). Loop rows for stability.
                ranks_asc = np.vstack([rankdata(row, method="average") for row in A]).astype(np.float64)
                ranks_desc = A.shape[1] + 1 - ranks_asc
            else:
                # fast ordinal ranks via double argsort (stable)
                order = np.argsort(A, axis=1, kind="mergesort")
                ranks_asc = np.empty_like(order, dtype=np.int32)
                row_indices = np.arange(order.shape[0])[:, None]
                ranks_asc[row_indices, order] = np.arange(1, A.shape[1] + 1, dtype=np.int32)
                ranks_desc = A.shape[1] - ranks_asc + 1
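                # e.g. a row [0.2, 0.9, 0.5] -> ranks_asc [1, 3, 2] and ranks_desc [3, 1, 2]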

            if method == "ucell":
                nU = A.shape[1]
                for name, cols in set_cols_in_U.items():
                    m = cols.size
                    if m < min_genes_per_set:
                        continue
                    r = ranks_asc[:, cols].astype(np.float64)         # (b × m)
                    U = r.sum(axis=1) - (m * (m + 1) / 2.0)          # Mann–Whitney U
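                    # U ranges from 0 (set genes hold the lowest ranks) to
                    # m*(nU - m) (set genes hold the top m ranks)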
                    denom = m * (nU - m)
                    out = np.zeros(U.shape[0], float)
                    np.divide(U, denom, out=out, where=denom > 0)    # normalized to [0,1]
                    scores[name][start:end] = out

            else:  # AUCell
                Lloc = L
                for name, cols in set_cols_in_U.items():
                    m_all = cols.size
                    if m_all < min_genes_per_set:
                        continue
                    r = ranks_desc[:, cols]                           # (b × m)
                    mask = (r <= Lloc)
                    contrib = (Lloc - r + 1) * mask                   # triangular weights
                    raw = contrib.sum(axis=1)
                    m_prime = min(m_all, Lloc)
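                    # best case: the m' set genes fill ranks 1..m', giving
                    # sum_{r=1..m'} (Lloc - r + 1) = m'*Lloc - m'*(m'-1)/2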
                    max_raw = m_prime * Lloc - (m_prime * (m_prime - 1)) / 2.0
                    out = np.zeros(raw.shape[0], float)
                    np.divide(raw, max_raw, out=out, where=max_raw > 0)  # normalize to [0,1]
                    scores[name][start:end] = out

        df = pd.DataFrame(scores, index=adata.obs_names)
        if obs_prefix:
            for k in df.columns:
                adata.obs[f"{obs_prefix}{k}"] = df[k].values
        return df if return_dataframe else df.values

    # ======================= VALUE-BASED BRANCH =======================
    # collect unique indices across all sets
    all_idx: List[int] = []
    for idx in present_idx.values():
        all_idx.extend(idx.tolist())
    uniq_idx = np.array(sorted(set(all_idx)), dtype=int)
    if uniq_idx.size == 0:
        raise ValueError("None of the provided genes are present in adata.var_names.")

    # slice (cells × uniq_genes), densify for percentiles
    Xu = X[:, uniq_idx].toarray() if sparse.issparse(X) else np.asarray(X)[:, uniq_idx]

    # optional log1p
    if log_transform:
        Xu = np.log1p(Xu)

    # per-gene clip + scale to [0,1]
    lo_p, hi_p = float(clip_percentiles[0]), float(clip_percentiles[1])
    if not (0.0 <= lo_p < hi_p <= 100.0):
        raise ValueError("clip_percentiles must satisfy 0 <= low < high <= 100.")
    lo = np.percentile(Xu, lo_p, axis=0)
    hi = np.percentile(Xu, hi_p, axis=0)
    Xu = np.clip(Xu, lo[None, :], hi[None, :])
    denom = (hi - lo)
    denom[denom <= 0] = np.inf
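    # constant genes (hi == lo) get an infinite denominator, so they scale to 0 rather than NaN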
    Xu = (Xu - lo[None, :]) / denom[None, :]
    Xu = np.where(np.isfinite(Xu), Xu, 0.0)

    # compact column map
    compact = {j: k for k, j in enumerate(uniq_idx)}

    # row-wise helpers
    def _row_topk_mean(A: np.ndarray, k: int) -> np.ndarray:
        if k <= 0:
            return np.zeros(A.shape[0], dtype=float)
        k = min(k, A.shape[1])
        idx = A.shape[1] - k
        part = np.partition(A, idx, axis=1)
        return part[:, -k:].mean(axis=1)

    def _row_nonzero_mean(A: np.ndarray) -> np.ndarray:
        mask = (A > 0)
        num = A.sum(axis=1)
        den = mask.sum(axis=1)
        out = np.zeros(A.shape[0], float)
        np.divide(num, den, out=out, where=den > 0)
        return out

    # pick aggregator
    if isinstance(agg, str):
        agg_l = agg.lower()
        if agg_l == "mean":
            agg_fn = lambda A: A.mean(axis=1)
        elif agg_l == "median":
            agg_fn = lambda A: np.median(A, axis=1)
        elif agg_l == "sum":
            agg_fn = lambda A: A.sum(axis=1)
        elif agg_l == "nonzero_mean":
            agg_fn = _row_nonzero_mean
        elif agg_l == "top_p_mean":
            if top_p is None or not (0 < float(top_p) <= 1):
                raise ValueError("For agg='top_p_mean', provide 0 < top_p <= 1.")
            def agg_fn(A, _p=float(top_p)):
                k = max(1, int(np.ceil(_p * A.shape[1])))
                return _row_topk_mean(A, k)
        elif agg_l == "top_k_mean":
            if top_k is None or int(top_k) < 1:
                raise ValueError("For agg='top_k_mean', provide top_k >= 1.")
            agg_fn = lambda A, _k=int(top_k): _row_topk_mean(A, _k)
        else:
            raise ValueError("agg must be 'mean','median','sum','nonzero_mean','top_p_mean','top_k_mean' or a callable.")
    elif callable(agg):
        agg_fn = lambda A: agg(A)
    else:
        raise ValueError("Invalid 'agg' argument.")

    # aggregate per set
    out = {}
    for name, idx in present_idx.items():
        if idx.size < min_genes_per_set:
            out[name] = np.full(n_cells, np.nan, dtype=float)
            continue
        cols = [compact[j] for j in idx]
        A = Xu[:, cols]  # (cells × genes_in_set)
        out[name] = agg_fn(A)

    df = pd.DataFrame(out, index=adata.obs_names)
    if obs_prefix:
        for k in df.columns:
            adata.obs[f"{obs_prefix}{k}"] = df[k].values
    return df if return_dataframe else df.values
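
Example

A minimal usage sketch; the marker gene names and the "highly_variable" var column are illustrative placeholders, and the import assumes the module path shown in the source header:

from pySingleCellNet.tools.gene import score_gene_sets

markers = {"stress": ["Fos", "Jun", "Hspa1a"], "cycle": ["Mki67", "Top2a"]}

# value-based: per-gene clip + min-max scale, then mean of the top half of each set per cell
df_val = score_gene_sets(adata, markers, agg="top_p_mean", top_p=0.5, obs_prefix="score_")

# rank-based: UCell-style normalized U statistic, ranking only highly variable genes
df_ucell = score_gene_sets(adata, markers, rank_method="ucell", rank_universe="highly_variable")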

whoare_genes_neighbors

whoare_genes_neighbors(adata, gene, n_neighbors=5, key='gene', use='connectivities')

Retrieve the top n_neighbors nearest genes to gene, using a precomputed gene–gene kNN graph stored in adata.uns (as produced by build_gene_knn).

This function handles both sparse CSR matrices and dense NumPy arrays in adata.uns.

Parameters

adata
    AnnData that has the following keys in adata.uns:
    - adata.uns[f"{key}_gene_index"] (np.ndarray of gene names, in order)
    - adata.uns[f"{key}_connectivities"] (CSR sparse matrix or dense ndarray)
    - adata.uns[f"{key}_distances"] (CSR sparse matrix or dense ndarray)
gene
    Gene name (must appear in adata.uns[f"{key}_gene_index"]).
n_neighbors
    Number of neighbors to return.
key
    Prefix under which the kNN graph was stored. For example, if build_gene_knn(...) was called with key="gene", the function will look for:
    - adata.uns["gene_gene_index"]
    - adata.uns["gene_connectivities"]
    - adata.uns["gene_distances"]
use
    One of {"connectivities", "distances"}.
    - If "connectivities", neighbors are ranked by descending connectivity weight.
    - If "distances", neighbors are ranked by ascending distance (only among nonzero entries).

Returns

neighbors : List[str] A list of gene names (length ≤ n_neighbors) that are closest to gene.
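
Example

A minimal sketch; the gene name "Pax6" is a placeholder, and it assumes build_gene_knn was already run with key="gene":

from pySingleCellNet.tools.gene import whoare_genes_neighbors

# rank neighbors by descending connectivity weight (the default)
nbrs = whoare_genes_neighbors(adata, "Pax6", n_neighbors=10, key="gene")

# or rank by ascending distance among the stored nonzero entries
nbrs_d = whoare_genes_neighbors(adata, "Pax6", n_neighbors=10, use="distances")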

Source code in src/pySingleCellNet/tools/gene.py
def whoare_genes_neighbors(
    adata,
    gene: str,
    n_neighbors: int = 5,
    key: str = "gene",
    use: str = "connectivities"
):
    """
    Retrieve the top `n_neighbors` nearest genes to `gene`, using a precomputed gene–gene kNN graph
    stored in adata.uns (as produced by build_gene_knn).

    This function handles both sparse CSR matrices and dense NumPy arrays in adata.uns.

    Parameters
    ----------
    adata
        AnnData that has the following keys in adata.uns:
          - adata.uns[f"{key}_gene_index"]      (np.ndarray of gene names, in order)
          - adata.uns[f"{key}_connectivities"]  (CSR sparse matrix or dense ndarray)
          - adata.uns[f"{key}_distances"]       (CSR sparse matrix or dense ndarray)
    gene
        Gene name (must appear in `adata.uns[f"{key}_gene_index"]`).
    n_neighbors
        Number of neighbors to return.
    key
        Prefix under which the kNN graph was stored. For example, if build_gene_knn(...)
        was called with `key="gene"`, the function will look for:
          - adata.uns["gene_gene_index"]
          - adata.uns["gene_connectivities"]
          - adata.uns["gene_distances"]
    use
        One of {"connectivities", "distances"}.  
        - If "connectivities", neighbors are ranked by descending connectivity weight.  
        - If "distances", neighbors are ranked by ascending distance (only among nonzero entries).

    Returns
    -------
    neighbors : List[str]
        A list of gene names (length ≤ n_neighbors) that are closest to `gene`.
    """
    if use not in ("connectivities", "distances"):
        raise ValueError("`use` must be either 'connectivities' or 'distances'.")

    idx_key = f"{key}_gene_index"
    conn_key = f"{key}_connectivities"
    dist_key = f"{key}_distances"

    if idx_key not in adata.uns:
        raise ValueError(f"Could not find `{idx_key}` in adata.uns.")
    if conn_key not in adata.uns or dist_key not in adata.uns:
        raise ValueError(f"Could not find `{conn_key}` or `{dist_key}` in adata.uns.")

    gene_index = np.array(adata.uns[idx_key])
    if gene not in gene_index:
        raise KeyError(f"Gene '{gene}' not found in {idx_key}.")
    i = int(np.where(gene_index == gene)[0][0])

    # Select the appropriate stored matrix (could be sparse CSR or dense ndarray)
    mat_key = conn_key if use == "connectivities" else dist_key
    stored = adata.uns[mat_key]

    # If stored is a NumPy array, treat it as a dense full matrix:
    if isinstance(stored, np.ndarray):
        row_vec = stored[i].copy()
        # Exclude self
        if use == "connectivities":
            row_vec[i] = -np.inf
            order = np.argsort(-row_vec)  # descending
        else:
            row_vec[i] = np.inf
            order = np.argsort(row_vec)   # ascending
        topk = order[:n_neighbors]
        return [gene_index[j] for j in topk]

    # Otherwise, assume stored is a sparse matrix (CSR or similar):
    if not sparse.issparse(stored):
        raise TypeError(f"Expected CSR or ndarray for `{mat_key}`, got {type(stored)}.")

    row = stored.getrow(i)
    # For connectivities: sort nonzero entries by descending weight
    if use == "connectivities":
        cols = row.indices
        weights = row.data
        mask = cols != i
        cols = cols[mask]
        weights = weights[mask]
        if weights.size == 0:
            return []
        order = np.argsort(-weights)
        topk = cols[order][:n_neighbors]
        return [gene_index[j] for j in topk]

    # For distances: sort nonzero entries by ascending distance
    else:  # use == "distances"
        cols = row.indices
        dists = row.data
        mask = cols != i
        cols = cols[mask]
        dists = dists[mask]
        if dists.size == 0:
            return []
        order = np.argsort(dists)
        topk = cols[order][:n_neighbors]
        return [gene_index[j] for j in topk]