
PyTorch Lightning provides a lightweight wrapper for organizing your PyTorch code and adding advanced features such as distributed training and 16-bit precision. W&B provides a lightweight wrapper for logging your ML experiments. You don’t need to combine the two yourself: W&B is incorporated directly into the PyTorch Lightning library through the WandbLogger. This page shows you how to use WandbLogger to track metrics, log hyperparameters, save model checkpoints as artifacts, log media, and run multi-GPU training with PyTorch Lightning and W&B.

Integrate with Lightning

The following sections show how to authenticate with W&B, install the wandb library, and attach a WandbLogger to your Lightning Trainer or Fabric instance. At its simplest, the integration with the Trainer looks like this:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(logger=wandb_logger)
Using wandb.log(): The WandbLogger logs to W&B using the Trainer's global_step. If you make additional calls to wandb.log() directly in your code, don't use the step argument in wandb.log(). Instead, log the Trainer's global_step like your other metrics:
wandb.log({"accuracy": 0.99, "trainer/global_step": step})

Sign up and create an API key

An API key authenticates your machine to W&B. You can generate an API key from your user profile or, for a more streamlined approach, by going directly to User Settings.
To generate an API key from your user profile:
  1. Click your user profile icon in the upper right corner.
  2. Select User Settings, then scroll to the API Keys section.
Copy the newly created API key immediately and save it in a secure location such as a password manager.

Install the wandb library and log in

To install the wandb library locally and log in:
  1. Set the WANDB_API_KEY environment variable to your API key.
    export WANDB_API_KEY=[YOUR-API-KEY]
    
  2. Install the wandb library and log in.
    pip install wandb
    
    wandb login
    

Use PyTorch Lightning’s WandbLogger

There are two WandbLogger classes for logging metrics, model weights, and media: lightning.pytorch.loggers.WandbLogger for the Lightning Trainer, and wandb.integration.lightning.fabric.WandbLogger for Lightning Fabric. Choose the class that matches your training setup, instantiate it, and pass it to Lightning's Trainer or Fabric. With the Trainer:
trainer = Trainer(logger=wandb_logger)

Common logger arguments

The following table lists common parameters for WandbLogger. Review the PyTorch Lightning documentation for details about all logger arguments.
Parameter   Description
project     The W&B project to log to
name        The name of your W&B run
log_model   Logs checkpoints throughout training if log_model="all", or at the end of training if log_model=True
save_dir    Path where data is saved

Log your hyperparameters

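Logging hyperparameters with W&B lets you compare runs and reproduce results. Use the method that matches your logger: with the Trainer, call self.save_hyperparameters() in your LightningModule's __init__; with Fabric, pass a dictionary to wandb_logger.log_hyperparams().
from lightning.pytorch import LightningModule

class LitModule(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()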

Log additional config parameters

To capture extra configuration values alongside your hyperparameters, update the run config directly. The keys and values below are examples:
import wandb

# add one parameter
wandb_logger.experiment.config["batch_size"] = 64

# add multiple parameters
wandb_logger.experiment.config.update({"optimizer": "adam", "dropout": 0.1})

# or use the wandb module directly (this goes through the global wandb.run)
wandb.config["batch_size"] = 64
wandb.config.update({"optimizer": "adam", "dropout": 0.1})

Log gradients, parameter histograms, and model topology

Pass your model object to wandb_logger.watch() to monitor your model's gradients and parameters as you train. See the PyTorch Lightning WandbLogger documentation.

Log metrics

To log your metrics to W&B when using the WandbLogger, call self.log('my_metric_name', metric_value) within your LightningModule, such as in your training_step or validation_step methods. The following code snippet shows how to define your LightningModule to log your metrics and your LightningModule hyperparameters. This example uses the torchmetrics library to calculate your metrics.
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """method used to define the model parameters"""
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        self.loss = CrossEntropyLoss()
        self.lr = lr

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

    def forward(self, x):
        """method used for inference input -> output"""

        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # apply 3 x (linear + relu)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        """needs to return a loss from a single batch"""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """used for logging metrics"""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

    def configure_optimizers(self):
        """defines model optimizer"""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """convenience function since train/valid/test steps are similar"""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.hparams.n_classes)
        return preds, loss, acc

Log the min/max of a metric

Using W&B’s define_metric function, you can define whether your W&B summary metric displays the min, max, mean, or best value for that metric. If define_metric isn’t used, the last value logged appears in your summary metrics. For more information, see the customize logging axes guide. To track the max validation accuracy in the W&B summary metric, call wandb.define_metric() only once, at the beginning of training:
import wandb
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        if self.trainer.global_step == 0:
            wandb.define_metric("val_accuracy", summary="max")

        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

Checkpoint a model

Saving checkpoints as W&B artifacts gives you versioned model files you can retrieve later by run, alias, or version. To save model checkpoints as W&B Artifacts, use the Lightning ModelCheckpoint callback and set the log_model argument in the WandbLogger.
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
The latest and best aliases are set automatically to make it easier to retrieve a model checkpoint from a W&B Artifact:
from pathlib import Path

# reference can be retrieved in artifacts panel
# "VERSION" can be a version (for example, "v2") or an alias ("latest" or "best")
checkpoint_reference = "[USER]/[PROJECT]/[MODEL-RUN_ID]:[VERSION]"
# download checkpoint locally (if not already cached); returns the local directory
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# load checkpoint
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
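If you build checkpoint references in code, a small helper can catch malformed versions early. The following is a sketch, assuming WandbLogger's default artifact name model-<run id>. The checkpoint_reference helper is hypothetical, and it only accepts the version and alias forms mentioned above.

```python
import re


def checkpoint_reference(entity: str, project: str, run_id: str, version: str = "best") -> str:
    """Hypothetical helper: build an artifact reference for a logged checkpoint.

    Assumes the default artifact name "model-<run id>".
    """
    if version not in ("latest", "best") and not re.fullmatch(r"v\d+", version):
        raise ValueError(f"expected 'latest', 'best', or a version like 'v2', got {version!r}")
    return f"{entity}/{project}/model-{run_id}:{version}"
```

Pass the result to wandb_logger.download_artifact(..., artifact_type="model").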
The model checkpoints you log are viewable through the W&B Artifacts UI, and include the full model lineage (see an example model checkpoint in the UI). To bookmark your best model checkpoints and centralize them across your team, link them to the W&B Model Registry. In the Registry, you can organize your best models by task, manage model lifecycle, track and audit throughout the ML lifecycle, and automate downstream actions with webhooks or jobs.

Log images, text, and more

The WandbLogger has log_image, log_text, and log_table methods for logging media. You can also call wandb.log() or trainer.logger.experiment.log() directly to log other media types such as Audio, Molecules, Point Clouds, and 3D Objects.
# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])

# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# using .log in the trainer; log the global step as a metric instead of passing step=
trainer.logger.experiment.log(
    {
        "samples": [wandb.Image(img, caption=caption) for (img, caption) in my_images],
        "trainer/global_step": trainer.global_step,
    }
)
Use Lightning’s Callbacks system to control when you log to W&B through the WandbLogger. The following example logs a sample of validation images and predictions:
import wandb
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import WandbLogger

# or
# from wandb.integration.lightning.fabric import WandbLogger


class LogPredictionSamplesCallback(Callback):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Called when the validation batch ends."""

        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case

        # Log 20 sample image predictions from the first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f"Ground Truth: {int(y_i)} - Prediction: {int(y_pred)}"
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # Option 1: log images with `WandbLogger.log_image`
            trainer.logger.log_image(key="sample_images", images=images, caption=captions)

            # Option 2: log images and predictions as a W&B Table
            columns = ["image", "ground truth", "prediction"]
            data = [
                [wandb.Image(x_i), int(y_i), int(y_pred)]
                for x_i, y_i, y_pred in zip(x[:n], y[:n], outputs[:n])
            ]
            trainer.logger.log_table(key="sample_table", columns=columns, data=data)


wandb_logger = WandbLogger()
trainer = pl.Trainer(logger=wandb_logger, callbacks=[LogPredictionSamplesCallback()])

Use multiple GPUs with Lightning and W&B

When you run distributed training, the way you reference wandb.run across ranks can affect whether training proceeds or deadlocks. PyTorch Lightning supports multi-GPU training through its DDP interface, and it requires each GPU (or rank) in your training loop to be set up in exactly the same way, with the same initial conditions. However, only the rank 0 process has access to the wandb.run object; on non-zero rank processes, wandb.run is None. Code that uses wandb.run can therefore crash the non-zero rank processes, and training then deadlocks: the rank 0 process waits for the other ranks to join, but they have already crashed.

For this reason, the recommended approach is to make your code independent of the wandb.run object. The following example trains an MNIST classifier with DDP on two GPUs without referencing wandb.run:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger


class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        # automatic optimization needs the loss (or a dict with a "loss" key)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # Setting all the random seeds to the same value.
    # This is important in a distributed training setting.
    # Each rank will get its own set of initial weights.
    # If they don't match up, the gradients will not match either,
    # leading to training that may not converge.
    pl.seed_everything(1)

    # MNIST is downloaded to ./data if it isn't there yet
    transform = transforms.ToTensor()
    train_dataset = MNIST("data", train=True, download=True, transform=transform)
    val_dataset = MNIST("data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project="[PROJECT-NAME]")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3,
        accelerator="gpu",
        devices=2,
        strategy="ddp",
        logger=wandb_logger,
        callbacks=callbacks,
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()

Examples

For an end-to-end walkthrough, you can follow along in a video tutorial with a Colab notebook.

Frequently asked questions

How does W&B integrate with Lightning?

The core integration is based on the Lightning loggers API, which lets you write much of your logging code in a framework-independent way. Logger instances are passed to the Lightning Trainer and are triggered based on that API’s rich hook-and-callback system. This keeps your research code well separated from engineering and logging code.

What does the integration log without any additional code?

If you set log_model, W&B saves your model checkpoints as artifacts, where you can view them or download them for use in future runs. W&B also captures system metrics, like GPU usage and network I/O. It captures environment information, like hardware and OS information. It captures code state, including Git commit and diff patch, notebook contents, and session history. It also captures anything printed to standard out.

What if I need to use wandb.run in my training setup?

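Make the value you need available to every process yourself; in other words, make sure that the initial conditions are the same on all processes. For example, set an environment variable in the launching process, before the rank processes start, so that they inherit it. wandb.run is only set once the run has started, for example after you access wandb_logger.experiment.
import os
import wandb
if os.environ.get("LOCAL_RANK", None) is None:  # launching process only
    os.environ["WANDB_DIR"] = wandb.run.dir
Then you can use os.environ["WANDB_DIR"] to set up the model checkpoints directory. This way, any non-zero rank process can access the run directory without referencing wandb.run.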
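A sketch of that last step, using a hypothetical checkpoint_dir helper that falls back to a local directory when WANDB_DIR isn't set. It assumes a launcher that starts the rank processes from the main process, so they inherit its environment.

```python
import os


def checkpoint_dir(default: str = "checkpoints") -> str:
    """Hypothetical helper: put checkpoints under WANDB_DIR when it is set.

    Rank processes started by the launching process inherit its environment,
    so every rank computes the same directory.
    """
    wandb_dir = os.environ.get("WANDB_DIR")
    return os.path.join(wandb_dir, "checkpoints") if wandb_dir else default
```

You could then pass it to the callback, for example ModelCheckpoint(dirpath=checkpoint_dir(), every_n_train_steps=100).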