Use W&B Keras callbacks to track experiments, log model checkpoints, and visualize model predictions during training. This integration is for Keras users who want to add experiment tracking and model versioning to their training workflows without rewriting their training loop. Keras callbacks are available in the wandb.integration.keras module with Python SDK versions 0.13.4 and above. The W&B Keras integration provides the following callbacks:
  • WandbMetricsLogger: Use this callback for experiment tracking. It logs your training and validation metrics along with system metrics to W&B.
  • WandbModelCheckpoint: Use this callback to log your model checkpoints to W&B Artifacts.
  • WandbEvalCallback: This base callback logs model predictions to W&B Tables for interactive visualization.

Install and import the Keras integration

Install the latest version of W&B.
pip install -U wandb
To use the Keras integration, import the required classes from wandb.integration.keras.
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback
The following sections describe each callback in detail with code examples.

Track experiments with WandbMetricsLogger

wandb.integration.keras.WandbMetricsLogger() logs Keras’ logs dictionary that callback methods such as on_epoch_end and on_batch_end take as an argument. The following partial example shows how to use WandbMetricsLogger() in a Keras workflow. First, compile the model with the desired optimizer, loss function, and metrics. Then, initialize a W&B run using wandb.init(). Finally, pass the WandbMetricsLogger() callback to model.fit().
import wandb
from wandb.integration.keras import WandbMetricsLogger
import tensorflow as tf

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="top@5_accuracy")],
)

# Initialize a new W&B Run
with wandb.init(config={"batch_size": 64}) as run:

    # Pass the WandbMetricsLogger to model.fit
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbMetricsLogger()]
    )
The previous example logs training and validation metrics such as loss, accuracy, and top@5_accuracy to W&B at the end of each epoch.

WandbMetricsLogger reference

Parameter | Description
log_freq | (epoch, batch, or an int) If epoch, logs metrics at the end of each epoch. If batch, logs metrics at the end of each batch. If an int, logs metrics at the end of that many batches. Defaults to epoch.
initial_global_step | (int) Use this argument to correctly log the learning rate when you resume training from some initial_epoch and use a learning rate scheduler. This can be computed as step_size * initial_step. Defaults to 0.
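The reason initial_global_step matters can be sketched with plain arithmetic. The example below is a minimal sketch, not part of the W&B API. It assumes that step_size means the number of training batches per epoch and that initial_step means the epoch you resume from. The exponential_decay function and all numbers are hypothetical stand-ins for a learning rate scheduler.

```python
def exponential_decay(step, initial_lr=1e-3, decay_steps=1000, decay_rate=0.5):
    """Hypothetical schedule: the learning rate halves every decay_steps batches."""
    return initial_lr * decay_rate ** (step / decay_steps)


# Assumed values for a run that resumes at epoch 4 with 500 batches per epoch.
steps_per_epoch = 500
initial_epoch = 4
initial_global_step = steps_per_epoch * initial_epoch

# If the step counter restarted at 0, the schedule would report the initial rate.
lr_if_step_reset = exponential_decay(0)
# With the offset, the schedule reports the rate the optimizer actually uses.
lr_at_resume = exponential_decay(initial_global_step)

print(initial_global_step, lr_if_step_reset, lr_at_resume)
```

Under these assumptions, you would pass the computed value as WandbMetricsLogger(initial_global_step=initial_global_step) and the same initial_epoch to model.fit().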

Checkpoint a model using WandbModelCheckpoint

Use the WandbModelCheckpoint callback to periodically save the Keras model (SavedModel format) or model weights and upload them to W&B as a wandb.Artifact for model versioning. This callback subclasses tf.keras.callbacks.ModelCheckpoint(), so the parent callback handles the checkpointing logic. Depending on how you configure it, this callback can save:
  • The model that has achieved best performance based on the monitor.
  • The model at the end of every epoch regardless of the performance.
  • The model at the end of the epoch or after a fixed number of training batches.
  • Only model weights or the whole model.
  • The model either in SavedModel format or in .h5 format.
Use this callback in conjunction with WandbMetricsLogger().
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

# Initialize a new W&B Run
with wandb.init(config={"bs": 12}) as run:

    # Pass the WandbModelCheckpoint to model.fit
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbModelCheckpoint("models"),
        ],
    )

WandbModelCheckpoint reference

Parameter | Description
filepath | (str) Path to save the model file.
monitor | (str) The metric name to monitor.
verbose | (int) Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays messages when the callback takes an action.
save_best_only | (Boolean) If True, saves a checkpoint only when the model is the best seen so far, according to the quantity defined by the monitor and mode attributes.
save_weights_only | (Boolean) If True, saves only the model’s weights.
mode | (auto, min, or max) Whether the monitored metric should be maximized or minimized. For example, set it to max for val_acc and to min for val_loss.
save_freq | (“epoch” or int) When using “epoch”, the callback saves the model after each epoch. When using an integer, the callback saves the model at the end of this many batches. When monitoring validation metrics such as val_acc or val_loss, save_freq must be set to “epoch” because those metrics are only available at the end of an epoch.
options | (object) Optional tf.train.CheckpointOptions object if save_weights_only is True, or optional tf.saved_model.SaveOptions object if save_weights_only is False.
initial_value_threshold | (float) Floating point initial “best” value of the metric to be monitored.
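The save_freq constraint in the table above is easy to violate by accident. The following sketch checks it before the callback is constructed. The checkpoint_kwargs helper and the assumption that validation metric names start with "val_" are hypothetical, introduced only for illustration. They are not part of the W&B or Keras API.

```python
from typing import Any, Dict, Union


def checkpoint_kwargs(
    filepath: str,
    monitor: str = "val_loss",
    mode: str = "auto",
    save_freq: Union[str, int] = "epoch",
    save_best_only: bool = False,
) -> Dict[str, Any]:
    """Build keyword arguments for WandbModelCheckpoint after a sanity check."""
    if mode not in ("auto", "min", "max"):
        raise ValueError(f"mode must be auto, min, or max, got {mode!r}")
    # Validation metrics only exist at the end of an epoch, so a batch-based
    # save_freq cannot monitor them. Assumes such metrics start with "val_".
    if monitor.startswith("val_") and save_freq != "epoch":
        raise ValueError(f"monitoring {monitor!r} requires save_freq='epoch'")
    return {
        "filepath": filepath,
        "monitor": monitor,
        "mode": mode,
        "save_freq": save_freq,
        "save_best_only": save_best_only,
    }


# Keep only the checkpoint with the lowest validation loss.
best_only = checkpoint_kwargs("models", monitor="val_loss", mode="min", save_best_only=True)
print(best_only)
```

You could then construct the callback with WandbModelCheckpoint(**best_only), since every key in the dictionary is a parameter listed in the reference table.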

Log checkpoints after N epochs

By default (save_freq="epoch"), the callback creates a checkpoint and uploads it as an artifact after each epoch. To create a checkpoint after a specific number of batches, set save_freq to an integer. To checkpoint after N epochs, compute the cardinality of the train dataloader and pass it to save_freq:
WandbModelCheckpoint(
    filepath="models/",
    save_freq=int((trainloader.cardinality()*N).numpy())
)
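When the training data is a pair of NumPy arrays rather than a tf.data.Dataset, there is no cardinality() method to call. The following is a minimal sketch of the same computation in that case. It assumes that Keras runs ceil(num_examples / batch_size) batches per epoch. The function name and the numbers are hypothetical.

```python
import math


def save_freq_for_epochs(num_examples: int, batch_size: int, n_epochs: int) -> int:
    """Return the number of batches in n_epochs epochs, for use as save_freq."""
    # The last batch of an epoch may be smaller, so round up.
    batches_per_epoch = math.ceil(num_examples / batch_size)
    return batches_per_epoch * n_epochs


# Hypothetical dataset: 60,000 examples, batch size 128, checkpoint every 5 epochs.
save_freq = save_freq_for_epochs(num_examples=60_000, batch_size=128, n_epochs=5)
print(save_freq)
```

The result can be passed as WandbModelCheckpoint(filepath="models/", save_freq=save_freq), provided model.fit() uses the same batch size.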

Log checkpoints efficiently on a TPU architecture

While checkpointing on TPUs, you might encounter the error message UnimplementedError: File system scheme '[local]' not implemented. This happens because the model directory (filepath) must use a cloud storage bucket path (gs://bucket-name/...), and this bucket must be accessible from the TPU server. To checkpoint to a local path instead, set experimental_io_device in the save options, as the following example shows. W&B then uploads the local checkpoint as an artifact.
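import tensorflow as tf
from wandb.integration.keras import WandbModelCheckpoint

# Route checkpoint I/O through the local host so that a local filepath works on TPUs
checkpoint_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")

WandbModelCheckpoint(
    filepath="models/",
    options=checkpoint_options,
)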

Visualize model predictions using WandbEvalCallback

WandbEvalCallback() is an abstract base class for building Keras callbacks, primarily for model prediction visualization and, secondarily, for dataset visualization. This abstract callback is independent of the dataset and the task. To use it, inherit from the base WandbEvalCallback() class and implement the add_ground_truth and add_model_predictions methods. WandbEvalCallback() is a utility class that provides methods to:
  • Create data and prediction wandb.Table() instances.
  • Log data and prediction Tables as wandb.Artifact().
  • Log the data table on_train_begin.
  • Log the prediction table on_epoch_end.
The following example uses WandbClfEvalCallback for an image classification task. This example callback logs the validation data (data_table) to W&B, performs inference, and logs the prediction (pred_table) to W&B at the end of every epoch.
import tensorflow as tf
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback


# Implement your model prediction visualization callback
class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validation_data, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)

        # Keep only the first num_samples examples so the logged tables stay small
        self.x = validation_data[0][:num_samples]
        self.y = validation_data[1][:num_samples]

    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(zip(self.x, self.y)):
            self.data_table.add_data(idx, wandb.Image(image), label)

    def add_model_predictions(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = tf.argmax(preds, axis=-1)

        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )


# ...

# Initialize a new W&B Run
with wandb.init(config={"hyper": "parameter"}) as run:

    # Add the Callbacks to Model.fit
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbClfEvalCallback(
                validation_data=(X_test, y_test),
                data_table_columns=["idx", "image", "label"],
                pred_table_columns=["epoch", "idx", "image", "label", "pred"],
            ),
        ],
    )

WandbEvalCallback reference

Parameter | Description
data_table_columns | (list) List of column names for the data_table.
pred_table_columns | (list) List of column names for the pred_table.

Memory footprint details

W&B logs the data_table when Keras invokes the on_train_begin method. After W&B uploads it as a W&B Artifact, you get a reference to this table, which you can access using the data_table_ref attribute. You can index its rows like a 2D list with self.data_table_ref.data[idx][n], where idx is the row number and n is the column number. See the add_model_predictions method in the previous example for usage.

Customize the callback

For more fine-grained control over when data and predictions are logged, override the on_train_begin or on_epoch_end methods. To log the samples after N batches, implement the on_train_batch_end method.
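Logging after N batches can be sketched as a small mixin. This is a minimal sketch under stated assumptions, not the library's own implementation. It assumes the class it is mixed into provides init_pred_table, add_model_predictions, log_pred_table, and pred_table_columns, which is how WandbEvalCallback appears to be structured in the SDK source. The LogEveryNBatches name and the every_n attribute are hypothetical.

```python
class LogEveryNBatches:
    """Mixin that logs the prediction table every `every_n` training batches.

    Assumption: the host class defines init_pred_table, add_model_predictions,
    log_pred_table, and pred_table_columns, as a WandbEvalCallback subclass does.
    """

    every_n = 100  # hypothetical default, override in the subclass

    def on_train_batch_end(self, batch, logs=None):
        # Keras passes the zero-based batch index within the current epoch.
        if (batch + 1) % self.every_n != 0:
            return
        # Start a fresh prediction table, fill it, then log it.
        self.init_pred_table(column_names=self.pred_table_columns)
        # The first argument fills the "epoch" column in the earlier example;
        # here it carries the batch index instead.
        self.add_model_predictions(batch, logs)
        self.log_pred_table()
```

To use it with the earlier example, you might declare class BatchClfEvalCallback(LogEveryNBatches, WandbClfEvalCallback) and set every_n on the subclass. The inherited on_epoch_end still logs at the end of every epoch unless you override it as well.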
If you’re implementing a callback for model prediction visualization by inheriting WandbEvalCallback and something needs to be clarified or fixed, open an issue.

Legacy WandbCallback

WandbCallback is the legacy all-in-one callback. For new projects, use the dedicated callbacks described in the previous sections (WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback). Use the W&B library WandbCallback() class to save all metrics and loss values tracked in model.fit().
import wandb
from wandb.integration.keras import WandbCallback

with wandb.init(config={"hyper": "parameter"}) as run:

    # code to set up your model in Keras

    # Pass the callback to model.fit
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbCallback()]
    )
You can watch the short video Get Started with Keras and W&B in Less Than a Minute. For a more detailed video, watch Integrate W&B with Keras. You can review the Colab Jupyter Notebook. For additional sample scripts, see the W&B example repo, including a Fashion MNIST example and the W&B Dashboard it generates.
The WandbCallback class supports logging configuration options: specifying a metric to monitor, tracking of weights and gradients, logging of predictions on training_data and validation_data, and more. See the reference documentation for keras.WandbCallback for full details.
WandbCallback:
  • Logs history data from any metrics collected by Keras: loss and anything passed into keras_model.compile().
  • Sets summary metrics for the run associated with the “best” training step, as defined by the monitor and mode attributes. This defaults to the epoch with the minimum val_loss. By default, WandbCallback saves the model associated with the best epoch.
  • Optionally logs gradient and parameter histograms.
  • Optionally saves training and validation data for wandb to visualize.

WandbCallback reference

Arguments
monitor | (str) Name of the metric to monitor. Defaults to val_loss.
mode | (str) One of {auto, min, max}. min: save the model when monitor is minimized. max: save the model when monitor is maximized. auto: try to guess when to save the model (default).
save_model | (boolean) If True, saves a model when monitor beats all previous epochs. If False, doesn’t save models.
save_graph | (boolean) If True, saves the model graph to wandb. Defaults to True.
save_weights_only | (boolean) If True, saves only the model’s weights (model.save_weights(filepath)). Otherwise, saves the full model.
log_weights | (boolean) If True, saves histograms of the weights of the model’s layers.
log_gradients | (boolean) If True, logs histograms of the training gradients.
training_data | (tuple) Same format (X,y) as passed to model.fit. This is needed for calculating gradients and is mandatory if log_gradients is True.
validation_data | (tuple) Same format (X,y) as passed to model.fit. A set of data for wandb to visualize. If you set this field, wandb makes a small number of predictions every epoch and saves the results for later visualization.
generator | (generator) A generator that returns validation data for wandb to visualize. This generator should return tuples (X,y). Either validation_data or generator should be set for wandb to visualize specific data examples.
validation_steps | (int) If validation_data is a generator, how many steps to run the generator for the full validation set.
labels | (list) If you are visualizing your data with wandb, this list of labels converts numeric output to understandable strings when you are building a classifier with multiple classes. For a binary classifier, you can pass in a list of two labels [label for false, label for true]. If neither validation_data nor generator is set, this does nothing.
predictions | (int) The number of predictions to make for visualization each epoch. The maximum is 100.
input_type | (string) Type of the model input to help visualization. Can be one of: (image, images, segmentation_mask).
output_type | (string) Type of the model output to help visualization. Can be one of: (image, images, segmentation_mask).
log_evaluation | (boolean) If True, saves a Table containing validation data and the model’s predictions at each epoch. See validation_indexes, validation_row_processor, and output_row_processor for additional details.
class_colors | ([float, float, float]) If the input or output is a segmentation mask, an array containing an RGB tuple (range 0-1) for each class.
log_batch_frequency | (integer) If None, the callback logs every epoch. If set to an integer, the callback logs training metrics every log_batch_frequency batches.
log_best_prefix | (string) If None, saves no extra summary metrics. If set to a string, prepends the monitored metric and epoch with the prefix and saves the results as summary metrics.
validation_indexes | ([wandb.data_types._TableLinkMixin]) An ordered list of index keys to associate with each validation example. If log_evaluation is True and you provide validation_indexes, the callback does not create a Table of validation data. Instead, it associates each prediction with the row represented by the TableLinkMixin. To obtain a list of row keys, use Table.get_index().
validation_row_processor | (Callable) A function to apply to the validation data, commonly used to visualize the data. The function receives an ndx (int) and a row (dict). If your model has a single input, then row["input"] contains the input data for the row. Otherwise, it contains the names of the input slots. If your fit function takes a single target, then row["target"] contains the target data for the row. Otherwise, it contains the names of the output slots. For example, if your input data is a single array, to visualize the data as an Image, provide lambda ndx, row: {"img": wandb.Image(row["input"])} as the processor. Ignored if log_evaluation is False or validation_indexes are present.
output_row_processor | (Callable) Same as validation_row_processor, but applied to the model’s output. row["output"] contains the results of the model output.
infer_missing_processors | (Boolean) Determines whether to infer validation_row_processor and output_row_processor if they are missing. Defaults to True. If you provide labels, W&B attempts to infer classification-type processors where appropriate.
log_evaluation_frequency | (int) Determines how often to log evaluation results. Defaults to 0 to log only at the end of training. Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
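A row processor is an ordinary function of (ndx, row) that returns a dictionary of columns. The following sketch shows an output_row_processor that turns a vector of class scores into a readable label. CLASS_NAMES, the column names pred_label and pred_score, and the assumption that row["output"] is a flat sequence of per-class scores are all hypothetical and depend on your model.

```python
CLASS_NAMES = ["cat", "dog", "bird"]  # hypothetical label set


def output_row_processor(ndx, row):
    """Map one row of model output to a predicted label and its score."""
    scores = list(row["output"])
    # Index of the highest score, computed without external dependencies.
    best = max(range(len(scores)), key=scores.__getitem__)
    return {"pred_label": CLASS_NAMES[best], "pred_score": float(scores[best])}


# Try the processor on a hand-written row before handing it to the callback.
example = output_row_processor(0, {"output": [0.1, 0.7, 0.2]})
print(example)
```

Under these assumptions, you would pass the function as WandbCallback(log_evaluation=True, output_row_processor=output_row_processor) together with your validation data.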

Frequently asked questions

Use Keras multiprocessing with wandb

When you set use_multiprocessing=True, this error might occur:
Error("You must call wandb.init() before wandb.config.batch_size")
To work around it:
  1. In the constructor of your Sequence class, call wandb.init(group='...').
  2. In your main script, guard the entry point with if __name__ == "__main__": and put the rest of your script logic inside it.