# Source code for amici.interpretation._attention_module

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
from anndata import AnnData
from einops import einsum, rearrange
from scvi import REGISTRY_KEYS
from sklearn.cluster import KMeans
from tqdm import tqdm

from amici._constants import NN_REGISTRY_KEYS

from ._utils import _get_compute_method_kwargs

if TYPE_CHECKING:
    from amici._model import AMICI


[docs] @dataclass class AMICIAttentionModule: _adata: AnnData | None = None _labels_key: str | None = None _attention_patterns_df: pd.DataFrame | None = None _nn_idxs_df: pd.DataFrame | None = None _nn_dists_df: pd.DataFrame | None = None _compute_kwargs: dict | None = None _flavor: (Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"] | None) = None
[docs] @classmethod def compute( cls, model: "AMICI", adata: AnnData | None = None, indices: list[int] | None = None, batch_size: int | None = None, flavor: Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"] = "vanilla", prog_bar: bool = True, ) -> "AMICIAttentionModule": """Different flavors of attention patterns retrieved from cache. Returns a DataFrame if return_nn_idxs_and_dists is False, otherwise returns a tuple of DataFrames (attention_patterns, nn_idxs, nn_dists). Code adapted from: https://github.com/callummcdougall/CircuitsVis/blob/main/python/circuitsvis/attention.py#L203-L211. einsum key: - b: n_batch - h: n_heads - n: n_neighbors - g: n_genes - m: n_heads * n_head_size (a.k.a. d_model) Args: adata (AnnData): AnnData object to get attention patterns from. indices (list[int], optional): Indices of the cells to get attention patterns from. batch_size (int, optional): Batch size to use for the data loader. flavor (Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"], optional): Flavor of attention patterns to retrieve. prog_bar (bool, optional): Whether to show a progress bar. Stores ------- Union[pd.DataFrame, tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]: - attention_patterns: DataFrame of attention patterns with the following columns: - "neighbor_{i}": Attention pattern for the i-th neighbor. - "head_idx": Index of the attention head. - "label": Label of the cell. - "cell_idx": Index of the cell. - nn_idxs: DataFrame of nearest neighbor indices (only if return_nn_idxs_and_dists is True) with the following columns: - "neighbor_{i}": Index of the i-th neighbor. - nn_dists: DataFrame of nearest neighbor distances (only if return_nn_idxs_and_dists is True) with the following columns: - "neighbor_{i}": Distance to the i-th neighbor. 
""" _compute_kwargs = _get_compute_method_kwargs(**locals()) model._check_if_trained(warn=True) adata = model._validate_anndata(adata) if indices is None: indices = np.arange(len(adata)) scdl = model._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) attention_patterns = [] nn_idxs = [] nn_dists = [] for tensors in tqdm(scdl, disable=not prog_bar): tensors = {k: v.to(model.device) for k, v in tensors.items()} _, gen_outputs = model.module( tensors, generative_kwargs={ "return_attention_patterns": True, "return_v": True, }, compute_loss=False, ) batch_attention_patterns = (gen_outputs["attention_patterns"].detach().cpu().numpy())[ :, :, :-1 ] # remove query_len dim and dummy dim nn_idxs.append(tensors[NN_REGISTRY_KEYS.NN_IDX_KEY].detach().cpu().numpy()) nn_dists.append(tensors[NN_REGISTRY_KEYS.NN_DIST_KEY].detach().cpu().numpy()) if flavor in ("value-weighted", "info-weighted", "gene-weighted"): batch_v = gen_outputs["attention_v"].detach().cpu().numpy() # batch x n_nns x n_heads x n_head_size W_O = model.module.attention_layer.W_O.detach().cpu().numpy() # n_heads x n_head_size x d_model info = einsum(batch_v, W_O, "b n h d, m h d -> b h n m") if flavor == "gene-weighted": proj_linear_W = ( model.module.linear_head.weight.detach().cpu().numpy() ) # n_genes x (d_model * n_ct_embed) if flavor == "value-weighted": v_norms = rearrange(np.linalg.norm(batch_v, axis=-1), "b n h -> b h n") batch_attention_patterns = batch_attention_patterns * v_norms if flavor == "info-weighted": info_norms = np.linalg.norm(info, axis=-1) batch_attention_patterns = batch_attention_patterns * info_norms if flavor == "gene-weighted": gene_info = einsum(info, proj_linear_W, "b h n m, g m -> b h n g") gene_info_norms = np.linalg.norm(gene_info, axis=-1) batch_attention_patterns = batch_attention_patterns * gene_info_norms attention_patterns.append(batch_attention_patterns) attention_patterns_head = [] labels_key = 
model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key for head_idx in range(model.module.n_heads): attention_patterns_head_idx = np.vstack(attention_patterns)[:, head_idx, :] head_idx_labels = np.repeat(head_idx, attention_patterns_head_idx.shape[0]).reshape(-1, 1) cell_type_labels = adata[indices].obs[labels_key].values.reshape(-1, 1) cell_idxs = adata[indices].obs_names.to_numpy().reshape(-1, 1) attention_patterns_head_df = pd.DataFrame( np.concatenate( ( attention_patterns_head_idx, head_idx_labels, cell_type_labels, cell_idxs, ), axis=1, ), columns=[f"neighbor_{i}" for i in range(model.n_neighbors)] + ["head", "label", "cell_idx"], ) attention_patterns_head.append(attention_patterns_head_df) attention_patterns_df = pd.concat(attention_patterns_head, ignore_index=True) nn_idxs_df = pd.DataFrame( adata.obs_names.to_numpy()[np.vstack(nn_idxs)], columns=[f"neighbor_{i}" for i in range(model.n_neighbors)], index=adata.obs_names[indices], ) nn_dists_df = pd.DataFrame( np.vstack(nn_dists), columns=[f"neighbor_{i}" for i in range(model.n_neighbors)], index=adata.obs_names[indices], ) return cls( _adata=adata, _attention_patterns_df=attention_patterns_df, _nn_idxs_df=nn_idxs_df, _nn_dists_df=nn_dists_df, _labels_key=labels_key, _flavor=flavor, _compute_kwargs=_compute_kwargs, )
[docs] def save(self, save_path: str): """Save attention patterns to file""" self._attention_patterns_df.to_csv(save_path) return self
[docs] def compute_communication_hubs( self, attention_quantile_threshold: float = 0.9, n_clusters: int | None = None, random_state: int = 42, ) -> pd.DataFrame: """ Compute hub analysis by clustering cells based on their high-interacting neighbor composition. This method: 1. Aggregates attention scores by taking the max across all heads 2. For each receiver cell type, computes a threshold as the specified quantile of attention scores to receivers of that type 3. For each cell, counts the cell types of neighbors with attention scores above the threshold 4. Normalizes these counts to create composition vectors 5. Clusters the composition vectors using KMeans Args: attention_quantile_threshold: The quantile threshold for classifying high-interacting neighbors. Defaults to 0.9 (90th percentile). n_clusters: Number of clusters for KMeans. If None, will use silhouette analysis to find the optimal number of clusters (2-12 range). random_state: Random state for KMeans clustering. Returns: pd.DataFrame: A DataFrame indexed by cell obs_names with the following columns: - One column per cell type containing the normalized count of high-interacting neighbors of that type - 'hub_cluster': The assigned hub cluster label """ # Get number of neighbors from column names neighbor_cols = [col for col in self._attention_patterns_df.columns if col.startswith('neighbor_')] n_neighbors = len(neighbor_cols) # Aggregate attention scores by max across all heads attention_scores_df = self._attention_patterns_df.groupby(["cell_idx"]).max() attention_scores_df = attention_scores_df.drop(columns=["head"]).reset_index().set_index(["cell_idx"]) # Get cell types cell_types = self._adata.obs[self._labels_key].unique() # Initialize result dataframe for high interacting counts high_interacting_counts = pd.DataFrame( 0.0, index=self._adata.obs_names, columns=cell_types ) # Compute interaction thresholds per receiver cell type interaction_thresholds = {} for cell_type in cell_types: 
receiver_cell_type_idxs = self._adata[self._adata.obs[self._labels_key] == cell_type].obs_names # Extract attention scores to neighbors of this cell type from all senders attention_to_receiver = attention_scores_df.loc[ attention_scores_df.index.isin(receiver_cell_type_idxs) ] attention_scores = attention_to_receiver.drop(columns=["label"]).values.flatten() # Compute quantile threshold (excluding zeros) non_zero_scores = attention_scores[attention_scores > 0] if len(non_zero_scores) > 0: interaction_threshold = np.quantile(non_zero_scores, q=attention_quantile_threshold) else: interaction_threshold = 0.0 interaction_thresholds[cell_type] = interaction_threshold # For each cell type, compute high interacting neighbor counts for cell_type in cell_types: receiver_cell_type_idxs = self._adata[self._adata.obs[self._labels_key] == cell_type].obs_names attention_to_receiver = attention_scores_df.loc[ attention_scores_df.index.isin(receiver_cell_type_idxs) ] receiver_nn_obs_names = self._nn_idxs_df.loc[ self._nn_idxs_df.index.isin(receiver_cell_type_idxs) ] # Get neighbor labels receiver_nn_labels = pd.DataFrame( self._adata.obs[self._labels_key].loc[np.array(receiver_nn_obs_names).flatten()] ).rename(columns={self._labels_key: "neighbor_label"}) # Melt attention scores attention_to_receiver_melted = pd.melt( attention_to_receiver.reset_index(), id_vars=["cell_idx", "label"], value_vars=[f"neighbor_{i}" for i in range(n_neighbors)], var_name="neighbor_col", value_name="attention_score", ) # Melt neighbor obs names melted_nn_obs_names = pd.melt( receiver_nn_obs_names.reset_index(), id_vars="index", value_vars=[f"neighbor_{i}" for i in range(n_neighbors)], var_name="neighbor_col", value_name="neighbor_idx", ) # Merge attention scores with neighbor info merged_attention_scores = pd.merge( attention_to_receiver_melted, melted_nn_obs_names, left_on=["neighbor_col", "cell_idx"], right_on=["neighbor_col", "index"], how="inner" ).drop(columns=["neighbor_col", "index"]).rename( 
columns={"cell_idx": "receiver_idx"} ).merge( receiver_nn_labels.reset_index(), right_on="index", left_on="neighbor_idx", how="left" ) # Filter by threshold and count threshold = interaction_thresholds[cell_type] high_interacting_scores = merged_attention_scores[ merged_attention_scores["attention_score"] > threshold ] if len(high_interacting_scores) > 0: high_interacting_counts_cell_type = ( high_interacting_scores[["receiver_idx", "neighbor_label"]] .groupby(["receiver_idx"]) .value_counts() ) high_interacting_counts_cell_type = ( high_interacting_counts_cell_type .reset_index() .pivot(columns="neighbor_label", index="receiver_idx", values="count") ) # Update the high_interacting_counts DataFrame for col in high_interacting_counts_cell_type.columns: if col in high_interacting_counts.columns: high_interacting_counts.loc[ high_interacting_counts_cell_type.index, col ] = high_interacting_counts_cell_type[col] # Normalize to get composition vectors row_sums = high_interacting_counts.sum(axis=1) high_interacting_counts_norm = high_interacting_counts.div(row_sums, axis=0) high_interacting_counts_norm = high_interacting_counts_norm.fillna(0) # Determine number of clusters if not specified if n_clusters is None: from sklearn.metrics import silhouette_score best_score = -1 best_k = 2 for k in range(2, 13): kmeans = KMeans(n_clusters=k, random_state=random_state) cluster_labels = kmeans.fit_predict(high_interacting_counts_norm) score = silhouette_score(high_interacting_counts_norm, cluster_labels) if score > best_score: best_score = score best_k = k n_clusters = best_k # Perform KMeans clustering kmeans = KMeans(n_clusters=n_clusters, random_state=random_state) kmeans.fit(high_interacting_counts_norm) high_interacting_counts_norm["hub_cluster"] = kmeans.labels_ return high_interacting_counts_norm
def _bin_attention_scores( self, cell_types=None, max_distance=100, bin_size=5, min_bin_count=50, prog_bar=True, ): """ Get binned attention scores for each cell type and head. Args: cell_types (list, optional): List of cell types to consider. Defaults to all cell types in `attention_patterns_df`. max_distance (int, optional): Maximum distance for binning. bin_size (int, optional): Size of each bin. min_bin_count (int, optional): Minimum count of bins. prog_bar (bool, optional): Whether to show a progress bar. binned_attention (bool, optional): If False, returns a pivoted attention dataframe with no binning. Returns ------- pd.DataFrame: DataFrame containing binned attention scores with the columns: - cell_i: index of the cell - cell_j: index of the neighbor - attention: attention score - head: attention head - cluster_i: index cell type - cluster_j: neighbor cell type - distance: distance between cell_i and cell_j - distance_bin: distance bin """ cell_types = cell_types or list(set(self._attention_patterns_df["label"].unique())) head_idxs = sorted(self._attention_patterns_df["head"].unique()) filtered_attention_df_list = [] for ct in tqdm(cell_types, desc="Cell type", disable=not prog_bar): attention_df_list = [] for head_idx in tqdm(head_idxs, desc="Head index", disable=not prog_bar): attention_patterns_head_idx = self._attention_patterns_df[ (self._attention_patterns_df["head"] == head_idx) & (self._attention_patterns_df["label"] == ct) ] nn_indices_flat = self._nn_idxs_df.loc[attention_patterns_head_idx["cell_idx"]].to_numpy().flatten() distances_flat = self._nn_dists_df.loc[attention_patterns_head_idx["cell_idx"]].to_numpy().flatten() attention_values_flat = ( attention_patterns_head_idx.drop(["head", "label", "cell_idx"], axis=1).to_numpy().flatten() ) cell_i_repeat = np.repeat(attention_patterns_head_idx["cell_idx"], self._nn_idxs_df.shape[1]) cell_i_labels = attention_patterns_head_idx["label"].to_numpy() cluster_i_repeat = np.repeat(cell_i_labels, 
self._nn_idxs_df.shape[1]) index_label_map = self._adata.obs[self._labels_key] cluster_j_flat = index_label_map.loc[nn_indices_flat].to_numpy() batch_df = pd.DataFrame( { "cell_i": cell_i_repeat, "cell_j": nn_indices_flat, "attention": attention_values_flat, "distance": distances_flat, "cluster_i": cluster_i_repeat, "cluster_j": cluster_j_flat, } ) batch_df = batch_df.dropna(subset=["cell_j"]) batch_df["head"] = head_idx attention_df_list.append(batch_df) if not attention_df_list: continue attention_df = pd.concat(attention_df_list, ignore_index=True) xticklabels = np.linspace(0, max_distance - bin_size, num=int(max_distance // bin_size)) attention_df["distance_bin"] = pd.cut( attention_df["distance"], bins=np.linspace(0, max_distance, num=int(max_distance // bin_size) + 1), right=False, include_lowest=True, labels=xticklabels, ).astype(float) # filter out distance bins with low attention values group_counts = attention_df.groupby(["distance_bin", "cluster_j"]).size().reset_index(name="count") groups_to_remove = group_counts[group_counts["count"] < min_bin_count] mask = ~attention_df.set_index(["distance_bin", "cluster_j"]).index.isin( groups_to_remove.set_index(["distance_bin", "cluster_j"]).index ) filtered_attention_df = attention_df[mask].reset_index(drop=True) filtered_attention_df_list.append(filtered_attention_df) return pd.concat(filtered_attention_df_list, ignore_index=True)
[docs] def plot_attention_summary( self, cell_type_sub=None, sel_head=None, plot_histogram=False, palette=None, max_distance=100, bin_size=5, min_bin_count=50, epoch=None, wandb_log=False, show=True, save_png=False, save_svg=False, save_dir="./figures", ): """ Plot a summary of attention patterns for different cell types and heads. Args: cell_types (list, optional): List of cell types to consider. Defaults to None. sel_head (int, optional): Selected head index. Defaults to None. plot_histogram (bool, optional): Whether to plot a histogram of attention scores. Defaults to False. flavor (str, optional): Flavor of attention to plot. Defaults to "vanilla". palette (str or list, optional): Color palette for the plot. Defaults to None. max_distance (int, optional): Maximum distance for binning. Defaults to 100. bin_size (int, optional): Size of each bin. Defaults to 5. min_bin_count (int, optional): Minimum count of bins. Defaults to 50. epoch (int, optional): Epoch number for logging. Defaults to None. wandb_log (bool, optional): Whether to log the plot to Weights and Biases. Defaults to True. show (bool, optional): Whether to show the plot. Defaults to False. save_png (bool, optional): Whether to save the plot as a PNG file. Defaults to False. save_svg (bool, optional): Whether to save the plot as an SVG file. Defaults to False. save_dir (str, optional): Directory to save the plot. Defaults to "./figures". 
Returns ------- None """ cell_types = cell_type_sub or list(set(self._attention_patterns_df["label"].unique())) binned_attention_df = self._bin_attention_scores( cell_types=cell_types, max_distance=max_distance, bin_size=bin_size, min_bin_count=min_bin_count, prog_bar=True, ) max_attention_value = binned_attention_df["attention"].max() for ct in cell_types: binned_attention_ct_df = binned_attention_df[binned_attention_df["cluster_i"] == ct] xticklabels = np.linspace(0, max_distance - bin_size, num=int(max_distance // bin_size)) if sel_head is not None: # plot single head binned_attention_ct_head_df = binned_attention_ct_df[binned_attention_ct_df["head"] == sel_head] sns.set_theme(style="whitegrid") if plot_histogram: fig, (ax1, ax2) = plt.subplots( 2, 1, figsize=(10, 6), sharex=True, gridspec_kw={"height_ratios": [3, 1]}, ) else: fig, (ax1) = plt.subplots( 1, 1, figsize=(10, 6), ) # Attention plot g = sns.pointplot( data=binned_attention_ct_head_df, x="distance_bin", y="attention", hue="cluster_j", errorbar="sd", native_scale=True, palette=palette or "tab10", ax=ax1, ) g.set(xlabel="Distance Bin", ylabel="Attention") g.set_xticklabels( [f"{x:.2f}".rstrip("0").rstrip(".") for x in xticklabels], rotation=45, ha="right", ) ax1.set_title( f"{self._flavor.capitalize()} Attention Patterns for Index Cell Type {ct} for Head {sel_head}, Binned by Distance", size=16, pad=10, ) ax1.legend(title="Cluster", bbox_to_anchor=(1.05, 1), loc="upper left") if self._flavor != "vanilla": ax1.set_ylim(0, max_attention_value) else: ax1.set_ylim(0, 1) if plot_histogram: # Histogram plot num_bins = int(max_distance // (bin_size / 4)) sns.histplot( data=binned_attention_ct_head_df[binned_attention_ct_head_df["distance"] < max_distance], x="distance", bins=num_bins, hue="cluster_j", palette=palette or "tab10", multiple="stack", ax=ax2, legend=False, ) ax1.set_xlabel("") ax2.set(xlabel="Distance Bin", ylabel="Number of Edges") ax2.set_xticks(xticklabels) ax2.set_xticklabels( 
[f"{x:.2f}".rstrip("0").rstrip(".") for x in xticklabels], rotation=45, ha="right", ) plt.tight_layout() if wandb_log: wandb.log( { "epoch": epoch, f"{self._flavor.capitalize()} Attention vs. Distance for Index Cell {ct}": wandb.Image(plt), } ) if save_svg: plt.savefig(os.path.join(save_dir, f"attn_{ct}_head_{sel_head}.svg")) if save_png: plt.savefig(os.path.join(save_dir, f"attn_{ct}_head_{sel_head}.png")) if show: plt.show() plt.close() else: # plot facetgrid for all heads sns.set_theme(style="whitegrid") g = sns.FacetGrid( binned_attention_ct_df, col="head", col_wrap=4, height=4, aspect=1, sharex=True, sharey=True, hue="cluster_j", palette=palette or "tab10", ) g.map_dataframe( sns.pointplot, "distance_bin", "attention", errorbar="sd", native_scale=True, ) g.set_titles("Head {col_name}") g.set_axis_labels("Distance Bin", "Attention") for ax in g.axes.flat: ax.set_xticks(xticklabels) ax.set_xticklabels( [f"{x:.2f}".rstrip("0").rstrip(".") for x in xticklabels], rotation=45, ha="right", ) g.figure.suptitle( f"{self._flavor.capitalize()} Attention Patterns for Index Cell Type {ct} per Head, Binned by Distance", size=16, ) g.figure.subplots_adjust(top=0.85) g.add_legend(title="Cluster") if self._flavor != "vanilla": plt.ylim(0, max_attention_value) else: plt.ylim(0, 1) plt.tight_layout() if wandb_log: wandb.log( { "epoch": epoch, f"{self._flavor.capitalize()} Attention vs. Distance for Index Cell {ct}": wandb.Image(plt), } ) if save_svg: plt.savefig(os.path.join(save_dir, f"attn_{ct}.svg")) if save_png: plt.savefig(os.path.join(save_dir, f"attn_{ct}.png")) if show: plt.show() plt.close()