Source code for amici.interpretation._explained_variance_module

import os
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import wandb
from anndata import AnnData
from scipy.stats import false_discovery_control
from scvi import REGISTRY_KEYS

from ._utils import _get_compute_method_kwargs

if TYPE_CHECKING:
    from amici._model import AMICI


[docs] @dataclass class AMICIExplainedVarianceModule: _adata: AnnData | None = None _explained_variance_df: pd.DataFrame | None = None _compute_kwargs: dict | None = None _n_permutations: int | None = None
[docs] @classmethod def compute( cls, model: "AMICI", adata: AnnData | None = None, alpha: float = 0.05, run_permutation_test: bool = True, n_permutations: int | None = None, ) -> "AMICIExplainedVarianceModule": """Creates a DataFrame with the explained variance scores. DataFrame contains the explained variance scores for each head across cells for all cell types or a subset of cell types. Args: adata (AnnData, optional): The AnnData object. alpha (float, optional): The desired significance level for the explained variance of each head-cell type pair via permutation testing. Defaults to 0.05. run_permutation_test (bool, optional): Whether to run the permutation test. Defaults to True. n_permutations (int, optional): The number of permutations to use for the permutation test. If not provided, the minimum number of permutations required to achieve a significance level of alpha will be used. Returns ------- AMICIExplainedVarianceModule: The module instance with computed explained variance scores stored in _explained_variance_df with columns: - head: The index of the head. - ct_name: The name of the cell type. - gene: The name of the gene. - expl_var_head_gene: The explained variance score when all but head is ablated for the gene. - expl_var_head: The explained variance score when all but head aggregated across all genes. if run_permutation_test is True, the following columns are also added: - p_value_head: The p-value of the explained variance score when all but head is ablated. - p_value_adj_head: The adjusted p-value of the explained variance score when all but head is ablated. """ _compute_kwargs = _get_compute_method_kwargs(**locals()) model._check_if_trained(warn=True) adata = model._validate_anndata(adata) labels_key = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key n_permutations = ( cls._min_n_permutations(model.module.n_heads, len(adata.obs[labels_key].unique()), alpha) if n_permutations is None else n_permutations ) expl_variance_scores = [] p_values = [] for head_idx in range(model.module.n_heads): for ct in list(adata.obs[labels_key].unique()): predictions_head, gene_exp = cls._get_ct_predictions_for_ablated_head( model, adata=adata, head_idx=head_idx, cell_type=ct, ) expl_var_head = cls._compute_expl_variance(predictions_head - gene_exp, gene_exp) expl_var_head_agg = cls._compute_expl_variance( predictions_head - gene_exp, gene_exp, aggregate_across_genes=True ) if run_permutation_test: p_value_head = cls._run_permutation_test_expl_variance(predictions_head, gene_exp, n_permutations) p_values.append(p_value_head) for i, gene in enumerate(adata.var_names): explained_variance_score_row = { "head": head_idx, "ct_name": ct, "gene": gene, "expl_var_head_gene": expl_var_head[i], "expl_var_head": expl_var_head_agg, } if run_permutation_test: explained_variance_score_row["p_value_head"] = p_value_head expl_variance_scores.append(explained_variance_score_row) expl_variance_scores_df = pd.DataFrame(expl_variance_scores) # adjust p-values via B-H correction if run_permutation_test: p_values_adj = false_discovery_control(p_values) p_value_adj_map = dict(zip(p_values, p_values_adj)) expl_variance_scores_df["p_value_adj_head"] = expl_variance_scores_df["p_value_head"].map(p_value_adj_map) return cls( _adata=adata, _explained_variance_df=expl_variance_scores_df, _compute_kwargs=_compute_kwargs, _n_permutations=n_permutations, )
@staticmethod def _min_n_permutations(n_heads: int, n_ct: int, alpha: float) -> int: """Compute the minimum number of permutations required to achieve a significance level of alpha for the explained variance scores. Args: n_heads (int): The number of heads. n_ct (int): The number of cell types. alpha (float): The desired significance level. Returns ------- int: The minimum number of permutations required to achieve a significance level of alpha for the explained variance scores. """ return int(np.ceil((n_heads * n_ct) / alpha)) @staticmethod def _compute_expl_variance( residuals: pd.DataFrame, gene_exp: pd.DataFrame, aggregate_across_genes: bool = False, ): """Compute the explained variance scores for the given cell type residuals and ground truth gene expression for cell type of interest. Args: residuals (pd.DataFrame): The residuals for the cell type of interest when ablating all heads except one. gene_exp (pd.DataFrame): The ground truth gene expression for the cell type of interest. aggregate_across_genes (bool, optional): Whether to aggregate the explained variance across all genes. Returns ------- np.ndarray: The explained variance scores for the cell type and head of interest. """ if aggregate_across_genes: expl_var = ( 1 - torch.var(torch.tensor(residuals.values), dim=0).sum() / torch.var(torch.tensor(gene_exp.values), dim=0).sum() ).item() else: expl_var = ( 1 - torch.var(torch.tensor(residuals.values), dim=0) / torch.var(torch.tensor(gene_exp.values), dim=0) ).numpy() expl_var = np.where(torch.var(torch.tensor(gene_exp.values), dim=0) == 0, 0, expl_var) return expl_var @staticmethod def _run_permutation_test_expl_variance( predictions: pd.DataFrame, # n_ct_cells x n_genes gene_exp: pd.DataFrame, # n_ct_cells x n_genes n_permutations: int, batch_size: int = 32, ): """Run a permutation test to compute the p-value of the explained variance scores. Args: predictions (torch.Tensor): The predictions tensor of shape n_ct_cells x n_genes. gene_exp (torch.Tensor): The ground truth gene expression tensor of shape n_ct_cells x n_genes. n_permutations (int): The number of permutations to run. batch_size (int, optional): The batch size to use for the permutation test. Returns ------- float: The one-sided p-value computed from the permutation test. """ n_ct_cells = predictions.shape[0] predictions = predictions.values gene_exp = gene_exp.values original_mse = np.mean((predictions - gene_exp) ** 2).item() lower_mse_count = 0 for batch_idx in range(0, n_permutations, batch_size): batch_end = min(batch_idx + batch_size, n_permutations) batch_size_actual = batch_end - batch_idx perm_indices = np.stack( [np.random.permutation(n_ct_cells) for _ in range(batch_size_actual)] ) # batch_size x n_ct_cells permuted_preds = predictions[perm_indices] # batch_size x n_ct_cells x n_genes expanded_gene_exp = gene_exp[None, :] # 1 x n_ct_cells x n_genes batch_mses = np.mean((permuted_preds - expanded_gene_exp) ** 2, axis=(1, 2)) lower_mse_count += np.sum(batch_mses <= original_mse) p_value = lower_mse_count / n_permutations return p_value @staticmethod def _get_ct_predictions_for_ablated_head( model: "AMICI", adata: AnnData, head_idx: int = -1, cell_type: str | None = None, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Gene expression residuals and ground truth gene expression for head_idx and cell_type. Args: adata (AnnData, optional): The AnnData object. head_idx (int, optional): The index of the head to ablate or to not ablate. Defaults to -1. cell_type (str, optional): The cell type to compute the residuals for. Returns ------- Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing: - predictions: A DataFrame of predictions of shape n_ct_cells x n_genes. - gene_exp: A DataFrame of ground truth gene expression of shape n_ct_cells x n_genes. """ def _mt_heads_ablation_hook( attn_result, hook, head_idxs_to_ablate, neighbor_mask=None, ): """Hook function to ablate heads. Args: attn_result (torch.Tensor): The attention result. head_idxs_to_ablate (list): The indices of the heads to zero out. neighbor_mask (torch.Tensor, optional): The neighbor mask. Returns ------- torch.Tensor: The modified attention result with specified heads ablated. """ for head_idx in head_idxs_to_ablate: attn_result[:, head_idx, :, :] = 0.0 if neighbor_mask is None: neighbor_mask = torch.ones_like(attn_result) else: # append dummy neighbor col neighbor_mask = torch.cat( [ neighbor_mask, torch.ones((neighbor_mask.shape[0], 1)).to(neighbor_mask.device), ], dim=1, ) neighbor_mask = neighbor_mask.unsqueeze(1).unsqueeze(2).expand_as(attn_result) attn_result *= neighbor_mask return attn_result assert cell_type is not None, "cell_type must be specified" labels_key = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key ct_indices = np.arange(len(adata))[adata.obs[labels_key] == cell_type] scdl = model._make_data_loader(adata=adata, indices=ct_indices) predictions = [] gene_expressions = [] for tensors in scdl: true_X = tensors[REGISTRY_KEYS.X_KEY].cpu() tensors = {k: v.to(model.device) for k, v in tensors.items()} model.module.reset_hooks() head_idxs_to_ablate = [i for i in range(model.module.n_heads) if i != head_idx] head_hook_fn = partial( _mt_heads_ablation_hook, head_idxs_to_ablate=head_idxs_to_ablate, ) prediction = ( model.module.run_with_hooks( tensors, fwd_hooks=[("attention_layer.hook_pattern", head_hook_fn)], )[1]["prediction"] .detach() .cpu() ) predictions.append(prediction) gene_expressions.append(true_X) all_predictions = torch.cat(predictions) gene_exp = torch.cat(gene_expressions) predictions_df = pd.DataFrame(all_predictions.numpy(), columns=adata.var_names) gene_exp_df = pd.DataFrame(gene_exp.numpy(), columns=adata.var_names) return predictions_df, gene_exp_df
[docs] def save(self, save_path: str): """Save explained variance scores to file""" self._explained_variance_df.to_csv(save_path) return self
[docs] def compute_max_explained_variance_head( self, cell_type: str, ) -> int: """ Compute the head with the maximum explained variance for the given cell type. Args: cell_type (str): The cell type to compute the maximum explained variance head for. Returns ------- int: The head with the maximum explained variance for the given cell type. """ return ( self._explained_variance_df[self._explained_variance_df["ct_name"] == cell_type] .groupby("head")["expl_var_head_gene"] .max() .idxmax() )
[docs] def plot_explained_variance_barplot( self, palette=None, cell_type_sub=None, wandb_log=False, show=True, save_png=False, save_svg=False, save_dir="./figures", ): """ Plots a barplot of the maximum explained variance score for each head and each cell type. Args: expl_variance_df (pd.DataFrame): A DataFrame containing the explained variance scores for each head and each cell type as returned by `model.get_expl_variance_scores()`. palette (str, optional): The color palette to use for the barplot. Defaults to None. cell_type_sub (list[str], optional): A list of cell types to include in the barplot. Defaults to None. wandb_log (bool, optional): Whether to log the plot to Weights and Biases. Defaults to True. show (bool, optional): Whether to display the plot. Defaults to False. save_png (bool, optional): Whether to save the plot as a PNG file. Defaults to False. save_svg (bool, optional): Whether to save the plot as an SVG file. Defaults to False. save_dir (str, optional): The directory to save the plot files. Defaults to "./figures". """ if cell_type_sub is not None: expl_variance_sub_df = self._explained_variance_df[ self._explained_variance_df["ct_name"].isin(cell_type_sub) ] else: expl_variance_sub_df = self._explained_variance_df max_var_per_head = expl_variance_sub_df.groupby(["head", "ct_name"])["expl_var_head_gene"].max().reset_index() plt.figure(figsize=(10, 6)) sns.barplot( data=max_var_per_head, x="head", y="expl_var_head_gene", hue="ct_name", palette=palette or "tab10", ) plt.xlabel("Head Index") plt.ylabel("Max Variance Score") plt.title("Maximum Variance Score Across Genes for Each Head Colored by Cell Type") plt.legend(title="Cell Type", bbox_to_anchor=(1, 1)) if wandb_log: wandb.log( { "Explained Variance Barplot per Head per Cell Type": wandb.Image(plt), } ) if save_png: plt.savefig(os.path.join(save_dir, "expl_variance_barplot_per_head.png")) if save_svg: plt.savefig(os.path.join(save_dir, "expl_variance_barplot_per_head.svg")) if show: plt.show() plt.close()
[docs] def plot_featurewise_explained_variance_heatmap( self, cell_type_sub=None, n_top_genes=20, wandb_log=False, show=True, save_png=False, save_svg=False, save_dir="./figures", ): """ Plots a heatmap of the explained variance scores for each head for n_top_genes genes. Args: expl_variance_df (pd.DataFrame): A DataFrame containing the explained variance scores for each head and each cell type as returned by `model.get_expl_variance_scores()" cell_type_sub (list[str], optional): A list of cell types for which to plot the heatmap. n_top_genes (int, optional): The number of top genes to include in the heatmap. wandb_log (bool, optional): Whether to log the plot to Weights and Biases. show (bool, optional): Whether to display the plot. save_png (bool, optional): Whether to save the plot as a PNG file. save_svg (bool, optional): Whether to save the plot as an SVG file. save_dir (str, optional): The directory to save the plot files. Defaults to "./figures". Saves the filename as "expl_variance_heatmap_{cell_type}.png" or "expl_variance_heatmap_{cell_type}.svg". """ if cell_type_sub is not None: expl_variance_sub_df = self._explained_variance_df[ self._explained_variance_df["ct_name"].isin(cell_type_sub) ] else: expl_variance_sub_df = self._explained_variance_df for ct in expl_variance_sub_df["ct_name"].unique(): expl_variance_ct_df = expl_variance_sub_df[expl_variance_sub_df["ct_name"] == ct] top_gene_names = ( expl_variance_ct_df.groupby("gene")["expl_var_head_gene"].mean().nlargest(n_top_genes).index ) expl_variance_ct_df = expl_variance_ct_df[expl_variance_ct_df["gene"].isin(top_gene_names)] pivot_table = expl_variance_ct_df.pivot_table( index="gene", columns="head", values="expl_var_head_gene", aggfunc="mean", sort=False, ) plt.figure(figsize=(10, 8)) sns.heatmap( pivot_table, annot=False, cmap="RdBu_r", cbar_kws={"label": "Explained Variance"}, center=0, xticklabels=True, yticklabels=True, ) plt.title(f"Explained Variance Heatmap by Gene and Head for Cell Type {ct}") plt.xlabel("Head") plt.ylabel("Gene") if wandb_log: wandb.log( { f"Featurewise Explained Variance for Cell Type {ct}": wandb.Image(plt), } ) if save_png: plt.savefig(os.path.join(save_dir, f"expl_variance_heatmap_{ct}.png")) if save_svg: plt.savefig(os.path.join(save_dir, f"expl_variance_heatmap_{ct}.svg")) if show: plt.show() plt.close()