amici.interpretation.AMICIAttentionModule

class amici.interpretation.AMICIAttentionModule(_adata=None, _labels_key=None, _attention_patterns_df=None, _nn_idxs_df=None, _nn_dists_df=None, _compute_kwargs=None, _flavor=None)

Bases: object

Methods

__init__

compute

Compute attention patterns of a given flavor from the model's attention cache.

compute_communication_hubs

Compute hub analysis by clustering cells based on their high-interacting neighbor composition.

plot_attention_summary

Plot a summary of attention patterns for different cell types and heads.

save

Save attention patterns to a file.

classmethod compute(model, adata=None, indices=None, batch_size=None, flavor='vanilla', prog_bar=True)

Compute attention patterns of a given flavor from the model's attention cache.

Returns an AMICIAttentionModule that stores the attention patterns together with the nearest-neighbor indices and distances (see Stores below). Code adapted from: callummcdougall/CircuitsVis.

einsum key:
  • b: n_batch

  • h: n_heads

  • n: n_neighbors

  • g: n_genes

  • m: n_heads * n_head_size (a.k.a. d_model)
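As an illustration of the einsum key, the sketch below derives a value-weighted pattern from a vanilla one. It assumes hypothetical cached arrays, pattern of shape (b, h, n) and values of shape (b, n, m), and uses one common definition of value weighting (scale each attention score by the norm of the attended neighbor's per-head value vector, then renormalize). AMICI's actual "value-weighted" flavor may be computed differently; value_weighted_attention is not part of the library.

```python
import numpy as np


def value_weighted_attention(pattern: np.ndarray, values: np.ndarray, n_heads: int) -> np.ndarray:
    """Sketch: weight attention scores by the norm of each neighbor's value vector.

    pattern: (b, h, n) attention weights over neighbors (hypothetical cache layout)
    values:  (b, n, m) neighbor value vectors, m = n_heads * n_head_size (d_model)
    """
    b, n, m = values.shape
    # Split d_model into heads: (b, n, h, d)
    per_head = values.reshape(b, n, n_heads, m // n_heads)
    # Norm of each neighbor's value vector for each head: (b, n, h)
    value_norms = np.linalg.norm(per_head, axis=-1)
    # Scale each attention score by the matching value norm: (b, h, n)
    weighted = np.einsum("bhn,bnh->bhn", pattern, value_norms)
    # Assumption: renormalize so each head's weights over neighbors sum to 1
    return weighted / weighted.sum(axis=-1, keepdims=True)
```

With uniform vanilla attention, neighbors with larger value vectors end up receiving proportionally more weight.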

Parameters:
  • model – The trained AMICI model to extract attention patterns from.

  • adata (AnnData, optional) – AnnData object to get attention patterns from.

  • indices (list[int], optional) – Indices of the cells to get attention patterns from.

  • batch_size (int, optional) – Batch size to use for the data loader.

  • flavor (Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"], optional) – Flavor of attention patterns to retrieve. Defaults to "vanilla".

  • prog_bar (bool, optional) – Whether to show a progress bar. Defaults to True.

Return type:

AMICIAttentionModule

Stores

  • attention_patterns (pd.DataFrame): DataFrame of attention patterns with the following columns:
    • “neighbor_{i}”: Attention score for the i-th neighbor.

    • “head_idx”: Index of the attention head.

    • “label”: Label of the cell.

    • “cell_idx”: Index of the cell.

  • nn_idxs (pd.DataFrame): DataFrame of nearest-neighbor indices with the following columns:
    • “neighbor_{i}”: Index of the i-th neighbor.

  • nn_dists (pd.DataFrame): DataFrame of nearest-neighbor distances with the following columns:
    • “neighbor_{i}”: Distance to the i-th neighbor.
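To work with the stored attention_patterns table as an array, one option is to reshape it to (cells, heads, neighbors). This sketch assumes one row per (cell_idx, head_idx) pair and the same set of heads for every cell; patterns_to_array is a hypothetical helper, not part of AMICI.

```python
import numpy as np
import pandas as pd


def patterns_to_array(attention_patterns: pd.DataFrame) -> np.ndarray:
    """Reshape a long-format attention table into a (cells, heads, neighbors) array."""
    # Order neighbor columns numerically so "neighbor_10" follows "neighbor_9"
    neighbor_cols = sorted(
        (c for c in attention_patterns.columns if c.startswith("neighbor_")),
        key=lambda c: int(c.split("_")[1]),
    )
    # Sort rows so that consecutive rows are the heads of the same cell
    df = attention_patterns.sort_values(["cell_idx", "head_idx"])
    n_cells = df["cell_idx"].nunique()
    n_heads = df["head_idx"].nunique()
    return df[neighbor_cols].to_numpy().reshape(n_cells, n_heads, len(neighbor_cols))
```

Taking arr.max(axis=1) on the result gives head-aggregated scores, the quantity described in step 1 of compute_communication_hubs.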

save(save_path)

Save attention patterns to a file at save_path.

compute_communication_hubs(attention_quantile_threshold=0.9, n_clusters=None, random_state=42)

Compute hub analysis by clustering cells based on their high-interacting neighbor composition.

This method:

  1. Aggregates attention scores by taking the max across all heads.

  2. For each receiver cell type, computes a threshold as the specified quantile of attention scores to receivers of that type.

  3. For each cell, counts the cell types of neighbors with attention scores above the threshold.

  4. Normalizes these counts to create composition vectors.

  5. Clusters the composition vectors using KMeans.
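Steps 1 to 4 can be sketched with NumPy and pandas as follows. This is an illustration, not AMICI's implementation: hub_composition is a hypothetical helper, its inputs are assumed to be dense arrays (attention of shape (cells, heads, neighbors), neighbor indices of shape (cells, neighbors), one label per cell), and "above the threshold" is taken as a strict inequality.

```python
import numpy as np
import pandas as pd


def hub_composition(attn, nn_idxs, labels, quantile=0.9):
    """Sketch of steps 1-4: composition of high-attention neighbor types per cell.

    attn:    (cells, heads, neighbors) attention scores of each receiver cell
    nn_idxs: (cells, neighbors) integer index of each neighbor cell
    labels:  (cells,) cell-type label of every cell
    """
    attn = np.asarray(attn)
    nn_idxs = np.asarray(nn_idxs)
    labels = np.asarray(labels)
    # Step 1: aggregate over heads by taking the max
    scores = attn.max(axis=1)                  # (cells, neighbors)
    neighbor_labels = labels[nn_idxs]          # (cells, neighbors)
    cell_types = np.unique(labels)
    counts = pd.DataFrame(0.0, index=np.arange(len(labels)), columns=cell_types)
    for t in cell_types:
        receivers = labels == t
        # Step 2: type-specific threshold from the scores received by cells of type t
        threshold = np.quantile(scores[receivers], quantile)
        # Step 3: flag neighbors whose score is strictly above the threshold
        high = scores[receivers] > threshold
        for s in cell_types:
            is_s = neighbor_labels[receivers] == s
            counts.loc[receivers, s] = (is_s & high).sum(axis=1)
    # Step 4: normalize counts to composition vectors (rows without hits stay 0)
    totals = counts.sum(axis=1).replace(0, 1)
    return counts.div(totals, axis=0)
```

Step 5 would then pass the composition matrix to a clusterer such as sklearn.cluster.KMeans(n_clusters=..., random_state=...).fit_predict(...). Using a per-receiver-type threshold means each cell type is judged against its own attention distribution rather than a single global cutoff.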

Parameters:
  • attention_quantile_threshold (float) – The quantile threshold for classifying high-interacting neighbors. Defaults to 0.9 (90th percentile).

  • n_clusters (int | None) – Number of clusters for KMeans. If None, will use silhouette analysis to find the optimal number of clusters (2-12 range).

  • random_state (int) – Random state for KMeans clustering.

Returns:

A DataFrame indexed by cell obs_names with the following columns:
  • One column per cell type containing the normalized count of high-interacting neighbors of that type

  • “hub_cluster”: The assigned hub cluster label.

Return type:

pd.DataFrame
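One way to interpret the returned DataFrame is to average the composition columns within each hub cluster. summarize_hubs is a hypothetical helper that assumes every column other than “hub_cluster” holds numeric composition values.

```python
import pandas as pd


def summarize_hubs(hubs: pd.DataFrame) -> pd.DataFrame:
    """Mean neighbor-type composition and number of cells in each hub cluster."""
    # Average the composition columns within each cluster
    summary = hubs.groupby("hub_cluster").mean()
    # Add cluster sizes, aligned on the cluster label
    summary["n_cells"] = hubs["hub_cluster"].value_counts().sort_index()
    return summary
```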

plot_attention_summary(cell_type_sub=None, sel_head=None, plot_histogram=False, palette=None, max_distance=100, bin_size=5, min_bin_count=50, epoch=None, wandb_log=False, show=True, save_png=False, save_svg=False, save_dir='./figures')

Plot a summary of attention patterns for different cell types and heads.

Parameters:
  • cell_type_sub (list, optional) – List of cell types to consider. Defaults to None.

  • sel_head (int, optional) – Selected head index. Defaults to None.

  • plot_histogram (bool, optional) – Whether to plot a histogram of attention scores. Defaults to False.

  • flavor – Not a parameter of this method; the plotted attention uses the flavor chosen when the module was created with compute().

  • palette (str or list, optional) – Color palette for the plot. Defaults to None.

  • max_distance (int, optional) – Maximum distance for binning. Defaults to 100.

  • bin_size (int, optional) – Size of each bin. Defaults to 5.

  • min_bin_count (int, optional) – Minimum number of entries a distance bin must contain to be included. Defaults to 50.

  • epoch (int, optional) – Epoch number for logging. Defaults to None.

  • wandb_log (bool, optional) – Whether to log the plot to Weights and Biases. Defaults to False.

  • show (bool, optional) – Whether to show the plot. Defaults to True.

  • save_png (bool, optional) – Whether to save the plot as a PNG file. Defaults to False.

  • save_svg (bool, optional) – Whether to save the plot as an SVG file. Defaults to False.

  • save_dir (str, optional) – Directory to save the plot. Defaults to “./figures”.

Returns:

None
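The binning parameters of plot_attention_summary can be understood with a sketch like the one below, which groups neighbor distances into bins of width bin_size up to max_distance, averages the attention scores in each bin, and drops bins with fewer than min_bin_count entries. This is an assumption about how the parameters interact, not the method's actual code; binned_attention and its inputs (flat arrays of distances and attention scores) are hypothetical.

```python
import numpy as np
import pandas as pd


def binned_attention(distances, scores, max_distance=100, bin_size=5, min_bin_count=50):
    """Mean attention score per distance bin, dropping sparsely populated bins."""
    df = pd.DataFrame({"distance": np.ravel(distances), "score": np.ravel(scores)})
    # Ignore neighbor pairs farther than max_distance
    df = df[df["distance"] <= max_distance].copy()
    # Bin edges 0, bin_size, 2 * bin_size, ... covering max_distance
    edges = np.arange(0, max_distance + bin_size, bin_size)
    df["bin"] = pd.cut(df["distance"], bins=edges, include_lowest=True)
    stats = df.groupby("bin", observed=True)["score"].agg(["mean", "count"])
    # Keep only bins with enough entries for a stable mean
    return stats[stats["count"] >= min_bin_count]
```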