AMICI Quick Start Tutorial#

AMICI is an interpretable attention framework for single-cell spatial transcriptomics data. It jointly estimates interaction length scales, adaptively resolves sender-receiver subpopulations, and links communication to downstream gene programs.

import warnings; warnings.filterwarnings("ignore") # remove scanpy warnings for the tutorial.
!pip install -q amici-st
import os
import torch
import scanpy as sc
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import matplotlib.pyplot as plt

from amici import AMICI
from amici.callbacks import AttentionPenaltyMonitor
from amici.interpretation import (
    AMICICounterfactualAttentionModule,
    AMICIAttentionModule,
    AMICIAblationModule,
)

Load and split the data#

Load the tutorial AnnData (the cell below downloads it from figshare if it is not already on disk). Perform any filtering of cells, but retain the raw counts. Split the dataset into a train set and a test set – we generally use a 90%-10% train-test split.

# Load data
# `backup_url` is the location scanpy downloads the file from when
# `adata_path` does not exist locally.
adata_path = "./data/mouse_cortex_tutorial.h5ad"
adata = sc.read(adata_path, backup_url="https://figshare.com/ndownloader/files/58303438")

# Saving the spatial coordinates in the adata.obsm["spatial"] key
# (this is the key later passed to AMICI.setup_anndata as `coord_obsm_key`).
adata.obsm["spatial"] = adata.obs[["centroid_x", "centroid_y"]].values

# Split on the `in_test` column of adata.obs; `.copy()` turns each subset into
# its own AnnData object.
adata_train = adata[adata.obs['in_test'] == False].copy()
adata_test = adata[adata.obs['in_test'] == True].copy()

print("Train set size: ", adata_train.shape)
print("Test set size: ", adata_test.shape)

# Create the cell type palette
# `labels_key` names the obs column holding the cell-type annotation; the
# palette maps each annotation value to a hex color used by the plots below.
labels_key = "subclass"
CELL_TYPE_PALETTE = {
    # Excitatory Neurons
    "L2/3 IT": "#e41a1c",
    "L4/5 IT": "#ff7f00",
    "L5 IT": "#fdbf6f",
    "L5 ET": "#e31a1c",
    "L6 IT": "#6a3d9a",
    "L6 IT Car3": "#cab2d6",
    "L6 CT": "#fb9a99",
    "L5/6 NP": "#a6cee3",
    "L6b": "#1f78b4",
    # Inhibitory Neurons
    "Pvalb": "#8dd3c7",
    "Sst": "#80b1d3",
    "Lamp5": "#33a02c",
    "Vip": "#b2df8a",
    "Sncg": "#bc80bd",
    # Glial Cells
    "Astro": "#bebada",
    "Oligo": "#fb8072",
    "OPC": "#b3de69",
    "Micro": "#fccde5",
    "VLMC": "#d9d9d9",
    # Vascular Cells
    "Endo": "#ffff33",
    "Peri": "#ffffb3",
    "PVM": "#fdb462",
    "SMC": "#8dd3c7",
    # Other
    "other": "#999999",
}
Train set size:  (27567, 254)
Test set size:  (6264, 254)

Visualize the spatial distribution and the train-test split of the dataset.

def visualize_spatial_distribution(adata, labels_key="subclass", x_lim=None, y_lim=None):
    """Scatter-plot cells at their spatial coordinates and outline the test region.

    Cells are colored by ``adata.obs[labels_key]`` using ``CELL_TYPE_PALETTE``.
    When at least one cell is flagged in ``adata.obs["in_test"]``, a dashed
    rectangle is drawn around the bounding box of those cells, padded by 20
    coordinate units on every side. When no cell is flagged, the plot is drawn
    without a rectangle.

    Args:
        adata: AnnData-like object providing ``obsm["spatial"]`` (two columns,
            used as X and Y) and the ``obs`` columns ``labels_key`` and
            ``"in_test"``.
        labels_key: Name of the ``obs`` column used to color the cells.
        x_lim: If given, the x-axis is limited to ``(0, x_lim)``.
        y_lim: If given, the y-axis is limited to ``(0, y_lim)``.
    """
    plot_df = pd.DataFrame(adata.obsm["spatial"].copy(), columns=["X", "Y"])
    plot_df[labels_key] = adata.obs[labels_key].values
    plot_df["in_test"] = adata.obs["in_test"].values

    plt.figure(figsize=(8, 6))
    # Pass the frame by keyword: `data` is keyword-only in some seaborn releases.
    sns.scatterplot(
        data=plot_df, x="X", y="Y", hue=labels_key, alpha=0.7, s=8, palette=CELL_TYPE_PALETTE
    )

    # `== True` (rather than plain boolean indexing) keeps this working when the
    # column is not a strict boolean dtype.
    test_df = plot_df[plot_df["in_test"] == True]  # noqa: E712
    if not test_df.empty:
        min_x, max_x = test_df["X"].min(), test_df["X"].max()
        min_y, max_y = test_df["Y"].min(), test_df["Y"].max()
        width = max_x - min_x
        height = max_y - min_y

        padding = 20
        rect = plt.Rectangle(
            (min_x - padding, min_y - padding),
            width + 2 * padding,
            height + 2 * padding,
            fill=False,
            color="black",
            linestyle="--",
            linewidth=2,
            label="Test Region",
        )
        # The patch is only added when it exists; adding it unconditionally
        # raised an unbound-variable error for data without test cells.
        plt.gca().add_patch(rect)

    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title("Spatial distribution of cells in the dataset")

    # Place the legend outside the axes so it does not cover the cells.
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(
        handles=handles,
        labels=labels,
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        borderaxespad=0.0,
        markerscale=2,
    )

    if x_lim is not None:
        plt.xlim(0, x_lim)
    if y_lim is not None:
        plt.ylim(0, y_lim)
    plt.tight_layout()
    plt.show()

# Plot the full dataset; cells flagged `in_test` are outlined as the test region.
visualize_spatial_distribution(adata)
../_images/30bacf291f1fc9609bed38bf89bc94fff504e5e91ec0f5a6bbd139f3e1372977.png

Setup and train AMICI#

Setup the hyperparameters for training AMICI. For your custom dataset, we suggest performing a sweep over a small range of values for the following hyperparameters:

  • end_attention_penalty

  • epoch_start and epoch_end

  • value_l1_penalty_coef

# Set up the seed for reproducibility
seed = 18
pl.seed_everything(seed)

# Keyword arguments forwarded to AttentionPenaltyMonitor when training.
# NOTE(review): presumably the attention penalty moves from `start_val` to
# `end_val` between `epoch_start` and `epoch_end` following `flavor` — confirm
# against the AMICI documentation.
penalty_schedule_params = {
    "start_val": 1e-6,
    "end_val": 1e-3,
    "epoch_start": 10,
    "epoch_end": 30,
    "flavor": "linear",
}
# Keyword arguments passed to the AMICI constructor.
model_params = {
    "n_heads": 8,
    "n_query_dim": 128,
    "n_head_size": 32,
    "n_nn_embed": 256,
    "n_nn_embed_hidden": 512,
    "attention_dummy_score": 3.0,
    "neighbor_dropout": 0.1,
    # The model's initial attention penalty is tied to the schedule's start value.
    "attention_penalty_coef": penalty_schedule_params[
        "start_val"
    ],
    "value_l1_penalty_coef": 1e-5,
}
# Experiment settings read by the AnnData setup and training cells.
# NOTE(review): "learning_rate_monitor" is not read by any cell visible in this
# tutorial — confirm whether it is still needed.
exp_params = {
    "lr": 1e-3,
    "epochs": 400,
    "batch_size": 512,
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 20,
    "learning_rate_monitor": True,
    "n_neighbors": 50,
}
Seed set to 18

Set up the AnnData for AMICI and define the model with the model parameters.

# Register the training AnnData with AMICI: labels come from obs[labels_key],
# coordinates from obsm["spatial"], and the neighbor count from exp_params.
AMICI.setup_anndata(
    adata_train,
    labels_key=labels_key,
    coord_obsm_key="spatial",
    n_neighbors=exp_params["n_neighbors"],
)
# Build the model on the training split with the hyperparameters in model_params.
model = AMICI(adata_train, **model_params)
INFO     Generating sequential column names                                                                        
INFO     Generating sequential column names                                                                        
INFO     Generating sequential column names                                                                        
# Location the trained model is saved to; the seed is embedded in the folder name.
model_path = os.path.join("./saved_models", f"cortex_{seed}_params")

# Forward the learning rate to the training plan only when one is configured.
plan_kwargs = {key: exp_params[key] for key in ("lr",) if key in exp_params}

Train the model using the above defined parameters.

# Early-stopping options are copied from exp_params under the same names that
# model.train expects (a missing key yields None, as with a direct .get()).
early_stopping_kwargs = {
    name: exp_params.get(name)
    for name in (
        "early_stopping",
        "early_stopping_monitor",
        "early_stopping_patience",
    )
}

# Train with validation after every epoch.
# NOTE(review): AttentionPenaltyMonitor presumably applies the schedule in
# penalty_schedule_params during training — confirm against the AMICI docs.
model.train(
    max_epochs=int(exp_params.get("epochs")),
    batch_size=int(exp_params.get("batch_size")),
    plan_kwargs=plan_kwargs,
    check_val_every_n_epoch=1,
    callbacks=[AttentionPenaltyMonitor(**penalty_schedule_params)],
    **early_stopping_kwargs,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 44/400:  11%|█         | 44/400 [06:24<51:51,  8.74s/it, v_num=1, train_loss_step=55.8, train_loss_epoch=56]  
Monitored metric elbo_validation did not improve in the last 20 records. Best score: 59.582. Signaling Trainer to stop.
# Persist the trained model to model_path, replacing any previous save there.
model.save(model_path, overwrite=True)

Evaluate on the test set. It’s important to set up the AnnData with all of the data, so that neighbors of test cells that lie outside the test set are still included.

# Register the full dataset (train and test cells together) before evaluating.
AMICI.setup_anndata(
    adata,
    labels_key=labels_key,
    coord_obsm_key="spatial",
    n_neighbors=exp_params["n_neighbors"],
)

# Get test set metrics
# Positional indices of the held-out cells, computed once and shared by both
# metric calls.
test_indices = np.where(adata.obs["in_test"])[0]
eval_kwargs = {"indices": test_indices, "batch_size": 128}

test_elbo = model.get_elbo(adata, **eval_kwargs).item()
reconstruction_errors = model.get_reconstruction_error(adata, **eval_kwargs)
test_reconstruction_loss = reconstruction_errors["reconstruction_loss"]

print(f"Test ELBO: {test_elbo}")
print(f"Test Reconstruction Loss: {test_reconstruction_loss}")
INFO     Generating sequential column names                                                                        
INFO     Generating sequential column names                                                                        
INFO     Generating sequential column names                                                                        
INFO     Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
Test ELBO: -59.968727111816406
Test Reconstruction Loss: 59.964866638183594

AMICI’s downstream interpretation#

Load the saved model (if it is not already in memory) and set up the AnnData.

# Reload the trained model from model_path, attaching the full dataset.
model = AMICI.load(
    model_path,
    adata=adata,
)
# Register the full dataset with the same labels key, coordinate key and
# neighbor count that were used for training.
AMICI.setup_anndata(
    adata,
    labels_key=labels_key,
    coord_obsm_key="spatial",
    n_neighbors=exp_params["n_neighbors"],
)
INFO     File ./saved_models/cortex_18_params/model.pt already downloaded                                          
INFO     Generating sequential column names                                                                        
INFO     Generating sequential column names                                                                        
INFO     Generating sequential column names                                                                        

High-level interaction scores#

# Reuse the ablation results stored on disk when present; otherwise compute
# them on the full dataset and store them for the next run.
ablation_residuals_path = "./data/cortex_ablation_residuals.pkl"
if not os.path.exists(ablation_residuals_path):
    ablation_residuals = model.get_neighbor_ablation_scores(adata=adata, compute_z_value=True)
    ablation_residuals.save_object(ablation_residuals_path)
else:
    ablation_residuals = AMICIAblationModule.load_object(ablation_residuals_path)

Plot the high-level interaction map showing all interactions between cell types. We generally apply a weight threshold to remove interactions with relatively low interaction strengths.

# Flatten the cell-type interaction weight matrix and take the chosen quantile
# as the cut-off used to filter the interaction graph.
interaction_weight_matrix_df = ablation_residuals._get_interaction_weight_matrix()
interaction_weight_matrix = interaction_weight_matrix_df.values.flatten()
quantile = 0.86
weight_threshold = np.quantile(interaction_weight_matrix, quantile)

# The same text is printed and used as the legend entry for the threshold line.
threshold_label = f"{quantile} quantile threshold: {weight_threshold:.2f}"
print(threshold_label)

# Density of all interaction weights with the threshold marked in red.
ax = sns.kdeplot(x=interaction_weight_matrix)
ax.set_title("Distribution of interaction weights")
ax.set_xlabel("Interaction weight")
ax.set_ylabel("Density")
ax.axvline(weight_threshold, color='r', linestyle='--', label=threshold_label)
ax.legend()
plt.show()
0.86 quantile threshold: 24.00
../_images/ccc0c43408853b988b153775d5856bc8c35dabda9a8ef6b47c1cc1e210377bc1.png
# Plot the directed cell-type interaction graph, passing the quantile-based
# weight_threshold computed above and a 0.05 significance threshold.
# NOTE(review): presumably only interactions passing both thresholds are drawn — confirm.
ablation_residuals.plot_interaction_directed_graph(
    significance_threshold=0.05,
    weight_threshold=weight_threshold,
    node_size=500,
    palette=CELL_TYPE_PALETTE,
)
../_images/4f63eefb54f74ba1d47d1996f1f75afda446807587991cbb473941541719ef22.png
<Figure size 640x480 with 0 Axes>
<Figure size 640x480 with 0 Axes>

Per-gene ablation scores#

For a specific receiver of interest, for example Astrocytes, we can plot the ablation scores as a dotplot ranked by an arbitrary number of top genes.

# Receiver cell type whose per-gene ablation scores are plotted; also reused by
# the counterfactual-attention cells below.
target_ct = "Astro"

# Dotplot of the top 5 genes, colored by "diff" and sized by "z_value"
# (per the color_by / size_by arguments).
ablation_residuals.plot_featurewise_contributions_dotplot(
    cell_type=target_ct,
    color_by="diff",
    size_by="z_value",
    min_size_by=10,
    step=10,
    n_top_genes=5,
)
../_images/35d9c9357e117763bfb9ab7f1d79fe068b82a375dde84192e4916e46eda87e0d.png

Counterfactual attention scores#

# Compute counterfactual attention patterns for the target receiver cell type
# on the full dataset.
counterfactual_attention_patterns = model.get_counterfactual_attention_patterns(
    cell_type=target_ct,
    adata=adata,
)

We can visualize the length scales for different pairs of interactions between the target cell type and other senders.

# Sender cell types whose length scales are compared.
sender_types = ["L4/5 IT", "L2/3 IT", "Oligo"]

# Plot the length-scale distribution over every attention head
# (head indices 0 .. model.module.n_heads - 1); the returned value is kept in
# length_scale_df.
length_scale_df = counterfactual_attention_patterns.plot_length_scale_distribution(
    head_idxs=range(model.module.n_heads),
    sender_types=sender_types,
    attention_threshold=0.1,
    plot_kde=True,
    sample_threshold=0.02,
    max_length_scale=100,
    palette=CELL_TYPE_PALETTE
)
<Figure size 1200x600 with 0 Axes>
../_images/5b336342ad68fc5e8a97b18c7753dd1f634570d5d23d167757652e002d5972f5.png

By default we plot the ablation scores across all heads but we can also visualize mediating genes learned by a specific head, corresponding to different length scales learned across attention heads.

# Recompute ablation scores for the listed sender cell types and for attention
# head 7 only.
# NOTE(review): the sender list repeats the values of `sender_types` defined
# earlier — keep the two in sync.
ablation_residuals_sub = model.get_neighbor_ablation_scores(
    adata=adata,
    ablated_neighbor_ct_sub=["L4/5 IT", "L2/3 IT", "Oligo"],
    compute_z_value=True,
    head_idx=7,
)
# Same dotplot as for the all-head scores, with different size/step settings
# and an explicit vrange.
ablation_residuals_sub.plot_featurewise_contributions_dotplot(
    cell_type=target_ct,
    color_by="diff",
    size_by="z_value",
    min_size_by=2,
    step=5,
    n_top_genes=5,
    vrange=0.1,
)
../_images/1b16e4a8e6ad830c7fdb2ad3cfb043e434317cd8b3f1813c8ea9dce6236adab9.png

Empirical attention scores#

# Compute the empirical attention patterns on the full dataset, processed in
# batches of 32.
attention_patterns = model.get_attention_patterns(
    adata,
    batch_size=32,
)
100%|██████████| 1058/1058 [00:10<00:00, 101.31it/s]

We can plot a general attention summary of the attention scores for any set of cell types of interest. The attention patterns module contains the attention scores between every cell type and its neighborhood, which can be used to spatially visualize populations with high attention for different cells.

# Summarize attention for astrocytes only; add more cell types to
# `cell_type_sub` to include them in the summary.
attention_patterns.plot_attention_summary(
    cell_type_sub=["Astro"],
    palette=CELL_TYPE_PALETTE,
)
Head index: 100%|██████████| 8/8 [00:01<00:00,  7.81it/s]
Cell type: 100%|██████████| 1/1 [00:01<00:00,  1.76s/it]
../_images/05a3c64f986a8c221b701e1a4c3c4dd47ad88ade871ede17ab5333d9ff14e68e.png