import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import wandb
from anndata import AnnData
from einops import rearrange, repeat
from scvi import REGISTRY_KEYS
from amici._constants import NN_REGISTRY_KEYS
from ._utils import _get_compute_method_kwargs
if TYPE_CHECKING:
from amici._model import AMICI
[docs]
@dataclass
class AMICICounterfactualAttentionModule:
_adata: AnnData | None = None
_compute_kwargs: dict | None = None
_counterfactual_attention_df: pd.DataFrame | None = None
_labels_key: str | None = None
[docs]
@classmethod
def compute(
cls,
model: "AMICI",
cell_type: str,
adata: AnnData | None = None,
indices: list[int] | None = None,
head_idxs: list[int] | None = None,
batch_size: int | None = None,
) -> "AMICICounterfactualAttentionModule":
"""
Compute the counterfactual attention patterns for a given cell type and head indices.
Args:
model: The trained AMICI model.
cell_type: The index cell type to get counterfactual attention scores for.
adata: Optional AnnData object to use. If None, uses model.adata.
indices: Optional list of cell indices to get scores for. If None, uses all cells.
head_idxs: Optional list of attention head indices to get scores for. If None, uses all heads.
batch_size: Optional batch size to use for the data loader.
Returns
-------
AMICICounterfactualAttentionModule: The module instance with computed counterfactual attention scores stored
in _counterfactual_attention_df with columns:
- query_label: Cell type label of the query cell
- neighbor_idx: Index of the neighbor cell
- neighbor_label: Cell type label of the neighbor cell
- head_idx: Attention head index
- base_attention_score: Base attention score
- pos_coef: Position coefficient of the counterfactual attention score function
- dummy_attention_score: Dummy attention score of the model
- distance_kernel_unit_scale: Distance kernel unit scale
"""
_compute_kwargs = _get_compute_method_kwargs(**locals())
_labels_key = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
model._check_if_trained(warn=True)
head_idxs = head_idxs if head_idxs is not None else list(range(model.module.n_heads))
adata = model._validate_anndata(adata)
if indices is None:
indices = list(range(adata.n_obs))
# filter down to only neighbor indices that show up in at least one neighborhood of the query cell type
cell_type_indices = np.arange(adata.n_obs)[adata.obs[_labels_key] == cell_type]
filter_scdl = model._make_data_loader(adata=adata, indices=cell_type_indices, batch_size=batch_size)
neighbor_indices = set()
for tensors in filter_scdl:
batch_neighbor_indices = tensors[NN_REGISTRY_KEYS.NN_IDX_KEY].cpu().detach().numpy().flatten()
neighbor_indices.update(batch_neighbor_indices)
indices = [idx for idx in indices if idx in neighbor_indices]
scdl = model._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
labels_cat_list = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).categorical_mapping.tolist()
if cell_type not in labels_cat_list:
raise ValueError(f"Cell type {cell_type} not found in adata")
labels_cat_val = labels_cat_list.index(cell_type)
labels_tensor = torch.tensor(labels_cat_val).unsqueeze(0).unsqueeze(0).to(model.device)
counterfactual_attention_dfs = []
batch_start_idx = 0
for neighbor_tensors in scdl:
batch_neighbor_labels = neighbor_tensors[REGISTRY_KEYS.LABELS_KEY].cpu().detach().numpy().flatten()
batch_size = len(batch_neighbor_labels)
nn_X = neighbor_tensors[REGISTRY_KEYS.X_KEY].unsqueeze(0).to(model.device)
inf_outputs = model.module.inference(
labels_tensor,
nn_X,
)
gen_outputs = model.module.generative(
labels_tensor,
inf_outputs["label_embed"],
inf_outputs["nn_embed"],
torch.full((1, nn_X.shape[1]), 0).to(model.device),
return_attention_scores=True,
)
gen_attn_scores = gen_outputs["attention_scores"][0, :, :-1].T # n_cells x n_heads
gen_pos_coefs = gen_outputs["pos_coefs"][0, :] # n_cells x n_heads
batch_attn_scores = gen_attn_scores.cpu().detach().numpy()
batch_pos_coefs = gen_pos_coefs.cpu().detach().numpy()
n_heads = batch_attn_scores.shape[1]
# Get the indices for this batch
batch_indices = indices[batch_start_idx : batch_start_idx + batch_size]
batch_base_attn_df = pd.DataFrame(
batch_attn_scores,
columns=range(n_heads),
).melt(var_name="head_idx", value_name="base_attention_score")
batch_pos_coef_df = pd.DataFrame(batch_pos_coefs, columns=range(n_heads)).melt(
var_name="head_idx", value_name="position_coefficient"
)
batch_counterfactual_attention_df = pd.DataFrame(
{
"query_label": cell_type,
"neighbor_idx": repeat(np.array(batch_indices), "n -> (h n)", h=n_heads),
"neighbor_label": repeat(
np.array(labels_cat_list)[batch_neighbor_labels],
"n -> (h n)",
h=n_heads,
),
"head_idx": batch_base_attn_df["head_idx"],
"base_attention_score": batch_base_attn_df["base_attention_score"],
"position_coefficient": batch_pos_coef_df["position_coefficient"],
"dummy_attention_score": model.module.attention_dummy_score,
"distance_kernel_unit_scale": model.module.distance_kernel_unit_scale,
}
)
counterfactual_attention_dfs.append(batch_counterfactual_attention_df)
batch_start_idx += batch_size
counterfactual_attention_df = pd.concat(counterfactual_attention_dfs, axis=0, ignore_index=True)
counterfactual_attention_df = counterfactual_attention_df[
counterfactual_attention_df["neighbor_label"] != cell_type
]
return cls(
_adata=adata,
_labels_key=_labels_key,
_counterfactual_attention_df=counterfactual_attention_df,
_compute_kwargs=_compute_kwargs,
)
[docs]
def save(self, save_path: str):
"""Save counterfactual attention scores to file"""
self._counterfactual_attention_df.to_csv(save_path)
return self
[docs]
def calculate_counterfactual_attention_at_distances(
self,
head_idx: int,
distances: list[float],
):
"""
Evaluate counterfactual attention score functions at given distances.
Args:
head_idx (int): Head index to evaluate.
distances (list[float]): List of distances to evaluate.
Returns
-------
pd.DataFrame: DataFrame containing the counterfactual attention scores with columns:
- query_label: Cell type label of the query cell
- neighbor_idx: Index of the neighbor cell
- neighbor_label: Cell type label of the neighbor cell
- head_{head_idx}: Attention score for that neighbor for head head_idx
- distance: Counterfactual distance of the neighbor
"""
counterfactual_attention_eval_dfs = []
head_counterfactual_attention_df = self._counterfactual_attention_df.loc[
self._counterfactual_attention_df["head_idx"] == head_idx
]
base_attention_score = head_counterfactual_attention_df["base_attention_score"].to_numpy()
pos_coef = head_counterfactual_attention_df["position_coefficient"].to_numpy()
dummy_attention_score = head_counterfactual_attention_df["dummy_attention_score"].to_numpy()
distance_kernel_unit_scale = head_counterfactual_attention_df["distance_kernel_unit_scale"].to_numpy()
for distance in distances:
attention_score = base_attention_score - pos_coef * (distance / distance_kernel_unit_scale)
attention_pattern = np.exp(attention_score) / (np.exp(attention_score) + np.exp(dummy_attention_score))
counterfactual_attention_eval_dfs.append(
pd.DataFrame.from_dict(
{
"query_label": head_counterfactual_attention_df["query_label"],
"neighbor_idx": head_counterfactual_attention_df["neighbor_idx"],
"neighbor_label": head_counterfactual_attention_df["neighbor_label"],
f"head_{head_idx}": attention_pattern,
"distance": distance,
}
)
)
return pd.concat(counterfactual_attention_eval_dfs, axis=0)
[docs]
def plot_counterfactual_attention_summary(
self,
head_idx,
distances,
neighbor_ct_sub=None,
palette=None,
save_dir="./figures",
save_svg=False,
save_png=False,
show=True,
wandb_log=False,
):
"""
Plot a summary of counterfactual attention patterns for the query cell type and neighbors in counterfactual_attention_df.
Args:
counterfactual_attention_df (pd.DataFrame): DataFrame containing counterfactual attention patterns
as returned by `model.get_counterfactual_attention_scores()`.
head_idx (int): Head index to plot.
distances (list): List of distances to plot.
neighbor_ct_sub (list, optional): List of neighbor cell types to plot.
palette (str or list, optional): Color palette for the plot.
save_dir (str, optional): Directory to save the plot.
save_svg (bool, optional): Whether to save the plot as an SVG file.
save_png (bool, optional): Whether to save the plot as a PNG file.
show (bool, optional): Whether to show the plot.
wandb_log (bool, optional): Whether to log the plot to Weights and Biases.
"""
counterfactual_attention_eval_df = self.calculate_counterfactual_attention_at_distances(head_idx, distances)
plt.figure(figsize=(10, 6))
legend_elements = []
if neighbor_ct_sub is None:
neighbor_ct_sub = counterfactual_attention_eval_df["neighbor_label"].unique()
if palette is None:
palette = {neighbor_ct: sns.color_palette("tab10")[i] for i, neighbor_ct in enumerate(neighbor_ct_sub)}
for neighbor_ct in neighbor_ct_sub:
query_label = counterfactual_attention_eval_df["query_label"].unique()[0]
neighbor_df = counterfactual_attention_eval_df[
counterfactual_attention_eval_df["neighbor_label"] == neighbor_ct
]
attention_col = f"head_{head_idx}"
# Calculate mean and standard deviation
grouped = neighbor_df.groupby("distance")[attention_col].agg(["mean", "std"]).reset_index()
# Sample neighbors for background traces
neighbor_idx_sample = np.random.choice(
neighbor_df["neighbor_idx"].unique(),
size=min(100, len(neighbor_df["neighbor_idx"].unique())),
replace=False,
)
# Plot individual traces with the neighbor type color
for idx in neighbor_idx_sample:
subset = neighbor_df[neighbor_df["neighbor_idx"] == idx]
plt.plot(
subset["distance"],
subset[attention_col],
color=palette[neighbor_ct],
alpha=0.1,
linewidth=0.8,
)
# Plot mean and confidence interval with the same color
plt.plot(
grouped["distance"],
grouped["mean"],
color=palette[neighbor_ct],
linewidth=2,
label="Mean",
)
plt.fill_between(
grouped["distance"],
grouped["mean"] - grouped["std"],
grouped["mean"] + grouped["std"],
color=palette[neighbor_ct],
alpha=0.3,
label="±1 std",
)
plt.xlabel("Distance")
plt.ylabel("Attention Score")
plt.title(f"Counterfactual Attention Patterns for {query_label} (Head {head_idx})")
# Create legend with matching colors
legend_elements.append(
mpatches.Patch(facecolor=palette[neighbor_ct], label=neighbor_ct),
)
legend_elements.append(mpatches.Patch(facecolor=palette[neighbor_ct], alpha=0.3, label="±1 std"))
plt.legend(
title="Neighbor Cell Type",
bbox_to_anchor=(1.05, 1),
loc="upper left",
handles=legend_elements,
)
plt.ylim(0, 1)
plt.tight_layout()
if wandb_log:
wandb.log(
{
f"Counterfactual Attention Patterns for {query_label}": wandb.Image(plt),
}
)
if save_svg:
plt.savefig(
os.path.join(
save_dir,
f"counterfactual_attention_patterns_head_{head_idx}_celltype_{query_label}.svg",
)
)
if save_png:
plt.savefig(
os.path.join(
save_dir,
f"counterfactual_attention_patterns_head_{head_idx}_celltype_{query_label}.png",
)
)
if show:
plt.show()
plt.close()
def _correct_length_scale_artifacts(
self,
length_scale_df: pd.DataFrame,
sample_threshold: float,
) -> pd.DataFrame:
"""
Check for length scales that may be artifacts due to a small fraction of receivers with senders within the length scale in the sample data.
Args:
length_scale_df (pd.DataFrame): The length scale distribution to correct.
sample_threshold (float): The threshold for the number of samples to consider for the correction.
Returns
-------
pd.DataFrame: The length scales with the artifacts corrected by setting length scales to 0 if there are not enough samples.
"""
# Get the median length scales and the query label
median_length_scale_df = pd.DataFrame(
length_scale_df.groupby(["head_idx", "sender_type"])["length_scale"].median()
).reset_index()
query_label = self._counterfactual_attention_df["query_label"].unique()[0]
for sender_type in median_length_scale_df["sender_type"].unique():
# Look at the sender type of interest only
median_length_sender_df = median_length_scale_df[median_length_scale_df["sender_type"] == sender_type]
for length_scale, head_idx in np.array(median_length_sender_df[["length_scale", "head_idx"]]):
# Get receiver cells, indices and distances to nearest neighbors
receiver_idxs = np.where(self._adata.obs[self._labels_key] == query_label)[0]
nn_idxs = self._adata.obsm["_nn_idx"][receiver_idxs]
nn_dists = self._adata.obsm["_nn_dist"][receiver_idxs]
# Check how many nearest neighbors are there of the sender type within the length scale
nn_labels = self._adata.obs[self._labels_key].values[rearrange(nn_idxs, "b n -> (b n)")]
nn_labels_sender_dist = (
rearrange(np.array(nn_labels), "(b n) -> b n", b=nn_idxs.shape[0]) == sender_type
) & (nn_dists < length_scale)
count_lt_d = nn_labels_sender_dist.sum(-1)
count_receivers = (count_lt_d >= 1).sum()
all_receivers = (self._adata.obs[self._labels_key] == query_label).sum()
# Use the threshold to correct the length scale to 0 if the number of receivers is too low
if count_receivers / all_receivers < sample_threshold:
length_scale_df.loc[
(length_scale_df["head_idx"] == head_idx) & (length_scale_df["sender_type"] == sender_type),
"length_scale",
] = 0
return length_scale_df
def _calculate_length_scales(
self,
head_idxs: list[int],
sender_types: list[str],
attention_threshold: float = 0.5,
sample_threshold: float = 0.02,
):
"""
Compute the length scales for each neighbor of the query cell type for the given head indices and a set of given sender cell types.
Args:
head_idxs (list[int]): The head indices to analyze.
sender_types (list[str]): The sender cell types to analyze.
attention_threshold (float, optional): The attention threshold below which we consider the length scale.
sample_threshold (float, optional): The threshold for the number of samples to consider for the correction.
Returns
-------
pd.DataFrame: A DataFrame containing the length scales for each head for the given sender cell types.
"""
assert (
self._counterfactual_attention_df["query_label"].unique() not in sender_types
), "Sender type cannot be the same as the query label"
length_scales_per_head = []
for head_idx in head_idxs:
for sender_type in sender_types:
# Filter for head and sender type
head_counterfactual_attention_df = self._counterfactual_attention_df.loc[
(self._counterfactual_attention_df["head_idx"] == head_idx)
& (self._counterfactual_attention_df["neighbor_label"] == sender_type)
]
base_attention_score = head_counterfactual_attention_df["base_attention_score"].to_numpy()
dummy_attention_score = head_counterfactual_attention_df["dummy_attention_score"].to_numpy()
pos_coef = head_counterfactual_attention_df["position_coefficient"].to_numpy()
distance_kernel_unit_scale = head_counterfactual_attention_df["distance_kernel_unit_scale"].to_numpy()
# Calculate the length scales with attention threshold
length_scales = (distance_kernel_unit_scale / pos_coef) * (
np.log((1 - attention_threshold) / attention_threshold)
+ base_attention_score
- dummy_attention_score
)
length_scale_per_head = pd.DataFrame(
{
"head_idx": np.repeat(head_idx, len(length_scales)),
"sender_type": np.repeat(sender_type, len(length_scales)),
"length_scale": length_scales,
"neighbor_idx": head_counterfactual_attention_df["neighbor_idx"].to_numpy(),
}
)
length_scales_per_head.append(length_scale_per_head)
length_scale_df = pd.concat(length_scales_per_head, axis=0)
length_scale_df["length_scale"].clip(lower=0, inplace=True)
# Before returning the length scales, correct for the artifact containing false length scales due to spurious counterfactual attention
length_scale_df = self._correct_length_scale_artifacts(length_scale_df, sample_threshold)
return length_scale_df
def _compute_length_scale_bootstrap_ci(
self,
length_scale_df: pd.DataFrame,
confidence_level: float = 0.95,
n_bootstrap: int = 10000,
random_state: int | None = 42,
statistic: str = "mean",
) -> pd.DataFrame:
"""
Compute bootstrap confidence intervals for length scales per head and sender type.
Args:
length_scale_df: DataFrame containing length scales.
confidence_level: The confidence level for the interval.
n_bootstrap: Number of bootstrap resamples.
random_state: Random seed for reproducibility.
statistic: The statistic to bootstrap, either "mean" or "median".
Returns
-------
pd.DataFrame: DataFrame with columns:
- head_idx: The attention head index
- sender_type: The sender cell type
- statistic: The statistic used ("mean" or "median")
- point_estimate: The point estimate (mean or median)
- ci_lower: Lower bound of the bootstrap confidence interval
- ci_upper: Upper bound of the bootstrap confidence interval
- n_samples: Number of samples used
"""
if statistic not in ("mean", "median"):
raise ValueError(f"statistic must be 'mean' or 'median', got '{statistic}'")
stat_func = np.mean if statistic == "mean" else np.median
rng = np.random.default_rng(random_state)
results = []
for (head_idx, sender_type), group in length_scale_df.groupby(["head_idx", "sender_type"]):
length_scales = group["length_scale"].to_numpy()
n = len(length_scales)
if n == 0:
continue
# Bootstrap resampling
bootstrap_stats = np.empty(n_bootstrap)
for i in range(n_bootstrap):
resample_idx = rng.integers(0, n, size=n)
bootstrap_stats[i] = stat_func(length_scales[resample_idx])
# Compute percentile confidence interval
alpha = 1 - confidence_level
ci_lower = np.percentile(bootstrap_stats, 100 * alpha / 2)
ci_upper = np.percentile(bootstrap_stats, 100 * (1 - alpha / 2))
results.append(
{
"head_idx": head_idx,
"sender_type": sender_type,
"statistic": statistic,
"point_estimate": stat_func(length_scales),
"ci_lower": ci_lower,
"ci_upper": ci_upper,
"n_samples": n,
}
)
return pd.DataFrame(results)
[docs]
def plot_length_scale_distribution(
self,
head_idxs: list[int],
sender_types: list[str],
attention_threshold: float = 0.5,
max_length_scale: float = 50,
sample_threshold: float = 0.02,
palette: dict | None = None,
show: bool = True,
save_png: bool = False,
save_svg: bool = False,
save_dir: str = "./figures",
show_ci: bool = False,
confidence_level: float = 0.95,
n_bootstrap: int = 10000,
ci_statistic: str = "mean",
):
"""Plot the distribution of length scales for each head and sender cell types based on counterfactual attention patterns.
Args:
head_idxs (list[int]): The head indices to analyze.
sender_types (list[str]): The sender cell types to analyze.
attention_threshold (float, optional): The attention threshold.
max_length_scale (float, optional): The maximum length scale for the plot.
sample_threshold (float, optional): The sample threshold for correction.
palette (dict, optional): Dictionary mapping sender cell types to colors.
show (bool, optional): Whether to display the plot.
save_png (bool, optional): Whether to save the plot as a PNG file.
save_svg (bool, optional): Whether to save the plot as an SVG file.
save_dir (str, optional): The directory to save the plot files.
show_ci (bool, optional): Whether to overlay bootstrap confidence intervals.
confidence_level (float, optional): Confidence level for bootstrap CI.
n_bootstrap (int, optional): Number of bootstrap resamples.
ci_statistic (str, optional): Statistic for bootstrap CI.
Returns
-------
tuple: (length_scale_df, ci_df) if show_ci=True, else length_scale_df
"""
cell_type = self._counterfactual_attention_df["query_label"].unique()[0]
length_scale_df = self._calculate_length_scales(head_idxs, sender_types, attention_threshold, sample_threshold)
# Compute bootstrap CIs if requested
ci_df = None
if show_ci:
ci_df = self._compute_length_scale_bootstrap_ci(
length_scale_df=length_scale_df,
confidence_level=confidence_level,
n_bootstrap=n_bootstrap,
statistic=ci_statistic,
)
fig, ax = plt.subplots(figsize=(12, 6))
median_length_scale = length_scale_df.groupby(["sender_type", "head_idx"])["length_scale"].median()
max_per_sender = median_length_scale.groupby("sender_type").max()
sender_types_order = list(max_per_sender.sort_values(ascending=False).index)
if palette is None:
color_palette = sns.color_palette("tab10", n_colors=len(sender_types_order))
palette = {st: color_palette[i] for i, st in enumerate(sender_types_order)}
sns.boxplot(
data=length_scale_df,
x="head_idx",
y="length_scale",
hue="sender_type",
hue_order=sender_types_order,
palette=palette,
dodge=True,
fliersize=0.05,
ax=ax,
)
if show_ci and ci_df is not None:
n_sender_types = len(sender_types_order)
box_width = 0.8 / n_sender_types
bracket_offset = (box_width / 2) + 0.08
for i, head_idx in enumerate(head_idxs):
for j, sender_type in enumerate(sender_types_order):
ci_row = ci_df[(ci_df["head_idx"] == head_idx) & (ci_df["sender_type"] == sender_type)]
if len(ci_row) == 0:
continue
ci_lower = ci_row["ci_lower"].values[0]
ci_upper = ci_row["ci_upper"].values[0]
x_center = i + (j - (n_sender_types - 1) / 2) * box_width
x_bracket = x_center + bracket_offset
bracket_cap_width = 0.06
line_width = 1.0
bracket_color = "black"
ax.plot(
[x_bracket, x_bracket],
[ci_lower, ci_upper],
color=bracket_color,
linewidth=line_width,
solid_capstyle="butt",
zorder=10,
)
ax.plot(
[x_bracket - bracket_cap_width, x_bracket],
[ci_lower, ci_lower],
color=bracket_color,
linewidth=line_width,
solid_capstyle="butt",
zorder=10,
)
ax.plot(
[x_bracket - bracket_cap_width, x_bracket],
[ci_upper, ci_upper],
color=bracket_color,
linewidth=line_width,
solid_capstyle="butt",
zorder=10,
)
# Set labels based on the provided head_idxs list
ax.set_xticks(range(len(head_idxs)))
ax.set_xticklabels([f"Head {h}" for h in head_idxs])
ax.set_xlabel("Attention Head")
ylabel = f"Length Scale (distance where attention ≤ {attention_threshold})"
if show_ci:
ylabel += f"\n(brackets: {ci_statistic} {int(confidence_level*100)}% CI)"
ax.set_ylabel(ylabel)
ax.set_ylim(-0.5, max_length_scale)
ax.set_title(f"Length Scale Distribution by Head and Sender Type for {cell_type}")
ax.grid(True, alpha=0.3)
# Adjust legend position
ax.legend(title="Sender Type", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
if save_png:
plt.savefig(f"{save_dir}/length_scale_distribution_{cell_type}.png")
if save_svg:
plt.savefig(f"{save_dir}/length_scale_distribution_{cell_type}.svg")
if show:
plt.show()
plt.close()
if show_ci:
return length_scale_df, ci_df
return length_scale_df