amici.interpretation.AMICICounterfactualAttentionModule

class amici.interpretation.AMICICounterfactualAttentionModule(_adata=None, _compute_kwargs=None, _counterfactual_attention_df=None, _labels_key=None)

Bases: object

Methods

  • __init__

  • calculate_counterfactual_attention_at_distances – Evaluate counterfactual attention score functions at given distances.

  • compute – Compute the counterfactual attention patterns for a given cell type and head indices.

  • plot_counterfactual_attention_summary – Plot a summary of counterfactual attention patterns for the query cell type and its neighbors.

  • plot_length_scale_distribution – Plot the distribution of length scales for each head and sender cell type based on counterfactual attention patterns.

  • save – Save counterfactual attention scores to a file.

classmethod compute(model, cell_type, adata=None, indices=None, head_idxs=None, batch_size=None)

Compute the counterfactual attention patterns for a given cell type and head indices.

Parameters:
  • model (AMICI) – The trained AMICI model.

  • cell_type (str) – The query (index) cell type for which to compute counterfactual attention scores.

  • adata (AnnData | None) – Optional AnnData object to use. If None, uses model.adata.

  • indices (list[int] | None) – Optional list of cell indices to get scores for. If None, uses all cells.

  • head_idxs (list[int] | None) – Optional list of attention head indices to get scores for. If None, uses all heads.

  • batch_size (int | None) – Optional batch size to use for the data loader.

Return type:

AMICICounterfactualAttentionModule

Returns:

AMICICounterfactualAttentionModule: The module instance with computed counterfactual attention scores stored in _counterfactual_attention_df with columns:

  • query_label: Cell type label of the query cell

  • neighbor_idx: Index of the neighbor cell

  • neighbor_label: Cell type label of the neighbor cell

  • head_idx: Attention head index

  • base_attention_score: Base attention score

  • pos_coef: Position coefficient of the counterfactual attention score function

  • dummy_attention_score: Dummy attention score of the model

  • distance_kernel_unit_scale: Distance kernel unit scale
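Because every row carries both a head_idx and a neighbor_idx, the stored table can be sliced per attention head. The sketch below is illustrative only: it uses a few made-up rows (not model output) held as plain dictionaries with a subset of the documented column names, and top_neighbors is a hypothetical helper, not part of amici. With the real pandas DataFrame, the same selection could be written as df[df["head_idx"] == 0].sort_values("base_attention_score", ascending=False).

```python
from typing import Any

# Made-up rows using a subset of the documented columns (illustration only).
ROWS: list[dict[str, Any]] = [
    {"query_label": "T cell", "neighbor_idx": 3, "neighbor_label": "B cell",
     "head_idx": 0, "base_attention_score": 0.8},
    {"query_label": "T cell", "neighbor_idx": 7, "neighbor_label": "Macrophage",
     "head_idx": 0, "base_attention_score": 1.4},
    {"query_label": "T cell", "neighbor_idx": 3, "neighbor_label": "B cell",
     "head_idx": 1, "base_attention_score": 2.1},
]


def top_neighbors(rows, head_idx, k=5):
    """Hypothetical helper: the k rows of one head with the highest base_attention_score."""
    # Filter into a new list so the input rows are left untouched.
    head_rows = [r for r in rows if r["head_idx"] == head_idx]
    head_rows.sort(key=lambda r: r["base_attention_score"], reverse=True)
    return head_rows[:k]


for row in top_neighbors(ROWS, head_idx=0):
    print(row["neighbor_label"], row["base_attention_score"])
```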

save(save_path)

Save the counterfactual attention scores to a file at save_path.

calculate_counterfactual_attention_at_distances(head_idx, distances)

Evaluate counterfactual attention score functions at given distances.

Parameters:
  • head_idx (int) – Head index to evaluate.

  • distances (list[float]) – List of distances at which to evaluate the score functions.

Returns:

pd.DataFrame: DataFrame containing the counterfactual attention scores with columns:
  • query_label: Cell type label of the query cell

  • neighbor_idx: Index of the neighbor cell

  • neighbor_label: Cell type label of the neighbor cell

  • head_{head_idx}: Counterfactual attention score of the neighbor for head head_idx (for example, head_0 when head_idx=0)

  • distance: Counterfactual distance of the neighbor
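The score column is named after the requested head, so downstream code has to build the column name dynamically. A minimal sketch, assuming one row per neighbor and distance as the column list suggests: it averages the scores of all neighbors that share a neighbor_label at each distance, giving one curve per neighbor cell type. The helper mean_curve_by_neighbor_type is hypothetical, this is not necessarily how plot_counterfactual_attention_summary aggregates, and the rows are made up (query_label is omitted for brevity).

```python
from collections import defaultdict
from statistics import mean


def mean_curve_by_neighbor_type(rows, head_idx):
    """Hypothetical helper: mean head_{head_idx} score per (neighbor_label, distance).

    Returns {neighbor_label: [(distance, mean_score), ...]} sorted by distance.
    """
    score_col = f"head_{head_idx}"  # column name pattern documented above
    grouped = defaultdict(list)
    for r in rows:
        grouped[(r["neighbor_label"], r["distance"])].append(r[score_col])
    curves = defaultdict(list)
    for (label, dist), scores in grouped.items():
        curves[label].append((dist, mean(scores)))
    # Tuples sort by distance first, giving each curve in distance order.
    return {label: sorted(points) for label, points in curves.items()}


# Made-up rows for head 0 at two counterfactual distances (illustration only).
rows = [
    {"neighbor_idx": 3, "neighbor_label": "B cell", "distance": 10.0, "head_0": 0.6},
    {"neighbor_idx": 4, "neighbor_label": "B cell", "distance": 10.0, "head_0": 0.4},
    {"neighbor_idx": 3, "neighbor_label": "B cell", "distance": 50.0, "head_0": 0.1},
    {"neighbor_idx": 4, "neighbor_label": "B cell", "distance": 50.0, "head_0": 0.3},
    {"neighbor_idx": 7, "neighbor_label": "Macrophage", "distance": 10.0, "head_0": 0.9},
]
print(mean_curve_by_neighbor_type(rows, head_idx=0))
```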

plot_counterfactual_attention_summary(head_idx, distances, neighbor_ct_sub=None, palette=None, save_dir='./figures', save_svg=False, save_png=False, show=True, wandb_log=False)

Plot a summary of counterfactual attention patterns for the query cell type and its neighbors, based on the scores stored in _counterfactual_attention_df.

Parameters:
  • head_idx (int) – Head index to plot.

  • distances (list) – List of distances to plot.

  • neighbor_ct_sub (list, optional) – Subset of neighbor cell types to include in the plot.

  • palette (str or list, optional) – Color palette for the plot.

  • save_dir (str, optional) – Directory in which to save the plot files.

  • save_svg (bool, optional) – Whether to save the plot as an SVG file.

  • save_png (bool, optional) – Whether to save the plot as a PNG file.

  • show (bool, optional) – Whether to show the plot.

  • wandb_log (bool, optional) – Whether to log the plot to Weights & Biases.

plot_length_scale_distribution(head_idxs, sender_types, attention_threshold=0.5, max_length_scale=50, sample_threshold=0.02, palette=None, show=True, save_png=False, save_svg=False, save_dir='./figures', show_ci=False, confidence_level=0.95, n_bootstrap=10000, ci_statistic='mean')

Plot the distribution of length scales for each head and sender cell type based on counterfactual attention patterns.

Parameters:
  • head_idxs (list[int]) – The head indices to analyze.

  • sender_types (list[str]) – The sender cell types to analyze.

  • attention_threshold (float, optional) – The attention threshold (default 0.5).

  • max_length_scale (float, optional) – The maximum length scale shown in the plot (default 50).

  • sample_threshold (float, optional) – The sample threshold used for correction (default 0.02).

  • palette (dict, optional) – Dictionary mapping sender cell types to colors.

  • show (bool, optional) – Whether to display the plot.

  • save_png (bool, optional) – Whether to save the plot as a PNG file.

  • save_svg (bool, optional) – Whether to save the plot as an SVG file.

  • save_dir (str, optional) – The directory to save the plot files.

  • show_ci (bool, optional) – Whether to overlay bootstrap confidence intervals.

  • confidence_level (float, optional) – Confidence level for the bootstrap CI (default 0.95).

  • n_bootstrap (int, optional) – Number of bootstrap resamples (default 10000).

  • ci_statistic (str, optional) – Statistic for which the bootstrap CI is computed (default 'mean').

Returns:

tuple or pd.DataFrame: (length_scale_df, ci_df) if show_ci=True; otherwise length_scale_df alone.
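With show_ci=True, the method overlays bootstrap confidence intervals controlled by confidence_level, n_bootstrap, and ci_statistic. The documentation does not state which bootstrap variant is used, so the following is only a generic percentile bootstrap sketch, assuming ci_statistic names a statistic such as 'mean' or 'median' (only 'mean' is confirmed, as the default). The function percentile_bootstrap_ci and its seed argument are hypothetical, and the length-scale values are made up.

```python
import random
from statistics import fmean, median


def percentile_bootstrap_ci(values, statistic="mean", confidence_level=0.95,
                            n_bootstrap=10_000, seed=0):
    """Hypothetical generic percentile bootstrap; amici's own procedure may differ.

    Returns (point_estimate, ci_low, ci_high) for the chosen statistic.
    """
    if not values:
        raise ValueError("values must be non-empty")
    if not 0.0 < confidence_level < 1.0:
        raise ValueError("confidence_level must lie strictly between 0 and 1")
    stat_fn = {"mean": fmean, "median": median}[statistic]
    rng = random.Random(seed)  # fixed seed makes the interval reproducible
    n = len(values)
    # Resample with replacement and recompute the statistic for each resample.
    boot = sorted(stat_fn(rng.choices(values, k=n)) for _ in range(n_bootstrap))
    alpha = 1.0 - confidence_level
    lo_idx = int((alpha / 2) * (n_bootstrap - 1))
    hi_idx = int((1 - alpha / 2) * (n_bootstrap - 1))
    return stat_fn(values), boot[lo_idx], boot[hi_idx]


length_scales = [12.0, 15.5, 9.0, 20.0, 14.0, 11.5]  # made-up values
point, lo, hi = percentile_bootstrap_ci(length_scales)
print(f"mean={point:.2f}, 95% CI=({lo:.2f}, {hi:.2f})")
```

Percentiles of the resampled statistic are taken directly as interval bounds, which is the simplest bootstrap variant; bias-corrected variants such as BCa can behave better for skewed length-scale distributions.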