# Source code for amici._module

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from scvi import REGISTRY_KEYS
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from transformer_lens.hook_points import HookedRootModule, HookPoint

from ._components import AttentionBlock, ResNetMLP
from ._constants import NN_REGISTRY_KEYS


class AMICIModule(HookedRootModule, BaseModuleClass):
    """scvi-tools module predicting expression from cell type plus neighbor context.

    Each cell's prediction is the stored expression profile of its label
    (``ct_profiles``) plus a residual. The residual is produced by attending from
    a query derived from the cell's label embedding to key/value embeddings of
    its neighbors' expression (``nn_X``). Attention scores receive an additive,
    per-head term that decreases linearly with neighbor distance (``nn_dist``),
    with a non-negative learned slope.

    Parameters
    ----------
    n_genes
        Number of genes: input width of ``nn_X`` and output width of the prediction.
    n_labels
        Number of label codes (rows of the label embedding table).
    empirical_ct_means
        Per-label expression profiles indexed by label code; presumably of shape
        ``(n_labels, n_genes)`` — TODO confirm against caller. Registered as the
        ``ct_profiles`` buffer.
    n_label_embed
        Width of the label embedding.
    n_kv_dim
        Per-head key/value width.
    n_query_embed_hidden
        Hidden width of the query MLP.
    n_query_dim
        Per-head query width.
    n_nn_embed
        Width of each neighbor's expression embedding.
    n_nn_embed_hidden
        Hidden width of the neighbor embedding MLP.
    n_pos_coef_mlp_hidden
        Hidden width of the MLP producing per-head distance coefficients.
    n_head_size
        Per-head output width of the attention block.
    n_heads
        Number of attention heads.
    neighbor_dropout
        Probability of masking each neighbor during training.
    attention_dummy_score
        Forwarded to ``AttentionBlock`` as ``dummy_attn_score`` (semantics defined there).
    attention_penalty_coef
        Weight of the attention-entropy penalty in ``loss`` (0 disables it).
    value_l1_penalty_coef
        Weight of the L1 penalty on attention values in ``loss`` (0 disables it).
    pos_coef_offset
        Added to the distance-coefficient MLP output before ``softplus``.
    distance_kernel_unit_scale
        Neighbor distances are divided by this before being scaled by the coefficients.
    """

    def __init__(
        self,
        n_genes: int,
        n_labels: int,
        empirical_ct_means: torch.Tensor,
        n_label_embed: int = 32,
        n_kv_dim: int = 256,
        n_query_embed_hidden: int = 512,
        n_query_dim: int = 64,
        n_nn_embed: int = 256,
        n_nn_embed_hidden: int = 1024,
        n_pos_coef_mlp_hidden: int = 512,
        n_head_size: int = 16,
        n_heads: int = 4,
        neighbor_dropout: float = 0.1,
        attention_dummy_score: float = 3.0,
        attention_penalty_coef: float = 0.0,
        value_l1_penalty_coef: float = 0.0,
        pos_coef_offset: float = -2.0,
        distance_kernel_unit_scale: float = 1.0,
    ):
        super().__init__()
        self.n_genes = n_genes
        self.n_labels = n_labels
        self.n_label_embed = n_label_embed
        self.n_query_embed_hidden = n_query_embed_hidden
        self.n_query_dim = n_query_dim
        self.n_kv_dim = n_kv_dim
        self.n_nn_embed = n_nn_embed
        self.n_nn_embed_hidden = n_nn_embed_hidden
        self.n_pos_coef_mlp_hidden = n_pos_coef_mlp_hidden
        self.attention_dummy_score = attention_dummy_score
        self.neighbor_dropout = neighbor_dropout
        self.attention_penalty_coef = attention_penalty_coef
        self.value_l1_penalty_coef = value_l1_penalty_coef
        self.distance_kernel_unit_scale = distance_kernel_unit_scale
        self.pos_coef_offset = pos_coef_offset
        self.n_head_size = n_head_size
        self.n_heads = n_heads

        # Kept as a plain attribute for external readers; the buffer below is the
        # copy used in `generative` and the one that follows `.to(device)`.
        self.empirical_ct_means = empirical_ct_means
        self.register_buffer("ct_profiles", self.empirical_ct_means)

        # Label code -> learned label embedding.
        self.ct_embed = nn.Embedding(self.n_labels, self.n_label_embed)
        # Label embedding -> one query per head.
        self.query_embed = ResNetMLP(
            n_input=self.n_label_embed,
            n_output=self.n_heads * self.n_query_dim,
            n_layers=2,
            n_hidden=self.n_query_embed_hidden,
            dropout=0.0,
        )
        # Neighbor expression -> neighbor embedding.
        self.nn_embed = ResNetMLP(
            n_input=self.n_genes,
            n_output=self.n_nn_embed,
            n_layers=2,
            n_hidden=self.n_nn_embed_hidden,
            dropout=0.0,
        )
        # (neighbor embedding, label embedding) -> one distance coefficient per head.
        self.pos_coef_mlp = ResNetMLP(
            n_input=self.n_nn_embed + self.n_label_embed,
            n_output=self.n_heads,
            n_layers=2,
            n_hidden=self.n_pos_coef_mlp_hidden,
            dropout=0.0,
            use_final_layer_norm=False,
        )
        # Neighbor embedding -> shared key/value vectors per head.
        self.kv_embed = ResNetMLP(
            n_input=self.n_nn_embed,
            n_output=self.n_heads * self.n_kv_dim,
            n_layers=2,
            n_hidden=self.n_kv_dim,
            dropout=0.0,
        )
        self.attention_layer = AttentionBlock(
            self.n_query_dim,
            self.n_kv_dim,
            self.n_head_size,
            self.n_heads,
            dummy_attn_score=self.attention_dummy_score,
            add_res_connection=True,
        )
        # Concatenated head outputs -> per-gene residual.
        self.linear_head = nn.Linear(
            self.n_heads * self.n_head_size,
            self.n_genes,
            bias=False,
        )

        self.hook_label_embed = HookPoint()  # [batch, n_label_embed]
        self.hook_nn_embed = HookPoint()  # [batch, n_neighbors, n_nn_embed]
        self.hook_final_residual = HookPoint()  # [batch, n_genes]
        # Not called by any method in this module; kept so the hook name stays registered.
        self.hook_pe_embed = HookPoint()
        self.setup()  # register hook points with HookedRootModule

    def _get_inference_input(self, tensors):
        """Select the label codes and neighbor expression used by ``inference``."""
        labels = tensors[REGISTRY_KEYS.LABELS_KEY]
        nn_X = tensors[NN_REGISTRY_KEYS.NN_X_KEY]
        return {
            "labels": labels,
            "nn_X": nn_X,
        }

    def inference(self, labels, nn_X):
        """Embed the cell's label and its neighbors' expression.

        Parameters
        ----------
        labels
            Label codes of shape ``(batch, 1)`` (required by the ``"b 1 d"`` rearrange).
            Cast to integer before the embedding lookup, so float-coded labels work too.
        nn_X
            Neighbor expression; presumably ``(batch, n_neighbors, n_genes)``.

        Returns
        -------
        dict with ``label_embed`` (batch x n_label_embed) and ``nn_embed``
        (batch x n_neighbors x n_nn_embed).
        """
        # nn.Embedding requires integer indices.
        label_embed = self.hook_label_embed(
            rearrange(self.ct_embed(labels.long()), "b 1 d -> b d")
        )  # batch x n_label_embed
        nn_embed = self.hook_nn_embed(self.nn_embed(nn_X))  # batch x n_neighbors x n_nn_embed
        return {
            "label_embed": label_embed,
            "nn_embed": nn_embed,
        }

    def _get_generative_input(self, tensors, inference_outputs):
        """Combine inference outputs with labels and neighbor distances for ``generative``."""
        labels = tensors[REGISTRY_KEYS.LABELS_KEY]
        label_embed = inference_outputs["label_embed"]
        nn_embed = inference_outputs["nn_embed"]
        nn_dist = tensors[NN_REGISTRY_KEYS.NN_DIST_KEY]
        return {
            "labels": labels,
            "label_embed": label_embed,
            "nn_embed": nn_embed,
            "nn_dist": nn_dist,
        }

    @auto_move_data
    def generative(
        self,
        labels,
        label_embed,
        nn_embed,
        nn_dist,
        return_attention_patterns: bool = False,
        return_attention_scores: bool = False,
        return_v: bool = False,
    ):
        """Predict expression as label profile + attention-derived residual.

        Parameters
        ----------
        labels
            Label codes of shape ``(batch, 1)``; cast to integer for indexing.
        label_embed
            ``(batch, n_label_embed)`` label embeddings.
        nn_embed
            ``(batch, n_neighbors, n_nn_embed)`` neighbor embeddings.
        nn_dist
            Neighbor distances; presumably ``(batch, n_neighbors)`` — TODO confirm.
        return_attention_patterns, return_attention_scores, return_v
            Request the corresponding attention internals in the output. Patterns
            and values are forced on when their penalty coefficient is positive,
            because ``loss`` reads them.

        Returns
        -------
        dict with ``residual_embed``, ``residual``, ``prediction`` (batch x n_genes),
        the optional attention outputs (``None`` when not requested), and
        ``pos_coefs`` (batch x n_neighbors x n_heads).
        """
        # The penalties in `loss` consume these outputs, so force them when active.
        return_attention_patterns = self.attention_penalty_coef > 0.0 or return_attention_patterns
        return_v = self.value_l1_penalty_coef > 0.0 or return_v

        # One query per head from the label embedding: batch x 1 x n_heads x n_query_dim.
        query_embed = rearrange(self.query_embed(label_embed), "b (h d) -> b 1 h d", h=self.n_heads)

        # Per-neighbor, per-head non-negative distance coefficients.
        label_embed_repeated = repeat(label_embed, "b d -> b n d", n=nn_embed.shape[1])
        pos_coefs = F.softplus(
            self.pos_coef_mlp(torch.cat([nn_embed, label_embed_repeated], dim=-1))
            + self.pos_coef_offset
        )  # batch x n_neighbors x n_heads
        # Additive attention term that decreases linearly with scaled distance.
        pos_attn_score = -pos_coefs * (nn_dist.unsqueeze(-1) / self.distance_kernel_unit_scale)

        # Keys and values share the same projection: batch x n_neighbors x n_heads x n_kv_dim.
        kv_embed = rearrange(self.kv_embed(nn_embed), "b n (h d) -> b n h d", h=self.n_heads)

        # During training, randomly mask whole neighbors (1 = keep, 0 = drop).
        attention_mask = None
        if self.training and self.neighbor_dropout > 0.0:
            keep = torch.rand((kv_embed.shape[0], kv_embed.shape[1]), device=kv_embed.device)
            attention_mask = (keep > self.neighbor_dropout).int()

        attn_outs = self.attention_layer(
            query_embed,
            kv_embed,
            kv_embed,
            attention_mask=attention_mask,
            pos_attn_score=pos_attn_score,
            return_base_attn_scores=return_attention_scores,
            return_attn_patterns=return_attention_patterns,
            return_v=return_v,
        )
        # Single query position -> batch x (n_heads * n_head_size).
        residual_embed = rearrange(attn_outs["x"], "b 1 d -> b d")

        # Drop the query-position axis (index 0) from the optional attention outputs.
        attention_scores = attn_outs["base_attn_scores"][:, :, 0, :] if return_attention_scores else None
        attention_patterns = attn_outs["attn_patterns"][:, :, 0, :] if return_attention_patterns else None
        attention_v = attn_outs["v"] if return_v else None

        # Linear readout of the residual on top of the label profile.
        residual = self.hook_final_residual(self.linear_head(residual_embed).float())  # batch x n_genes

        # Indexing with (batch,) codes already yields (batch, n_genes). No extra
        # squeeze here: it would drop the gene axis when n_genes == 1 and the sum
        # below would then broadcast to (batch, batch).
        batch_ct_means = self.ct_profiles[labels.long().squeeze(-1)]
        prediction = (batch_ct_means + residual).float()

        return {
            "residual_embed": residual_embed,
            "residual": residual,
            "prediction": prediction,
            "attention_scores": attention_scores,
            "attention_patterns": attention_patterns,
            "attention_v": attention_v,
            "pos_coefs": pos_coefs,
        }

    # HACK: kl_weight argument exists to support VAEMixin get_reconstruction_error
    def loss(self, tensors, inference_outputs, generative_outputs, kl_weight=1.0):
        """Loss computation.

        Per cell: Gaussian NLL with unit variance summed over genes (equivalent to
        half the squared error), plus ``attention_penalty_coef`` times the mean
        over heads of the attention entropy, plus ``value_l1_penalty_coef`` times
        the mean over key positions of the summed absolute attention values. The
        scalar loss is the batch mean. ``kl_weight`` is accepted but unused.
        """
        true_X = tensors[REGISTRY_KEYS.X_KEY]
        prediction = generative_outputs["prediction"]
        reconstruction_loss = F.gaussian_nll_loss(
            prediction, true_X, var=torch.ones_like(prediction), reduction="none"
        ).sum(-1)

        attention_penalty = torch.zeros(true_X.shape[0], device=true_X.device)
        if self.attention_penalty_coef > 0.0:
            attention_patterns = generative_outputs["attention_patterns"]  # batch x head_index x key_pos
            # Clamp inside the log only, to avoid log(0) while keeping p * log p ~ 0 at p = 0.
            eps = torch.finfo(attention_patterns.dtype).eps
            attention_entropy_terms = (
                -1 * attention_patterns * torch.log(torch.clamp(attention_patterns, min=eps, max=1 - eps))
            )
            attention_penalty = reduce(
                reduce(
                    attention_entropy_terms,
                    "batch head_index key_pos -> batch head_index",
                    "sum",
                ),
                "batch head_index -> batch",
                "mean",
            )

        value_l1_penalty = torch.zeros(true_X.shape[0], device=true_X.device)
        if self.value_l1_penalty_coef > 0.0:
            attention_v = generative_outputs["attention_v"]
            value_l1_penalty = reduce(
                reduce(
                    torch.abs(attention_v),
                    "batch key_pos head_index head_size -> batch key_pos",
                    "sum",
                ),
                "batch key_pos -> batch",
                "mean",
            )

        loss = torch.mean(
            reconstruction_loss
            + self.attention_penalty_coef * attention_penalty
            + self.value_l1_penalty_coef * value_l1_penalty
        )
        return LossOutput(
            loss=loss,
            reconstruction_loss=reconstruction_loss,
            kl_local={
                "attention_penalty": self.attention_penalty_coef * attention_penalty,
                "value_l1_penalty": self.value_l1_penalty_coef * value_l1_penalty,
            },
            extra_metrics={"attention_penalty_coef": torch.tensor(self.attention_penalty_coef)},
        )