機械学習ワークフローにおいて、モデルアーティファクトをいつ・どのように追跡、共有、管理するかを説明します。このページでは、学習中の実験のロギング、Reports の生成、および各タスクに適した W&B API を用いたログデータへのアクセス方法を扱います。
このチュートリアルでは、次のものを使用します。
マシンをW&Bで認証するには、まず wandb.ai/settings でAPIキーを生成する必要があります。APIキーをコピーし、安全な場所に保管してください。
このチュートリアルで必要になる W&B ライブラリと、その他の必要なパッケージをインストールします。
W&B の Python SDK をインポートします:
次のコードブロックで、チームのエンティティ名を指定します。
TEAM_ENTITY = "<Team_Entity>" # チームエンティティを置き換えてください
PROJECT = "my-awesome-project"
次のコードは、基本的な機械学習ワークフローをシミュレートします。モデルの学習、メトリクスのログへの記録、モデルをアーティファクトとして保存する処理を行います。
学習中に W&B とやり取りするには、W&B Python SDK(wandb.sdk)を使用します。まず wandb.Run.log() を使って損失値(loss)をログに記録し、その後 wandb.Artifact を使って学習済みモデルをアーティファクトとして保存し、最後に Artifact.add_file を使ってモデルファイルを追加します。
import random # データのシミュレーション用
def model(training_data: int) -> int:
"""デモ用のモデルシミュレーション。"""
return training_data * 2 + random.randint(-1, 1)
# 重みとノイズのシミュレーション
weights = random.random() # ランダムな重みを初期化
noise = random.random() / 5 # ノイズをシミュレートするための小さなランダムノイズ
# ハイパーパラメータと設定
config = {
"epochs": 10, # 学習エポック数
"learning_rate": 0.01, # オプティマイザの学習率
}
# コンテキストマネージャを使用してW&B runを初期化・終了する
with wandb.init(project=PROJECT, entity=TEAM_ENTITY, config=config) as run:
# 学習ループのシミュレーション
for epoch in range(config["epochs"]):
xb = weights + noise # シミュレートされた入力学習データ
yb = weights + noise * 2 # シミュレートされた目標出力(入力ノイズの2倍)
y_pred = model(xb) # モデルの予測値
loss = (yb - y_pred) ** 2 # 平均二乗誤差(MSE)損失
print(f"epoch={epoch}, loss={loss}")
# エポックと損失をW&Bに記録
run.log({
"epoch": epoch,
"loss": loss,
})
# モデルアーティファクトの一意の名前
model_artifact_name = f"model-demo"
# シミュレートされたモデルファイルを保存するローカルパス
PATH = "model.txt"
# モデルをローカルに保存
with open(PATH, "w") as f:
f.write(str(weights)) # モデルの重みをファイルに保存
# アーティファクトオブジェクトを作成
# ローカルに保存したモデルをアーティファクトオブジェクトに追加
artifact = wandb.Artifact(name=model_artifact_name, type="model", description="学習済みモデル")
artifact.add_file(local_path=PATH)
artifact.save()
前のコードブロックの主なポイントは次のとおりです:
- 学習中のメトリクスをログするには、
wandb.Run.log() を使用します。
- モデル(データセットなど)をアーティファクトとして W&B のプロジェクトに保存するには、
wandb.Artifact を使用します。
モデルを学習してアーティファクトとして保存したので、それを W&B のレジストリに公開できます。wandb.Run.use_artifact() を使用して、プロジェクトからアーティファクトを取得し、Model Registry で公開する準備をします。wandb.Run.use_artifact() には 2 つの主な役割があります:
https://wandb.ai/login で自分のアカウントにログインします。
Projects の下に、my-awesome-project(または上でプロジェクト名として使用した名前)が表示されているはずです。これをクリックして、そのプロジェクトのワークスペースに入ります。
ここから、これまでに行った各 run の詳細を確認できます。このスクリーンショットでは、コードを複数回再実行しており、そのたびに run が生成されています。各 run にはランダムに生成された名前が付けられています。
組織内の他のユーザーとモデルを共有するには、wandb.Run.link_artifact() を使用して collection に登録します。次のコードは、そのアーティファクトを registry にリンクし、チームで利用できるようにします。
# アーティファクト名はチームのプロジェクト内の特定のアーティファクトバージョンを指定します
artifact_name = f'{TEAM_ENTITY}/{PROJECT}/{model_artifact_name}:v0'
print("Artifact name: ", artifact_name)
REGISTRY_NAME = "Model" # W&B のレジストリ名
COLLECTION_NAME = "DemoModels" # レジストリ内のコレクション名
# レジストリ内のアーティファクトのターゲットパスを作成します
target_path = f"wandb-registry-{REGISTRY_NAME}/{COLLECTION_NAME}"
print("Target path: ", target_path)
with wandb.init(entity=TEAM_ENTITY, project=PROJECT) as run:
model_artifact = run.use_artifact(artifact_or_name=artifact_name, type="model")
run.link_artifact(artifact=model_artifact, target_path=target_path)
wandb.Run.link_artifact() を実行すると、そのモデルのアーティファクトはレジストリ内の DemoModels コレクションに保存されます。そこから、バージョン履歴、lineage map、その他のメタデータ などの詳細を確認できます。
アーティファクトをレジストリにリンクする方法の詳細については、Link artifacts to a registry を参照してください。
推論用にレジストリからモデルアーティファクトを取得する
推論でモデルを使用するには、wandb.Run.use_artifact() を使ってレジストリから公開済みアーティファクトを取得します。これによりアーティファクトオブジェクトが返されるので、wandb.Artifact.download() を使ってそのアーティファクトをローカルファイルとしてダウンロードできます。
REGISTRY_NAME = "Model" # W&B のレジストリ名
COLLECTION_NAME = "DemoModels" # レジストリ内のコレクション名
VERSION = 0 # 取得するアーティファクトのバージョン
model_artifact_name = f"wandb-registry-{REGISTRY_NAME}/{COLLECTION_NAME}:v{VERSION}"
print(f"モデルアーティファクト名: {model_artifact_name}")
with wandb.init(entity=TEAM_ENTITY, project=PROJECT) as run:
registry_model = run.use_artifact(artifact_or_name=model_artifact_name)
local_model_path = registry_model.download()
レジストリからアーティファクトを取得する方法の詳細については、レジストリからアーティファクトをダウンロードするを参照してください。
使用している機械学習フレームワークによっては、重みをロードする前にモデルアーキテクチャを再構築する必要がある場合があります。これは使用する特定のフレームワークやモデルに依存するため、本書では詳細は扱わず、読者の演習として残しています。
W&B Report と Workspace API は現在パブリックプレビュー中です。
作業内容を要約するために、レポート を作成して共有します。レポートをプログラムから作成するには、W&B Report と Workspace API を使用します。
まず、W&B Reports API をインストールします。
pip install wandb wandb-workspaces -qqq
次のコードブロックは、Markdown、パネルグリッドなどを含む複数のブロックからなるレポートを作成します。ブロックを追加したり、既存のブロックの内容を変更したりして、レポートをカスタマイズできます。
このコードブロックの出力では、作成されたレポートの URL へのリンクが表示されます。ブラウザでこのリンクを開いて、レポートを表示できます。
import wandb_workspaces.reports.v2 as wr
experiment_summary = """This is a summary of the experiment conducted to train a simple model using W&B."""
dataset_info = """The dataset used for training consists of synthetic data generated by a simple model."""
model_info = """The model is a simple linear regression model that predicts output based on input data with some noise."""
report = wr.Report(
project=PROJECT,
entity=TEAM_ENTITY,
title="My Awesome Model Training Report",
description=experiment_summary,
blocks= [
wr.TableOfContents(),
wr.H2("Experiment Summary"),
wr.MarkdownBlock(text=experiment_summary),
wr.H2("Dataset Information"),
wr.MarkdownBlock(text=dataset_info),
wr.H2("Model Information"),
wr.MarkdownBlock(text = model_info),
wr.PanelGrid(
panels=[
wr.LinePlot(title="Train Loss", x="Step", y=["loss"], title_x="Step", title_y="Loss")
],
),
]
)
# レポートをW&Bに保存する
report.save()
レポートをプログラムから作成する方法や、W&B アプリを使ってインタラクティブにレポートを作成する方法の詳細については、W&B Docs の Developer ガイドにある Create a report を参照してください。
W&B Public APIs を使用して、W&B の履歴データにクエリを実行し、分析および管理を行います。これは、アーティファクトの系譜や依存関係を追跡したり、異なるバージョンを比較したり、時間経過に伴うモデルのパフォーマンスを分析したりする場合に役立ちます。
次のコードブロックは、特定のコレクション内のすべてのアーティファクトについてモデルレジストリをクエリする方法を示します。コレクションを取得し、その各バージョンをループして、各アーティファクトの名前とバージョンを出力します。
import wandb
# wandb APIを初期化する
api = wandb.Api()
# 文字列 `model` を含み、タグ `text-classification` または
# エイリアス `latest` を持つすべてのアーティファクトバージョンを検索する
registry_filters = {
"name": {"$regex": "model"}
}
# 論理演算子 $or を使用してアーティファクトバージョンをフィルタリングする
version_filters = {
"$or": [
{"tag": "text-classification"},
{"alias": "latest"}
]
}
# フィルターに一致するすべてのアーティファクトバージョンのイテラブルを返す
artifacts = api.registries(filter=registry_filters).collections().versions(filter=version_filters)
# 見つかった各アーティファクトの名前、コレクション、エイリアス、タグ、作成日時を出力する
for art in artifacts:
print(f"artifact name: {art.name}")
print(f"collection artifact belongs to: { art.collection.name}")
print(f"artifact aliases: {art.aliases}")
print(f"tags attached to artifact: {art.tags}")
print(f"artifact created at: {art.created_at}\n")
レジストリに対するクエリの詳細については、Query registry items を参照してください。