amici.interpretation.AMICICounterfactualAttentionModule
- class amici.interpretation.AMICICounterfactualAttentionModule(_adata=None, _compute_kwargs=None, _counterfactual_attention_df=None, _labels_key=None)
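Bases: object
Methods
__init__()
calculate_counterfactual_attention_at_distances(head_idx, distances) – Evaluate counterfactual attention score functions at given distances.
compute(model, cell_type[, adata, indices, ...]) – Compute the counterfactual attention patterns for a given cell type and head indices.
plot_counterfactual_attention_summary(head_idx, distances[, ...]) – Plot a summary of counterfactual attention patterns for the query cell type and neighbors in counterfactual_attention_df.
plot_length_scale_distribution(head_idxs, sender_types[, ...]) – Plot the distribution of length scales for each head and sender cell type based on counterfactual attention patterns.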
Save counterfactual attention scores to file
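A minimal workflow sketch of how the documented methods chain together, assuming a trained AMICI model is passed in as model. The helper name run_counterfactual_attention, the cell type label "Tumor", head 0 and the distance grid are hypothetical placeholders; only the method names and parameter names are taken from the signatures on this page.

```python
def run_counterfactual_attention(model, cell_type="Tumor", head_idx=0):
    """Compute, evaluate, and plot counterfactual attention for one query cell type.

    Hypothetical helper: `model` is assumed to be a trained AMICI model, and the
    default cell type "Tumor" and head 0 are placeholders for your own labels.
    """
    # Imported inside the function so the sketch can be defined without amici installed.
    from amici.interpretation import AMICICounterfactualAttentionModule

    # compute() is a classmethod that returns a module instance holding the scores.
    module = AMICICounterfactualAttentionModule.compute(
        model, cell_type, head_idxs=[head_idx]
    )

    # Evaluate the score functions for this head on a placeholder distance grid.
    distances = [0, 10, 20, 30, 40, 50]
    scores_df = module.calculate_counterfactual_attention_at_distances(head_idx, distances)

    # Show the summary plot without saving SVG/PNG files or logging to W&B.
    module.plot_counterfactual_attention_summary(
        head_idx, distances, show=True, save_svg=False, save_png=False, wandb_log=False
    )
    return module, scores_df
```

Restricting head_idxs to the single head being plotted keeps compute() from scoring heads that are never used downstream.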
- classmethod compute(model, cell_type, adata=None, indices=None, head_idxs=None, batch_size=None)
Compute the counterfactual attention patterns for a given cell type and head indices.
- Parameters:
model (AMICI) – The trained AMICI model.
cell_type (str) – The index (query) cell type to compute counterfactual attention scores for.
adata (AnnData | None) – Optional AnnData object to use. If None, uses model.adata.
indices (list[int] | None) – Optional list of cell indices to compute scores for. If None, uses all cells.
head_idxs (list[int] | None) – Optional list of attention head indices to compute scores for. If None, uses all heads.
batch_size (int | None) – Optional batch size for the data loader.
- Return type:
AMICICounterfactualAttentionModule
- Returns:
AMICICounterfactualAttentionModule: The module instance, with the computed counterfactual attention scores stored in _counterfactual_attention_df, which has the columns:
query_label: Cell type label of the query cell
neighbor_idx: Index of the neighbor cell
neighbor_label: Cell type label of the neighbor cell
head_idx: Attention head index
base_attention_score: Base attention score
pos_coef: Position coefficient of the counterfactual attention score function
dummy_attention_score: Dummy attention score of the model
distance_kernel_unit_scale: Distance kernel unit scale
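The returned table appears to be in long format (one row per neighbor cell and head), so per-head summaries reduce to a group-by. The sketch below uses plain Python records with made-up values, and only a subset of the documented columns, as stand-ins for rows of _counterfactual_attention_df, and averages base_attention_score per neighbor cell type and head. With pandas the equivalent would be df.groupby(["neighbor_label", "head_idx"])["base_attention_score"].mean().

```python
from collections import defaultdict
from statistics import mean

# Made-up rows using a subset of the documented columns (values are illustrative).
rows = [
    {"query_label": "Tumor", "neighbor_idx": 3, "neighbor_label": "T cell", "head_idx": 0, "base_attention_score": 0.75},
    {"query_label": "Tumor", "neighbor_idx": 7, "neighbor_label": "T cell", "head_idx": 0, "base_attention_score": 0.25},
    {"query_label": "Tumor", "neighbor_idx": 9, "neighbor_label": "Fibroblast", "head_idx": 0, "base_attention_score": 0.125},
    {"query_label": "Tumor", "neighbor_idx": 3, "neighbor_label": "T cell", "head_idx": 1, "base_attention_score": 0.5},
]


def mean_base_score(records):
    """Average base_attention_score per (neighbor_label, head_idx) pair."""
    groups = defaultdict(list)
    for r in records:
        groups[(r["neighbor_label"], r["head_idx"])].append(r["base_attention_score"])
    return {key: mean(scores) for key, scores in groups.items()}


summary = mean_base_score(rows)
print(summary)  # {('T cell', 0): 0.5, ('Fibroblast', 0): 0.125, ('T cell', 1): 0.5}
```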
- calculate_counterfactual_attention_at_distances(head_idx, distances)
Evaluate counterfactual attention score functions at given distances.
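- Parameters:
head_idx (int) – Index of the attention head to evaluate.
distances (list) – Distances at which to evaluate the counterfactual attention score functions.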
- Returns:
- pd.DataFrame: DataFrame containing the counterfactual attention scores with columns:
query_label: Cell type label of the query cell
neighbor_idx: Index of the neighbor cell
neighbor_label: Cell type label of the neighbor cell
head_{head_idx}: Attention score of the neighbor for head head_idx at the given distance
distance: Counterfactual distance of the neighbor
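Because the score column is named after the requested head, downstream code has to build that name dynamically. The sketch below, using made-up records in place of the returned DataFrame, averages the head_{head_idx} column over all neighbors of each cell type at each distance to produce per-neighbor-type attention curves. This is one plausible aggregation, not necessarily the one plot_counterfactual_attention_summary uses; the helper attention_curves is hypothetical.

```python
from collections import defaultdict
from statistics import mean


def attention_curves(records, head_idx):
    """Map neighbor_label -> list of (distance, mean attention), sorted by distance."""
    score_col = f"head_{head_idx}"  # column name depends on the requested head
    grouped = defaultdict(list)
    for r in records:
        grouped[(r["neighbor_label"], r["distance"])].append(r[score_col])
    curves = defaultdict(list)
    # Sorting by (label, distance) yields each curve in increasing distance order.
    for (label, distance), scores in sorted(grouped.items()):
        curves[label].append((distance, mean(scores)))
    return dict(curves)


# Made-up rows for head 0 (illustrative values, not model output).
rows = [
    {"query_label": "Tumor", "neighbor_idx": 3, "neighbor_label": "T cell", "head_0": 0.75, "distance": 0},
    {"query_label": "Tumor", "neighbor_idx": 3, "neighbor_label": "T cell", "head_0": 0.25, "distance": 20},
    {"query_label": "Tumor", "neighbor_idx": 7, "neighbor_label": "T cell", "head_0": 0.5, "distance": 0},
    {"query_label": "Tumor", "neighbor_idx": 7, "neighbor_label": "T cell", "head_0": 0.125, "distance": 20},
    {"query_label": "Tumor", "neighbor_idx": 9, "neighbor_label": "B cell", "head_0": 0.5, "distance": 0},
]
curves = attention_curves(rows, head_idx=0)
```

With pandas, the same table would be df.groupby(["neighbor_label", "distance"])[f"head_{head_idx}"].mean().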
- plot_counterfactual_attention_summary(head_idx, distances, neighbor_ct_sub=None, palette=None, save_dir='./figures', save_svg=False, save_png=False, show=True, wandb_log=False)
Plot a summary of counterfactual attention patterns for the query cell type and its neighbors, based on the scores stored in the instance's _counterfactual_attention_df (populated by compute()).
- Parameters:
head_idx (int) – Head index to plot.
distances (list) – List of distances to plot.
neighbor_ct_sub (list, optional) – List of neighbor cell types to plot.
palette (str or list, optional) – Color palette for the plot.
save_dir (str, optional) – Directory to save the plot.
save_svg (bool, optional) – Whether to save the plot as an SVG file.
save_png (bool, optional) – Whether to save the plot as a PNG file.
show (bool, optional) – Whether to show the plot.
wandb_log (bool, optional) – Whether to log the plot to Weights and Biases.
- plot_length_scale_distribution(head_idxs, sender_types, attention_threshold=0.5, max_length_scale=50, sample_threshold=0.02, palette=None, show=True, save_png=False, save_svg=False, save_dir='./figures', show_ci=False, confidence_level=0.95, n_bootstrap=10000, ci_statistic='mean')
Plot the distribution of length scales for each head and sender cell type based on counterfactual attention patterns.
- Parameters:
head_idxs (list[int]) – The attention head indices to analyze.
sender_types (list[str]) – The sender cell types to analyze.
attention_threshold (float, optional) – The attention threshold.
max_length_scale (float, optional) – The maximum length scale for the plot.
sample_threshold (float, optional) – The sample threshold for correction.
palette (dict, optional) – Dictionary mapping sender cell types to colors.
show (bool, optional) – Whether to display the plot.
save_png (bool, optional) – Whether to save the plot as a PNG file.
save_svg (bool, optional) – Whether to save the plot as an SVG file.
save_dir (str, optional) – The directory to save the plot files.
show_ci (bool, optional) – Whether to overlay bootstrap confidence intervals.
confidence_level (float, optional) – Confidence level for bootstrap CI.
n_bootstrap (int, optional) – Number of bootstrap resamples.
ci_statistic (str, optional) – Statistic for bootstrap CI.
- Returns:
tuple: (length_scale_df, ci_df) if show_ci=True, else length_scale_df
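The parameter list does not spell out how a length scale is derived or how the bootstrap interval is formed. The sketch below assumes one plausible reading: a neighbor's length scale is the first evaluated distance at which its attention drops below attention_threshold, capped at max_length_scale, and the interval is a percentile bootstrap of the chosen statistic over neighbors. Both are illustrative assumptions rather than the AMICI implementation; the helpers length_scale and percentile_bootstrap_ci are hypothetical, and sample_threshold is not modeled.

```python
import random
from statistics import mean, median


def length_scale(distances, scores, attention_threshold=0.5, max_length_scale=50.0):
    """Hypothetical: first distance at which the score falls below the threshold.

    Returns max_length_scale if the score never drops below the threshold within
    the evaluated distances, and caps any crossing at max_length_scale.
    """
    for d, s in sorted(zip(distances, scores)):
        if s < attention_threshold:
            return min(d, max_length_scale)
    return max_length_scale


def percentile_bootstrap_ci(values, statistic="mean", confidence_level=0.95,
                            n_bootstrap=10000, seed=0):
    """Percentile bootstrap CI for the mean or median of `values` (assumed method)."""
    stat_fn = {"mean": mean, "median": median}[statistic]
    rng = random.Random(seed)
    n = len(values)
    # Resample with replacement and record the statistic for each resample.
    boot = sorted(stat_fn(rng.choices(values, k=n)) for _ in range(n_bootstrap))
    alpha = 1.0 - confidence_level
    lo_idx = int((alpha / 2) * (n_bootstrap - 1))
    hi_idx = int((1 - alpha / 2) * (n_bootstrap - 1))
    return boot[lo_idx], boot[hi_idx]


# Made-up attention curves for three neighbors of one sender type (one head).
distances = [0, 10, 20, 30, 40]
curves = [
    [0.9, 0.7, 0.4, 0.2, 0.1],     # drops below 0.5 at distance 20
    [0.8, 0.6, 0.55, 0.3, 0.1],    # drops below 0.5 at distance 30
    [0.95, 0.9, 0.85, 0.8, 0.75],  # never drops below 0.5, so capped at 50
]
scales = [length_scale(distances, c) for c in curves]
ci_low, ci_high = percentile_bootstrap_ci(scales, n_bootstrap=2000)
```

A fixed seed makes the interval reproducible across runs; with only three neighbors the resulting interval is wide, which is expected for a bootstrap on so few samples.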