Passer au contenu principal
Utilisez les callbacks Keras pour suivre les expériences, enregistrer les points de contrôle du modèle et visualiser les prédictions du modèle. Les callbacks Keras sont disponibles dans le module wandb.integration.keras à partir de la version 0.13.4 du SDK Python. L’intégration Keras de W&B fournit les callbacks suivants :
  • WandbMetricsLogger : Utilisez ce callback pour le suivi des expériences. Il enregistre vos métriques d’entraînement et de validation, ainsi que les métriques système, dans W&B.
  • WandbModelCheckpoint : Utilisez ce callback pour enregistrer les points de contrôle de votre modèle dans les Artifacts W&B.
  • WandbEvalCallback: Ce callback de base enregistre les prédictions du modèle dans les Tables W&B pour une visualisation interactive.

Installer et importer l’intégration Keras

Installez la dernière version de W&B.
pip install -U wandb
Pour utiliser l’intégration Keras, importez les classes requises à partir de wandb.integration.keras :
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback
Les sections suivantes décrivent en détail chaque fonction de rappel, avec des exemples de code.

Suivre les expériences avec WandbMetricsLogger

wandb.integration.keras.WandbMetricsLogger() enregistre automatiquement le dictionnaire logs de Keras, passé en argument aux méthodes de callback telles que on_epoch_end, on_batch_end, etc. L’exemple partiel ci-dessous montre comment utiliser WandbMetricsLogger() dans un flux de travail Keras. Commencez par compiler le modèle avec l’optimiseur, la fonction de perte et les métriques souhaités. Ensuite, initialisez un run W&B à l’aide de wandb.init(). Enfin, transmettez le callback WandbMetricsLogger() à model.fit().
import wandb
from wandb.integration.keras import WandbMetricsLogger
import tensorflow as tf

model.compile(
    optimizer = "adam",
    loss = "categorical_crossentropy",
    metrics = ["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top@5_accuracy')]
)

# Initialiser un nouveau run W&B
with wandb.init(config={"batch_size": 64}) as run:

    # Passer le WandbMetricsLogger à model.fit
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbMetricsLogger()]
    )
L’exemple précédent consigne dans W&B les métriques d’entraînement et de validation, telles que loss, accuracy et top@5_accuracy, à la fin de chaque époque. Il consigne également :

Référence de WandbMetricsLogger

ParamètreDescription
log_freq(epoch, batch ou un int) : si epoch, journalise les métriques à la fin de chaque époque. Si batch, journalise les métriques à la fin de chaque lot. Si c’est un int, journalise les métriques après ce nombre de lots. Valeur par défaut : epoch.
initial_global_step(int) : utilisez cet argument pour journaliser correctement le taux d’apprentissage lorsque vous reprenez l’entraînement à partir d’un initial_epoch et qu’un ordonnanceur de taux d’apprentissage est utilisé. Cette valeur peut être calculée comme step_size * initial_step. La valeur par défaut est 0.

Créer un checkpoint d’un modèle avec WandbModelCheckpoint

Utilisez le callback WandbModelCheckpoint pour enregistrer périodiquement le modèle Keras (au format SavedModel) ou les poids du modèle, puis les envoyer à W&B sous forme de wandb.Artifact pour la gestion des versions du modèle. Ce callback est une sous-classe de tf.keras.callbacks.ModelCheckpoint(). La logique de checkpointing est donc gérée par le callback parent. Ce callback enregistre :
  • Le modèle ayant obtenu les meilleures performances en fonction de la métrique surveillée.
  • Le modèle à la fin de chaque époque, indépendamment des performances.
  • Le modèle à la fin de l’époque ou après un nombre fixe de lots d’entraînement.
  • Soit uniquement les poids du modèle, soit le modèle complet.
  • Le modèle soit au format SavedModel, soit au format .h5.
Utilisez ce callback avec WandbMetricsLogger().
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

# Initialiser un nouveau run W&B
with wandb.init(config={"bs": 12}) as run:

    # Passer le WandbModelCheckpoint à model.fit
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbModelCheckpoint("models"),
        ],
    )

Référence WandbModelCheckpoint

ParamètreDescription
filepath(str) : chemin d’enregistrement du fichier de modèle.
monitor(str) : nom de la métrique à surveiller.
verbose(int) : mode de verbosité, 0 ou 1. Le mode 0 est silencieux et le mode 1 affiche des messages lorsque le callback effectue une action.
save_best_only(Boolean) : si save_best_only=True, enregistre uniquement le dernier modèle ou le modèle considéré comme le meilleur, selon les attributs monitor et mode.
save_weights_only(Boolean) : si True, enregistre uniquement les poids du modèle.
mode(auto, min ou max) : pour val_acc, définissez-le sur max ; pour val_loss, définissez-le sur min, et ainsi de suite.
save_freq(“epoch” ou int) : lorsque vous utilisez epoch, le callback enregistre le modèle après chaque époque. Lorsque vous utilisez un entier, le callback enregistre le modèle à la fin de ce nombre de lots. Notez que lors de la surveillance de métriques de validation telles que val_acc ou val_loss, save_freq doit être défini sur “epoch”, car ces métriques ne sont disponibles qu’à la fin d’une époque.
options(str) : objet tf.train.CheckpointOptions facultatif si save_weights_only est true, ou objet tf.saved_model.SaveOptions facultatif si save_weights_only est false.
initial_value_threshold(float) : valeur initiale à virgule flottante du « meilleur » résultat de la métrique à surveiller.

Enregistrer des checkpoints après N époques

Par défaut (save_freq="epoch"), le callback crée un checkpoint et le téléverse comme artifact après chaque époque. Pour créer un checkpoint après un nombre précis de lots, définissez save_freq sur un entier. Pour créer un checkpoint après N époques, calculez la cardinalité du dataloader train et passez-la à save_freq :
WandbModelCheckpoint(
    filepath="models/",
    save_freq=int((trainloader.cardinality()*N).numpy())
)

Journaliser efficacement les checkpoints sur une architecture TPU

Lors de la création de checkpoints sur des TPU, vous pouvez rencontrer le message d’erreur UnimplementedError: File system scheme '[local]' not implemented. Cela se produit parce que le répertoire du modèle (filepath) doit utiliser un chemin de bucket de stockage cloud (gs://bucket-name/...), et que ce bucket doit être accessible depuis le serveur TPU. À la place, W&B utilise le chemin local pour créer les checkpoints, qui sont ensuite téléversés en tant qu’artifact.
checkpoint_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")

WandbModelCheckpoint(
    filepath="models/,
    options=checkpoint_options,
)

Visualiser les prédictions du modèle avec WandbEvalCallback

WandbEvalCallback() est une classe de base abstraite permettant de créer des callbacks Keras, principalement pour la prédiction de modèle et, dans un second temps, pour la visualisation du jeu de données. Ce callback abstrait est indépendant du jeu de données et de la tâche. Pour l’utiliser, héritez de cette classe de base WandbEvalCallback() et implémentez les méthodes add_ground_truth et add_model_prediction. WandbEvalCallback() est une classe utilitaire qui fournit des méthodes pour :
  • Créer des instances wandb.Table() pour les données et les prédictions.
  • Journaliser les Tables de données et de prédictions en tant que wandb.Artifact().
  • Journaliser le tableau de données dans on_train_begin.
  • Journaliser le tableau de prédictions dans on_epoch_end.
L’exemple suivant utilise WandbClfEvalCallback pour une tâche de classification d’images. Ce callback d’exemple journalise les données de validation (data_table) dans W&B, effectue l’inférence, puis journalise les prédictions (pred_table) dans W&B à la fin de chaque époque.
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback


# Implémenter votre callback de visualisation des prédictions du modèle
class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validation_data, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)

        self.x = validation_data[0]
        self.y = validation_data[1]

    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(zip(self.x, self.y)):
            self.data_table.add_data(idx, wandb.Image(image), label)

    def add_model_predictions(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = tf.argmax(preds, axis=-1)

        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )


# ...

# Initialiser un nouveau run W&B
with wandb.init(config={"hyper": "parameter"}) as run:

    # Ajouter les callbacks à Model.fit
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbClfEvalCallback(
                validation_data=(X_test, y_test),
                data_table_columns=["idx", "image", "label"],
                pred_table_columns=["epoch", "idx", "image", "label", "pred"],
            ),
        ],
    )

Référence de WandbEvalCallback

ParamètreDescription
data_table_columns(liste) Liste des noms de colonnes du data_table
pred_table_columns(liste) Liste des noms de colonnes du pred_table

Détails de l’empreinte mémoire

Nous journalisons data_table dans W&B lorsque la méthode on_train_begin est appelée. Une fois téléversé en tant qu’Artifact W&B, nous obtenons une référence à ce tableau, accessible via la variable de classe data_table_ref. data_table_ref est une liste 2D qui peut être indexée comme self.data_table_ref[idx][n], où idx est le numéro de ligne et n le numéro de colonne. Voyons son utilisation dans l’exemple ci-dessous.

Personnaliser le callback

Vous pouvez redéfinir les méthodes on_train_begin ou on_epoch_end pour bénéficier d’un contrôle plus précis. Si vous souhaitez journaliser les échantillons après N lots, vous pouvez implémenter la méthode on_train_batch_end.
Si vous implémentez un callback pour visualiser les prédictions du modèle en héritant de WandbEvalCallback et que certains points doivent être clarifiés ou corrigés, faites-le-nous savoir en ouvrant une issue.

WandbCallback [obsolète]

Utilisez la classe WandbCallback() de la bibliothèque W&B pour enregistrer automatiquement toutes les métriques et valeurs de perte suivies dans model.fit().
import wandb
from wandb.integration.keras import WandbCallback

with wandb.init(config={"hyper": "parameter"}) as run:

    # code pour configurer votre modèle dans Keras

    # Passer le callback à model.fit
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbCallback()]
    )
Vous pouvez regarder la courte vidéo Premiers pas avec Keras et W&B en moins d’une minute. Pour une vidéo plus détaillée, regardez Intégrer W&B à Keras. Vous pouvez également consulter le notebook Jupyter Colab.
Voir notre dépôt d’exemples pour des scripts, notamment un exemple Fashion MNIST et le tableau de bord W&B qu’il génère.
La classe WandbCallback prend en charge un large éventail d’options de configuration de journalisation : spécifier une métrique à surveiller, suivre les poids et les gradients, journaliser les prédictions sur training_data et validation_data, entre autres. Consultez la documentation de référence de keras.WandbCallback pour plus de détails. Le WandbCallback
  • Journalise automatiquement les données d’historique de toutes les métriques collectées par Keras : la perte et tout ce qui est transmis à keras_model.compile().
  • Définit les métriques de synthèse pour le run associé à la « meilleure » étape d’entraînement, telle que définie par les attributs monitor et mode. Par défaut, il s’agit de l’époque où val_loss est minimale. Par défaut, WandbCallback enregistre le modèle associé à la meilleure epoch.
  • Journalise de façon facultative l’histogramme des gradients et des paramètres.
  • Enregistre de façon facultative les données d’entraînement et de validation pour que wandb puisse les visualiser.

Référence de WandbCallback

Arguments
monitor(str) nom de la métrique à surveiller. Valeur par défaut : val_loss.
mode(str) l’une des valeurs {auto, min, max}. min - enregistrer le modèle lorsque la métrique surveillée est minimisée max - enregistrer le modèle lorsque la métrique surveillée est maximisée auto - essayer de déterminer quand enregistrer le modèle (par défaut).
save_modelTrue - enregistre un modèle lorsque monitor surpasse toutes les époques précédentes False - n’enregistre pas de modèles
save_graph(booléen) si True, enregistre le graphe du modèle dans wandb (True par défaut).
save_weights_only(booléen) si True, enregistre uniquement les poids du modèle (model.save_weights(filepath)). Sinon, enregistre le modèle complet).
log_weights(booléen) si True, enregistre les histogrammes des poids des couches du modèle.
log_gradients(booléen) si True, journalise les histogrammes des gradients d’entraînement
training_data(tuple) Même format (X,y) que celui transmis à model.fit. Nécessaire pour calculer les gradients ; ce paramètre est obligatoire si log_gradients est True.
validation_data(tuple) Même format (X,y) que celui transmis à model.fit. Un jeu de données à visualiser par wandb. Si vous définissez ce champ, à chaque époque, wandb effectue un petit nombre de prédictions et enregistre les résultats pour une visualisation ultérieure.
generator(générateur) un générateur qui renvoie des données de validation que wandb peut visualiser. Ce générateur doit renvoyer des tuples (X,y). Vous devez définir validate_data ou generator pour que wandb puisse visualiser des exemples de données spécifiques.
validation_steps(int) si validation_data est un générateur, nombre d’étapes nécessaires pour exécuter le générateur sur l’ensemble complet de validation.
labels(liste) Si vous visualisez vos données avec wandb, cette liste d’étiquettes convertit les sorties numériques en chaînes compréhensibles si vous créez un classificateur à plusieurs classes. Pour un classificateur binaire, vous pouvez fournir une liste de deux étiquettes [label for false, label for true]. Si validate_data et generator sont tous deux à false, cela n’a aucun effet.
prédictions(int) le nombre de prédictions à effectuer pour la visualisation à chaque époque, avec un maximum de 100.
input_type(string) type de l’entrée du modèle pour faciliter la visualisation. Peut être l’une des valeurs suivantes : (image, images, segmentation_mask).
output_type(string) type de la sortie du modèle pour faciliter la visualisation. Peut être l’une des valeurs suivantes : (image, images, segmentation_mask).
log_evaluation(booléen) si True, enregistre un tableau contenant les données de validation et les prédictions du modèle à chaque époque. Voir validation_indexes, validation_row_processor et output_row_processor pour plus de détails.
class_colors([float, float, float]) si l’entrée ou la sortie est un masque de segmentation, un tableau contenant un tuple RGB (plage 0-1) pour chaque classe.
log_batch_frequency(entier) si None, le callback journalise chaque époque. S’il est défini sur un entier, le callback journalise les métriques d’entraînement tous les log_batch_frequency lots.
log_best_prefix(chaîne) si None, n’enregistre aucune métrique summary supplémentaire. Si une chaîne est définie, ajoute ce préfixe à la métrique surveillée et à l’époque, puis enregistre les résultats comme métriques summary.
validation_indexes([wandb.data_types._TableLinkMixin]) une liste ordonnée de clés d’index à associer à chaque exemple de validation. Si log_evaluation est défini sur True et que vous fournissez validation_indexes, cela ne crée pas de Table de données de validation. À la place, chaque prédiction est associée à la ligne représentée par le TableLinkMixin. Pour obtenir une liste de clés de ligne, utilisez Table.get_index() .
validation_row_processor(Callable) une fonction à appliquer aux données de validation, généralement utilisée pour visualiser les données. La fonction reçoit un ndx (int) et un row (dict). Si votre modèle n’a qu’une seule entrée, alors row["input"] contient les données d’entrée de la ligne. Sinon, il contient les noms des slots d’entrée. Si votre fonction fit prend une seule cible, alors row["target"] contient les données cibles de la ligne. Sinon, il contient les noms des slots de sortie. Par exemple, si vos données d’entrée correspondent à un seul tableau, pour visualiser les données sous forme d’image, fournissez lambda ndx, row: {"img": wandb.Image(row["input"])} comme fonction de traitement. Ignoré si log_evaluation vaut False ou si validation_indexes sont présents.
output_row_processor(Callable) identique à validation_row_processor, mais appliqué à la sortie du modèle. row["output"] contient les résultats du modèle.
infer_missing_processors(Boolean) Détermine si validation_row_processor et output_row_processor doivent être inférés s’ils sont absents. La valeur par défaut est True. Si vous fournissez labels, W&B essaie d’inférer des processeurs de type Classification lorsque c’est pertinent.
log_evaluation_frequency(int) Détermine à quelle fréquence les résultats d’évaluation sont consignés dans le journal. La valeur par défaut est 0, afin de les consigner uniquement à la fin de l’entraînement. Définissez 1 pour les consigner à chaque époque, 2 pour une époque sur deux, et ainsi de suite. N’a aucun effet lorsque log_evaluation est défini sur False.

Questions fréquentes

Comment puis-je utiliser le mode multiprocessing de Keras avec wandb ?

Lorsque vous définissez use_multiprocessing=True, cette erreur peut se produire :
Error("You must call wandb.init() before wandb.config.batch_size")
Pour contourner ce problème :
  1. Lors de l’instanciation de la classe Sequence, ajoutez : wandb.init(group='...').
  2. Dans main, assurez-vous d’utiliser if __name__ == "__main__": et placez-y le reste de la logique de votre script.