Introduction

A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include callback_tensorboard() to visualize training progress and results with TensorBoard, or callback_model_checkpoint() to periodically save your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.

Setup

library(tensorflow)
library(keras)
envir::import_from(magrittr, `%<>%`)
envir::import_from(dplyr, last)

tf_version()

Keras callbacks overview

All callbacks subclass the keras$callbacks$Callback class, and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful to get a view on internal states and statistics of the model during training.

You can pass a list of callbacks (as the named argument callbacks) to the following Keras model functions:

  • fit()
  • evaluate()
  • predict()
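
For instance, attaching a built-in callback when fitting a model looks like this (a minimal sketch; x_train and y_train stand in for your data):

model %>% fit(
  x_train, y_train,
  epochs = 10,
  validation_split = 0.2,
  callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 2))
)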

An overview of callback methods

Global methods

on_(train|test|predict)_begin(self, logs=None)

Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Called at the end of fit/evaluate/predict.

Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Called at the end of training/testing/predicting a batch. Within this method, logs is a named list containing the metrics results.

Epoch-level methods (training only)

on_epoch_begin(self, epoch, logs=None)

Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)

Called at the end of an epoch during training.
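
In the R interface (via %py_class%, shown below), these methods are written without an explicit self argument and with logs = NULL as the default. A minimal do-nothing sketch, just to show the shape (the class name is illustrative):

MinimalLogger(keras$callbacks$Callback) %py_class% {
  on_train_begin <- function(logs = NULL)
    cat("Starting training...\n")

  on_epoch_end <- function(epoch, logs = NULL)
    cat(sprintf("Finished epoch %d\n", epoch))
}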

A basic example

Let’s take a look at a concrete example. To get started, we’ll define a simple Sequential Keras model:

get_model <- function() {
  model <- keras_model_sequential() %>% 
    layer_dense(1, input_shape = 784) %>% 
    compile(
      optimizer = optimizer_rmsprop(learning_rate=0.1),
      loss = "mean_squared_error",
      metrics = "mean_absolute_error"
    )
  model
}

Then, load the MNIST data for training and testing from the Keras datasets API:

mnist <- dataset_mnist()

flatten_and_rescale <- function(x) {
  x <- array_reshape(x, c(-1, 784)) 
  x <- x / 255
  x
}

mnist$train$x <- flatten_and_rescale(mnist$train$x) 
mnist$test$x  <- flatten_and_rescale(mnist$test$x) 

# limit to 500 samples
mnist$train$x <- mnist$train$x[1:500,]
mnist$train$y <- mnist$train$y[1:500]
mnist$test$x  <- mnist$test$x[1:500,]
mnist$test$y  <- mnist$test$y[1:500]

Now, define a simple custom callback that logs:

  • When fit/evaluate/predict starts & ends
  • When each epoch starts & ends
  • When each training batch starts & ends
  • When each evaluation (test) batch starts & ends
  • When each inference (prediction) batch starts & ends

show <- function(msg, logs) {
  cat(glue::glue(msg, .envir = parent.frame()),
      "got logs: ", sep = "; ")
  str(logs); cat("\n")
}

CustomCallback(keras$callbacks$Callback) %py_class% {
  on_train_begin <- function(logs = NULL)
    show("Starting training", logs)
  
  on_train_end <- function(logs = NULL)
    show("Stop training", logs)
  
  on_epoch_begin <- function(epoch, logs = NULL)
    show("Start epoch {epoch} of training", logs)
  
  on_epoch_end <- function(epoch, logs = NULL)
    show("End epoch {epoch} of training", logs)
  
  on_test_begin <- function(logs = NULL)
    show("Start testing", logs)
  
  on_test_end <- function(logs = NULL)
    show("Stop testing", logs)
  
  on_predict_begin <- function(logs = NULL)
    show("Start predicting", logs)
  
  on_predict_end <- function(logs = NULL)
    show("Stop predicting", logs)
  
  on_train_batch_begin <- function(batch, logs = NULL)
    show("...Training: start of batch {batch}", logs)
  
  on_train_batch_end <- function(batch, logs = NULL)
    show("...Training: end of batch {batch}",  logs)
  
  on_test_batch_begin <- function(batch, logs = NULL)
    show("...Evaluating: start of batch {batch}", logs)
  
  on_test_batch_end <- function(batch, logs = NULL)
    show("...Evaluating: end of batch {batch}", logs)
  
  on_predict_batch_begin <- function(batch, logs = NULL)
    show("...Predicting: start of batch {batch}", logs)
  
  on_predict_batch_end <- function(batch, logs = NULL)
    show("...Predicting: end of batch {batch}", logs)
}

Let’s try it out:

model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 128,
  epochs = 2,
  verbose = 0,
  validation_split = 0.5,
  callbacks = list(CustomCallback())
)
res <- model %>%
  evaluate(
    mnist$test$x,
    mnist$test$y,
    batch_size = 128,
    verbose = 0,
    callbacks = list(CustomCallback())
  )
res <- model %>%
  predict(mnist$test$x,
          batch_size = 128,
          callbacks = list(CustomCallback()))

Usage of the logs list

The logs list contains the loss value, and all the metrics at the end of a batch or epoch. In this example, that includes the loss and the mean absolute error.

LossAndErrorPrintingCallback(keras$callbacks$Callback) %py_class% {
  on_train_batch_end <- function(batch, logs = NULL)
    cat(sprintf("Up to batch %i, the average loss is %7.2f.\n",
                batch,  logs$loss))
  
  on_test_batch_end <- function(batch, logs = NULL)
    cat(sprintf("Up to batch %i, the average loss is %7.2f.\n",
                batch, logs$loss))
  
  on_epoch_end <- function(epoch, logs = NULL)
    cat(sprintf(
      "The average loss for epoch %2i is %9.2f and mean absolute error is %7.2f.\n",
      epoch, logs$loss, logs$mean_absolute_error
    ))
}

model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 128,
  epochs = 2,
  verbose = 0,
  callbacks = list(LossAndErrorPrintingCallback())
)

res <- model %>% evaluate(
  mnist$test$x,
  mnist$test$y,
  batch_size = 128,
  verbose = 0,
  callbacks = list(LossAndErrorPrintingCallback())
)

Usage of self$model attribute

In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self$model.

Here are a few of the things you can do with self$model in a callback:

  • Set self$model$stop_training <- TRUE to immediately interrupt training.
  • Mutate hyperparameters of the optimizer (available as self$model$optimizer), such as self$model$optimizer$learning_rate.
  • Save the model at periodic intervals (see the short sketch after this list).
  • Record the output of predict(model) on a few test samples at the end of each epoch, to use as a sanity check during training.
  • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
  • etc.
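
For instance, the "save the model at periodic intervals" idea can be sketched as follows (the class name, the every argument, and the filename pattern are illustrative, not part of this guide):

PeriodicCheckpoint(keras$callbacks$Callback) %py_class% {
  initialize <- function(every = 5) {
    super$initialize()
    self$every <- every
  }

  on_epoch_end <- function(epoch, logs = NULL) {
    # `epoch` is 0-based; save the weights every `self$every` epochs.
    if ((epoch + 1) %% self$every == 0) {
      self$model$save_weights(sprintf("weights_epoch_%05d.h5", epoch + 1))
      cat(sprintf("Saved weights at epoch %d\n", epoch + 1))
    }
  }
}

It would be attached like any other callback, e.g. callbacks = list(PeriodicCheckpoint(every = 3)).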

Let’s see this in action in a couple of examples.

Examples of Keras callback applications

Early stopping at minimum loss

This first example shows how to create a Callback that stops training when the loss reaches its minimum, by setting the attribute self$model$stop_training (a boolean). Optionally, you can provide a patience argument to specify how many epochs training should wait before stopping after reaching a local minimum.

keras$callbacks$EarlyStopping provides a more complete and general implementation.
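
In the R interface, the same built-in callback is available as callback_early_stopping(); an illustrative call (not used in the example below):

callback_early_stopping(monitor = "loss", patience = 2, restore_best_weights = TRUE)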

EarlyStoppingAtMinLoss(keras$callbacks$Callback) %py_class% {
  "Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after the minimum has been hit. After
        this many epochs with no improvement, training stops.
  "
  
  initialize <- function(patience = 0) {
    # call keras$callbacks$Callback$__init__(), so it can set up `self`
    super$initialize()
    self$patience <- patience
    # best_weights to store the weights at which the minimum loss occurs.
    self$best_weights <- NULL
  }
  
  on_train_begin <- function(logs = NULL) {
    # The number of epochs waited while the loss is no longer at a minimum.
    self$wait <- 0
    # The epoch the training stops at.
    self$stopped_epoch <- 0
    # Initialize the best as infinity.
    self$best <- Inf
  }
  
  on_epoch_end <- function(epoch, logs = NULL) {
    current <- logs$loss
    if (current < self$best) {
      self$best <- current
      self$wait <- 0
      # Record the best weights if the current result is better (lower).
      self$best_weights <- self$model$get_weights()
    } else {
      self$wait %<>% `+`(1)
      if (self$wait >= self$patience) {
        self$stopped_epoch <- epoch
        self$model$stop_training <- TRUE
        cat("Restoring model weights from the end of the best epoch.\n")
        self$model$set_weights(self$best_weights)
      }
    }
  }
  
  on_train_end <- function(logs = NULL)
    if (self$stopped_epoch > 0)
      cat(sprintf("Epoch %05d: early stopping\n", self$stopped_epoch + 1))
  
}


model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 64,
  steps_per_epoch = 5,
  epochs = 30,
  verbose = 0,
  callbacks = list(LossAndErrorPrintingCallback(), 
                   EarlyStoppingAtMinLoss())
)

Learning rate scheduling

In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.

See keras$callbacks$LearningRateScheduler for a more general implementation (in RStudio, press F1 while the cursor is over LearningRateScheduler and its documentation page will open in a browser).
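
In the R interface the built-in version is callback_learning_rate_scheduler(), which takes the same kind of schedule function. An illustrative call (not used below):

callback_learning_rate_scheduler(function(epoch, lr) if (epoch < 3) lr else lr * 0.9)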

CustomLearningRateScheduler(keras$callbacks$Callback) %py_class% {
  "Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  "
  
  `__init__` <- function(schedule) {
    super()$`__init__`()
    self$schedule <- schedule
  }
  
  on_epoch_begin <- function(epoch, logs = NULL) {
    ## When in doubt about what types of objects are in scope (e.g., self$model)
    ## use a debugger to interact with the actual objects at the console!
    # browser()
    
    if (!"learning_rate" %in% names(self$model$optimizer))
      stop('Optimizer must have a "learning_rate" attribute.')
    
    # Get the current learning rate from the model's optimizer;
    # use as.numeric() to convert the tf.Variable to an R numeric.
    lr <- as.numeric(self$model$optimizer$learning_rate)
    # Call the schedule function to get the scheduled learning rate.
    scheduled_lr <- self$schedule(epoch, lr)
    # Set the value back on the optimizer before this epoch starts.
    self$model$optimizer$learning_rate <- scheduled_lr
    cat(sprintf("\nEpoch %05d: Learning rate is %6.4f.\n", epoch, scheduled_lr))
  }
}


LR_SCHEDULE <- tibble::tribble(~ start_epoch, ~ learning_rate,
                               0, .1,
                               3, 0.05,
                               6, 0.01,
                               9, 0.005,
                               12, 0.001)


lr_schedule <- function(epoch, learning_rate) {
  "Helper function to retrieve the scheduled learning rate based on epoch."
  # Use the rate whose start_epoch is the most recent one not after `epoch`.
  if (epoch <= last(LR_SCHEDULE$start_epoch))
    with(LR_SCHEDULE, learning_rate[findInterval(epoch, start_epoch)])
  else
    learning_rate
}
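
A few spot checks of the helper (the values follow directly from LR_SCHEDULE above):

lr_schedule(0, 0.1)     # 0.1   -- the first scheduled rate
lr_schedule(6, 0.05)    # 0.01  -- the rate scheduled to start at epoch 6
lr_schedule(20, 0.001)  # 0.001 -- past the last start_epoch, the current rate is kept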


model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 64,
  steps_per_epoch = 5,
  epochs = 15,
  verbose = 0,
  callbacks = list(
    LossAndErrorPrintingCallback(),
    CustomLearningRateScheduler(lr_schedule)
  )
)

Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!
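
For example, a few of the built-in callbacks can be combined in a single fit() call (the file paths below are placeholders):

model %>% fit(
  mnist$train$x, mnist$train$y,
  epochs = 10,
  verbose = 0,
  validation_split = 0.2,
  callbacks = list(
    callback_model_checkpoint("model.{epoch:02d}.h5"),  # save the model during training
    callback_csv_logger("training_log.csv"),            # append per-epoch metrics to a CSV file
    callback_tensorboard(log_dir = "logs")              # write logs for TensorBoard
  )
)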