amici.interpretation.AMICIAttentionModule
- class amici.interpretation.AMICIAttentionModule(_adata=None, _labels_key=None, _attention_patterns_df=None, _nn_idxs_df=None, _nn_dists_df=None, _compute_kwargs=None, _flavor=None)
Bases: object

Methods
__init__
compute – Different flavors of attention patterns retrieved from cache.
compute_communication_hubs – Compute hub analysis by clustering cells based on their high-interacting neighbor composition.
plot_attention_summary – Plot a summary of attention patterns for different cell types and heads.
Save attention patterns to file.
- classmethod compute(model, adata=None, indices=None, batch_size=None, flavor='vanilla', prog_bar=True)
Different flavors of attention patterns retrieved from cache.
Returns a DataFrame if return_nn_idxs_and_dists is False, otherwise returns a tuple of DataFrames (attention_patterns, nn_idxs, nn_dists). Code adapted from: callummcdougall/CircuitsVis.
einsum key:
- b: n_batch
- h: n_heads
- n: n_neighbors
- g: n_genes
- m: n_heads * n_head_size (a.k.a. d_model)
- Parameters:
adata (AnnData) – AnnData object to get attention patterns from.
indices (list[int], optional) – Indices of the cells to get attention patterns from.
batch_size (int, optional) – Batch size to use for the data loader.
flavor (Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"], optional) – Flavor of attention patterns to retrieve.
prog_bar (bool, optional) – Whether to show a progress bar.
- Return type:
Stores
- Union[pd.DataFrame, tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]:
- attention_patterns: DataFrame of attention patterns with the following columns:
“neighbor_{i}”: Attention pattern for the i-th neighbor.
“head_idx”: Index of the attention head.
“label”: Label of the cell.
“cell_idx”: Index of the cell.
- nn_idxs: DataFrame of nearest neighbor indices (only if return_nn_idxs_and_dists is True) with the following columns:
“neighbor_{i}”: Index of the i-th neighbor.
- nn_dists: DataFrame of nearest neighbor distances (only if return_nn_idxs_and_dists is True) with the following columns:
“neighbor_{i}”: Distance to the i-th neighbor.
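As a minimal sketch of the column layout described above, the following builds toy `attention_patterns` and `nn_dists` frames by hand and reshapes them into one row per (cell, head, neighbor) pair. All values are made up, and indexing `nn_dists` by cell is an assumption for illustration; this page does not state how that frame is indexed.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_cells, n_heads, n_neighbors = 3, 2, 4
neighbor_cols = [f"neighbor_{i}" for i in range(n_neighbors)]

# Toy attention_patterns: one row per (cell, head), with the documented columns.
rows = []
for cell_idx in range(n_cells):
    for head_idx in range(n_heads):
        weights = rng.dirichlet(np.ones(n_neighbors))  # made-up scores summing to 1
        row = dict(zip(neighbor_cols, weights))
        row.update(head_idx=head_idx, label=f"type_{cell_idx % 2}", cell_idx=cell_idx)
        rows.append(row)
attention_patterns = pd.DataFrame(rows)

# Toy nn_dists: ASSUMED to hold one row per cell, indexed by cell_idx.
nn_dists = pd.DataFrame(
    rng.uniform(0, 100, size=(n_cells, n_neighbors)), columns=neighbor_cols
)
nn_dists.index.name = "cell_idx"

# Long format: one row per (cell, head, neighbor) pair.
attn_long = attention_patterns.melt(
    id_vars=["cell_idx", "head_idx", "label"],
    value_vars=neighbor_cols,
    var_name="neighbor",
    value_name="attention",
)
dist_long = nn_dists.reset_index().melt(
    id_vars=["cell_idx"], var_name="neighbor", value_name="distance"
)
long_df = attn_long.merge(dist_long, on=["cell_idx", "neighbor"])
```

The long format pairs every attention score with the distance to the same neighbor, which is convenient for grouping by head, label, or distance.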
- compute_communication_hubs(attention_quantile_threshold=0.9, n_clusters=None, random_state=42)
Compute hub analysis by clustering cells based on their high-interacting neighbor composition.
This method:
1. Aggregates attention scores by taking the max across all heads
2. For each receiver cell type, computes a threshold as the specified quantile of attention scores to receivers of that type
3. For each cell, counts the cell types of neighbors with attention scores above the threshold
4. Normalizes these counts to create composition vectors
5. Clusters the composition vectors using KMeans
- Parameters:
attention_quantile_threshold (float) – The quantile threshold for classifying high-interacting neighbors. Defaults to 0.9 (90th percentile).
n_clusters (int | None) – Number of clusters for KMeans. If None, will use silhouette analysis to find the optimal number of clusters (2-12 range).
random_state (int) – Random state for KMeans clustering.
- Returns:
- A DataFrame indexed by cell obs_names with the following columns:
One column per cell type containing the normalized count of high-interacting neighbors of that type
“hub_cluster”: The assigned hub cluster label
- Return type:
pd.DataFrame
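The five steps listed for this method can be sketched on plain arrays as follows. This is an illustrative reimplementation under stated assumptions, not the library's code: the function name `hub_sketch` and its array inputs are hypothetical, the result uses a default integer index rather than obs_names, and the silhouette-based choice of `n_clusters` is left out (a fixed value is passed instead).

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans


def hub_sketch(attn, nn_idxs, labels, quantile=0.9, n_clusters=2, random_state=42):
    """Hypothetical helper. attn: (n_cells, n_heads, n_neighbors);
    nn_idxs: (n_cells, n_neighbors) integer indices into labels."""
    labels = np.asarray(labels)
    cell_types = np.unique(labels)

    # 1. Aggregate over heads by taking the maximum score per neighbor.
    max_attn = attn.max(axis=1)  # (n_cells, n_neighbors)

    # 2. One threshold per receiver cell type: the chosen quantile of all
    #    scores received by cells of that type.
    thresholds = {t: np.quantile(max_attn[labels == t], quantile) for t in cell_types}
    cell_thresholds = np.array([thresholds[t] for t in labels])[:, None]

    # 3. Count neighbor cell types among scores above the receiver's threshold.
    high = max_attn > cell_thresholds
    neighbor_types = labels[nn_idxs]
    counts = np.stack(
        [(high & (neighbor_types == t)).sum(axis=1) for t in cell_types], axis=1
    )

    # 4. Normalize counts to composition vectors (rows with no hits stay zero).
    totals = counts.sum(axis=1, keepdims=True)
    composition = np.divide(
        counts, totals, out=np.zeros(counts.shape), where=totals > 0
    )

    # 5. Cluster the composition vectors with KMeans.
    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    out = pd.DataFrame(composition, columns=cell_types)
    out["hub_cluster"] = km.fit_predict(composition)
    return out


# Toy data: 60 cells, 4 heads, 10 neighbors, three cell types.
rng = np.random.default_rng(0)
attn = rng.random((60, 4, 10))
nn_idxs = rng.integers(0, 60, size=(60, 10))
labels = np.array(["A", "B", "C"] * 20)
hubs = hub_sketch(attn, nn_idxs, labels)
```

Each row of `hubs` holds the normalized neighbor-type composition for one cell plus its cluster label, mirroring the column layout described under Returns.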
- plot_attention_summary(cell_type_sub=None, sel_head=None, plot_histogram=False, palette=None, max_distance=100, bin_size=5, min_bin_count=50, epoch=None, wandb_log=False, show=True, save_png=False, save_svg=False, save_dir='./figures')
Plot a summary of attention patterns for different cell types and heads.
- Parameters:
cell_type_sub (list, optional) – List of cell types to consider. Defaults to None.
sel_head (int, optional) – Selected head index. Defaults to None.
plot_histogram (bool, optional) – Whether to plot a histogram of attention scores. Defaults to False.
flavor (str, optional) – Flavor of attention to plot. Defaults to “vanilla”.
palette (str or list, optional) – Color palette for the plot. Defaults to None.
max_distance (int, optional) – Maximum distance for binning. Defaults to 100.
bin_size (int, optional) – Size of each bin. Defaults to 5.
min_bin_count (int, optional) – Minimum count of bins. Defaults to 50.
epoch (int, optional) – Epoch number for logging. Defaults to None.
wandb_log (bool, optional) – Whether to log the plot to Weights and Biases. Defaults to False.
show (bool, optional) – Whether to show the plot. Defaults to True.
save_png (bool, optional) – Whether to save the plot as a PNG file. Defaults to False.
save_svg (bool, optional) – Whether to save the plot as an SVG file. Defaults to False.
save_dir (str, optional) – Directory to save the plot. Defaults to “./figures”.
- Returns:
None
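The `max_distance`, `bin_size`, and `min_bin_count` parameters suggest that attention is summarized per distance bin. The exact aggregation used by the plot is not documented on this page, so the sketch below is an assumption for illustration only: it bins toy (distance, attention) pairs into fixed-width bins, averages each bin, and drops bins with fewer than `min_bin_count` observations. The helper name `bin_attention_by_distance` is hypothetical.

```python
import numpy as np
import pandas as pd


def bin_attention_by_distance(distance, attention, max_distance=100, bin_size=5,
                              min_bin_count=50):
    """Hypothetical helper: mean attention per fixed-width distance bin."""
    df = pd.DataFrame({"distance": distance, "attention": attention})
    # Keep only pairs closer than max_distance (distances assumed non-negative).
    df = df[df["distance"] < max_distance].copy()
    # Left-closed bins: [0, bin_size), [bin_size, 2 * bin_size), ...
    edges = np.arange(0, max_distance + bin_size, bin_size)
    df["bin"] = pd.cut(df["distance"], bins=edges, right=False)
    summary = df.groupby("bin", observed=True)["attention"].agg(["mean", "count"])
    # Drop sparsely populated bins, whose means would be noisy.
    return summary[summary["count"] >= min_bin_count]


# Toy data: made-up distances and attention scores.
rng = np.random.default_rng(0)
distance = rng.uniform(0, 150, size=5000)
attention = rng.random(5000)
summary = bin_attention_by_distance(distance, attention)
```

Under this reading, a larger `bin_size` smooths the curve at the cost of distance resolution, while `min_bin_count` trades coverage for stability of the per-bin mean.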