Skip to content

plotting functions

Various plotting functions.

bar_classifier_f1

bar_classifier_f1(adata, ground_truth='celltype', class_prediction='SCN_class', bar_height=0.8)

Plots a bar graph of F1 scores per class based on ground truth and predicted classifications.

Parameters:

  • adata (AnnData) –

    Annotated data matrix.

  • ground_truth (str, default: 'celltype' ) –

    The column name in adata.obs containing the true class labels. Defaults to "celltype".

  • class_prediction (str, default: 'SCN_class' ) –

    The column name in adata.obs containing the predicted class labels. Defaults to "SCN_class".

  • bar_height (float, default: 0.8 ) –

    Relative height of each bar. Defaults to 0.8.

Returns:

  • Figure – the matplotlib figure containing the bar plot.

Source code in src/pySingleCellNet/plotting/bar.py
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
def bar_classifier_f1(
    adata: AnnData,
    ground_truth: str = "celltype",
    class_prediction: str = "SCN_class",
    bar_height=0.8,
    colors_key: str = "SCN_class_colors",
):
    """
    Plots a horizontal bar graph of F1 scores per class based on ground truth and predicted classifications.

    One bar is drawn per category of ``adata.obs[ground_truth]``. Each bar is annotated with the
    number of observations of that class in the ground truth.

    Args:
        adata (AnnData): Annotated data matrix.
        ground_truth (str, optional): The column name in `adata.obs` containing the true class labels.
            Non-categorical columns are converted to categorical; the categories define the bars.
            Defaults to "celltype".
        class_prediction (str, optional): The column name in `adata.obs` containing the predicted class labels.
            Defaults to "SCN_class".
        bar_height (float, optional): Relative bar height, forwarded as ``width`` to
            ``DataFrame.plot.barh``. Defaults to 0.8.
        colors_key (str, optional): Key in `adata.uns` holding a dict that maps class names to colors.
            Classes without an entry (or a missing key) are drawn in gray. Defaults to "SCN_class_colors".

    Returns:
        matplotlib.figure.Figure: The figure that contains the bar plot.
    """
    truth = adata.obs[ground_truth]
    # `.cat.categories` below requires a categorical column.
    if not isinstance(truth.dtype, pd.CategoricalDtype):
        truth = truth.astype("category")
    classes = list(truth.cat.categories)

    # One F1 score per ground-truth category, in category order.
    fscores = f1_score(
        truth,
        adata.obs[class_prediction],
        average=None,
        labels=classes,
    )

    f1_scores_df = pd.DataFrame({
        "Class": classes,
        "F1-Score": fscores,
        "Count": truth.value_counts().reindex(classes, fill_value=0).values,
    })

    # Colors come from a class -> color mapping in .uns; unmapped classes fall back to gray
    # instead of producing NaN colors that matplotlib rejects.
    color_map = adata.uns.get(colors_key, {})
    f1_scores_df["Color"] = f1_scores_df["Class"].map(color_map).fillna("gray")

    # Constrained layout is requested for this figure only, rather than flipping the global
    # `figure.constrained_layout.use` rcParam for every later figure.
    fig, ax = plt.subplots(constrained_layout=True)

    # Shrink annotation text as the number of classes grows, clamped to 7-10 pt.
    text_size = max(min(12 - len(classes) // 2, 10), 7)

    # Draw on `ax`: without it, DataFrame.plot creates a separate figure and `fig` stays empty.
    f1_scores_df.plot.barh(
        x="Class",
        y="F1-Score",
        color=f1_scores_df["Color"],
        legend=False,
        width=bar_height,
        ax=ax,
    )

    ax.set_xlabel("F1-Score")
    ax.set_title("F1-Scores per Class")
    ax.set_xlim(0, 1.1)  # headroom beyond 1.0 so full-score bars remain visible

    # Annotate each bar (its y position equals the row index) with the class size.
    # Text is white when the score is >= 0.20 and black otherwise.
    for i, (count, score) in enumerate(zip(f1_scores_df["Count"], f1_scores_df["F1-Score"])):
        ax.text(
            0.03,
            i,
            f"n = {count}",
            ha="left",
            va="center",
            color="white" if score >= 0.20 else "black",
            fontsize=text_size,
        )

    return fig

bar_compare_celltype_composition

bar_compare_celltype_composition(adata1, adata2, celltype_col, min_delta, colors=None, metric='log_ratio')

Compare cell type proportions between two AnnData objects and plot either the log-ratio or the difference for cell types whose percentages differ by more than min_delta.

Parameters:

  • adata1 (AnnData) –

    First AnnData object.

  • adata2 (AnnData) –

    Second AnnData object.

  • celltype_col (str) –

    Column name in .obs indicating cell types.

  • min_delta (float) –

    Minimum absolute difference in percentages to include in the plot.

  • colors (dict, default: None ) –

    Dictionary with cell types as keys and colors as values for the bars.

  • metric (str, default: 'log_ratio' ) –

    "log_ratio" (default) or "difference" to specify which metric to plot.

Returns:

  • None

    Displays the bar plot.

Source code in src/pySingleCellNet/plotting/bar.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def bar_compare_celltype_composition(adata1, adata2, celltype_col, min_delta, colors=None, metric="log_ratio"):
    """
    Compare cell type proportions between two AnnData objects and plot either the log-ratio or the
    difference for cell types whose percentages changed by more than ``min_delta``.

    Percentages are computed per object from ``.obs[celltype_col]``. A cell type absent from one
    object counts as 0% there. A pseudocount of 1e-6 keeps the log-ratio finite.

    Parameters:
        adata1 (AnnData): First AnnData object ("A" in the axis label).
        adata2 (AnnData): Second AnnData object ("B" in the axis label).
        celltype_col (str): Column name in `.obs` indicating cell types.
        min_delta (float): A cell type is plotted only if the absolute difference of its percentages
            is strictly greater than this value. This applies regardless of ``metric``.
        colors (dict, optional): Mapping from cell type to any matplotlib color (name, hex string,
            or RGB/RGBA tuple). Cell types without an entry are drawn in 'skyblue'.
        metric (str, optional): "log_ratio" (default) or "difference" to specify which metric to plot.

    Returns:
        None: Displays the bar plot, with bars sorted by value. If no cell type passes
        ``min_delta``, a warning is issued and nothing is plotted.

    Raises:
        ValueError: If ``metric`` is not "log_ratio" or "difference".
    """
    import warnings

    from matplotlib.colors import to_rgba

    def compute_percentages(adata):
        # Percentage of cells per cell type within one object.
        return adata.obs[celltype_col].value_counts(normalize=True) * 100

    pct1 = compute_percentages(adata1)
    pct2 = compute_percentages(adata2)

    # Align on the union of cell types so types present in only one object compare against 0%.
    all_celltypes = pct1.index.union(pct2.index)
    pct1 = pct1.reindex(all_celltypes, fill_value=0)
    pct2 = pct2.reindex(all_celltypes, fill_value=0)

    differences = pct1 - pct2
    if metric == "log_ratio":
        plot_values = np.log2((pct1 + 1e-6) / (pct2 + 1e-6))  # pseudocount avoids division by zero
        xlabel = "Log2(Percent in A / Percent in B)"
    elif metric == "difference":
        plot_values = differences
        xlabel = "Difference in Percentages (A - B)"
    else:
        raise ValueError("Invalid metric. Choose either 'log_ratio' or 'difference'.")

    # The threshold is always on the absolute percentage difference, whichever metric is plotted.
    plot_data = plot_values[differences.abs() > min_delta].sort_values()
    if plot_data.empty:
        # pandas cannot draw an empty bar plot ("no numeric data to plot").
        warnings.warn(
            f"No cell type differs by more than {min_delta} percentage points; nothing to plot.",
            stacklevel=2,
        )
        return None

    # to_rgba normalizes any matplotlib color spec: names, hex strings such as '#1f77b4',
    # and RGB(A) tuples.
    if colors:
        bar_colors = [to_rgba(colors[ct]) if ct in colors else "skyblue" for ct in plot_data.index]
    else:
        bar_colors = "skyblue"

    fig, ax = plt.subplots(figsize=(10, 6))
    plot_data.plot(kind="barh", color=bar_colors, edgecolor="black", ax=ax)
    ax.axvline(0, color="black", linewidth=1)
    ax.set_xlabel(xlabel)
    fig.tight_layout()
    plt.show()

heatmap_clustering_eval

heatmap_clustering_eval(df, index_col='label_col', metrics=('n_clusters', 'unique_strict_genes', 'unique_naive_genes', 'frac_pairs_with_at_least_n_strict'), bar_sum_cols=('unique_strict_genes', 'unique_naive_genes'), cmap_eval='viridis', scale_eval='zscore', linewidth=0.5, value_fmt=None, title='Clustering parameter sweep (select best rows)', render=True, set_default_font=True)

Marsilea heatmap to guide clustering parameter selection.

  • Left: textual columns for parsed parameters (pc, k, res).
  • Center: eval heatmap with raw numbers printed in cells (includes n_clusters as the first column).
  • Right: bar = unique_strict_genes + unique_naive_genes; rows are sorted in descending order by this score.

Row names (index strings) are NOT shown.

Source code in src/pySingleCellNet/plotting/heatmap.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def heatmap_clustering_eval(
    df: pd.DataFrame,
    index_col: str = "label_col",
    metrics=("n_clusters", "unique_strict_genes", "unique_naive_genes", "frac_pairs_with_at_least_n_strict"),
    bar_sum_cols=("unique_strict_genes", "unique_naive_genes"),
    cmap_eval: str = "viridis",
    scale_eval: str = "zscore",        # 'zscore' | 'minmax' | 'none' (per column)
    linewidth: float = 0.5,
    value_fmt: dict | None = None,     # e.g., {"frac_pairs_with_at_least_n_strict": "{:.2f}"}
    title: str = "Clustering parameter sweep (select best rows)",
    render: bool = True,
    set_default_font: bool = True,     # avoids 'pc/k/res' being misread as font family
):
    """
    Marsilea heatmap to guide clustering parameter selection.

    Layout:
        Left:   textual columns for parameters parsed from the ``index_col`` labels (pc, k, res).
        Center: eval heatmap of ``metrics``, colored per column according to ``scale_eval`` and
                annotated with the raw values (formatted via ``value_fmt``).
        Right:  bar of the selection score, which is the row-wise sum of ``bar_sum_cols``.
                Rows are sorted by this score in descending order.

    Row names (index strings) are NOT shown.

    Args:
        df: Table of clustering runs. Only the first row per unique ``index_col`` value is used.
        index_col: Column holding the run label, e.g. "autoc_pc20_pct1.00_s01_k10_res0.05".
            Labels are converted to str before parsing. "pc" and "k" are the integers following
            those letters, and "res" is the number following "res". A parameter is NaN/"NA"
            when its pattern is absent.
        metrics: Columns shown in the heatmap, in this order.
        bar_sum_cols: Columns summed into the selection score.
        cmap_eval: Matplotlib colormap name for the heatmap.
        scale_eval: Per-column color scaling: 'zscore', 'minmax', or 'none'. With 'zscore' and
            'minmax', constant columns (and missing values) are colored as 0.
        linewidth: Grid line width between heatmap cells.
        value_fmt: Optional mapping from column to a format string for the in-cell text. Columns
            not listed use "{:.3f}". When None, columns whose name contains "frac" or "ratio"
            use "{:.3f}" and all others use "{:.0f}".
        title: Figure title.
        render: If True, call ``render()`` on the canvas before returning.
        set_default_font: If True, set matplotlib's global ``font.family`` rcParam to
            "DejaVu Sans". This is not reverted afterwards.

    Returns:
        dict with keys:
            "canvas": the marsilea Heatmap.
            "row_order": list of ``index_col`` values in plotted order.
            "score": pd.Series of selection scores in plotted order.
            "params": pd.DataFrame with parsed "pc", "k" and "res" columns.

    Raises:
        KeyError: If any requested column is missing from ``df``.
        ValueError: If ``scale_eval`` is not one of 'zscore', 'minmax', 'none'.
    """

    # Optional: force a sane default font to avoid font-family warnings.
    if set_default_font:
        try:
            import matplotlib as mpl
            mpl.rcParams["font.family"] = "DejaVu Sans"
        except Exception:
            # Best effort only: a font tweak must never block building the plot.
            pass

    # --- Normalize selections and validate inputs before doing any work
    metrics = list(metrics)
    bar_sum_cols = list(bar_sum_cols)

    need = {index_col, *metrics, *bar_sum_cols}
    missing = need - set(df.columns)
    if missing:
        raise KeyError(f"Missing columns: {sorted(missing)}; available={sorted(df.columns)}")
    if scale_eval not in ("zscore", "minmax", "none"):
        raise ValueError("scale_eval must be one of {'zscore','minmax','none'}")

    def _first_row_per_label(cols):
        """Return `cols` indexed by `index_col`, keeping the first row of each label."""
        return (
            df[[index_col, *cols]]
            .drop_duplicates(subset=[index_col])
            .set_index(index_col)
        )

    base = _first_row_per_label(metrics)
    # Selection score = sum of the chosen gene-count columns.
    score = _first_row_per_label(bar_sum_cols).sum(axis=1)

    # --- Order rows by descending score
    order = score.sort_values(ascending=False).index
    base = base.loc[order]
    score = score.loc[order]

    # --- Raw values to print in cells
    X_raw = base.loc[:, metrics].astype(float)

    if value_fmt is None:
        # Fractions/ratios get 3 decimals. Everything else, including n_clusters, is shown as
        # an integer.
        value_fmt = {
            col: "{:.3f}" if ("frac" in col or "ratio" in col) else "{:.0f}"
            for col in metrics
        }

    text_matrix = np.array(
        [[value_fmt.get(c, "{:.3f}").format(v) for c, v in zip(X_raw.columns, row)]
         for row in X_raw.values]
    )

    # --- Color matrix for the eval heatmap (column-wise scaling).
    # Zero spread (constant or single-row columns) becomes NaN, which is then colored as 0.
    X_color = X_raw.copy()
    if scale_eval == "zscore":
        X_color = (X_color - X_color.mean(axis=0)) / X_color.std(axis=0).replace(0, np.nan)
        X_color = X_color.fillna(0.0)
    elif scale_eval == "minmax":
        rng = (X_color.max(axis=0) - X_color.min(axis=0)).replace(0, np.nan)
        X_color = (X_color - X_color.min(axis=0)) / rng
        X_color = X_color.fillna(0.0)

    # --- Parse pc / k / res from labels like "autoc_pc20_pct1.00_s01_k10_res0.05".
    # Labels are coerced to str so non-string (e.g. integer) labels yield NaN instead of a TypeError.
    def _extract_num(pat, s, cast=float):
        m = re.search(pat, str(s))
        return cast(m.group(1)) if m else np.nan

    labels_series = order.to_series()
    params_df = pd.DataFrame({
        "pc":  labels_series.map(lambda s: _extract_num(r"pc(\d+)", s, int)),
        "k":   labels_series.map(lambda s: _extract_num(r"k(\d+)", s, int)),
        "res": labels_series.map(lambda s: _extract_num(r"res([0-9]*\.?[0-9]+)", s, float)),
    }, index=order)

    # --- Marsilea plotting
    import marsilea as ma
    import marsilea.plotter as mp

    # Evals heatmap (no row name labels)
    h_eval = ma.Heatmap(
        X_color.values,
        linewidth=linewidth,
        label="Evals",
        cmap=cmap_eval,
    )
    h_eval.add_top(mp.Labels(list(X_color.columns)))      # show metric names, not row labels
    h_eval.add_layer(mp.TextMesh(text_matrix))            # overlay raw numbers

    # The bar label names the summed columns. The historical short label is kept for the defaults.
    if bar_sum_cols == ["unique_strict_genes", "unique_naive_genes"]:
        bar_label = "unique_strict + unique_naive"
    else:
        bar_label = " + ".join(map(str, bar_sum_cols))

    # Right-side bar (with clear label & padding)
    h_eval.add_right(
        mp.Numbers(score.values, label=bar_label),
        size=0.9,
        pad=0.15,   # generous so the label is clear
    )

    # Title & legends
    h_eval.add_legends()
    h_eval.add_title(title)

    # --- LEFT textual parameter columns: pc | k | res (aligned with rows)
    # label + label_props make 'pc/k/res' column titles rather than font-family names.
    label_props = {"family": "DejaVu Sans", "weight": "bold"}

    def _param_labels(texts, name):
        """Build one left-side text column titled `name`."""
        return mp.Labels(texts, label=name, label_loc="top", label_props=dict(label_props))

    def _int_texts(col):
        """Render an integer parameter column as text, with missing values shown as 'NA'."""
        return params_df[col].astype("Int64").astype(str).replace("<NA>", "NA")

    pc_col = _param_labels(_int_texts("pc"), "pc")
    k_col = _param_labels(_int_texts("k"), "k")
    res_col = _param_labels(
        params_df["res"].map(lambda x: f"{x:.2f}" if pd.notna(x) else "NA"),
        "res",
    )

    # Attach the three text columns on the left (sizes tuned for readability)
    h_eval.add_left(pc_col,  size=0.7, pad=0.05)
    h_eval.add_left(k_col,   size=0.7, pad=0.05)
    h_eval.add_left(res_col, size=0.9, pad=0.10)

    if render:
        h_eval.render()

    return {
        "canvas": h_eval,
        "row_order": order.to_list(),
        "score": score,
        "params": params_df,
    }

heatmap_gsea

heatmap_gsea(gmat, clean_signatures=False, clean_cells=False, column_colors=None, figsize=(8, 6), label_font_size=7, cbar_pos=[0.2, 0.92, 0.6, 0.02], dendro_ratio=(0.3, 0.1), cbar_title='NES', col_cluster=False, row_cluster=False)

Generates a heatmap with hierarchical clustering for gene set enrichment analysis (GSEA) results.

Parameters:

  • gmat (DataFrame) –

    A matrix of GSEA scores with gene sets as rows and samples as columns.

  • clean_signatures (bool, default: False ) –

    If True, removes gene sets with zero enrichment scores across all samples. Defaults to False.

  • clean_cells (bool, default: False ) –

    If True, removes samples with zero enrichment scores across all gene sets. Defaults to False.

  • column_colors (Series or DataFrame, default: None ) –

    Colors to annotate columns, typically representing sample groups. Defaults to None.

  • figsize (tuple, default: (8, 6) ) –

    Figure size in inches (width, height). Defaults to (8, 6).

  • label_font_size (int, default: 7 ) –

    Font size for axis and colorbar labels. Defaults to 7.

  • cbar_pos (list, default: [0.2, 0.92, 0.6, 0.02] ) –

    Position of the colorbar [left, bottom, width, height]. Defaults to [0.2, 0.92, 0.6, 0.02] for a horizontal top placement.

  • dendro_ratio (tuple, default: (0.3, 0.1) ) –

    Proportion of the figure allocated to the row and column dendrograms. Defaults to (0.3, 0.1).

  • cbar_title (str, default: 'NES' ) –

    Title of the colorbar. Defaults to 'NES'.

  • col_cluster (bool, default: False ) –

    If True, performs hierarchical clustering on columns. Defaults to False.

  • row_cluster (bool, default: False ) –

    If True, performs hierarchical clustering on rows. Defaults to False.

Returns:

  • None

Displays

A heatmap with optional hierarchical clustering and a horizontal colorbar at the top.

Source code in src/pySingleCellNet/plotting/heatmap.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
def heatmap_gsea(
    gmat,
    clean_signatures=False,
    clean_cells=False,
    column_colors=None,
    figsize=(8, 6),
    label_font_size=7,
    cbar_pos=(0.2, 0.92, 0.6, 0.02),  # Positioned at the top; a tuple avoids a shared mutable default
    dendro_ratio=(0.3, 0.1),
    cbar_title='NES',
    col_cluster=False,
    row_cluster=False,
):
    """
    Generates a heatmap with optional hierarchical clustering for gene set enrichment analysis (GSEA) results.

    Args:
        gmat (pd.DataFrame):
            A matrix of GSEA scores with gene sets as rows and samples as columns. The input is not modified.
        clean_signatures (bool, optional):
            If True, removes gene sets whose scores are zero (or missing) in every sample. Defaults to False.
        clean_cells (bool, optional):
            If True, removes samples whose scores are zero (or missing) for every gene set. Defaults to False.
        column_colors (pd.Series or pd.DataFrame, optional):
            Colors to annotate columns, typically representing sample groups. Defaults to None.
        figsize (tuple, optional):
            Figure size in inches (width, height). Defaults to (8, 6).
        label_font_size (int, optional):
            Font size for axis and colorbar labels. Defaults to 7.
        cbar_pos (sequence of 4 floats, optional):
            Position of the colorbar [left, bottom, width, height]. Defaults to (0.2, 0.92, 0.6, 0.02)
            for a horizontal top placement.
        dendro_ratio (tuple, optional):
            Proportion of the figure allocated to the row and column dendrograms. Defaults to (0.3, 0.1).
        cbar_title (str, optional):
            Title of the colorbar. Defaults to 'NES'.
        col_cluster (bool, optional):
            If True, performs hierarchical clustering on columns. Defaults to False.
        row_cluster (bool, optional):
            If True, performs hierarchical clustering on rows. Defaults to False.
            The dendrogram axes are always hidden.

    Returns:
        None

    Displays:
        A heatmap with optional hierarchical clustering and a horizontal colorbar at the top.
    """
    gsea_matrix = gmat.copy()

    # A row/column is "empty" only if every entry is zero. Testing `sum != 0` would also drop
    # rows/columns whose signed scores cancel out (e.g. [1.5, -1.5]). NaN is treated as 0, so
    # all-missing rows/columns are still removed. Positional masks keep this safe with
    # duplicate labels. Dropping all-zero columns never changes which rows are empty, so one
    # mask computation serves both filters.
    nonzero = gsea_matrix.fillna(0).ne(0)
    if clean_cells:
        gsea_matrix = gsea_matrix.loc[:, nonzero.any(axis=0).to_numpy()]
    if clean_signatures:
        gsea_matrix = gsea_matrix.loc[nonzero.any(axis=1).to_numpy(), :]

    grid = sns.clustermap(
        data=gsea_matrix,
        cmap=Roma_20.mpl_colormap.reversed(),
        center=0,  # colormap midpoint at a score of 0
        yticklabels=1,
        xticklabels=1,
        linewidth=.05,
        linecolor='white',
        method='average',
        metric='euclidean',
        dendrogram_ratio=dendro_ratio,
        col_colors=column_colors,
        figsize=figsize,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        cbar_pos=cbar_pos,
        cbar_kws={'orientation': 'horizontal'},
    )

    grid.ax_cbar.set_title(cbar_title, fontsize=label_font_size, pad=10)
    grid.ax_cbar.tick_params(labelsize=label_font_size, direction='in')

    # Hide dendrograms; clustering (if enabled) still determines the row/column order.
    grid.ax_row_dendrogram.set_visible(False)
    grid.ax_col_dendrogram.set_visible(False)

    heatmap_ax = grid.ax_heatmap
    heatmap_ax.set_yticklabels(heatmap_ax.get_ymajorticklabels(), fontsize=label_font_size)
    heatmap_ax.set_xticklabels(
        heatmap_ax.get_xmajorticklabels(),
        rotation=45,
        ha="right",
        rotation_mode="anchor",
        fontsize=label_font_size,
    )

    plt.show()

heatmap_scores

heatmap_scores(adata, groupby, vmin=0, vmax=1, obsm_name='SCN_score', order_by=None, figure_subplot_bottom=0.4)

Plots a heatmap of single cell scores, grouping cells according to a specified .obs column and optionally ordering within each group.

Parameters:

  • adata (AnnData) –

    An AnnData object containing the single cell data.

  • groupby (str) –

    The name of the column in .obs used for grouping cells in the heatmap.

  • vmin (float, default: 0 ) –

    Minimum value for color scaling. Defaults to 0.

  • vmax (float, default: 1 ) –

    Maximum value for color scaling. Defaults to 1.

  • obsm_name (str, default: 'SCN_score' ) –

    The key in .obsm to retrieve the matrix for plotting. Defaults to 'SCN_score'.

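  • order_by (str, default: None ) –

    The name of the column in .obs used for ordering cells within each group. Defaults to None.

  • figure_subplot_bottom (float, default: 0.4 ) –

    Value of matplotlib's figure.subplot.bottom rcParam used while drawing the heatmap. Defaults to 0.4.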

Returns:

  • None

    The function plots a heatmap and does not return any value.

Source code in src/pySingleCellNet/plotting/heatmap.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
def heatmap_scores(
    adata: AnnData, 
    groupby: str, 
    vmin: float = 0, 
    vmax: float = 1, 
    obsm_name='SCN_score', 
    order_by: str = None,
    figure_subplot_bottom: float = 0.4
):
    """
    Draw a heatmap of per-cell scores stored in ``adata.obsm[obsm_name]``.

    The score matrix is wrapped in a temporary AnnData that shares ``adata.obs``. Cells are
    sorted by ``groupby`` and, when ``order_by`` is given, by that column within each group.
    The result is passed to ``scanpy.pl.heatmap`` grouped by ``groupby``, with
    ``swap_axes=True``, no dendrogram, and the ``Batlow_20`` colormap.

    Args:
        adata (AnnData): An AnnData object containing the single cell data.
        groupby (str): The name of the column in .obs used for grouping cells in the heatmap.
        vmin (float, optional): Minimum value for color scaling. Defaults to 0.
        vmax (float, optional): Maximum value for color scaling. Defaults to 1.
        obsm_name (str, optional): The key in .obsm to retrieve the matrix for plotting. Defaults to 'SCN_score'.
        order_by (str, optional): The name of the column in .obs used for ordering cells within each group. Defaults to None.
        figure_subplot_bottom (float, optional): Value of matplotlib's ``figure.subplot.bottom``
            rcParam, applied only while the heatmap is drawn. Defaults to 0.4.

    Returns:
        None: The function plots a heatmap and does not return any value.
    """
    # Score matrix as X, original cell annotations as obs.
    score_ad = AnnData(adata.obsm[obsm_name], obs=adata.obs)

    # Primary sort by group; optional secondary sort within each group.
    sort_keys = [groupby] if order_by is None else [groupby, order_by]
    cell_order = score_ad.obs.sort_values(by=sort_keys).index
    score_ad = score_ad[cell_order, :]

    # rc_context restores the previous rcParams once drawing is done.
    rc_overrides = {'figure.subplot.bottom': figure_subplot_bottom}
    with plt.rc_context(rc_overrides):
        sc.pl.heatmap(
            score_ad,
            score_ad.var_names.values,
            groupby=groupby,
            cmap=Batlow_20.mpl_colormap,
            dendrogram=False,
            swap_axes=True,
            vmin=vmin,
            vmax=vmax,
        )

make_bivariate_cmap

make_bivariate_cmap(c00='#f0f0f0', c10='#e31a1c', c01='#1f78b4', c11='#ffff00', n=128)

Create a bivariate colormap by bilinearly interpolating four corner colors.

This builds an (n × n) grid of RGB colors, blending smoothly between the specified corner colors:

  • c00 at (low, low)
  • c10 at (high, low)
  • c01 at (low, high)
  • c11 at (high, high)

Parameters:

  • c00 (str, default: '#f0f0f0' ) –

    Matplotlib color spec (hex, name, or RGB tuple) for the low/low corner.

  • c10 (str, default: '#e31a1c' ) –

    Color for the high/low corner.

  • c01 (str, default: '#1f78b4' ) –

    Color for the low/high corner.

  • c11 (str, default: '#ffff00' ) –

    Color for the high/high corner.

  • n (int, default: 128 ) –

    Resolution per axis. The total length of the returned colormap is n*n.

Returns:

  • ListedColormap ( ListedColormap ) –

    A colormap with n*n entries blending between the four corners.

Source code in src/pySingleCellNet/plotting/helpers.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def make_bivariate_cmap(
    c00: str = "#f0f0f0",
    c10: str = "#e31a1c",
    c01: str = "#1f78b4",
    c11: str = "#ffff00",
    n: int = 128
) -> ListedColormap:
    """Create a bivariate colormap by bilinearly interpolating four corner colors.

    This builds an (n × n) grid of RGB colors, blending smoothly between
    the specified corner colors:
      - c00 at (low, low)
      - c10 at (high, low)
      - c01 at (low, high)
      - c11 at (high, high)

    The grid is flattened row by row. With ``x = y = np.linspace(0, 1, n)``,
    entry ``j * n + i`` of the colormap is the blend at ``(x[i], y[j])``, so
    the first variable (x) varies fastest.

    Args:
        c00: Matplotlib color spec (hex, name, or RGB tuple) for the low/low corner.
        c10: Color for the high/low corner.
        c01: Color for the low/high corner.
        c11: Color for the high/high corner.
        n:   Resolution per axis (must be >= 1). The total length of the returned colormap is n*n.

    Returns:
        ListedColormap: A colormap with n*n entries blending between the four corners.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    # Convert corner colors to RGB arrays of shape (3,).
    rgb00, rgb10, rgb01, rgb11 = (np.array(to_rgb(c)) for c in (c00, c10, c01, c11))

    # Bilinear weights via broadcasting instead of an n*n Python loop.
    # x runs along axis 1 (columns) and y along axis 0 (rows); the trailing axis is RGB.
    # The per-element arithmetic (and its order) matches the scalar formula exactly.
    x = np.linspace(0, 1, n)[np.newaxis, :, np.newaxis]  # shape (1, n, 1)
    y = np.linspace(0, 1, n)[:, np.newaxis, np.newaxis]  # shape (n, 1, 1)

    lut = (
        rgb00 * (1 - x) * (1 - y)
        + rgb10 * x * (1 - y)
        + rgb01 * (1 - x) * y
        + rgb11 * x * y
    )  # shape (n, n, 3)

    # Flatten to (n*n, 3) and return as a ListedColormap
    return ListedColormap(lut.reshape(n * n, 3))

scatter_genes_oneper

scatter_genes_oneper(adata, genes, embedding_key='X_spatial', spot_size=2, alpha=0.9, clip_percentiles=(0, 99.5), log_transform=True, cmap='Reds', figsize=None, panel_width=4.0, n_rows=1)

Plot expression of multiple genes on a 2D embedding arranged in a grid.

Each gene is optionally log-transformed, percentile-clipped, and rescaled to [0,1]. Cells are plotted on the embedding, colored by expression, with highest values drawn on top. A single colorbar is placed to the right of the grid. If figsize is None, each panel has width panel_width and height proportional to the embedding's aspect ratio; total figure dims reflect n_rows and computed columns.

Parameters:

  • adata (AnnData) –

    AnnData containing the embedding in adata.obsm[embedding_key].

  • embedding_key (str, default: 'X_spatial' ) –

    Key in .obsm for an (n_obs, 2) coordinate array.

  • genes (Sequence[str]) –

    List of gene names to plot (must be in adata.var_names).

  • spot_size (float, default: 2 ) –

    Marker size for scatter plots. Default 2.

  • alpha (float, default: 0.9 ) –

    Transparency for markers. Default 0.9.

  • clip_percentiles (tuple, default: (0, 99.5) ) –

    (low_pct, high_pct) to clip expression before rescaling.

  • log_transform (bool, default: True ) –

    If True, apply np.log1p to raw expression.

  • cmap (Union[str, Colormap], default: 'Reds' ) –

    Colormap or name for all plots.

  • figsize (Optional[tuple], default: None ) –

    (width, height) of entire figure. If None, computed from panel_width, n_rows, and embedding aspect ratio.

  • panel_width (float, default: 4.0 ) –

    Width (in inches) of each panel when figsize is None.

  • n_rows (int, default: 1 ) –

    Number of rows in the grid. Default 1.

Raises:

  • ValueError

    If embedding is missing/malformed or genes not found.

Source code in src/pySingleCellNet/plotting/spatial.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def scatter_genes_oneper(
    adata: AnnData,
    genes: Sequence[str],
    embedding_key: str = "X_spatial",
    spot_size: float = 2,
    alpha: float = 0.9,
    clip_percentiles: tuple = (0, 99.5),
    log_transform: bool = True,
    cmap: Union[str, plt.Colormap] = 'Reds',
    figsize: Optional[tuple] = None,
    panel_width: float = 4.0,
    n_rows: int = 1
) -> None:
    """Plot expression of multiple genes on a 2D embedding arranged in a grid.

    Each gene is optionally log-transformed, percentile-clipped, and rescaled
    to [0, 1]. Cells are plotted on the embedding, colored by expression, with
    the highest values drawn on top. A single colorbar is placed to the right
    of the grid. If `figsize` is None, each panel has width `panel_width` and a
    height proportional to the embedding's aspect ratio; the total figure size
    reflects `n_rows` and the computed number of columns.

    Args:
        adata: AnnData containing the embedding in `adata.obsm[embedding_key]`.
        genes: List of gene names to plot (must be in `adata.var_names`).
        embedding_key: Key in `.obsm` for an (n_obs, 2) coordinate array.
        spot_size: Marker size for scatter plots. Default 2.
        alpha: Transparency for markers. Default 0.9.
        clip_percentiles: (low_pct, high_pct) to clip expression before rescaling.
        log_transform: If True, apply `np.log1p` to raw expression.
        cmap: Colormap or name for all plots.
        figsize: (width, height) of entire figure. If None, computed from
            `panel_width`, `n_rows`, and embedding aspect ratio.
        panel_width: Width (in inches) of each panel when `figsize` is None.
        n_rows: Number of rows in the grid. Default 1.

    Raises:
        ValueError: If the embedding is missing/malformed, `genes` is empty,
            `n_rows` is not positive, or a gene is not in `adata.var_names`.
    """
    def _get_array(x):
        # Dense 1-D copy of a single-gene column; np.asarray also turns an
        # np.matrix into an ndarray so ravel() really yields 1-D.
        return x.toarray().ravel() if hasattr(x, 'toarray') else np.asarray(x).ravel()

    coords = adata.obsm.get(embedding_key)
    if coords is None or coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"adata.obsm['{embedding_key}'] must be an (n_obs, 2) array.")
    coords = np.asarray(coords)
    x_vals, y_vals = coords[:, 0], coords[:, 1]

    # Validate all inputs before any figure is created.
    gene_list = list(genes)
    if not gene_list:
        raise ValueError("`genes` must contain at least one gene name.")
    if n_rows < 1:
        raise ValueError("`n_rows` must be a positive integer.")
    for gene in gene_list:
        if gene not in adata.var_names:
            raise ValueError(f"Gene '{gene}' not found in adata.var_names.")

    cols = math.ceil(len(gene_list) / n_rows)

    # Compute figsize if not provided
    if figsize is None:
        x_range = x_vals.max() - x_vals.min()
        y_range = y_vals.max() - y_vals.min()
        # A zero extent on either axis would give a zero or infinite panel
        # height, so fall back to square panels in that case.
        aspect = x_range / y_range if x_range > 0 and y_range > 0 else 1.0
        panel_height = panel_width / aspect
        fig_width = panel_width * cols
        fig_height = panel_height * n_rows
    else:
        fig_width, fig_height = figsize

    fig, axes = plt.subplots(n_rows, cols, figsize=(fig_width, fig_height), squeeze=False)
    axes_flat = axes.flatten()

    first_scatter = None
    for ax, gene in zip(axes_flat, gene_list):
        vals = _get_array(adata[:, gene].X)
        if log_transform:
            vals = np.log1p(vals)
        lo, hi = np.percentile(vals, clip_percentiles)
        clipped = np.clip(vals, lo, hi)
        norm = (clipped - lo) / (hi - lo) if hi > lo else np.zeros_like(clipped)

        # Ascending order: highest-expressing cells are drawn last, i.e. on top.
        order = np.argsort(norm)
        sc = ax.scatter(
            x_vals[order],
            y_vals[order],
            c=norm[order],
            cmap=cmap,
            s=spot_size,
            alpha=alpha,
            vmin=0, vmax=1
        )
        ax.set_title(gene)
        ax.set_xticks([])
        ax.set_yticks([])
        if first_scatter is None:
            first_scatter = sc

    # Turn off unused axes
    for ax in axes_flat[len(gene_list):]:
        ax.axis('off')

    # Lay out the subplot grid in the left 85% of the figure *before* adding the
    # manually positioned colorbar axis, so tight_layout only sees the grid.
    fig.subplots_adjust(right=0.85)
    fig.tight_layout(rect=[0, 0, 0.85, 1])

    # Colorbar in the reserved right-hand strip: x from 0.88, 5%-95% of height.
    # All panels share vmin=0/vmax=1 and the same cmap, so one colorbar fits all.
    cbar_ax = fig.add_axes([0.88, 0.05, 0.02, 0.9])
    cb = fig.colorbar(first_scatter, cax=cbar_ax)
    cb.set_label('normalized expression')

    plt.show()

scatter_qc_adata

scatter_qc_adata(adata, title_suffix='')

Creates a figure with two scatter plot panels for visualizing data from an AnnData object.

The first panel shows 'total_counts' vs 'n_genes_by_counts', colored by 'pct_counts_mt'. The second panel shows 'n_genes_by_counts' vs 'pct_counts_mt'. An optional title suffix can be added to customize the axis titles.

Parameters:

  • adata (AnnData) –

    The AnnData object containing the dataset. Must contain 'total_counts', 'n_genes_by_counts', and 'pct_counts_mt' in adata.obs.

  • title_suffix (str, default: '' ) –

    A string to append to the axis titles, useful for specifying experimental conditions (e.g., "C11 day 2"). Defaults to an empty string.

Returns:

  • None

    The function displays a matplotlib figure with two scatter plots.

Example

scatter_qc_adata(adata, title_suffix="C11 day 2")

Source code in src/pySingleCellNet/plotting/scatter.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def scatter_qc_adata(adata, title_suffix=""):
    """
    Creates a figure with two scatter plot panels for visualizing QC metrics from an AnnData object.

    The first panel shows 'total_counts' vs 'n_genes_by_counts', colored by 'pct_counts_mt'.
    The second panel shows 'n_genes_by_counts' vs 'pct_counts_mt'. An optional title suffix
    is appended, in parentheses, to every axis label and panel title.

    Args:
        adata (AnnData): The AnnData object containing the dataset.
                         Must contain 'total_counts', 'n_genes_by_counts', and 'pct_counts_mt' in `adata.obs`.
        title_suffix (str, optional): A string to append to the axis labels and titles, useful for specifying
                                      experimental conditions (e.g., "C11 day 2"). When empty (the default),
                                      nothing is appended.

    Returns:
        None: The function displays a matplotlib figure with two scatter plots.

    Example:
        >>> scatter_qc_adata(adata, title_suffix="C11 day 2")
    """
    # " (suffix)" only when a suffix was given, so default labels don't end in "()".
    suffix = f" ({title_suffix})" if title_suffix else ""

    # Extract necessary columns from the adata object
    total_counts = adata.obs['total_counts']
    n_genes_by_counts = adata.obs['n_genes_by_counts']
    pct_counts_mt = adata.obs['pct_counts_mt']

    # Create a figure with two subplots (1 row, 2 columns)
    fig, (ax_counts, ax_mito) = plt.subplots(1, 2, figsize=(12, 6))

    # First subplot: total_counts vs n_genes_by_counts, colored by pct_counts_mt
    scatter_counts = ax_counts.scatter(
        total_counts, n_genes_by_counts, c=pct_counts_mt, cmap='viridis', alpha=0.5, s=1
    )
    ax_counts.set_xlabel(f'Total Counts{suffix}')
    ax_counts.set_ylabel(f'Number of Genes by Counts{suffix}')
    ax_counts.set_title(f'Total Counts vs Genes{suffix}')
    fig.colorbar(scatter_counts, ax=ax_counts, label='% Mito')

    # Second subplot: n_genes_by_counts vs pct_counts_mt
    ax_mito.scatter(n_genes_by_counts, pct_counts_mt, alpha=0.5, s=1)
    ax_mito.set_xlabel(f'Number of Genes by Counts{suffix}')
    ax_mito.set_ylabel(f'% Mito{suffix}')
    ax_mito.set_title(f'Genes vs % Mito{suffix}')

    # Adjust layout to avoid overlap
    plt.tight_layout()

    # Show the plot
    plt.show()

spatial_contours

spatial_contours(adata, genes, spatial_key='spatial', summary_func=np.mean, spot_size=30, alpha=0.8, log_transform=True, clip_percentiles=(1, 99), cmap='viridis', contour_kwargs=None, scatter_kwargs=None)

Scatter spatial expression of one or more genes with smooth contour overlay.

If multiple genes are provided, each is preprocessed (log1p → clip → normalize), then combined per cell via summary_func (e.g. mean, sum, max) on the normalized values. A smooth contour of the summarized signal is overlaid onto the spatial scatter.

Parameters:

  • adata (AnnData) –

    AnnData with spatial coordinates in adata.obsm[spatial_key].

  • genes (Union[str, Sequence[str]]) –

    Single gene name or list of gene names to plot (must be in adata.var_names).

  • spatial_key (str, default: 'spatial' ) –

    Key in .obsm for an (n_obs, 2) coords array.

  • summary_func (Callable[[ndarray], ndarray], default: mean ) –

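    Function that combines the normalized per-gene values for each cell. It is called as summary_func(M, axis=1), where M is the (n_obs, n_genes) matrix of normalized values, and must return a length-n_obs array (e.g. np.mean, np.sum, np.max). Defaults to np.mean.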

  • spot_size (float, default: 30 ) –

    Scatter marker size.

  • alpha (float, default: 0.8 ) –

    Scatter alpha transparency.

  • log_transform (bool, default: True ) –

    If True, apply np.log1p to raw expression before clipping.

  • clip_percentiles (tuple, default: (1, 99) ) –

    Tuple (low_pct, high_pct) percentiles to clip each gene.

  • cmap (str, default: 'viridis' ) –

    Colormap name for the scatter (e.g. 'viridis').

  • contour_kwargs (dict, default: None ) –

    Dict of smoothing and contouring parameters: levels (int or list of levels; default 6), grid_res (int grid resolution; default 200), smooth_sigma (float Gaussian blur sigma; default 2), and contour_kwargs (dict of line-style kwargs; default {'colors': 'k', 'linewidths': 1}). Keys that are not supplied keep their defaults.

  • scatter_kwargs (dict, default: None ) –

    Extra kwargs passed to ax.scatter.

Raises:

  • ValueError

    If any gene is missing or spatial coords are malformed.

Source code in src/pySingleCellNet/plotting/spatial.py
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
def spatial_contours(
    adata: AnnData,
    genes: Union[str, Sequence[str]],
    spatial_key: str = 'spatial',
    summary_func: Callable[[np.ndarray], np.ndarray] = np.mean,
    spot_size: float = 30,
    alpha: float = 0.8,
    log_transform: bool = True,
    clip_percentiles: tuple = (1, 99),
    cmap: str = 'viridis',
    contour_kwargs: dict = None,
    scatter_kwargs: dict = None
) -> None:
    """Scatter spatial expression of one or more genes with smooth contour overlay.

    If multiple genes are provided, each is preprocessed (log1p → clip
    → normalize), then combined per cell via `summary_func` (e.g. mean, sum,
    max) on the normalized values. A smooth contour of the summarized signal
    is overlaid onto the spatial scatter.

    Args:
        adata: AnnData with spatial coordinates in `adata.obsm[spatial_key]`.
        genes: Single gene name or list of gene names to plot (must be in `adata.var_names`).
        spatial_key: Key in `.obsm` for an (n_obs, 2) coords array.
        summary_func: Reduction across genes. It is called as
            `summary_func(M, axis=1)`, where `M` is the (n_obs, n_genes)
            matrix of normalized values, and must return one value per cell.
            NumPy reductions such as `np.mean`, `np.sum`, `np.max` work
            directly. Defaults to `np.mean`.
        spot_size: Scatter marker size.
        alpha: Scatter alpha transparency.
        log_transform: If True, apply `np.log1p` to raw expression before clipping.
        clip_percentiles: Tuple `(low_pct, high_pct)` percentiles to clip each gene.
        cmap: Colormap name for the scatter (e.g. 'viridis').
        contour_kwargs: Dict of parameters for smoothing & contouring; keys not
            given keep their defaults:
            - levels: int or list of levels (default 6)
            - grid_res: int grid resolution (default 200)
            - smooth_sigma: float Gaussian blur sigma (default 2)
            - contour_kwargs: dict of line style kwargs (default {'colors':'k','linewidths':1})
        scatter_kwargs: Extra kwargs passed to `ax.scatter` (override the defaults).

    Raises:
        ValueError: If `genes` is empty, any gene is missing, spatial coords
            are malformed, or `summary_func` does not return one value per cell.
    """
    # ensure genes is a non-empty list of known genes
    gene_list = [genes] if isinstance(genes, str) else list(genes)
    if not gene_list:
        raise ValueError("`genes` must contain at least one gene name.")
    for g in gene_list:
        if g not in adata.var_names:
            raise ValueError(f"Gene '{g}' not found in adata.var_names.")

    # validate spatial coords before doing any per-gene work
    coords = adata.obsm.get(spatial_key)
    if coords is None or coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"adata.obsm['{spatial_key}'] must be an (n_obs, 2) array.")
    coords = np.asarray(coords)
    x, y = coords[:, 0], coords[:, 1]

    # helper to extract a dense 1-D numpy array (sparse, dense or np.matrix input)
    def _get_array(mat):
        return mat.toarray().ravel() if hasattr(mat, 'toarray') else np.asarray(mat).ravel()

    # preprocess each gene: extract, log1p, clip, normalize to [0,1]
    normed = []
    for g in gene_list:
        vals = _get_array(adata[:, g].X)
        if log_transform:
            vals = np.log1p(vals)
        lo, hi = np.percentile(vals, clip_percentiles)
        vals = np.clip(vals, lo, hi)
        normed.append((vals - lo) / (hi - lo) if hi > lo else np.zeros_like(vals))
    # stack into (n_obs, n_genes) and summarize across genes
    M = np.column_stack(normed)
    summary = np.asarray(summary_func(M, axis=1)).ravel()
    if summary.shape[0] != M.shape[0]:
        raise ValueError(
            f"summary_func must return one value per cell ({M.shape[0]}), "
            f"got {summary.shape[0]}."
        )

    # scatter
    fig, ax = plt.subplots(figsize=(6, 6))
    sc_kw = {"c": summary, "cmap": cmap, "s": spot_size, "alpha": alpha}
    if scatter_kwargs:
        sc_kw.update(scatter_kwargs)
    sc = ax.scatter(x, y, **sc_kw)
    ax.set_aspect('equal')
    # functools.partial and other callables may lack __name__
    func_name = getattr(summary_func, "__name__", type(summary_func).__name__)
    title = (
        gene_list[0] if len(gene_list) == 1
        else f"{len(gene_list)} genes ({func_name})"
    )
    ax.set_title(f"Spatial expression: {title}")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    fig.colorbar(sc, ax=ax, label="summarized (normalized)")

    # smooth + contour, starting from the default contour params
    ck = {
        "levels": 6,
        "grid_res": 200,
        "smooth_sigma": 2,
        "contour_kwargs": {"colors": "k", "linewidths": 1}
    }
    if contour_kwargs:
        ck.update(contour_kwargs)
    _smooth_contour(
        x, y, summary,
        levels=ck["levels"],
        grid_res=ck["grid_res"],
        smooth_sigma=ck["smooth_sigma"],
        contour_kwargs=ck["contour_kwargs"]
    )
    plt.tight_layout()
    plt.show()

spatial_two_genes

spatial_two_genes(adata, gene1, gene2, cmap, spot_size=2, alpha=0.9, spatial_key='X_spatial', log_transform=False, clip_percentiles=(0, 99.5), priority_metric='sum', show_xcoords=False, show_ycoords=False, show_bbox=False, show_legend=True, width_ratios=(10, 1))

Plot two‐gene spatial expression with a bivariate colormap.

Parameters:

  • adata (AnnData) –

    AnnData with spatial coords in adata.obsm[spatial_key].

  • gene1 (str) –

    First gene name (must be in adata.var_names).

  • gene2 (str) –

    Second gene name.

  • cmap (ListedColormap) –

    Bivariate colormap from make_bivariate_cmap (n×n LUT).

  • spot_size (float, default: 2 ) –

    Scatter point size.

  • alpha (float, default: 0.9 ) –

    Point alpha transparency.

  • spatial_key (str, default: 'X_spatial' ) –

    Key in adata.obsm for an (n_obs, 2) coords array.

  • log_transform (bool, default: False ) –

    If True, apply np.log1p to raw expression.

  • clip_percentiles (tuple, default: (0, 99.5) ) –

    Tuple (low_pct, high_pct) to clip each gene.

  • priority_metric (str, default: 'sum' ) –

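    Which value determines drawing order (cells are drawn in ascending order, so the highest values end up on top): - 'sum': u + v (default) - 'gene1': u only - 'gene2': v only. Here u and v are the normalized (clipped, rescaled to [0, 1]) expression of gene1 and gene2.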

  • show_xcoords (bool, default: False ) –

    Whether to display x-axis ticks and labels.

  • show_ycoords (bool, default: False ) –

    Whether to display y-axis ticks and labels.

  • show_bbox (bool, default: False ) –

    Whether to display the bounding box (spines).

  • show_legend (bool, default: True ) –

    Whether to display the legend/colorbar.

  • width_ratios (Tuple[float, float], default: (10, 1) ) –

    2‐tuple giving the relative widths of (scatter_panel, legend_panel). Defaults to (10, 1).

Raises:

  • ValueError

    If spatial coords are missing/malformed or if priority_metric is invalid.

Source code in src/pySingleCellNet/plotting/spatial.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
def spatial_two_genes(
    adata: AnnData,
    gene1: str,
    gene2: str,
    cmap: ListedColormap,
    spot_size: float = 2,
    alpha: float = 0.9,
    spatial_key: str = 'X_spatial',
    log_transform: bool = False,
    clip_percentiles: tuple = (0, 99.5),
    priority_metric: str = 'sum',
    show_xcoords: bool = False,
    show_ycoords: bool = False,
    show_bbox: bool = False,
    show_legend: bool = True,
    width_ratios: Tuple[float, float] = (10, 1)
) -> None:
    """Plot two‐gene spatial expression with a bivariate colormap.

    Each gene is optionally log1p-transformed, percentile-clipped and rescaled
    to [0, 1] (u for `gene1`, v for `gene2`). Each cell's color is bilinearly
    interpolated from the n×n LUT at (u, v). Cells are drawn in ascending order
    of `priority_metric`, so cells with the highest values end up on top.

    Args:
        adata: AnnData with spatial coords in `adata.obsm[spatial_key]`.
        gene1: First gene name (must be in `adata.var_names`).
        gene2: Second gene name.
        cmap: Bivariate colormap from `make_bivariate_cmap`. `cmap.colors` must
            hold n×n colors (RGB or RGBA; alpha is ignored), laid out row-major
            with rows indexed by gene2 and columns by gene1.
        spot_size: Scatter point size.
        alpha: Point alpha transparency.
        spatial_key: Key in `adata.obsm` for an (n_obs, 2) coords array.
        log_transform: If True, apply `np.log1p` to raw expression.
        clip_percentiles: Tuple `(low_pct, high_pct)` to clip each gene.
        priority_metric: Which metric to sort drawing order by:
            - 'sum': u + v (default)
            - 'gene1': u only
            - 'gene2': v only
        show_xcoords: Whether to display x-axis ticks and labels.
        show_ycoords: Whether to display y-axis ticks and labels.
        show_bbox: Whether to display the bounding box (spines).
        show_legend: Whether to display the legend/colorbar.
        width_ratios: 2‐tuple giving the relative widths of
                  (scatter_panel, legend_panel). Defaults to (10, 1).

    Raises:
        ValueError: If spatial coords are missing/malformed, if
                    `priority_metric` is invalid, or if `cmap.colors` does
                    not contain a square number (n×n) of colors.
    """
    from matplotlib.colors import to_rgba_array

    # 1) validate cheap arguments before processing any expression data
    if priority_metric not in ('sum', 'gene1', 'gene2'):
        raise ValueError("priority_metric must be 'sum', 'gene1', or 'gene2'")
    coords = adata.obsm.get(spatial_key)
    if coords is None or coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"adata.obsm['{spatial_key}'] must be an (n_obs, 2) array")
    coords = np.asarray(coords)

    # 2) LUT as an (n, n, 3) float RGB array; accepts RGB/RGBA tuples or color strings
    lut = to_rgba_array(cmap.colors)[:, :3]
    m = lut.shape[0]
    n = int(round(np.sqrt(m)))
    if m == 0 or n * n != m:
        raise ValueError(f"cmap must contain n*n colors to form a square LUT; got {m}.")
    C = lut.reshape(n, n, 3)

    # 3) per-gene: extract, optional log1p, percentile-clip, normalize to [0,1]
    def _normalized(gene):
        mat = adata[:, gene].X
        vals = mat.toarray().ravel() if hasattr(mat, 'toarray') else np.asarray(mat).ravel()
        if log_transform:
            vals = np.log1p(vals)
        lo, hi = np.percentile(vals, clip_percentiles)
        vals = np.clip(vals, lo, hi)
        return (vals - lo) / (hi - lo) if hi > lo else np.zeros_like(vals)

    u = _normalized(gene1)
    v = _normalized(gene2)

    # 4) bilinear interpolation into the LUT (rows follow v, columns follow u)
    gu = u * (n - 1)
    gv = v * (n - 1)
    i0 = np.floor(gu).astype(int)
    j0 = np.floor(gv).astype(int)
    i1 = np.minimum(i0 + 1, n - 1)
    j1 = np.minimum(j0 + 1, n - 1)
    du = (gu - i0)[:, None]
    dv = (gv - j0)[:, None]
    rgb = (
        C[j0, i0] * ((1 - du) * (1 - dv))
        + C[j0, i1] * (du * (1 - dv))
        + C[j1, i0] * ((1 - du) * dv)
        + C[j1, i1] * (du * dv)
    )
    # Floating-point round-off can push a channel marginally outside [0, 1],
    # which matplotlib rejects as an RGB value.
    rgb = np.clip(rgb, 0.0, 1.0)

    # 5) draw order: ascending priority so the highest values are drawn last
    priority = {'sum': u + v, 'gene1': u, 'gene2': v}[priority_metric]
    order = np.argsort(priority)
    coords_sorted = coords[order]
    colors_sorted = rgb[order]

    # 6) plot scatter + optional legend
    fig, (ax_sc, ax_cb) = plt.subplots(
        1, 2,
        figsize=(8, 4),
        gridspec_kw={'width_ratios': width_ratios, 'wspace': 0.3}
    )
    ax_sc.scatter(
        coords_sorted[:, 0],
        coords_sorted[:, 1],
        c=colors_sorted,
        s=spot_size,
        alpha=alpha
    )
    ax_sc.set_aspect('equal')
    ax_sc.set_title(f"{gene1} :: {gene2}")

    # axis display options
    if not show_xcoords:
        ax_sc.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    if not show_ycoords:
        ax_sc.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    if not show_bbox:
        for spine in ax_sc.spines.values():
            spine.set_visible(False)

    # legend: the LUT itself, gene1 along x and gene2 along y
    if show_legend:
        ax_cb.imshow(C, origin='lower', extent=[0, 1, 0, 1])
        ax_cb.set_xlabel(f"{gene1}")
        ax_cb.set_ylabel(f"{gene2}")
        ax_cb.set_xticks([0, 1])
        ax_cb.set_yticks([0, 1])
        ax_cb.set_aspect('equal')
    else:
        ax_cb.axis('off')

    plt.show()

stackedbar_composition

stackedbar_composition(adata, groupby, obs_column='SCN_class', labels=None, bar_width=0.75, color_dict=None, ax=None, order_by_similarity=False, similarity_metric='correlation', include_legend=True, legend_rows=10)

Plots a stacked bar chart of cell type proportions for a single AnnData object grouped by a specified column.

Parameters:

  • adata (AnnData) –

    An AnnData object.

  • groupby (str) –

    The column in .obs to group by.

  • obs_column (str, default: 'SCN_class' ) –

    The name of the .obs column to use for categories. Defaults to 'SCN_class'.

  • labels (List[str], default: None ) –

    Custom labels for each group to be displayed on the x-axis. If not provided, the unique values of the groupby column will be used. The length of labels must match the number of unique groups.

  • bar_width (float, default: 0.75 ) –

    The width of the bars in the plot. Defaults to 0.75.

  • color_dict (Dict[str, str], default: None ) –

    A dictionary mapping categories to specific colors. If not provided, adata.uns['SCN_class_colors'] is used when present; categories without a color get matplotlib's default color cycle.

  • ax (Axes, default: None ) –

    The axis to plot on. If not provided, a new figure and axis will be created.

  • order_by_similarity (bool, default: False ) –

    Whether to order the bars by similarity in composition. Defaults to False.

  • similarity_metric (str, default: 'correlation' ) –

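    Distance metric passed to scipy.spatial.distance.pdist when ordering bars by similarity (groups are ordered by the leaves of an average-linkage clustering of their composition vectors). Defaults to 'correlation'.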

  • include_legend (bool, default: True ) –

    Whether to include a legend in the plot. Defaults to True.

  • legend_rows (int, default: 10 ) –

    Maximum number of entries per legend column; the legend uses ceil(n_categories / legend_rows) columns. Defaults to 10.

Raises:

  • ValueError

    If the length of labels does not match the number of unique groups.

Examples:

>>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name')
>>> fig, ax = plt.subplots()
>>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name', ax=ax, include_legend=False, legend_rows=5)
Source code in src/pySingleCellNet/plotting/bar.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def stackedbar_composition(
    adata: AnnData, 
    groupby: str, 
    obs_column='SCN_class', 
    labels=None, 
    bar_width: float = 0.75, 
    color_dict=None, 
    ax=None,
    order_by_similarity: bool = False,
    similarity_metric: str = 'correlation',
    include_legend: bool = True,
    legend_rows: int = 10
):
    """
    Plots a stacked bar chart of cell type proportions for a single AnnData object grouped by a specified column.

    Args:
        adata (anndata.AnnData): An AnnData object.
        groupby (str): The column in `.obs` to group by.
        obs_column (str, optional): The name of the `.obs` column to use for categories. Defaults to 'SCN_class'.
        labels (List[str], optional): Custom labels for each group to be displayed on the x-axis.
            If not provided, the unique values of the groupby column will be used (its category order when
            categorical). The length of `labels` must match the number of unique groups.
        bar_width (float, optional): The width of the bars in the plot. Defaults to 0.75.
        color_dict (Dict[str, str], optional): A dictionary mapping categories to specific colors. If not provided,
            `adata.uns['SCN_class_colors']` is used when it is a dict; categories without a color get matplotlib's
            default color cycle.
        ax (matplotlib.axes.Axes, optional): The axis to plot on. If not provided, a new figure and axis will be
            created, laid out and shown.
        order_by_similarity (bool, optional): Whether to order the bars by similarity in composition
            (average-linkage clustering of the per-group proportion vectors). Defaults to False.
        similarity_metric (str, optional): Distance metric passed to `scipy.spatial.distance.pdist` for the
            similarity ordering. Defaults to 'correlation'.
        include_legend (bool, optional): Whether to include a legend in the plot. Defaults to True.
        legend_rows (int, optional): Maximum number of entries per legend column; the legend uses
            ceil(n_categories / legend_rows) columns. Defaults to 10.

    Returns:
        matplotlib.axes.Axes: The axis the chart was drawn on.

    Raises:
        ValueError: If `groupby` is not a column of `.obs`, or if the length of `labels` does not match
            the number of unique groups.

    Examples:
        >>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name')
        >>> fig, ax = plt.subplots()
        >>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name', ax=ax, include_legend=False, legend_rows=5)
    """
    from collections.abc import Mapping

    # Ensure the groupby column exists in .obs
    if groupby not in adata.obs.columns:
        raise ValueError(f"The groupby column '{groupby}' does not exist in the .obs attribute.")

    group_values = adata.obs[groupby]
    class_values = adata.obs[obs_column]

    # Categorical columns keep their declared category order; others use order of appearance.
    # (isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.)
    if isinstance(group_values.dtype, pd.CategoricalDtype):
        unique_groups = group_values.cat.categories.to_list()
    else:
        unique_groups = group_values.unique().tolist()

    # Ensure labels are provided or create default ones
    if labels is None:
        labels = unique_groups
    elif len(labels) != len(unique_groups):
        raise ValueError("Length of 'labels' must match the number of unique groups.")
    labels = list(labels)

    if color_dict is None:
        stored = adata.uns.get('SCN_class_colors', {})
        # Only a mapping can be looked up by category name; anything else
        # (e.g. a plain list of colors) falls back to matplotlib's defaults.
        color_dict = stored if isinstance(stored, Mapping) else {}

    # Per-group proportions of each obs_column value. Zero entries are dropped so a
    # categorical column doesn't contribute categories that are absent from a group.
    category_counts = []
    categories = set()
    for group in unique_groups:
        counts = class_values[group_values == group].value_counts(normalize=True)
        counts = counts[counts > 0]
        category_counts.append(counts)
        categories.update(counts.index)

    categories = sorted(categories)
    category_row = {category: j for j, category in enumerate(categories)}

    # proportions[j, i] = fraction of group i's cells in category j
    proportions = np.zeros((len(categories), len(unique_groups)))
    for i, counts in enumerate(category_counts):
        for category, fraction in counts.items():
            proportions[category_row[category], i] = fraction

    # Ordering groups by similarity if requested
    if order_by_similarity:
        dist_matrix = pdist(proportions.T, metric=similarity_metric)
        linkage_matrix = linkage(dist_matrix, method='average')
        order = leaves_list(linkage_matrix)
        proportions = proportions[:, order]
        unique_groups = [unique_groups[i] for i in order]
        labels = [labels[i] for i in order]

    # Remember whether we own the figure *before* ax is reassigned below.
    created_axes = ax is None
    if created_axes:
        _, ax = plt.subplots()

    positions = range(len(unique_groups))
    bottom = np.zeros(len(unique_groups))
    for category, row in zip(categories, proportions):
        ax.bar(
            positions, 
            row, 
            bottom=bottom, 
            label=category, 
            width=bar_width, 
            edgecolor='white', 
            linewidth=.5,
            color=color_dict.get(category, None)
        )
        bottom += row

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel('Proportion')
    ax.set_title(f'{obs_column} proportions by {groupby}')

    if include_legend:
        # At least one column, even when there are no categories to list.
        num_columns = max(1, int(np.ceil(len(categories) / legend_rows)))
        ax.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left', ncol=num_columns)

    if created_axes:
        ax.figure.tight_layout()
        plt.show()
    return ax

stackedbar_composition_list

stackedbar_composition_list(adata_list, obs_column='SCN_class', labels=None, bar_width=0.75, color_dict=None, legend_loc='outside center right')

Plots a stacked bar chart of category proportions for a list of AnnData objects.

This function takes a list of AnnData objects, and for a specified column in the .obs attribute, it plots a stacked bar chart. Each bar represents an AnnData object with segments showing the proportion of each category within that object.

Parameters:

  • adata_list (List[AnnData]) –

    A list of AnnData objects.

  • obs_column (str, default: 'SCN_class' ) –

    The name of the .obs column to use for categories. Defaults to 'SCN_class'.

  • labels (List[str], default: None ) –

    Custom labels for each AnnData object to be displayed on the x-axis. If not provided, defaults to 'AnnData {i+1}' for each object. The length of labels must match the number of AnnData objects provided.

  • bar_width (float, default: 0.75 ) –

    The width of the bars in the plot. Defaults to 0.75.

  • color_dict (Dict[str, str], default: None ) –

    A dictionary mapping categories to specific colors. If not provided, default colors will be used.

Raises:

  • ValueError

    If the length of labels does not match the number of AnnData objects.

Examples:

>>> stackedbar_composition_list([adata1, adata2], obs_column='your_column_name', labels=['Sample 1', 'Sample 2'])
Source code in src/pySingleCellNet/plotting/bar.py
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
def _stackedbar_colors_from_uns(adata, obs_column, colors_key):
    """Return a category -> color dict from ``adata.uns[colors_key]``, or None if unusable."""
    from collections.abc import Mapping

    colors = adata.uns.get(colors_key)
    if colors is None or isinstance(colors, Mapping):
        return colors
    # scanpy stores `<column>_colors` as a sequence aligned with that column's categories;
    # only trust that alignment when the key really belongs to the plotted column.
    column = adata.obs[obs_column]
    if (
        colors_key == f"{obs_column}_colors"
        and isinstance(column.dtype, pd.CategoricalDtype)
        and len(column.cat.categories) == len(colors)
    ):
        return dict(zip(column.cat.categories, colors))
    return None


def stackedbar_composition_list(
    adata_list, 
    obs_column = 'SCN_class', 
    labels = None, 
    bar_width: float = 0.75, 
    color_dict = None,
    legend_loc = "outside center right",
    colors_key = 'SCN_class_colors',
):
    """
    Plots a stacked bar chart of category proportions for a list of AnnData objects.

    This function takes a list of AnnData objects, and for a specified column in the `.obs` attribute,
    it plots a stacked bar chart. Each bar represents an AnnData object with segments showing the proportion
    of each category within that object.

    Args:
        adata_list (List[anndata.AnnData]): A non-empty list of AnnData objects.
        obs_column (str, optional): The name of the `.obs` column to use for categories. Defaults to 'SCN_class'.
        labels (List[str], optional): Custom labels for each AnnData object to be displayed on the x-axis.
            If not provided, defaults to 'AnnData {i+1}' for each object. The length of `labels` must match
            the number of AnnData objects provided.
        bar_width (float, optional): The width of the bars in the plot. Defaults to 0.75.
        color_dict (Dict[str, str], optional): A dictionary mapping categories to specific colors. If not
            provided, colors are looked up in ``adata_list[0].uns[colors_key]``. If that entry is missing or
            unusable, matplotlib's default color cycle is used. Categories absent from the mapping also get
            default colors.
        legend_loc (str, optional): ``loc`` passed to ``Figure.legend``. The figure is created with
            constrained layout, so "outside ..." locations are supported. Defaults to "outside center right".
        colors_key (str, optional): Key in ``adata_list[0].uns`` used when ``color_dict`` is None. The entry
            may be a dict (category -> color). It may also follow the scanpy convention: a list of colors
            aligned with the categories of a categorical ``obs_column``, used only when
            ``colors_key == f"{obs_column}_colors"``. Defaults to 'SCN_class_colors'.

    Returns:
        matplotlib.figure.Figure: The figure containing the plot.

    Raises:
        ValueError: If `adata_list` is empty, or if the length of `labels` does not match the number of
            AnnData objects.

    Examples:
        >>> stackedbar_composition_list([adata1, adata2], obs_column='your_column_name', labels=['Sample 1', 'Sample 2'])
    """
    if not adata_list:
        raise ValueError("'adata_list' must contain at least one AnnData object.")

    # Ensure labels are provided, or create default ones
    if labels is None:
        labels = [f'AnnData {i+1}' for i in range(len(adata_list))]
    elif len(labels) != len(adata_list):
        raise ValueError("Length of 'labels' must match the number of AnnData objects provided.")

    # Per-object category proportions, and the sorted union of all categories seen.
    category_counts = [adata.obs[obs_column].value_counts(normalize=True) for adata in adata_list]
    categories = sorted({category for counts in category_counts for category in counts.index})
    row_of = {category: j for j, category in enumerate(categories)}

    # proportions[j, i] = fraction of category j within adata_list[i]
    proportions = np.zeros((len(categories), len(adata_list)))
    for i, counts in enumerate(category_counts):
        for category, fraction in counts.items():
            proportions[row_of[category], i] = fraction

    if color_dict is None:
        color_dict = _stackedbar_colors_from_uns(adata_list[0], obs_column, colors_key)

    # constrained_layout is required for the "outside ..." legend locations.
    fig, ax = plt.subplots(constrained_layout=True)
    positions = range(len(adata_list))
    bottom = np.zeros(len(adata_list))
    legend_handles = []
    for category, row in zip(categories, proportions):
        color = color_dict.get(category) if color_dict else None
        bars = ax.bar(
            positions,
            row,
            bottom=bottom,
            label=category,
            width=bar_width,
            edgecolor='white',
            linewidth=.5,
            color=color,
        )
        bottom += row
        # Take the color actually drawn, so categories that fell back to
        # default colors still get a matching legend entry.
        legend_handles.append(mpatches.Patch(color=bars.patches[0].get_facecolor(), label=category))

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel('Proportion')
    ax.set_title(f'{obs_column} proportions')

    # The legend describes exactly the plotted categories, in stacking order.
    fig.legend(handles=legend_handles, loc=legend_loc, frameon=False)
    return fig

umap_scores

umap_scores(adata, scn_classes, obsm_name='SCN_score', alpha=0.75, s=10, display=True)

Plots UMAP projections of scRNA-seq data with specified scores.

Parameters:

  • adata (AnnData) –

    The AnnData object containing the scRNA-seq data.

  • scn_classes (list) –

    A list of SCN classes to visualize on the UMAP.

  • obsm_name (str, default: 'SCN_score' ) –

    The name of the obsm key containing the SCN scores. Defaults to 'SCN_score'.

  • alpha (float, default: 0.75 ) –

    The transparency level of the points on the UMAP plot. Defaults to 0.75.

  • s (int, default: 10 ) –

    The size of the points on the UMAP plot. Defaults to 10.

  • display (bool, default: True ) –

    If True, the plot is displayed immediately. If False, the axis object is returned. Defaults to True.

Returns:

  • matplotlib.axes.Axes or None: If display is False, returns the matplotlib axes object. Otherwise, returns None.

Source code in src/pySingleCellNet/plotting/dot.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def umap_scores(
    adata: AnnData,
    scn_classes: list,
    obsm_name='SCN_score',
    alpha=0.75,
    s=10,
    display=True
):
    """
    Plots UMAP projections of scRNA-seq data with specified scores.

    A temporary AnnData is built whose matrix is ``adata.obsm[obsm_name]``, and the parent's
    ``X_umap`` embedding is copied onto it. Then ``sc.pl.umap`` colors cells by each entry of
    ``scn_classes``, with the color scale pinned to [0, 1].

    NOTE(review): presumably the score matrix carries the class names as columns (e.g. a DataFrame),
    so they become var_names usable as ``color`` keys. Confirm against the caller.

    Args:
        adata (AnnData): 
            The AnnData object containing the scRNA-seq data and an ``X_umap`` embedding.
        scn_classes (list): 
            A list of SCN classes to visualize on the UMAP.
        obsm_name (str, optional): 
            The name of the obsm key containing the SCN scores. Defaults to 'SCN_score'.
        alpha (float, optional): 
            The transparency level of the points on the UMAP plot. Defaults to 0.75.
        s (int, optional): 
            The size of the points on the UMAP plot. Defaults to 10.
        display (bool, optional): 
            If True, the plot is displayed immediately. If False, the axis object is returned. Defaults to True.

    Returns:
        matplotlib.axes.Axes or None: 
            If `display` is False, returns what ``sc.pl.umap(..., show=False)`` returned. Otherwise, returns None.
    """
    score_view = AnnData(adata.obsm[obsm_name], obs=adata.obs)
    # Copy the embedding so the temporary object does not alias the caller's array.
    score_view.obsm['X_umap'] = adata.obsm['X_umap'].copy()

    plot_kwargs = dict(color=scn_classes, alpha=alpha, s=s, vmin=0, vmax=1, show=False)
    axes = sc.pl.umap(score_view, **plot_kwargs)

    if not display:
        return axes
    plt.show()

umi_counts_ranked

umi_counts_ranked(adata, total_counts_column='total_counts')

Identifies and plots the knee point of the UMI count distribution in an AnnData object.

Parameters:

  • adata (AnnData) –

    The input AnnData object.

  • total_counts_column (str, default: 'total_counts' ) –

    Column in adata.obs containing total UMI counts. Default is "total_counts".

  • show (bool) –

    If True, displays a log-log plot with the knee point. Default is True.

Returns:

  • float

    The UMI count value at the knee point.

Source code in src/pySingleCellNet/plotting/dot.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def umi_counts_ranked(adata, total_counts_column="total_counts", show=True):
    """
    Identifies and plots the knee point of the UMI count distribution in an AnnData object.

    Cells are ranked by total UMI count in descending order. The knee is the rank at which
    the second derivative (``np.gradient`` applied twice) of the cumulative UMI fraction
    is largest.

    Parameters:
        adata (AnnData): The input AnnData object.
        total_counts_column (str): Column in `adata.obs` containing total UMI counts. Default is "total_counts".
        show (bool): If True, displays a log-log plot with the knee point. Default is True.

    Returns:
        float: The UMI count value at the knee point.

    Raises:
        ValueError: If `adata.obs[total_counts_column]` contains no cells.
    """
    # Extract total UMI counts and sort them in descending order (rank 1 = largest)
    sorted_umi_counts = np.sort(adata.obs[total_counts_column])[::-1]
    if sorted_umi_counts.size == 0:
        raise ValueError(f"adata.obs['{total_counts_column}'] is empty; cannot locate a knee point.")

    if sorted_umi_counts.size < 2:
        # np.gradient needs at least two samples; a single cell is trivially the knee.
        knee_idx = 0
    else:
        # Compute cumulative UMI counts (normalized to a fraction)
        cumulative_counts = np.cumsum(sorted_umi_counts)
        cumulative_fraction = cumulative_counts / cumulative_counts[-1]

        # NOTE(review): with counts sorted descending, the first derivative is non-increasing,
        # so the second derivative is never positive. argmax therefore picks the flattest
        # stretch (often the low-count tail) rather than the classic elbow. Kept as-is to
        # preserve existing results; confirm this is the intended heuristic.
        first_derivative = np.gradient(cumulative_fraction)
        second_derivative = np.gradient(first_derivative)
        knee_idx = int(np.argmax(second_derivative))

    knee_point_value = sorted_umi_counts[knee_idx]

    if show:
        # Generate log-log plot
        cell_ranks = np.arange(1, len(sorted_umi_counts) + 1)
        plt.figure(figsize=(10, 6))
        plt.plot(cell_ranks, sorted_umi_counts, marker='o', markersize=2, linestyle='-', linewidth=0.5, label="UMI Counts")
        plt.axvline(cell_ranks[knee_idx], color="red", linestyle="--", label=f"Knee Point: {knee_point_value}")
        plt.title('UMI Counts Per Cell (Log-Log Scale)', fontsize=14)
        plt.xlabel('Cell Rank (Descending)', fontsize=12)
        plt.ylabel('Total UMI Counts', fontsize=12)
        plt.xscale('log')
        plt.yscale('log')
        plt.grid(True, linestyle='--', linewidth=0.5)
        plt.legend()
        plt.tight_layout()
        plt.show()

    return float(knee_point_value)