amici.callbacks.AttentionPenaltyMonitor

class amici.callbacks.AttentionPenaltyMonitor(epoch_start=10, epoch_end=30, start_val=0.0001, end_val=0.001, flavor='log')

Bases: Callback
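This page does not say what the constructor arguments control. Judging only from their names, one plausible reading is that the callback ramps an attention-penalty weight from start_val at epoch_start to end_val at epoch_end, interpolating on a log scale when flavor='log'. The function below is a hypothetical sketch of such a schedule under that assumption. It is not amici's implementation, and both the name penalty_at_epoch and the 'linear' flavor are invented for illustration.

```python
import math


def penalty_at_epoch(
    epoch: int,
    epoch_start: int = 10,
    epoch_end: int = 30,
    start_val: float = 1e-4,
    end_val: float = 1e-3,
    flavor: str = "log",
) -> float:
    """Hypothetical ramp: hold start_val up to epoch_start and end_val from epoch_end on."""
    if epoch <= epoch_start:
        return start_val
    if epoch >= epoch_end:
        return end_val
    # Fraction of the ramp completed, strictly between 0 and 1 here.
    frac = (epoch - epoch_start) / (epoch_end - epoch_start)
    if flavor == "log":
        # Interpolate linearly in log space, i.e. a geometric ramp.
        log_val = math.log(start_val) + frac * (math.log(end_val) - math.log(start_val))
        return math.exp(log_val)
    if flavor == "linear":
        # Assumed alternative flavor: plain linear interpolation.
        return start_val + frac * (end_val - start_val)
    raise ValueError(f"unknown flavor: {flavor!r}")


for epoch in (0, 10, 20, 30, 40):
    print(epoch, penalty_at_epoch(epoch))
```

With a log-scale ramp, the midpoint (epoch 20 with the defaults) is the geometric mean of start_val and end_val rather than their arithmetic mean, which keeps the early part of the ramp gentle when the two values differ by an order of magnitude.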

Methods

__init__

load_state_dict

Called when loading a checkpoint; implement it to reload the callback's state from its state_dict.

on_after_backward

Called after loss.backward() and before optimizers are stepped.

on_before_backward

Called before loss.backward().

on_before_optimizer_step

Called before optimizer.step().

on_before_zero_grad

Called before optimizer.zero_grad().

on_exception

Called when any trainer execution is interrupted by an exception.

on_fit_end

Called when fit ends.

on_fit_start

Called when fit begins.

on_load_checkpoint

Called when loading a model checkpoint; use it to reload state.

on_predict_batch_end

Called when the predict batch ends.

on_predict_batch_start

Called when the predict batch begins.

on_predict_end

Called when prediction ends.

on_predict_epoch_end

Called when the predict epoch ends.

on_predict_epoch_start

Called when the predict epoch begins.

on_predict_start

Called when prediction begins.

on_sanity_check_end

Called when the validation sanity check ends.

on_sanity_check_start

Called when the validation sanity check starts.

on_save_checkpoint

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_batch_end

Called when the test batch ends.

on_test_batch_start

Called when the test batch begins.

on_test_end

Called when testing ends.

on_test_epoch_end

Called when the test epoch ends.

on_test_epoch_start

Called when the test epoch begins.

on_test_start

Called when testing begins.

on_train_batch_end

Called when the train batch ends.

on_train_batch_start

Called when the train batch begins.

on_train_end

Called when training ends.

on_train_epoch_end

Called when the train epoch ends.

on_train_epoch_start

Called when the train epoch begins.

on_train_start

Called when training begins.

on_validation_batch_end

Called when the validation batch ends.

on_validation_batch_start

Called when the validation batch begins.

on_validation_end

Called when the validation loop ends.

on_validation_epoch_end

Called when the validation epoch ends.

on_validation_epoch_start

Called when the validation epoch begins.

on_validation_start

Called when the validation loop begins.

setup

Called when fit, validate, test, predict, or tune begins.

state_dict

Called when saving a checkpoint; implement it to generate the callback's state_dict.

teardown

Called when fit, validate, test, predict, or tune ends.

Attributes

state_key

Identifier for the state of the callback.

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

Return type:

None
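This is the only hook on the page with its own source link, so it is presumably where AttentionPenaltyMonitor applies its scheduled value; the page does not confirm this. The sketch below shows the general pattern such a hook could follow: read trainer.current_epoch and write a value onto the module. SimpleNamespace objects stand in for the real Trainer and LightningModule, and the attribute name attention_penalty is a hypothetical placeholder, since how amici actually stores the value is not documented here.

```python
from types import SimpleNamespace


class EpochRampSketch:
    """Hypothetical stand-in for a callback that updates a module attribute every epoch."""

    def __init__(self, epoch_start=10, epoch_end=30, start_val=1e-4, end_val=1e-3):
        self.epoch_start = epoch_start
        self.epoch_end = epoch_end
        self.start_val = start_val
        self.end_val = end_val

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        # Clamp the epoch into [epoch_start, epoch_end] and measure ramp progress in [0, 1].
        epoch = min(max(trainer.current_epoch, self.epoch_start), self.epoch_end)
        span = max(self.epoch_end - self.epoch_start, 1)
        frac = (epoch - self.epoch_start) / span
        # Geometric interpolation: a constant ratio per epoch, i.e. linear on a log scale.
        pl_module.attention_penalty = self.start_val * (self.end_val / self.start_val) ** frac


callback = EpochRampSketch()
module = SimpleNamespace(attention_penalty=None)  # stand-in for a LightningModule
for epoch in (0, 20, 40):
    trainer = SimpleNamespace(current_epoch=epoch)  # stand-in for a Trainer
    callback.on_train_epoch_start(trainer, module)
    print(epoch, module.attention_penalty)
```

Because the hook runs once at the start of every training epoch, the module sees a fixed value for the whole epoch, and the value changes only at epoch boundaries.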

load_state_dict(state_dict)

Called when loading a checkpoint; implement it to reload the callback's state from its state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

Return type:

None

on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

Return type:

None

on_before_optimizer_step(trainer, pl_module, optimizer)

Called before optimizer.step().

Return type:

None

on_before_zero_grad(trainer, pl_module, optimizer)

Called before optimizer.zero_grad().

Return type:

None

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type:

None

on_fit_end(trainer, pl_module)

Called when fit ends.

Return type:

None

on_fit_start(trainer, pl_module)

Called when fit begins.

Return type:

None

on_load_checkpoint(trainer, pl_module, checkpoint)

Called when loading a model checkpoint; use it to reload state.

Parameters:
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.

Return type:

None

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the predict batch ends.

Return type:

None

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the predict batch begins.

Return type:

None

on_predict_end(trainer, pl_module)

Called when prediction ends.

Return type:

None

on_predict_epoch_end(trainer, pl_module)

Called when the predict epoch ends.

Return type:

None

on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

Return type:

None

on_predict_start(trainer, pl_module)

Called when prediction begins.

Return type:

None

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

Return type:

None

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

Return type:

None

on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (dict[str, Any]) – the checkpoint dictionary that will be saved.

Return type:

None
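As a sketch of how on_save_checkpoint and on_load_checkpoint pair up: the save hook receives the checkpoint dictionary before it is written, so a callback can add its own key to it, and the load hook receives the loaded dictionary and can read that key back. Plain dicts and None stand in for the real checkpoint, Trainer, and LightningModule, and the key name "penalty_monitor_extra" is hypothetical. For state that belongs to the callback itself, the state_dict/load_state_dict pair documented on this page is the more targeted mechanism.

```python
class CheckpointExtrasSketch:
    """Hypothetical callback that stashes one extra value in the checkpoint dict."""

    def __init__(self):
        self.last_value = 1e-4

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: dict) -> None:
        # Mutate the dict in place; whatever it contains afterwards gets saved.
        checkpoint["penalty_monitor_extra"] = {"last_value": self.last_value}

    def on_load_checkpoint(self, trainer, pl_module, checkpoint: dict) -> None:
        # Read the value back if present; older checkpoints may lack the key.
        extra = checkpoint.get("penalty_monitor_extra")
        if extra is not None:
            self.last_value = extra["last_value"]


saver = CheckpointExtrasSketch()
saver.last_value = 5e-4
checkpoint = {"epoch": 12}  # stand-in for the real checkpoint dictionary
saver.on_save_checkpoint(None, None, checkpoint)

restored = CheckpointExtrasSketch()
restored.on_load_checkpoint(None, None, checkpoint)
print(restored.last_value)
```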

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the test batch ends.

Return type:

None

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the test batch begins.

Return type:

None

on_test_end(trainer, pl_module)

Called when testing ends.

Return type:

None

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

Return type:

None

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

Return type:

None

on_test_start(trainer, pl_module)

Called when testing begins.

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the train batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type:

None
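A worked illustration of the note above, with made-up numbers: if training_step returns a loss of 2.0 and the trainer runs with accumulate_grad_batches=4, the value seen in outputs["loss"] is 2.0 / 4 = 0.5. To recover the un-normalized loss inside the hook, multiply back by the accumulation factor, which a real Trainer exposes as trainer.accumulate_grad_batches. The helper raw_loss_from_outputs is hypothetical, not a Lightning API.

```python
def raw_loss_from_outputs(outputs: dict, accumulate_grad_batches: int) -> float:
    """Undo the normalization described in the note (illustrative helper only)."""
    return float(outputs["loss"]) * accumulate_grad_batches


raw_loss = 2.0
accumulate = 4
outputs = {"loss": raw_loss / accumulate}  # what the hook would receive, per the note
print(outputs["loss"], raw_loss_from_outputs(outputs, accumulate))
```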

on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called when the train batch begins.

Return type:

None

on_train_end(trainer, pl_module)

Called when training ends.

Return type:

None

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_train_start(trainer, pl_module)

Called when training begins.

Return type:

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the validation batch ends.

Return type:

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the validation batch begins.

Return type:

None

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type:

None

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type:

None

on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

Return type:

None

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type:

None

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type:

None
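The stage argument tells the callback which entry point triggered setup and teardown. The hedged sketch below branches on it; the strings "fit", "validate", "test" and "predict" are the values Lightning 2.x passes, while the class StageAwareSketch and its active flag are purely illustrative. A callback that adjusts a training-time penalty would typically only need to act during "fit".

```python
class StageAwareSketch:
    """Hypothetical callback that only enables itself while fitting."""

    def __init__(self):
        self.active = False

    def setup(self, trainer, pl_module, stage: str) -> None:
        # Only a training run ("fit") should update a training-time penalty.
        self.active = stage == "fit"

    def teardown(self, trainer, pl_module, stage: str) -> None:
        # Reset so a later validate/test/predict run starts clean.
        self.active = False


cb = StageAwareSketch()
cb.setup(None, None, "fit")
print(cb.active)
cb.teardown(None, None, "fit")
cb.setup(None, None, "test")
print(cb.active)
```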

state_dict()

Called when saving a checkpoint; implement it to generate the callback's state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
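A minimal sketch of the state_dict/load_state_dict contract: state_dict returns plain values describing the callback's state, and load_state_dict restores them on a fresh instance. The class and its attributes are hypothetical; in a real run, Lightning stores the returned dictionary under checkpoint["callbacks"][state_key] and hands it back to load_state_dict when the checkpoint is restored.

```python
from typing import Any, Dict


class StatefulSketch:
    """Hypothetical callback whose state survives a checkpoint round trip."""

    def __init__(self):
        self.current_val = 1e-4
        self.epochs_seen = 0

    def state_dict(self) -> Dict[str, Any]:
        # Return only plain, picklable values.
        return {"current_val": self.current_val, "epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore exactly what state_dict() produced.
        self.current_val = state_dict["current_val"]
        self.epochs_seen = state_dict["epochs_seen"]


original = StatefulSketch()
original.current_val, original.epochs_seen = 3e-4, 15
saved = original.state_dict()  # Lightning would store this under checkpoint["callbacks"][state_key]

fresh = StatefulSketch()
fresh.load_state_dict(saved)
print(fresh.current_val, fresh.epochs_seen)
```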

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback's state from the checkpoint dictionary via checkpoint["callbacks"][state_key]. A callback implementation must provide a unique state key if (1) the callback has state and (2) the states of multiple instances of that callback need to be kept separately.
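To illustrate the second condition: Lightning's base Callback.state_key returns the class's __qualname__, so two instances of the same stateful callback would otherwise collide on one checkpoint slot. The plain-Python sketch below folds the constructor arguments into the key so that each instance gets its own entry. KeyedSketch and its arguments are hypothetical, and the dict named checkpoint only mimics the "callbacks" section of a real checkpoint.

```python
class KeyedSketch:
    """Hypothetical stateful callback whose state key includes its configuration."""

    def __init__(self, start_val: float, end_val: float):
        self.start_val = start_val
        self.end_val = end_val

    @property
    def state_key(self) -> str:
        # The class name alone would collide across instances; append the config.
        return f"{type(self).__qualname__}(start_val={self.start_val}, end_val={self.end_val})"

    def state_dict(self) -> dict:
        return {"start_val": self.start_val, "end_val": self.end_val}


callbacks = [KeyedSketch(1e-4, 1e-3), KeyedSketch(1e-5, 1e-4)]
checkpoint = {"callbacks": {}}
for cb in callbacks:
    # Mirrors checkpoint["callbacks"][state_key] from the description above.
    checkpoint["callbacks"][cb.state_key] = cb.state_dict()
print(list(checkpoint["callbacks"]))
```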

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Return type:

None