amici.AMICI#
- class amici.AMICI(adata, **model_kwargs)[source]#
Bases: VAEMixin, WandbUnsupervisedTrainingMixin, BaseModelClassMethods
__init__
convert_legacy_save: Converts a legacy saved model (<v0.15.0) to the updated save format.
deregister_manager: Deregisters the AnnDataManager instance associated with adata.
get_anndata_manager: Retrieves the AnnDataManager for a given AnnData object.
get_attention_patterns: Different flavors of attention patterns retrieved from cache.
get_counterfactual_attention_patterns: Compute the counterfactual attention patterns for a given cell type and head indices.
get_elbo: Compute the evidence lower bound (ELBO) on the data.
get_expl_variance_scores: Creates a DataFrame with the explained variance scores.
get_from_registry: Returns the object in AnnData associated with the key in the data registry.
get_gene_residual_contributions: Get the gene residual contributions for each cell at full attention.
get_latent_representation: Compute the latent representation of the data.
get_marginal_ll: Compute the marginal log-likelihood of the data.
get_neighbor_ablation_scores: Difference in gene expression prediction error for cell type of interest when ablating neighbor cell types for a specific head of interest.
get_nn_embed
get_predictions
get_reconstruction_error: Compute the reconstruction error on the data.
load: Instantiate a model from the saved output.
load_registry: Return the full registry saved with the model.
register_manager: Registers an AnnDataManager instance with this model class.
save: Save the state of the model.
setup_anndata: Sets up the AnnData object for this model.
to_device: Move model to device.
train: Train the model.
view_anndata_setup: Print summary of the setup for the initial AnnData or a given AnnData object.
view_setup_args: Print args used to setup a saved model.
Attributes
adata: Data attached to model instance.
adata_manager: Manager instance associated with self.adata.
device: The current device that the module's params are on.
history: Returns computed metrics during training.
is_trained: Whether the model has been trained.
summary_string: Summary string of the model.
test_indices: Observations that are in test set.
train_indices: Observations that are in train set.
validation_indices: Observations that are in validation set.
- classmethod setup_anndata(adata, layer=None, labels_key=None, coord_obsm_key=None, nn_dist_key='_nn_dist', nn_idx_key='_nn_idx', cell_radius_key=None, n_neighbors=None, **kwargs)[source]#
Sets up the AnnData object for this model.
A mapping will be created between the data fields used by this model and their respective locations in adata. None of the existing data in adata is modified; fields are only added to adata.
Each model class deriving from this class provides parameters to this method according to its needs. To operate correctly with the model initialization, the implementation must call register_manager() on a model-specific instance of AnnDataManager.
- get_attention_patterns(model, adata=None, indices=None, batch_size=None, flavor='vanilla', prog_bar=True)#
Different flavors of attention patterns retrieved from cache.
Returns a DataFrame if return_nn_idxs_and_dists is False, otherwise returns a tuple of DataFrames (attention_patterns, nn_idxs, nn_dists). Code adapted from: callummcdougall/CircuitsVis.
einsum key:
- b: n_batch
- h: n_heads
- n: n_neighbors
- g: n_genes
- m: n_heads * n_head_size (a.k.a. d_model)
- Parameters:
adata (AnnData) – AnnData object to get attention patterns from.
indices (list[int], optional) – Indices of the cells to get attention patterns from.
batch_size (int, optional) – Batch size to use for the data loader.
flavor (Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"], optional) – Flavor of attention patterns to retrieve.
prog_bar (bool, optional) – Whether to show a progress bar.
- Return type:
Stores#
- Union[pd.DataFrame, tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]:
- attention_patterns: DataFrame of attention patterns with the following columns:
“neighbor_{i}”: Attention pattern for the i-th neighbor.
“head_idx”: Index of the attention head.
“label”: Label of the cell.
“cell_idx”: Index of the cell.
- nn_idxs: DataFrame of nearest neighbor indices (only if return_nn_idxs_and_dists is True) with the following columns:
“neighbor_{i}”: Index of the i-th neighbor.
- nn_dists: DataFrame of nearest neighbor distances (only if return_nn_idxs_and_dists is True) with the following columns:
“neighbor_{i}”: Distance to the i-th neighbor.
- get_counterfactual_attention_patterns(model, cell_type, adata=None, indices=None, head_idxs=None, batch_size=None)#
Compute the counterfactual attention patterns for a given cell type and head indices.
- Parameters:
model (AMICI) – The trained AMICI model.
cell_type (str) – The index cell type to get counterfactual attention scores for.
adata (AnnData | None) – Optional AnnData object to use. If None, uses model.adata.
indices (list[int] | None) – Optional list of cell indices to get scores for. If None, uses all cells.
head_idxs (list[int] | None) – Optional list of attention head indices to get scores for. If None, uses all heads.
batch_size (int | None) – Optional batch size to use for the data loader.
- Return type:
- Returns:
AMICICounterfactualAttentionModule: The module instance with computed counterfactual attention scores stored in _counterfactual_attention_df with columns:
query_label: Cell type label of the query cell
neighbor_idx: Index of the neighbor cell
neighbor_label: Cell type label of the neighbor cell
head_idx: Attention head index
base_attention_score: Base attention score
pos_coef: Position coefficient of the counterfactual attention score function
dummy_attention_score: Dummy attention score of the model
distance_kernel_unit_scale: Distance kernel unit scale
- get_expl_variance_scores(model, adata=None, alpha=0.05, run_permutation_test=True, n_permutations=None)#
Creates a DataFrame with the explained variance scores.
DataFrame contains the explained variance scores for each head across cells for all cell types or a subset of cell types.
- Parameters:
adata (AnnData, optional) – The AnnData object.
alpha (float, optional) – The desired significance level for the explained variance of each head-cell type pair via permutation testing. Defaults to 0.05.
run_permutation_test (bool, optional) – Whether to run the permutation test. Defaults to True.
n_permutations (int, optional) – The number of permutations to use for the permutation test. If not provided, the minimum number of permutations required to achieve a significance level of alpha will be used.
- Return type:
- Returns:
AMICIExplainedVarianceModule: The module instance with computed explained variance scores stored in _explained_variance_df with columns:
head: The index of the head.
ct_name: The name of the cell type.
gene: The name of the gene.
expl_var_head_gene: The explained variance score for the gene when all heads except head are ablated.
expl_var_head: The explained variance score when all heads except head are ablated, aggregated across all genes.
- if run_permutation_test is True, the following columns are also added:
p_value_head: The p-value of the explained variance score when all heads except head are ablated.
p_value_adj_head: The adjusted p-value of the explained variance score when all heads except head are ablated.
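The link between alpha and the number of permutations can be illustrated with a minimal sketch. It assumes the common add-one p-value convention, under which the smallest reachable p-value with n permutations is 1 / (n + 1); the rule AMICI actually applies when n_permutations is omitted may differ. The function names and the null scores below are hypothetical and only serve the illustration.

```python
import math
import random


def min_permutations(alpha: float) -> int:
    """Smallest n such that 1 / (n + 1) <= alpha (add-one convention, an assumption)."""
    # round() guards against float noise in 1 / alpha before taking the ceiling
    return math.ceil(round(1.0 / alpha, 9)) - 1


def permutation_p_value(observed: float, null_scores: list[float]) -> float:
    """Add-one permutation p-value: share of null scores at least as large as observed."""
    n_extreme = sum(score >= observed for score in null_scores)
    return (1 + n_extreme) / (1 + len(null_scores))


rng = random.Random(0)
n_perm = min_permutations(0.05)
# Hypothetical null scores in [0, 1); a real test would recompute the
# explained variance on permuted data for each entry.
null_scores = [rng.random() for _ in range(n_perm)]
p_value = permutation_p_value(observed=2.0, null_scores=null_scores)
```

Because the observed score exceeds every null score here, p_value is the smallest value that n_perm permutations can produce, which is why fewer permutations could never reach the requested alpha under this convention.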
- get_neighbor_ablation_scores(model, cell_type, head_idx, adata=None, ablated_neighbor_ct_sub=None, ablated_neighbor_indices=None, compute_z_value=False)#
Difference in gene expression prediction error for cell type of interest when ablating neighbor cell types for a specific head of interest.
- Parameters:
cell_type (str) – The cell type to compute the residuals for. None indicates using all cell types.
head_idx (int, optional) – The index of the head to test. None indicates using all heads.
adata (AnnData, optional) – The AnnData object.
ablated_neighbor_ct_sub (list[str], optional) – The neighbor cell types to ablate.
ablated_neighbor_indices (list[int], optional) – The indices of the neighbor cells to ablate.
compute_z_value (bool, optional) – Whether to save the z-value using the correlation coefficient or not.
- Return type:
- Returns:
pd.DataFrame: A DataFrame with the difference in gene expression prediction error for cell type of interest when ablating neighbor cell types. The DataFrame has the columns:
[ablated_neighbor_ct]: The difference in gene expression prediction error when ablating the neighbor cell type ablated_neighbor_ct.
- get_gene_residual_contributions(adata=None, indices=None, batch_size=None, head_idxs=None)[source]#
Get the gene residual contributions for each cell at full attention.
Compute the gene-wise residual contributions for each cell irrespective of the attention score for the head. As the value vectors do not depend on the distances or the index cell, we only need to provide the neighbor gene expressions.
- Parameters:
- Return type:
DataFrame
- Returns:
pd.DataFrame: A DataFrame with the gene residual contributions for each cell at full attention. The DataFrame has the columns:
neighbor: The index of the neighbor cell.
head: The index of the head.
{gene}: The gene residual contribution for the gene.
- property adata: AnnOrMuData#
Data attached to model instance.
- property adata_manager: AnnDataManager#
Manager instance associated with self.adata.
- classmethod convert_legacy_save(dir_path, output_dir_path, overwrite=False, prefix=None, **save_kwargs)#
Converts a legacy saved model (<v0.15.0) to the updated save format.
- Parameters:
dir_path (str) – Path to directory where legacy model is saved.
output_dir_path (str) – Path to save converted save files.
overwrite (bool) – Overwrite existing data or not. If False and directory already exists at output_dir_path, error will be raised.
**save_kwargs – Keyword arguments passed into save().
- Return type:
- deregister_manager(adata=None)#
Deregisters the AnnDataManager instance associated with adata.
If adata is None, deregisters all AnnDataManager instances in both the class and instance-specific manager stores, except for the one associated with this model instance.
- get_anndata_manager(adata, required=False)#
Retrieves the AnnDataManager for a given AnnData object.
Requires that self.id has been set. Checks for an AnnDataManager specific to this model instance.
- get_elbo(adata=None, indices=None, batch_size=None, dataloader=None, return_mean=True, **kwargs)#
Compute the evidence lower bound (ELBO) on the data.
The ELBO is the reconstruction error plus the Kullback-Leibler (KL) divergences between the variational distributions and the priors. It is different from the marginal log-likelihood; specifically, it is a lower bound on the marginal log-likelihood plus a term that is constant with respect to the variational distribution. It still gives good insights on the modeling of the data and is fast to compute.
- Parameters:
adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.
indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.
batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.
dataloader (Iterator[dict[str, Tensor | None]] | None) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.
return_mean (bool) – Whether to return the mean of the ELBO or the ELBO for each observation.
**kwargs – Additional keyword arguments to pass into the forward method of the module.
- Returns:
Evidence lower bound (ELBO) of the data.
Notes
This is not the negative ELBO, so higher is better.
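The relation between the ELBO and the marginal log-likelihood can be checked on a toy model. The sketch below is not AMICI code: it assumes a hypothetical model with a single binary latent variable and uses the "higher is better" sign convention from the note above, that is, the expected log-likelihood minus the KL divergence.

```python
import math

# Hypothetical toy model: one binary latent z and one fixed observation x.
prior = [0.5, 0.5]        # p(z)
likelihood = [0.9, 0.2]   # p(x | z) evaluated at the observed x


def elbo(q: list[float]) -> float:
    """E_q[log p(x | z)] - KL(q(z) || p(z)) for a discrete distribution q."""
    expected_log_lik = sum(qz * math.log(lik) for qz, lik in zip(q, likelihood))
    kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, prior) if qz > 0)
    return expected_log_lik - kl


evidence = sum(pz * lik for pz, lik in zip(prior, likelihood))   # p(x)
log_evidence = math.log(evidence)
posterior = [pz * lik / evidence for pz, lik in zip(prior, likelihood)]

elbo_rough = elbo([0.5, 0.5])   # variational distribution far from the posterior
elbo_exact = elbo(posterior)    # variational distribution equal to the posterior
```

In this toy setting elbo_rough falls below log_evidence, while elbo_exact matches it up to floating-point error, which is the sense in which the ELBO bounds the marginal log-likelihood from below.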
- get_from_registry(adata, registry_key)#
Returns the object in AnnData associated with the key in the data registry.
AnnData object should be registered with the model prior to calling this function via the self._validate_anndata method.
- get_latent_representation(adata=None, indices=None, give_mean=True, mc_samples=5000, batch_size=None, return_dist=False, dataloader=None)#
Compute the latent representation of the data.
This is typically denoted as \(z_n\).
- Parameters:
adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.
indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.
give_mean (bool) – If True, returns the mean of the latent distribution. If False, returns an estimate of the mean using mc_samples Monte Carlo samples.
mc_samples (int) – Number of Monte Carlo samples to use for the estimator for distributions with no closed-form mean (e.g., the logistic normal distribution). Not used if give_mean is True or if return_dist is True.
batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.
return_dist (bool) – If True, returns the mean and variance of the latent distribution. Otherwise, returns the mean of the latent distribution.
dataloader (Iterator[dict[str, Tensor | None]]) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.
- Return type:
ndarray[Any, dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]] | tuple[ndarray[Any, dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]], ndarray[Any, dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]]
- Returns:
An array of shape (n_obs, n_latent) if return_dist is False. Otherwise, returns a tuple of arrays of shape (n_obs, n_latent) with the mean and variance of the latent distribution.
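The mc_samples parameter matters for latent distributions whose mean has no closed form, such as the logistic normal. A minimal sketch of such an estimator is shown below; it is a standalone illustration with hypothetical parameters, not the code path used by get_latent_representation.

```python
import math
import random


def softmax(values: list[float]) -> list[float]:
    """Numerically stable softmax."""
    top = max(values)
    exps = [math.exp(v - top) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mc_mean_logistic_normal(mu, sigma, mc_samples, rng):
    """Average softmax(z) over draws z ~ N(mu, diag(sigma^2))."""
    totals = [0.0] * len(mu)
    for _ in range(mc_samples):
        z = [rng.gauss(m, s) for m, s in zip(mu, sigma)]
        for i, p in enumerate(softmax(z)):
            totals[i] += p
    return [t / mc_samples for t in totals]


# Hypothetical Gaussian parameters for a three-dimensional latent.
mu, sigma = [1.0, 0.0, -1.0], [2.0, 0.5, 0.5]
plug_in = softmax(mu)   # softmax of the Gaussian mean, not the true mean
mc_mean = mc_mean_logistic_normal(mu, sigma, mc_samples=5000, rng=random.Random(0))
```

The Monte Carlo mean still sums to one but differs from the softmax of the Gaussian mean, which is why a sampling-based estimate is needed for this family.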
- get_marginal_ll(adata=None, indices=None, n_mc_samples=1000, batch_size=None, return_mean=True, dataloader=None, **kwargs)#
Compute the marginal log-likelihood of the data.
The computation here is a biased estimator of the marginal log-likelihood of the data.
- Parameters:
adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.
indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.
n_mc_samples (int) – Number of Monte Carlo samples to use for the estimator. Passed into the module’s marginal_ll method.
batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.
return_mean (bool) – Whether to return the mean of the marginal log-likelihood or the marginal log-likelihood for each observation.
dataloader (Iterator[dict[str, Tensor | None]]) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.
**kwargs – Additional keyword arguments to pass into the module’s marginal_ll method.
- Return type:
- Returns:
If return_mean is True, returns the mean marginal log-likelihood. Otherwise returns a tensor of shape (n_obs,) with the marginal log-likelihood for each observation.
Notes
This is not the negative log-likelihood, so higher is better.
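Why a Monte Carlo estimate of the marginal log-likelihood is biased can be seen in a small sketch: the average of the likelihood weights is unbiased, but taking its logarithm is not. The example assumes a hypothetical one-dimensional Gaussian model with samples drawn from the prior; the module's marginal_ll method may use a different proposal distribution.

```python
import math
import random


def log_normal_pdf(x: float, mean: float, var: float) -> float:
    """Log density of a univariate normal distribution."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def estimate_marginal_ll(x: float, n_mc_samples: int, rng: random.Random) -> float:
    """Log of the average of p(x | z_k), with z_k drawn from the prior N(0, 1)."""
    log_weights = [
        log_normal_pdf(x, mean=rng.gauss(0.0, 1.0), var=1.0)
        for _ in range(n_mc_samples)
    ]
    # log-mean-exp, computed stably by factoring out the largest term
    top = max(log_weights)
    return top + math.log(sum(math.exp(w - top) for w in log_weights) / n_mc_samples)


x = 0.5
# With z ~ N(0, 1) and x | z ~ N(z, 1), the marginal is x ~ N(0, 2).
exact = log_normal_pdf(x, mean=0.0, var=2.0)
estimate = estimate_marginal_ll(x, n_mc_samples=1000, rng=random.Random(0))
```

With enough samples the estimate approaches the exact value; with few samples it tends to fall below it on average, which follows from Jensen's inequality applied to the logarithm.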
- get_reconstruction_error(adata=None, indices=None, batch_size=None, dataloader=None, return_mean=True, **kwargs)#
Compute the reconstruction error on the data.
The reconstruction error is the negative log likelihood of the data given the latent variables. It is different from the marginal log-likelihood, but still gives good insights on the modeling of the data and is fast to compute. This is typically written as \(p(x \mid z)\), the likelihood term given one posterior sample.
- Parameters:
adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.
indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.
batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.
dataloader (Iterator[dict[str, Tensor | None]] | None) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.
return_mean (bool) – Whether to return the mean reconstruction loss or the reconstruction loss for each observation.
**kwargs – Additional keyword arguments to pass into the forward method of the module.
- Returns:
Reconstruction error for the data.
Notes
This is not the negative reconstruction error, so higher is better.
- property history#
Returns computed metrics during training.
- classmethod load(dir_path, adata=None, accelerator='auto', device='auto', prefix=None, backup_url=None)#
Instantiate a model from the saved output.
- Parameters:
dir_path (str) – Path to saved outputs.
adata (AnnData | MuData | None) – AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the saved scvi setup dictionary. If None, will check for and load anndata saved with the model.
accelerator (str) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.
device (int | str) – The device to use. Can be set to a non-negative index (int or str) or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and accelerator is not determined to be “cpu”, then device will be set to the first available device.
backup_url (str | None) – URL to retrieve saved outputs from if not present on disk.
- Returns:
Model with loaded state dictionaries.
Examples
>>> model = ModelClass.load(save_path, adata)
>>> model.get_....
- static load_registry(dir_path, prefix=None)#
Return the full registry saved with the model.
- classmethod register_manager(adata_manager)#
Registers an AnnDataManager instance with this model class.
Stores the AnnDataManager reference in a class-specific manager store. Intended for use in the setup_anndata() class method followed up by retrieval of the AnnDataManager via the _get_most_recent_anndata_manager() method in the model init method.
Notes
Subsequent calls to this method with an AnnDataManager instance referring to the same underlying AnnData object will overwrite the reference to previous AnnDataManager.
- save(dir_path, prefix=None, overwrite=False, save_anndata=False, save_kwargs=None, legacy_mudata_format=False, **anndata_write_kwargs)#
Save the state of the model.
Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0.
- Parameters:
dir_path (str) – Path to a directory.
prefix (str | None) – Prefix to prepend to saved file names.
overwrite (bool) – Overwrite existing data or not. If False and directory already exists at dir_path, error will be raised.
save_anndata (bool) – If True, also saves the anndata.
save_kwargs (dict | None) – Keyword arguments passed into save().
legacy_mudata_format (bool) – If True, saves the model var_names in the legacy format if the model was trained with a MuData object. The legacy format is a flat array with variable names across all modalities concatenated, while the new format is a dictionary with keys corresponding to the modality names and values corresponding to the variable names for each modality.
anndata_write_kwargs – Kwargs for write().
- property summary_string#
Summary string of the model.
- to_device(device)#
Move model to device.
- Parameters:
device (str | int) – Device to move model to. Options: ‘cpu’ for CPU, integer GPU index (e.g. 0), or ‘cuda:X’ where X is the GPU index (e.g. ‘cuda:0’). See torch.device for more info.
Examples
>>> adata = scvi.data.synthetic_iid()
>>> model = scvi.model.SCVI(adata)
>>> model.to_device("cpu")  # moves model to CPU
>>> model.to_device("cuda:0")  # moves model to GPU 0
>>> model.to_device(0)  # also moves model to GPU 0
- train(max_epochs=None, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, external_indexing=None, shuffle_set_split=True, batch_size=128, early_stopping=False, plan_kwargs=None, use_wandb=False, wandb_project=None, wandb_entity=None, wandb_run_name=None, **trainer_kwargs)#
Train the model.
- Parameters:
max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
accelerator (str) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.
devices (Union[int, list[int], str]) – The devices to use. Can be set to a non-negative index (int or str), a sequence of device indices (list or comma-separated str), the value -1 to indicate all available devices, or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and accelerator is not determined to be “cpu”, then devices will be set to the first available device.
train_size (float) – Size of training set in the range [0.0, 1.0].
validation_size (Optional[float]) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
shuffle_set_split (bool) – Whether to shuffle indices before splitting. If False, the val, train, and test set are split in the sequential order of the data according to validation_size and train_size percentages.
external_indexing (Optional[list[array, array, array]]) – A list of data split indices in the order of training, validation, and test sets. Validation and test set are not required and can be left empty.
batch_size (int) – Minibatch size to use during training.
early_stopping (bool) – Perform early stopping. Additional arguments can be passed in **kwargs. See Trainer for further options.
plan_kwargs (Optional[dict]) – Keyword args for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
**trainer_kwargs – Other keyword args for Trainer.
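The default for max_epochs quoted above can be worked through with a small pure-Python equivalent of the NumPy expression. The function name default_max_epochs is hypothetical and exists only for this illustration.

```python
def default_max_epochs(n_cells: int) -> int:
    """Pure-Python equivalent of np.min([round((20000 / n_cells) * 400), 400])."""
    return min(round((20000 / n_cells) * 400), 400)


small = default_max_epochs(5_000)     # heuristic gives 1600, capped at 400
large = default_max_epochs(100_000)   # 20000 / 100000 * 400 = 80
```

Datasets with up to 20,000 cells therefore train for the full 400 epochs by default, and larger datasets receive proportionally fewer epochs.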
- view_anndata_setup(adata=None, hide_state_registries=False)#
Print summary of the setup for the initial AnnData or a given AnnData object.