amici.AMICI#

class amici.AMICI(adata, **model_kwargs)[source]#

Bases: VAEMixin, WandbUnsupervisedTrainingMixin, BaseModelClass

Methods

__init__

convert_legacy_save

Converts a legacy saved model (<v0.15.0) to the updated save format.

deregister_manager

Deregisters the AnnDataManager instance associated with adata.

get_anndata_manager

Retrieves the AnnDataManager for a given AnnData object.

get_attention_patterns

Different flavors of attention patterns retrieved from cache.

get_counterfactual_attention_patterns

Compute the counterfactual attention patterns for a given cell type and head indices.

get_elbo

Compute the evidence lower bound (ELBO) on the data.

get_expl_variance_scores

Creates a DataFrame with the explained variance scores.

get_from_registry

Returns the object in AnnData associated with the key in the data registry.

get_gene_residual_contributions

Get the gene residual contributions for each cell at full attention.

get_latent_representation

Compute the latent representation of the data.

get_marginal_ll

Compute the marginal log-likelihood of the data.

get_neighbor_ablation_scores

Difference in gene expression prediction error for cell type of interest when ablating neighbor cell types for a specific head of interest.

get_nn_embed

get_predictions

get_reconstruction_error

Compute the reconstruction error on the data.

load

Instantiate a model from the saved output.

load_registry

Return the full registry saved with the model.

register_manager

Registers an AnnDataManager instance with this model class.

save

Save the state of the model.

setup_anndata

Sets up the AnnData object for this model.

to_device

Move model to device.

train

Train the model.

view_anndata_setup

Print summary of the setup for the initial AnnData or a given AnnData object.

view_setup_args

Print args used to setup a saved model.

Attributes

adata

Data attached to model instance.

adata_manager

Manager instance associated with self.adata.

device

The current device that the module's params are on.

history

Returns computed metrics during training.

is_trained

Whether the model has been trained.

summary_string

Summary string of the model.

test_indices

Observations that are in test set.

train_indices

Observations that are in train set.

validation_indices

Observations that are in validation set.

classmethod setup_anndata(adata, layer=None, labels_key=None, coord_obsm_key=None, nn_dist_key='_nn_dist', nn_idx_key='_nn_idx', cell_radius_key=None, n_neighbors=None, **kwargs)[source]#

Sets up the AnnData object for this model.

A mapping will be created between data fields used by this model to their respective locations in adata. None of the data in adata are modified. Only adds fields to adata.

Each model class deriving from this class provides parameters to this method according to its needs. To operate correctly with the model initialization, the implementation must call register_manager() on a model-specific instance of AnnDataManager.

get_attention_patterns(model, adata=None, indices=None, batch_size=None, flavor='vanilla', prog_bar=True)#

Different flavors of attention patterns retrieved from cache.

Returns a DataFrame if return_nn_idxs_and_dists is False; otherwise returns a tuple of DataFrames (attention_patterns, nn_idxs, nn_dists). Code adapted from callummcdougall/CircuitsVis.

einsum key:
  • b: n_batch
  • h: n_heads
  • n: n_neighbors
  • g: n_genes
  • m: n_heads * n_head_size (a.k.a. d_model)

Parameters:
  • adata (AnnData) – AnnData object to get attention patterns from.

  • indices (list[int], optional) – Indices of the cells to get attention patterns from.

  • batch_size (int, optional) – Batch size to use for the data loader.

  • flavor (Literal["vanilla", "value-weighted", "info-weighted", "gene-weighted"], optional) – Flavor of attention patterns to retrieve.

  • prog_bar (bool, optional) – Whether to show a progress bar.

Return type:

AMICIAttentionModule

Stores#

Union[pd.DataFrame, tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]:
  • attention_patterns: DataFrame of attention patterns with the following columns:
    • “neighbor_{i}”: Attention pattern for the i-th neighbor.

    • “head_idx”: Index of the attention head.

    • “label”: Label of the cell.

    • “cell_idx”: Index of the cell.

  • nn_idxs: DataFrame of nearest neighbor indices (only if return_nn_idxs_and_dists is True) with the following columns:
    • “neighbor_{i}”: Index of the i-th neighbor.

  • nn_dists: DataFrame of nearest neighbor distances (only if return_nn_idxs_and_dists is True) with the following columns:
    • “neighbor_{i}”: Distance to the i-th neighbor.
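As a minimal sketch of how the documented attention_patterns columns can be read, the snippet below finds the most-attended neighbor for each (cell_idx, head_idx) row. The rows are hand-written toy values that only mimic the column layout described above; they are not output from AMICI, and the helper top_neighbor is hypothetical.

```python
# Toy rows mimicking the documented columns; every value here is made up.
rows = [
    {"cell_idx": 0, "head_idx": 0, "label": "A",
     "neighbor_0": 0.7, "neighbor_1": 0.2, "neighbor_2": 0.1},
    {"cell_idx": 0, "head_idx": 1, "label": "A",
     "neighbor_0": 0.1, "neighbor_1": 0.3, "neighbor_2": 0.6},
    {"cell_idx": 1, "head_idx": 0, "label": "B",
     "neighbor_0": 0.25, "neighbor_1": 0.5, "neighbor_2": 0.25},
]


def top_neighbor(row):
    """Return (neighbor position, weight) for the largest neighbor_{i} value."""
    weights = {
        int(key.split("_", 1)[1]): value
        for key, value in row.items()
        if key.startswith("neighbor_")
    }
    position = max(weights, key=weights.get)
    return position, weights[position]


# Map each (cell_idx, head_idx) pair to its most-attended neighbor position.
summary = {(r["cell_idx"], r["head_idx"]): top_neighbor(r) for r in rows}
```

The neighbor position i indexes into the matching neighbor_{i} columns of nn_idxs and nn_dists when those are returned.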

get_counterfactual_attention_patterns(model, cell_type, adata=None, indices=None, head_idxs=None, batch_size=None)#

Compute the counterfactual attention patterns for a given cell type and head indices.

Parameters:
  • model (AMICI) – The trained AMICI model.

  • cell_type (str) – The index cell type to get counterfactual attention scores for.

  • adata (AnnData | None) – Optional AnnData object to use. If None, uses model.adata.

  • indices (list[int] | None) – Optional list of cell indices to get scores for. If None, uses all cells.

  • head_idxs (list[int] | None) – Optional list of attention head indices to get scores for. If None, uses all heads.

  • batch_size (int | None) – Optional batch size to use for the data loader.

Return type:

AMICICounterfactualAttentionModule

Returns:

AMICICounterfactualAttentionModule: The module instance with computed counterfactual attention scores stored in _counterfactual_attention_df with columns:

  • query_label: Cell type label of the query cell

  • neighbor_idx: Index of the neighbor cell

  • neighbor_label: Cell type label of the neighbor cell

  • head_idx: Attention head index

  • base_attention_score: Base attention score

  • pos_coef: Position coefficient of the counterfactual attention score function

  • dummy_attention_score: Dummy attention score of the model

  • distance_kernel_unit_scale: Distance kernel unit scale

get_expl_variance_scores(model, adata=None, alpha=0.05, run_permutation_test=True, n_permutations=None)#

Creates a DataFrame with the explained variance scores.

DataFrame contains the explained variance scores for each head across cells for all cell types or a subset of cell types.

Parameters:
  • adata (AnnData, optional) – The AnnData object.

  • alpha (float, optional) – The desired significance level for the explained variance of each head-cell type pair via permutation testing. Defaults to 0.05.

  • run_permutation_test (bool, optional) – Whether to run the permutation test. Defaults to True.

  • n_permutations (int, optional) – The number of permutations to use for the permutation test. If not provided, the minimum number of permutations required to achieve a significance level of alpha will be used.

Return type:

AMICIExplainedVarianceModule

Returns:

AMICIExplainedVarianceModule: The module instance with computed explained variance scores stored in _explained_variance_df with columns:

  • head: The index of the head.

  • ct_name: The name of the cell type.

  • gene: The name of the gene.

  • expl_var_head_gene: The explained variance score for the gene when all heads except head are ablated.

  • expl_var_head: The explained variance score when all heads except head are ablated, aggregated across all genes.

If run_permutation_test is True, the following columns are also added:
  • p_value_head: The p-value of the explained variance score when all but head is ablated.

  • p_value_adj_head: The adjusted p-value of the explained variance score when all but head is ablated.
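The link between alpha and the minimum number of permutations can be illustrated with a generic permutation test. This is a sketch under stated assumptions, not AMICI's implementation: it assumes the common p-value convention (1 + number of null statistics >= observed) / (1 + n_permutations), whose smallest attainable value is 1 / (n_permutations + 1), and it uses a toy difference-in-means statistic instead of explained variance. The helper names are hypothetical, and the p-value adjustment method is not covered because the source does not name it.

```python
import math
import random


def min_permutations(alpha):
    """Smallest n with 1 / (n + 1) <= alpha (assumed p-value convention)."""
    return math.ceil(round(1.0 / alpha, 9)) - 1


def permutation_p_value(observed, null_stats):
    """One-sided p-value with the +1 correction."""
    n_extreme = sum(1 for stat in null_stats if stat >= observed)
    return (1 + n_extreme) / (1 + len(null_stats))


def mean_difference(values, labels):
    """Toy statistic: mean of group 1 minus mean of group 0."""
    group_1 = [v for v, lab in zip(values, labels) if lab == 1]
    group_0 = [v for v, lab in zip(values, labels) if lab == 0]
    return sum(group_1) / len(group_1) - sum(group_0) / len(group_0)


rng = random.Random(0)
values = [5.1, 4.8, 5.3, 5.0, 1.2, 0.9, 1.1, 1.0]  # made-up measurements
labels = [1, 1, 1, 1, 0, 0, 0, 0]

n_perm = min_permutations(0.05)
observed = mean_difference(values, labels)

# Build the null distribution by shuffling the labels.
null_stats = []
for _ in range(n_perm):
    shuffled = labels[:]
    rng.shuffle(shuffled)
    null_stats.append(mean_difference(values, shuffled))

p_value = permutation_p_value(observed, null_stats)
```

Under this convention, fewer permutations than min_permutations(alpha) can never produce a p-value at or below alpha, which is why a minimum exists at all.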

get_neighbor_ablation_scores(model, cell_type, head_idx, adata=None, ablated_neighbor_ct_sub=None, ablated_neighbor_indices=None, compute_z_value=False)#

Difference in gene expression prediction error for cell type of interest when ablating neighbor cell types for a specific head of interest.

Parameters:
  • cell_type (str) – The cell type to compute the residuals for. None indicates using all cell types.

  • head_idx (int, optional) – The index of the head to test. None indicates using all heads.

  • adata (AnnData, optional) – The AnnData object.

  • ablated_neighbor_ct_sub (list[str], optional) – The neighbor cell types to ablate.

  • ablated_neighbor_indices (list[int], optional) – The indices of the neighbor cells to ablate.

  • compute_z_value (bool, optional) – Whether to save the z-value using the correlation coefficient or not.

Return type:

AMICIAblationModule

Returns:

pd.DataFrame: A DataFrame with the difference in gene expression prediction error for cell type of interest when ablating neighbor cell types. The DataFrame has the columns:

  • [ablated_neighbor_ct]: The difference in gene expression prediction error when ablating the neighbor cell type ablated_neighbor_ct.

get_gene_residual_contributions(adata=None, indices=None, batch_size=None, head_idxs=None)[source]#

Get the gene residual contributions for each cell at full attention.

Compute the gene-wise residual contributions for each cell irrespective of the attention score for the head. As the value vectors do not depend on the distances or the index cell, we only need to provide the neighbor gene expressions.

Parameters:
  • adata (AnnData | None) – The AnnData object.

  • indices (list[int] | None) – The indices of the cells to get gene residual contributions for.

  • batch_size (int | None) – The batch size.

  • head_idxs (list[int] | None) – The indices of the heads to get gene residual contributions for.

Return type:

DataFrame

Returns:

pd.DataFrame: A DataFrame with the gene residual contributions for each cell at full attention. The DataFrame has the columns:

  • neighbor: The index of the neighbor cell.

  • head: The index of the head.

  • {gene}: The gene residual contribution for the gene.

property adata: AnnOrMuData#

Data attached to model instance.

property adata_manager: AnnDataManager#

Manager instance associated with self.adata.

classmethod convert_legacy_save(dir_path, output_dir_path, overwrite=False, prefix=None, **save_kwargs)#

Converts a legacy saved model (<v0.15.0) to the updated save format.

Parameters:
  • dir_path (str) – Path to directory where legacy model is saved.

  • output_dir_path (str) – Path to save converted save files.

  • overwrite (bool) – Overwrite existing data or not. If False and directory already exists at output_dir_path, error will be raised.

  • prefix (str | None) – Prefix of saved file names.

  • **save_kwargs – Keyword arguments passed into save().

Return type:

None

deregister_manager(adata=None)#

Deregisters the AnnDataManager instance associated with adata.

If adata is None, deregisters all AnnDataManager instances in both the class and instance-specific manager stores, except for the one associated with this model instance.

property device: str#

The current device that the module’s params are on.

get_anndata_manager(adata, required=False)#

Retrieves the AnnDataManager for a given AnnData object.

Requires self.id has been set. Checks for an AnnDataManager specific to this model instance.

Parameters:
  • adata (AnnData | MuData) – AnnData object to find manager instance for.

  • required (bool) – If True, errors on missing manager. Otherwise, returns None when manager is missing.

Return type:

AnnDataManager | None

get_elbo(adata=None, indices=None, batch_size=None, dataloader=None, return_mean=True, **kwargs)#

Compute the evidence lower bound (ELBO) on the data.

The ELBO is the reconstruction error plus the Kullback-Leibler (KL) divergences between the variational distributions and the priors. It is different from the marginal log-likelihood; specifically, it is a lower bound on the marginal log-likelihood plus a term that is constant with respect to the variational distribution. It still gives good insights on the modeling of the data and is fast to compute.

Parameters:
  • adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.

  • indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.

  • batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.

  • dataloader (Iterator[dict[str, Tensor | None]] | None) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.

  • return_mean (bool) – Whether to return the mean of the ELBO or the ELBO for each observation.

  • **kwargs – Additional keyword arguments to pass into the forward method of the module.

Returns:

Evidence lower bound (ELBO) of the data.

Notes

This is not the negative ELBO, so higher is better.
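The lower-bound relationship can be checked numerically on a toy model. The sketch below is not AMICI's module code: it assumes a single binary latent variable with hand-picked probabilities and writes the ELBO as the expected log-likelihood minus the KL divergence, the sign convention under which higher is better.

```python
import math

# Toy model (all numbers made up): binary latent z, one observed data point x.
prior = [0.5, 0.5]       # p(z) for z in {0, 1}
likelihood = [0.9, 0.2]  # p(x | z) for the observed x


def elbo(q):
    """E_q[log p(x | z)] - KL(q || prior) for a discrete q over z."""
    expected_ll = sum(qz * math.log(lz) for qz, lz in zip(q, likelihood))
    kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, prior) if qz > 0)
    return expected_ll - kl


# Exact marginal likelihood p(x) and exact posterior p(z | x).
evidence = sum(pz * lz for pz, lz in zip(prior, likelihood))
log_evidence = math.log(evidence)
posterior = [pz * lz / evidence for pz, lz in zip(prior, likelihood)]

elbo_uniform = elbo([0.5, 0.5])   # a poor variational distribution
elbo_posterior = elbo(posterior)  # the best possible one
```

The bound is tight only when the variational distribution equals the true posterior; any other choice gives a strictly smaller value.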

get_from_registry(adata, registry_key)#

Returns the object in AnnData associated with the key in the data registry.

AnnData object should be registered with the model prior to calling this function via the self._validate_anndata method.

Parameters:
  • registry_key (str) – key of object to get from data registry.

  • adata (AnnData | MuData) – AnnData to pull data from.

Returns:

The requested data as a NumPy array.

get_latent_representation(adata=None, indices=None, give_mean=True, mc_samples=5000, batch_size=None, return_dist=False, dataloader=None)#

Compute the latent representation of the data.

This is typically denoted as \(z_n\).

Parameters:
  • adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.

  • indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.

  • give_mean (bool) – If True, returns the mean of the latent distribution. If False, returns an estimate of the mean using mc_samples Monte Carlo samples.

  • mc_samples (int) – Number of Monte Carlo samples to use for the estimator for distributions with no closed-form mean (e.g., the logistic normal distribution). Not used if give_mean is True or if return_dist is True.

  • batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.

  • return_dist (bool) – If True, returns the mean and variance of the latent distribution. Otherwise, returns the mean of the latent distribution.

  • dataloader (Iterator[dict[str, Tensor | None]]) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.

Return type:

ndarray[Any, dtype[TypeVar(_ScalarType_co, bound= generic, covariant=True)]] | tuple[ndarray[Any, dtype[TypeVar(_ScalarType_co, bound= generic, covariant=True)]], ndarray[Any, dtype[TypeVar(_ScalarType_co, bound= generic, covariant=True)]]]

Returns:

An array of shape (n_obs, n_latent) if return_dist is False. Otherwise, returns a tuple of arrays (n_obs, n_latent) with the mean and variance of the latent distribution.
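The role of mc_samples can be illustrated with a small Monte Carlo estimator for the mean of a logistic normal distribution, which has no closed form. This is a standalone sketch with made-up parameters, not the code path used by get_latent_representation; the helper names are hypothetical.

```python
import math
import random


def softmax(values):
    """Numerically stable softmax of a list of floats."""
    top = max(values)
    exps = [math.exp(v - top) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mc_mean_logistic_normal(mean, std, mc_samples, seed=0):
    """Average softmax(Normal(mean, std)) draws to estimate the mean."""
    rng = random.Random(seed)
    totals = [0.0] * len(mean)
    for _ in range(mc_samples):
        draw = softmax([rng.gauss(m, s) for m, s in zip(mean, std)])
        totals = [t + d for t, d in zip(totals, draw)]
    return [t / mc_samples for t in totals]


# Made-up Gaussian parameters for a 3-dimensional latent variable.
estimate = mc_mean_logistic_normal(
    mean=[0.0, 1.0, -1.0], std=[0.5, 0.5, 0.5], mc_samples=5000
)
```

More samples reduce the variance of the estimate at the cost of extra compute; with give_mean=True the documented method returns the mean of the latent distribution instead of this estimate.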

get_marginal_ll(adata=None, indices=None, n_mc_samples=1000, batch_size=None, return_mean=True, dataloader=None, **kwargs)#

Compute the marginal log-likelihood of the data.

The computation here is a biased estimator of the marginal log-likelihood of the data.

Parameters:
  • adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.

  • indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.

  • n_mc_samples (int) – Number of Monte Carlo samples to use for the estimator. Passed into the module’s marginal_ll method.

  • batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.

  • return_mean (bool) – Whether to return the mean of the marginal log-likelihood or the marginal log-likelihood for each observation.

  • dataloader (Iterator[dict[str, Tensor | None]]) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.

  • **kwargs – Additional keyword arguments to pass into the module’s marginal_ll method.

Return type:

float | Tensor

Returns:

If return_mean is True, returns the mean marginal log-likelihood. Otherwise, returns a tensor of shape (n_obs,) with the marginal log-likelihood for each observation.

Notes

This is not the negative log-likelihood, so higher is better.
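One common way to estimate a marginal log-likelihood with Monte Carlo samples is importance sampling, sketched below on a toy Gaussian model where the exact answer is known. This is an illustration under assumptions, not the module's marginal_ll method: the model (z ~ N(0, 1), x | z ~ N(z, 1)), the proposal distributions, and all helper names are chosen for the example. The estimator takes the log of an average, so by Jensen's inequality it underestimates the true value in expectation, which is one sense in which such estimators are biased.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    """Log density of a univariate normal distribution."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(values):
    """Stable log of the average of exp(values)."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values) / len(values))


def marginal_ll_estimate(x, q_mean, q_var, n_mc_samples, seed=0):
    """log (1/K) sum_k p(x | z_k) p(z_k) / q(z_k), with z_k ~ N(q_mean, q_var)."""
    rng = random.Random(seed)
    log_weights = []
    for _ in range(n_mc_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_weights.append(
            log_normal_pdf(x, z, 1.0)          # log p(x | z)
            + log_normal_pdf(z, 0.0, 1.0)      # log p(z)
            - log_normal_pdf(z, q_mean, q_var)  # log q(z)
        )
    return log_mean_exp(log_weights)


x = 1.0
exact = log_normal_pdf(x, 0.0, 2.0)  # marginally, x ~ N(0, 2)
from_prior = marginal_ll_estimate(x, 0.0, 1.0, n_mc_samples=1000)
from_posterior = marginal_ll_estimate(x, x / 2, 0.5, n_mc_samples=1000)
```

When the proposal equals the exact posterior N(x / 2, 1 / 2), every importance weight equals p(x) and the estimate is exact; a looser proposal such as the prior needs more samples to get close.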

get_reconstruction_error(adata=None, indices=None, batch_size=None, dataloader=None, return_mean=True, **kwargs)#

Compute the reconstruction error on the data.

The reconstruction error is the negative log likelihood of the data given the latent variables. It is different from the marginal log-likelihood, but still gives good insights on the modeling of the data and is fast to compute. This is typically written as \(p(x \mid z)\), the likelihood term given one posterior sample.

Parameters:
  • adata (AnnData | None) – AnnData object with var_names in the same order as the ones used to train the model. If None and dataloader is also None, it defaults to the object used to initialize the model.

  • indices (Sequence[int] | None) – Indices of observations in adata to use. If None, defaults to all observations. Ignored if dataloader is not None.

  • batch_size (int | None) – Minibatch size for the forward pass. If None, defaults to scvi.settings.batch_size. Ignored if dataloader is not None.

  • dataloader (Iterator[dict[str, Tensor | None]] | None) – An iterator over minibatches of data on which to compute the metric. The minibatches should be formatted as a dictionary of Tensor with keys as expected by the model. If None, a dataloader is created from adata.

  • return_mean (bool) – Whether to return the mean reconstruction loss or the reconstruction loss for each observation.

  • **kwargs – Additional keyword arguments to pass into the forward method of the module.

Returns:

Reconstruction error for the data.

Notes

This is not the negative reconstruction error, so higher is better.

property history#

Returns computed metrics during training.

property is_trained: bool#

Whether the model has been trained.

classmethod load(dir_path, adata=None, accelerator='auto', device='auto', prefix=None, backup_url=None)#

Instantiate a model from the saved output.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • adata (AnnData | MuData | None) – AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the saved scvi setup dictionary. If None, will check for and load anndata saved with the model.

  • accelerator (str) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • device (int | str) – The device to use. Can be set to a non-negative index (int or str) or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and accelerator is not determined to be “cpu”, then device will be set to the first available device.

  • prefix (str | None) – Prefix of saved file names.

  • backup_url (str | None) – URL to retrieve saved outputs from if not present on disk.

Returns:

Model with loaded state dictionaries.

Examples

>>> model = ModelClass.load(save_path, adata)
>>> model.get_....

static load_registry(dir_path, prefix=None)#

Return the full registry saved with the model.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • prefix (str | None) – Prefix of saved file names.

Returns:

The full registry saved with the model.

classmethod register_manager(adata_manager)#

Registers an AnnDataManager instance with this model class.

Stores the AnnDataManager reference in a class-specific manager store. Intended for use in the setup_anndata() class method followed up by retrieval of the AnnDataManager via the _get_most_recent_anndata_manager() method in the model init method.

Notes

Subsequent calls to this method with an AnnDataManager instance referring to the same underlying AnnData object will overwrite the reference to previous AnnDataManager.

save(dir_path, prefix=None, overwrite=False, save_anndata=False, save_kwargs=None, legacy_mudata_format=False, **anndata_write_kwargs)#

Save the state of the model.

Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0.

Parameters:
  • dir_path (str) – Path to a directory.

  • prefix (str | None) – Prefix to prepend to saved file names.

  • overwrite (bool) – Overwrite existing data or not. If False and directory already exists at dir_path, error will be raised.

  • save_anndata (bool) – If True, also saves the anndata.

  • save_kwargs (dict | None) – Keyword arguments passed into save().

  • legacy_mudata_format (bool) – If True, saves the model var_names in the legacy format if the model was trained with a MuData object. The legacy format is a flat array with variable names across all modalities concatenated, while the new format is a dictionary with keys corresponding to the modality names and values corresponding to the variable names for each modality.

  • anndata_write_kwargs – Kwargs for write()

property summary_string#

Summary string of the model.

property test_indices: ndarray#

Observations that are in test set.

to_device(device)#

Move model to device.

Parameters:

device (str | int) – Device to move model to. Options: ‘cpu’ for CPU, integer GPU index (e.g., 0), or ‘cuda:X’ where X is the GPU index (e.g., ‘cuda:0’). See torch.device for more info.

Examples

>>> adata = scvi.data.synthetic_iid()
>>> model = scvi.model.SCVI(adata)
>>> model.to_device("cpu")  # moves model to CPU
>>> model.to_device("cuda:0")  # moves model to GPU 0
>>> model.to_device(0)  # also moves model to GPU 0

train(max_epochs=None, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, external_indexing=None, shuffle_set_split=True, batch_size=128, early_stopping=False, plan_kwargs=None, use_wandb=False, wandb_project=None, wandb_entity=None, wandb_run_name=None, **trainer_kwargs)#

Train the model.

Parameters:
  • max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400])

  • accelerator (str) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (Union[int, list[int], str]) – The devices to use. Can be set to a non-negative index (int or str), a sequence of device indices (list or comma-separated str), the value -1 to indicate all available devices, or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and accelerator is not determined to be “cpu”, then devices will be set to the first available device.

  • train_size (float) – Size of training set in the range [0.0, 1.0].

  • validation_size (Optional[float]) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.

  • shuffle_set_split (bool) – Whether to shuffle indices before splitting. If False, the validation, train, and test sets are split in the sequential order of the data according to validation_size and train_size percentages.

  • external_indexing (Optional[list[array, array, array]]) – A list of data split indices in the order of training, validation, and test sets. Validation and test set are not required and can be left empty.

  • batch_size (int) – Minibatch size to use during training.

  • early_stopping (bool) – Perform early stopping. Additional arguments can be passed in **kwargs. See Trainer for further options.

  • plan_kwargs (Optional[dict]) – Keyword args for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.

  • **trainer_kwargs – Other keyword args for Trainer.
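The max_epochs default and the train/validation/test sizes described above can be sketched in plain Python. The epoch heuristic follows the documented formula, with the built-in min standing in for np.min. The split helper is an assumption about rounding (ceiling for the training set, floor for an explicit validation set), not a confirmed description of the library's behavior, and both function names are hypothetical.

```python
import math


def default_max_epochs(n_cells):
    """Documented heuristic: min(round((20000 / n_cells) * 400), 400)."""
    return min(round((20000 / n_cells) * 400), 400)


def split_sizes(n_cells, train_size=0.9, validation_size=None):
    """Return (n_train, n_validation, n_test); the rounding is an assumption."""
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # Everything that is not training data becomes validation data.
        n_validation = n_cells - n_train
    else:
        n_validation = math.floor(validation_size * n_cells)
    n_test = n_cells - n_train - n_validation
    return n_train, n_validation, n_test


epochs_small = default_max_epochs(1_000)   # capped at 400
epochs_large = default_max_epochs(40_000)  # fewer passes for larger data
default_split = split_sizes(1_000)
three_way_split = split_sizes(1_000, train_size=0.8, validation_size=0.1)
```

Under the heuristic, datasets with at most 20,000 cells train for 400 epochs and larger datasets train for proportionally fewer, so the total number of cell passes stays roughly constant.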

property train_indices: ndarray#

Observations that are in train set.

property validation_indices: ndarray#

Observations that are in validation set.

view_anndata_setup(adata=None, hide_state_registries=False)#

Print summary of the setup for the initial AnnData or a given AnnData object.

Parameters:
  • adata (AnnData | MuData | None) – AnnData object setup with setup_anndata or transfer_fields().

  • hide_state_registries (bool) – If True, prints a shortened summary without details of each state registry.

Return type:

None

static view_setup_args(dir_path, prefix=None)#

Print args used to setup a saved model.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • prefix (str | None) – Prefix of saved file names.

Return type:

None