"""Source code for ``amici.callbacks._callbacks``: PyTorch Lightning training callbacks."""
import os
import warnings
from contextlib import nullcontext, redirect_stdout
import numpy as np
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
def _log_attention_penalty_coef_update(epoch, epoch_start, epoch_end, start_val, end_val):
if epoch < epoch_start:
return 0
elif epoch_start <= epoch < epoch_end:
log_space = np.logspace(
np.log10(start_val),
np.log10(end_val),
num=(epoch_end - epoch_start),
endpoint=True,
base=10.0,
)
return log_space[epoch - epoch_start]
else:
return end_val
def _linear_attention_penalty_coef_update(epoch, epoch_start, epoch_end, start_val, end_val):
if epoch < epoch_start:
return 0
elif epoch_start <= epoch < epoch_end:
return start_val + (end_val - start_val) * (epoch - epoch_start) / (epoch_end - epoch_start)
else:
return end_val
class AttentionPenaltyMonitor(Callback):
    """Set ``pl_module.module.attention_penalty_coef`` at the start of every training epoch.

    The coefficient follows an epoch schedule:

    * ``0`` before ``epoch_start``;
    * a ramp from ``start_val`` towards ``end_val`` between ``epoch_start`` and
      ``epoch_end``;
    * ``end_val`` afterwards.

    Parameters
    ----------
    epoch_start
        First epoch of the ramp; earlier epochs get a coefficient of ``0``.
    epoch_end
        End of the ramp; ``end_val`` is used from this epoch on.
    start_val
        Coefficient at ``epoch_start``.
    end_val
        Final coefficient.
    flavor
        Ramp shape: ``"log"`` (log-spaced; needs positive ``start_val`` and
        ``end_val``) or ``"linear"``.

    Raises
    ------
    ValueError
        If ``flavor`` is not ``"log"`` or ``"linear"``.
    """

    # Supported ramp shapes, validated eagerly so a typo fails at construction time.
    _FLAVORS = ("log", "linear")

    def __init__(
        self,
        epoch_start=10,
        epoch_end=30,
        start_val=1e-4,
        end_val=1e-3,
        flavor="log",
    ):
        super().__init__()
        if flavor not in self._FLAVORS:
            raise ValueError(f"flavor must be one of {self._FLAVORS}, got {flavor!r}")
        self.epoch_start = epoch_start
        self.epoch_end = epoch_end
        self.start_val = start_val
        self.end_val = end_val
        self.flavor = flavor

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Compute the coefficient for ``pl_module.current_epoch`` and store it on ``pl_module.module``."""
        if self.flavor == "log":
            schedule = _log_attention_penalty_coef_update
        elif self.flavor == "linear":
            schedule = _linear_attention_penalty_coef_update
        else:
            # ``flavor`` may have been reassigned after construction; without this branch the
            # coefficient would be left undefined and fail with an UnboundLocalError below.
            raise ValueError(f"flavor must be one of {self._FLAVORS}, got {self.flavor!r}")

        attention_penalty_coef = schedule(
            pl_module.current_epoch,
            self.epoch_start,
            self.epoch_end,
            self.start_val,
            self.end_val,
        )
        pl_module.module.attention_penalty_coef = attention_penalty_coef
class ModelInterpretationLogging(Callback):
    """Periodically compute and plot model-interpretation summaries during training.

    Every ``n_epochs_plot`` training epochs, the model found on ``trainer._model`` is
    asked for its attention patterns (``get_attention_patterns``) and explained-variance
    scores (``get_expl_variance_scores``) on ``model.adata``. Both results are then
    plotted, and every one of these calls is made with ``wandb_log=True``.

    Parameters
    ----------
    n_epochs_plot
        Run the interpretation every this many training epochs. Must be >= 1.
    verbose
        If ``False`` (default), stdout is redirected to ``os.devnull`` and warnings are
        ignored while the interpretation runs. If ``True``, output and warnings are
        left untouched.

    Raises
    ------
    ValueError
        If ``n_epochs_plot`` is less than 1.
    """

    def __init__(self, n_epochs_plot: int = 1, verbose: bool = False):
        super().__init__()
        if n_epochs_plot < 1:
            # A value of 0 would otherwise raise ZeroDivisionError at the first epoch end.
            raise ValueError(f"n_epochs_plot must be >= 1, got {n_epochs_plot}")
        # Count of training epochs finished, kept by this callback itself
        # (independent of trainer.current_epoch); also passed to the model as ``epoch``.
        self.epoch = 0
        self.n_epochs_plot = n_epochs_plot
        self.verbose = verbose

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Run the interpretation and plots whenever ``self.epoch`` is a multiple of ``n_epochs_plot``."""
        self.epoch += 1
        if self.epoch % self.n_epochs_plot != 0:
            return

        # NOTE(review): relies on the trainer exposing the high-level model (with ``adata`` and
        # the interpretation methods) on the private ``_model`` attribute — confirm with the caller.
        model = trainer._model
        quiet = not self.verbose
        with (
            open(os.devnull, "w") as devnull,
            redirect_stdout(devnull) if quiet else nullcontext(),
            # catch_warnings() restores the global warning filters on exit, so the "ignore"
            # filter installed below cannot leak into the rest of the process.
            warnings.catch_warnings() if quiet else nullcontext(),
        ):
            if quiet:
                warnings.simplefilter("ignore")
            attention_patterns = model.get_attention_patterns(model.adata, epoch=self.epoch, wandb_log=True)
            attention_patterns.plot_attention_summary(wandb_log=True)
            explained_variance_scores = model.get_expl_variance_scores(
                model.adata, epoch=self.epoch, wandb_log=True
            )
            explained_variance_scores.plot_explained_variance_barplot(wandb_log=True)