メインコンテンツへスキップ
Keras のコールバックを使って Experiments をトラッキングし、モデルのチェックポイントをログし、モデルの予測を可視化します。Keras のコールバックは、Python SDK バージョン 0.13.4 以降で wandb.integration.keras モジュールから利用できます。 W&B Keras インテグレーションは次のコールバックを提供します:
  • WandbMetricsLogger : このコールバックを Experiment Tracking に使用します。学習および検証メトリクスに加えて、システムメトリクスを W&B にログします。
  • WandbModelCheckpoint : このコールバックを使用して、モデルのチェックポイントを W&B の Artifacts にログします。
  • WandbEvalCallback: このベースコールバックは、インタラクティブな可視化のためにモデルの予測を W&B の Tables にログします。

Keras インテグレーションをインストールしてインポートする

最新バージョンの W&B をインストールします。
pip install -U wandb
Keras インテグレーションを使用するには、wandb.integration.keras から必要なクラスをインポートします。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback
以降のセクションでは、各コールバックについてコード例とともに詳しく説明します。

WandbMetricsLogger を使って実験をトラッキングする

Colab で試す wandb.integration.keras.WandbMetricsLogger() は、on_epoch_endon_batch_end などのコールバックメソッドが引数として受け取る、Keras の logs 辞書を自動的に記録します。 以下のコード例は、Keras のワークフローで WandbMetricsLogger() を使用する方法を示しています。まず、使用したいオプティマイザ、損失関数、およびメトリクスでモデルをコンパイルします。次に、wandb.init() を使って W&B run を初期化します。最後に、WandbMetricsLogger() コールバックを model.fit() に渡します。
import wandb
from wandb.integration.keras import WandbMetricsLogger
import tensorflow as tf

model.compile(
    optimizer = "adam",
    loss = "categorical_crossentropy",
    metrics = ["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top@5_accuracy')]
)

# 新しい W&B Run を初期化する
with wandb.init(config={"batch_size": 64}) as run:

    # WandbMetricsLogger を model.fit に渡す
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbMetricsLogger()]
    )
先ほどの例では、各エポックの終了時に lossaccuracytop@5_accuracy などの学習および検証メトリクスを W&B に記録します。また、次の情報も記録します:

WandbMetricsLogger リファレンス

ParameterDescription
log_freq(epochbatch、または int): epoch の場合、各エポックの終了時にメトリクスをログします。batch の場合、各バッチの終了時にメトリクスをログします。int の場合、その数のバッチごとにメトリクスをログします。デフォルトは epoch です。
initial_global_step(int): initial_epoch から学習を再開し、かつ学習率スケジューラを使用している場合に、学習率を正しくログするために使用します。これは step_size * initial_step として計算されます。デフォルトは 0 です。

WandbModelCheckpoint を使ってモデルをチェックポイントする

Colab で試す WandbModelCheckpoint コールバックを使用すると、Keras モデル(SavedModel 形式)またはモデル重みを定期的に保存し、モデルのバージョニングのためにそれらを wandb.Artifact として W&B にアップロードできます。 このコールバックは tf.keras.callbacks.ModelCheckpoint() を継承しているため、チェックポイント処理のロジックは親コールバックが担当します。 このコールバックは次のものを保存します:
  • monitor で指定した指標に基づき、最高の性能を達成したモデル。
  • 性能に関わらず、各エポックの終了時点のモデル。
  • 各エポックの終了時、または一定数の学習バッチごとに処理したあとのモデル。
  • モデル重みのみ、またはモデル全体。
  • SavedModel 形式または .h5 形式のいずれかのモデル。
このコールバックは WandbMetricsLogger() と併用してください。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

# 新しい W&B run を初期化する
with wandb.init(config={"bs": 12}) as run:

    # WandbModelCheckpoint を model.fit に渡す
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbModelCheckpoint("models"),
        ],
    )

WandbModelCheckpoint リファレンス

ParameterDescription
filepath(str): モデルファイルを保存するパス。\
monitor(str): 監視するメトリクス名。
verbose(int): 詳細度レベル。0 または 1。0 の場合は出力なし、1 の場合はコールバックがアクションを実行したときにメッセージを表示。
save_best_only(Boolean): save_best_only=True の場合、monitormode 属性で定義された条件に従い、最新のモデルまたは最良とみなしたモデルのみを保存。
save_weights_only(Boolean): True の場合、モデルの重みのみを保存。
mode(auto, min, または max): val_acc の場合は maxval_loss の場合は min に設定し、それ以外も同様。
save_freq(“epoch” または int): "epoch" を使用する場合、コールバックは各エポック後にモデルを保存。整数を使用する場合、そのバッチ数ぶんの処理が終了するたびにモデルを保存。val_accval_loss などの検証メトリクスを監視する場合、これらのメトリクスはエポックの最後にのみ利用可能なため、save_freq は必ず "epoch" に設定する必要がある。
options(str): save_weights_onlytrue の場合はオプションの tf.train.CheckpointOptions オブジェクト、save_weights_onlyfalse の場合はオプションの tf.saved_model.SaveOptions オブジェクト。
initial_value_threshold(float): 監視対象メトリクスに対する浮動小数点数の初期「ベスト」値。

N エポックごとにチェックポイントを記録する

デフォルト(save_freq="epoch")では、コールバックは各エポック終了後にチェックポイントを作成し、それをアーティファクトとしてアップロードします。特定のバッチ数ごとにチェックポイントを作成するには、save_freq を整数に設定します。N エポックごとにチェックポイントを作成するには、train データローダーの要素数(cardinality)を計算し、その値を save_freq に渡します:
WandbModelCheckpoint(
    filepath="models/",
    save_freq=int((trainloader.cardinality()*N).numpy())
)

TPU アーキテクチャでチェックポイントを効率的に記録する

TPU 上でチェックポイントを保存する際、UnimplementedError: File system scheme '[local]' not implemented というエラーメッセージが表示される場合があります。これは、モデルディレクトリ(filepath)にはクラウドストレージバケットのパス(gs://bucket-name/...)を指定する必要があり、かつそのバケットが TPU サーバーからアクセス可能でなければならないために発生します。一方で、W&B はチェックポイントの保存にローカルパスを使用し、そのローカルパス上のデータをアーティファクトとしてアップロードします。
checkpoint_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")

WandbModelCheckpoint(
    filepath="models/,
    options=checkpoint_options,
)

WandbEvalCallback を使ってモデルの予測を可視化する

Colab で試す WandbEvalCallback() は、主にモデルの予測用、次にデータセットの可視化用の Keras コールバックを構築するための抽象基底クラスです。 この抽象コールバックは、データセットやタスクに依存しない設計になっています。これを使うには、この WandbEvalCallback() 基底コールバッククラスを継承し、add_ground_truthadd_model_prediction メソッドを実装します。 WandbEvalCallback() は、次のメソッドを提供するユーティリティクラスです:
  • データと予測の wandb.Table() インスタンスを作成する。
  • データと予測のテーブルを wandb.Artifact() としてログする。
  • on_train_begin でデータテーブルをログする。
  • on_epoch_end で予測テーブルをログする。
次の例では、画像分類タスクに WandbClfEvalCallback を使用します。このサンプルコールバックは、検証データ(data_table)を W&B にログし、推論を実行し、各エポックの最後に予測(pred_table)を W&B にログします。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback


# モデル予測の可視化コールバックを実装する
class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validation_data, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)

        self.x = validation_data[0]
        self.y = validation_data[1]

    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(zip(self.x, self.y)):
            self.data_table.add_data(idx, wandb.Image(image), label)

    def add_model_predictions(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = tf.argmax(preds, axis=-1)

        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )


# ...

# 新しい W&B Run を初期化する
with wandb.init(config={"hyper": "parameter"}) as run:

    # Model.fit にコールバックを追加する
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbClfEvalCallback(
                validation_data=(X_test, y_test),
                data_table_columns=["idx", "image", "label"],
                pred_table_columns=["epoch", "idx", "image", "label", "pred"],
            ),
        ],
    )

WandbEvalCallback リファレンス

ParameterDescription
data_table_columns(list) data_table の列名リスト
pred_table_columns(list) pred_table の列名リスト

メモリフットプリントの詳細

on_train_begin メソッドが呼び出されたタイミングで、data_table を W&B にログします。これが W&B のアーティファクトとしてアップロードされると、そのテーブルを参照するオブジェクトを取得でき、クラス変数 data_table_ref を使ってアクセスできます。data_table_ref は 2 次元リストであり、self.data_table_ref[idx][n] のようにインデックス指定できます。ここで、idx は行番号、n は列番号です。以下の例でその使い方を見てみましょう。

コールバックをカスタマイズする

より細かく制御したい場合は、on_train_beginon_epoch_end メソッドをオーバーライドできます。N 個のバッチごとにサンプルを記録したい場合は、on_train_batch_end メソッドを実装してください。
WandbEvalCallback を継承してモデル予測の可視化用コールバックを実装しており、明確化や修正が必要な点があれば、issue を作成してお知らせください。

WandbCallback [レガシー]

W&B ライブラリの WandbCallback() クラスを使用して、model.fit() で追跡されるすべてのメトリクスや損失値を自動的に保存します。
import wandb
from wandb.integration.keras import WandbCallback

with wandb.init(config={"hyper": "parameter"}) as run:

    # Kerasでモデルをセットアップするコード

    # コールバックをmodel.fitに渡す
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbCallback()]
    )
1 分以内で始められる短い動画 Get Started with Keras and W&B in Less Than a Minute を視聴できます。 より詳しい内容については、Integrate W&B with Keras の動画を視聴してください。Colab Jupyter Notebook も参照できます。
スクリプトについては、example repo を参照してください。Fashion MNIST example や、それによって生成される W&B Dashboard などが含まれています。
WandbCallback クラスは、多様なロギング設定オプションをサポートします。監視するメトリクスの指定、weight と gradient のトラッキング、training_data および validation_data 上での予測のログ取得などが可能です。 詳細については、keras.WandbCallback のリファレンスドキュメントを参照してください。 WandbCallback は次のことを行います。
  • Keras が収集したメトリクス(loss と keras_model.compile() に渡されたもの)から履歴データを自動的にログします。
  • monitormode 属性で定義される「最適」な学習ステップに対応する run に対して、サマリーメトリクスを設定します。デフォルトでは、val_loss が最小のエポックになります。WandbCallback は、デフォルトで最良のエポックに対応するモデルを保存します。
  • オプションで、gradient とパラメータのヒストグラムをログします。
  • オプションで、wandb が可視化できるように学習データと検証データを保存します。

WandbCallback リファレンス

引数
monitor(str) 監視対象とする指標名。デフォルトは val_loss
mode(str) {auto, min, max} のいずれか。min - 監視対象の値が最小になったときにモデルを保存します max - 監視対象の値が最大になったときにモデルを保存します auto - モデルをいつ保存するかを自動判別しようとします(デフォルト)。
save_modelTrue - 監視対象がこれまでのすべてのエポックでの値を上回ったときにモデルを保存します。False - モデルを保存しません。
save_graph(boolean) True の場合、モデルのグラフを wandb に保存します (デフォルトは True)。
save_weights_only(boolean) True の場合は、model.save_weights(filepath) を使用してモデルの重みのみを保存します。それ以外の場合はモデル全体を保存します)。
log_weights(boolean) True の場合、モデルの各層の重みのヒストグラムを保存します。
log_gradients(boolean) True の場合、学習時の勾配のヒストグラムを記録します。
training_data(tuple) model.fit に渡される (X,y) と同じ形式です。勾配を計算するために必要で、log_gradientsTrue の場合は必須です。
validation_data(tuple) model.fit に渡される (X,y) と同じ形式。wandb が可視化するためのデータセット。このフィールドを設定すると、wandb は各エポックで少数の予測を行い、その結果を後で可視化できるように保存します。
generator(generator) wandb が可視化するための検証データを返す generator。この generator はタプル (X, y) を返す必要があります。wandb が特定のデータ例を可視化できるようにするには、validate_data または generator のいずれかを設定する必要があります。
validation_steps(int) validation_data が generator の場合、検証セット全体を処理するために generator を何ステップ実行するか。
labels(list) wandb でデータを可視化する場合、多クラス分類器を構築しているときには、このラベルのリストによって数値出力が理解しやすい文字列に変換されます。二値分類器の場合は、2 つのラベル [label for false, label for true] を含むリストを渡すことができます。validate_datagenerator がどちらも false の場合、これは何もしません。
predictions(int) 可視化のために各エポックで実行する予測の数。最大は 100。
input_type(string) 可視化のためのモデル入力の型。次のいずれかです: (image, images, segmentation_mask)。
output_type(string) 可視化に役立てるためのモデル出力の型。指定可能な値は次のいずれかです: (image, images, segmentation_mask).
log_evaluation(boolean) True の場合、検証データと各エポックにおけるモデルの予測を含む Table を保存します。詳細については、validation_indexesvalidation_row_processoroutput_row_processor を参照してください。
class_colors([float, float, float]) 入力または出力がセグメンテーションマスクの場合、各クラスに対する RGB タプル (0〜1 の範囲) を含む配列。
log_batch_frequency(integer) None の場合、コールバックは毎エポックでログを記録します。整数を指定した場合、コールバックは log_batch_frequency バッチごとに学習メトリクスを記録します。
log_best_prefix(string) None の場合、追加のサマリーメトリクスは保存しません。文字列を指定した場合、その接頭辞を監視対象のメトリクスとエポックに付けて、結果をサマリーメトリクスとして保存します。
validation_indexes([wandb.data_types._TableLinkMixin]) 各検証サンプルに関連付けるインデックスキーの順序付きリスト。log_evaluation が True で validation_indexes を指定した場合、検証データ用の Table は作成されません。代わりに、各予測を TableLinkMixin で表される行に関連付けます。行キーのリストを取得するには、Table.get_index() を使用します。
validation_row_processor(Callable) 検証データに適用する関数で、主にデータの可視化に使用します。関数は ndx (int) と row (dict) を受け取ります。モデルが単一の入力を持つ場合、row["input"] にはその行の入力データが含まれます。そうでない場合は、入力スロットの名前が含まれます。fit 関数が単一のターゲットを取る場合、row["target"] にはその行のターゲットデータが含まれます。そうでない場合は、出力スロットの名前が含まれます。たとえば、入力データが単一の配列であり、それを Image として可視化したい場合は、プロセッサとして lambda ndx, row: {"img": wandb.Image(row["input"])} を指定します。log_evaluation が False の場合、または validation_indexes が指定されている場合は無視されます。
output_row_processor(Callable) validation_row_processor と同様ですが、モデルの出力に対して適用されます。row["output"] にはモデルの出力結果が含まれます。
infer_missing_processors(Boolean) 欠落している場合に validation_row_processoroutput_row_processor を自動的に推論するかどうかを決定します。デフォルトは True です。labels を指定すると、該当する場合には W&B が分類用の processor を推論しようとします。
log_evaluation_frequency(int) 評価結果をどれくらいの頻度でログするかを指定します。デフォルトは 0 で、学習の終了時にのみログを記録します。1 に設定すると毎エポック、2 に設定すると 1 エポックおきにログします。log_evaluation が False の場合、この設定は無視されます。

よくある質問

Keras のマルチプロセッシングを wandb と併用するにはどうすればよいですか?

use_multiprocessing=True を設定したときに、次のエラーが発生する場合があります。
Error("You must call wandb.init() before wandb.config.batch_size")
回避策は次のとおりです。
  1. Sequence クラスのコンストラクタでは wandb.init(group='...') を追加します。
  2. main では if __name__ == "__main__": を必ず使用し、その条件分岐の中にスクリプトの残りのロジックを記述します。