Passer au contenu principal
PyTorch Lightning fournit un wrapper léger pour organiser votre code PyTorch et ajouter facilement des fonctionnalités avancées telles que l’entraînement distribué et la précision en 16 bits. W&B fournit un wrapper léger pour consigner vos expériences de ML. Mais vous n’avez pas besoin de combiner les deux vous-même : W&B est directement intégré à la bibliothèque PyTorch Lightning via le WandbLogger.

Intégrer Lightning

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(logger=wandb_logger)
Utilisation de wandb.log() : Le WandbLogger envoie les journaux vers W&B en utilisant le global_step du Trainer. Si vous effectuez des appels supplémentaires à wandb.log directement dans votre code, n’utilisez pas l’argument step dans wandb.log().À la place, enregistrez le global_step du Trainer comme vos autres métriques :
wandb.log({"accuracy":0.99, "trainer/global_step": step})
Tableaux de bord interactifs

Inscrivez-vous et créez une clé API

Une clé API permet d’authentifier votre machine auprès de W&B. Vous pouvez générer une clé API depuis votre profil.
Pour une méthode plus directe, créez une clé API en accédant directement aux Paramètres utilisateur. Copiez immédiatement la clé API nouvellement créée et conservez-la dans un endroit sûr, par exemple dans un gestionnaire de mots de passe.
  1. Cliquez sur l’icône de votre profil en haut à droite.
  2. Sélectionnez Paramètres utilisateur, puis faites défiler jusqu’à la section Clés API.

Installez la bibliothèque wandb et connectez-vous

Pour installer la bibliothèque wandb localement et vous connecter :
  1. Définissez la variable d’environnement WANDB_API_KEY sur votre clé API.
    export WANDB_API_KEY=<your_api_key>
    
  2. Installez la bibliothèque wandb et connectez-vous.
    pip install wandb
    
    wandb login
    

Utiliser le WandbLogger de PyTorch Lightning

PyTorch Lightning propose plusieurs classes WandbLogger pour consigner des métriques, ainsi que les poids du modèle, des médias, et bien plus encore. Pour l’utiliser avec Lightning, instanciez WandbLogger, puis transmettez-le à Trainer ou à Fabric de Lightning.
trainer = Trainer(logger=wandb_logger)

Arguments courants du logger

Vous trouverez ci-dessous quelques-uns des paramètres les plus utilisés dans WandbLogger. Consultez la documentation de PyTorch Lightning pour en savoir plus sur tous les arguments du logger.
ParamètreDescription
projectDéfinit dans quel projet wandb journaliser
nameDonne un nom à votre run wandb
log_modelJournalise tous les modèles si log_model="all", ou à la fin de l’entraînement si log_model=True
save_dirChemin où les données sont enregistrées

Journalisez vos hyperparamètres

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        self.save_hyperparameters()

Journalisez des paramètres de configuration supplémentaires

# ajouter un paramètre
wandb_logger.experiment.config["key"] = value

# ajouter plusieurs paramètres
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# utiliser directement le module wandb
wandb.config["key"] = value
wandb.config.update()

Journaliser les gradients, l’histogramme des paramètres et la topologie du modèle

Vous pouvez passer l’objet de votre modèle à wandblogger.watch() pour surveiller les gradients et les paramètres de votre modèle pendant l’entraînement. Voir la documentation de WandbLogger pour PyTorch Lightning

Journaliser des métriques

Vous pouvez journaliser vos métriques dans W&B lorsque vous utilisez WandbLogger en appelant self.log('my_metric_name', metric_vale) dans votre LightningModule, par exemple dans les méthodes training_step ou validation_step.L’extrait de code ci-dessous montre comment définir votre LightningModule pour journaliser vos métriques ainsi que les hyperparamètres de votre LightningModule. Cet exemple utilise la bibliothèque torchmetrics pour calculer vos métriques.
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """méthode utilisée pour définir les paramètres du modèle"""
        super().__init__()

        # les images mnist ont la forme (1, 28, 28) (canaux, largeur, hauteur)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        self.loss = CrossEntropyLoss()
        self.lr = lr

        # enregistre les hyperparamètres dans self.hparams (journalisés automatiquement par W&B)
        self.save_hyperparameters()

    def forward(self, x):
        """méthode utilisée pour l'inférence, de l'entrée vers la sortie"""

        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # effectuons 3 x (linear + relu)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        """doit renvoyer une perte pour un seul lot"""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Journaliser la perte et la métrique
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """utilisée pour journaliser des métriques"""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Journaliser la perte et la métrique
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

    def configure_optimizers(self):
        """définit l'optimiseur du modèle"""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """fonction utilitaire, car les étapes train/valid/test sont similaires"""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y)
        return preds, loss, acc

Journaliser les valeurs min/max d’une métrique

Avec la fonction define_metric de wandb, vous pouvez définir si vous souhaitez que votre métrique de synthèse W&B affiche la valeur minimale, maximale, moyenne ou la meilleure valeur pour cette métrique. Si define_metric _ n’est pas utilisé, la dernière valeur enregistrée apparaîtra dans vos métriques de synthèse. Voir la documentation de référence de define_metric ici et le guide ici pour en savoir plus. Pour indiquer à W&B de suivre la précision de validation maximale dans la métrique de synthèse W&B, appelez wandb.define_metric une seule fois, au début de l’entraînement :
class My_LitModule(LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0:
            wandb.define_metric("val_accuracy", summary="max")

        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Journaliser la perte et la métrique
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

Créer un point de contrôle d’un modèle

Pour enregistrer des points de contrôle du modèle en tant qu’Artifacts W&B Artifacts, utilisez le callback Lightning ModelCheckpoint et définissez l’argument log_model dans WandbLogger.
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
Les alias latest et best sont automatiquement définis pour vous permettre de récupérer facilement un point de contrôle du modèle depuis un Artifact W&B :
# la référence peut être récupérée dans le panneau des Artifacts
# "VERSION" peut être une version (ex : "v2") ou un alias ("latest" ou "best")
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
# télécharger le point de contrôle en local (s'il n'est pas déjà en cache)
wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# charger le point de contrôle
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
Les points de contrôle de modèle que vous enregistrez sont visibles dans l’UI W&B Artifacts et incluent la traçabilité complète du modèle (voir un exemple de point de contrôle de modèle dans l’UI ici). Pour mettre vos meilleurs points de contrôle de modèle en favoris et les centraliser pour votre équipe, vous pouvez les lier au registre de modèles W&B. Vous pouvez y organiser vos meilleurs modèles par tâche, gérer le cycle de vie des modèles, faciliter leur suivi et leur audit tout au long du cycle de vie du ML, et automatiser les actions en aval avec des webhooks ou des jobs.

Journaliser des images, du texte et plus encore

WandbLogger fournit les méthodes log_image, log_text et log_table pour journaliser des médias. Vous pouvez aussi appeler directement wandb.log ou trainer.logger.experiment.log pour journaliser d’autres types de médias comme Audio, les molécules, les nuages de points, les objets 3D, etc.
# utiliser des tenseurs, des tableaux numpy ou des images PIL
wandb_logger.log_image(key="samples", images=[img1, img2])

# ajouter des légendes
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# utiliser un chemin de fichier
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# utiliser .log dans le trainer
trainer.logger.experiment.log(
    {"samples": [wandb.Image(img, caption=caption) for (img, caption) in my_images]},
    step=current_trainer_global_step,
)
Vous pouvez utiliser le système de callbacks de Lightning pour contrôler quand vous journalisez dans W&B via WandbLogger. Dans cet exemple, nous journalisons un échantillon de nos images de validation et de nos prédictions :
import torch
import wandb
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

# or
# from wandb.integration.lightning.fabric import WandbLogger


class LogPredictionSamplesCallback(Callback):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Called when the validation batch ends."""

        # `outputs` provient de `LightningModule.validation_step`
        # ce qui correspond ici aux prédictions de notre modèle

        # Let's log 20 sample image predictions from the first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f"Ground Truth: {y_i} - Prediction: {y_pred}"
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # Option 1: log images with `WandbLogger.log_image`
            wandb_logger.log_image(key="sample_images", images=images, caption=captions)

            # Option 2: log images and predictions as a W&B Table
            columns = ["image", "ground truth", "prediction"]
            data = [
                [wandb.Image(x_i), y_i, y_pred] or x_i,
                y_i,
                y_pred in list(zip(x[:n], y[:n], outputs[:n])),
            ]
            wandb_logger.log_table(key="sample_table", columns=columns, data=data)


trainer = pl.Trainer(callbacks=[LogPredictionSamplesCallback()])

Utiliser plusieurs GPU avec Lightning et W&B

PyTorch Lightning prend en charge le multi-GPU via son interface DDP. Cependant, la conception de PyTorch Lightning exige que vous fassiez attention à la façon dont vous instanciez vos GPU. Lightning part du principe que chaque GPU (ou rang) dans votre boucle d’entraînement doit être instancié exactement de la même manière, avec les mêmes conditions initiales. Cependant, seul le processus de rang 0 a accès à l’objet wandb.run, et pour les processus de rang non nul : wandb.run = None. Cela peut faire échouer vos processus non nuls. Une telle situation peut vous placer dans un interblocage, car le processus de rang 0 attendra que les processus de rang non nul le rejoignent, alors qu’ils ont déjà planté. Pour cette raison, faites attention à la façon dont vous configurez votre code d’entraînement. La méthode recommandée consiste à rendre votre code indépendant de l’objet wandb.run.
class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        return {"train_loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # Définir toutes les graines aléatoires sur la même valeur.
    # C'est important dans un contexte d'entraînement distribué.
    # Chaque rang obtiendra son propre ensemble de poids initiaux.
    # S'ils ne correspondent pas, les gradients ne correspondront pas non plus,
    # ce qui risque de mener à un entraînement qui ne converge pas.
    pl.seed_everything(1)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project="<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3, gpus=2, logger=wandb_logger, strategy="ddp", callbacks=callbacks
    )
    trainer.fit(model, train_loader, val_loader)

Exemples

Vous pouvez suivre ce tutoriel dans une vidéo avec un notebook Colab.

Foire aux questions

Comment W&B s’intègre-t-il à Lightning ?

L’intégration principale s’appuie sur l’API Lightning loggers, qui vous permet d’écrire une grande partie de votre code de journalisation de manière indépendante du framework. Les Logger sont transmis au Lightning Trainer et sont déclenchés par le riche système de hooks et de callbacks de cette API. Cela permet de bien séparer votre code de recherche du code d’ingénierie et de journalisation.

Que journalise l’intégration sans code supplémentaire ?

Nous enregistrerons les points de contrôle du modèle dans W&B, où vous pourrez les consulter ou les télécharger pour les utiliser dans de futurs runs. Nous capturerons également les métriques système, comme l’utilisation du GPU et les entrées/sorties réseau, des informations sur l’environnement, comme les caractéristiques du matériel et du système d’exploitation, l’état du code (y compris le commit git et le patch diff, le contenu du notebook et l’historique de la session), ainsi que tout ce qui est affiché dans la sortie standard.

Que faire si j’ai besoin d’utiliser wandb.run dans ma configuration d’entraînement ?

Vous devez vous-même élargir la portée de la variable à laquelle vous souhaitez accéder. En d’autres termes, veillez à ce que les conditions initiales soient identiques dans tous les processus.
if os.environ.get("LOCAL_RANK", None) is None:
    os.environ["WANDB_DIR"] = wandb.run.dir
Si c’est le cas, vous pouvez utiliser os.environ["WANDB_DIR"] pour configurer le répertoire des points de contrôle du modèle. Ainsi, tout processus de rang non nul peut accéder à wandb.run.dir.