メインコンテンツへスキップ
GitHub ソース

function confusion_matrix

confusion_matrix(
    probs: 'Sequence[Sequence[float]] | None' = None,
    y_true: 'Sequence[T] | None' = None,
    preds: 'Sequence[T] | None' = None,
    class_names: 'Sequence[str] | None' = None,
    title: 'str' = 'Confusion Matrix Curve',
    split_table: 'bool' = False
) → CustomChart
確率または予測のシーケンスから混同行列を構築します。 Args:
  • probs: 各クラスに対する予測確率のシーケンス。シーケンスの形状は (N, K) で、N はサンプル数、K はクラス数です。probs を指定した場合は、preds は指定しないでください。
  • y_true: 正解ラベルのシーケンス。
  • preds: 予測クラスラベルのシーケンス。preds を指定した場合は、probs は指定しないでください。
  • class_names: クラス名のシーケンス。指定しない場合、クラス名は “Class_1”、“Class_2” などとして定義されます。
  • title: 混同行列チャートのタイトル。
  • split_table: テーブルを W&B UI 内の別セクションに分割して表示するかどうか。True の場合、テーブルは “Custom Chart Tables” という名前のセクションに表示されます。デフォルトは False です。
Returns:
  • CustomChart: W&B にログできるカスタムチャートオブジェクト。チャートをログするには、wandb.log() に渡します。
Raises:
  • ValueError: probspreds の両方が指定されている場合、または予測と正解ラベルの数が等しくない場合に発生します。さらに、一意な予測クラスの数がクラス名の数を上回る場合、または一意な正解ラベルの数がクラス名の数を上回る場合にも発生します。
  • wandb.Error: NumPy がインストールされていない場合に発生します。
Examples: 野生動物分類のためにランダムな確率値を用いて混同行列をログする例:
import numpy as np
import wandb

# 野生動物のクラス名を定義する
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

# ランダムな正解ラベルを生成する(10サンプルに対して0〜3)
wildlife_y_true = np.random.randint(0, 4, size=10)

# 各クラスのランダムな確率を生成する(10サンプル × 4クラス)
wildlife_probs = np.random.rand(10, 4)
wildlife_probs = np.exp(wildlife_probs) / np.sum(
    np.exp(wildlife_probs),
    axis=1,
    keepdims=True,
)

# W&B の run を初期化して混同行列をログする
with wandb.init(project="wildlife_classification") as run:
    confusion_matrix = wandb.plot.confusion_matrix(
         probs=wildlife_probs,
         y_true=wildlife_y_true,
         class_names=wildlife_class_names,
         title="Wildlife Classification Confusion Matrix",
    )
    run.log({"wildlife_confusion_matrix": confusion_matrix})
この例では、ランダムな確率を使用して混同行列を生成します。 シミュレートしたモデル予測と 85% の正解率を用いて混同行列をログする例です:
import numpy as np
import wandb

# 野生動物のクラス名を定義する
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

# 200枚の動物画像の正解ラベルをシミュレートする(不均衡分布)
wildlife_y_true = np.random.choice(
    [0, 1, 2, 3],
    size=200,
    p=[0.2, 0.3, 0.25, 0.25],
)

# 精度85%でモデルの予測をシミュレートする
wildlife_preds = [
    y_t
    if np.random.rand() < 0.85
    else np.random.choice([x for x in range(4) if x != y_t])
    for y_t in wildlife_y_true
]

# W&B の run を初期化して混同行列をログに記録する
with wandb.init(project="wildlife_classification") as run:
    confusion_matrix = wandb.plot.confusion_matrix(
         preds=wildlife_preds,
         y_true=wildlife_y_true,
         class_names=wildlife_class_names,
         title="Simulated Wildlife Classification Confusion Matrix",
    )
    run.log({"wildlife_confusion_matrix": confusion_matrix})
この例では、予測が精度 85% になるようにシミュレートし、混同行列を生成します。