메인 콘텐츠로 건너뛰기
PyTorch Lightning은 PyTorch 코드를 구성하고 분산 트레이닝, 16비트 정밀도와 같은 고급 기능을 쉽게 추가할 수 있도록 하는 가벼운 래퍼를 제공합니다. W&B는 ML 실험을 기록하기 위한 가벼운 래퍼를 제공합니다. 하지만 두 가지를 직접 결합할 필요는 없습니다. W&B는 WandbLogger를 통해 PyTorch Lightning 라이브러리에 이미 통합되어 있습니다.

Lightning과 연동하기

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(logger=wandb_logger)
wandb.log() 사용 시: WandbLogger는 Trainer의 global_step을 사용해 W&B에 로그를 기록합니다. 코드에서 wandb.log를 직접 추가로 호출하는 경우, wandb.log()에서 step 인자를 사용하지 마세요.대신, 다른 메트릭과 마찬가지로 Trainer의 global_step 값을 로그로 기록하세요:
wandb.log({"accuracy":0.99, "trainer/global_step": step})
대화형 대시보드

가입 및 API 키 생성하기

API 키는 사용 중인 머신을 W&B에 인증하는 데 사용됩니다. 사용자 프로필에서 API 키를 생성할 수 있습니다.
더 간편한 방법을 원한다면 User Settings에서 직접 API 키를 생성하세요. 새로 생성된 API 키를 즉시 복사하여 비밀번호 관리자와 같은 안전한 위치에 저장하세요.
  1. 오른쪽 상단에서 사용자 프로필 아이콘을 클릭합니다.
  2. User Settings를 선택한 다음 API Keys 섹션까지 스크롤합니다.

wandb 라이브러리 설치 및 로그인

로컬에 wandb 라이브러리를 설치하고 로그인하려면 다음을 수행하세요.
  1. WANDB_API_KEY 환경 변수를 API 키 값으로 설정합니다.
    export WANDB_API_KEY=<your_api_key>
    
  2. wandb 라이브러리를 설치하고 로그인합니다.
    pip install wandb
    
    wandb login
    

PyTorch Lightning의 WandbLogger 사용

PyTorch Lightning은 메트릭과 모델 가중치, 미디어 등을 로깅하기 위한 여러 종류의 WandbLogger 클래스를 제공합니다. Lightning과 통합하려면 WandbLogger를 인스턴스화하고 이를 Lightning의 Trainer 또는 Fabric에 전달하세요.
trainer = Trainer(logger=wandb_logger)

공통 로거 인자

아래는 WandbLogger에서 가장 많이 사용되는 파라미터입니다. 모든 로거 인자에 대한 자세한 내용은 PyTorch Lightning 문서를 참고하세요.
ParameterDescription
project로그를 남길 wandb Project를 정의합니다
namewandb run에 사용할 이름을 지정합니다
log_modellog_model="all"이면 모든 모델을 로그하고, log_model=True이면 트레이닝 종료 시 모델을 로그합니다
save_dir데이터가 저장될 경로를 지정합니다

하이퍼파라미터 로깅

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        self.save_hyperparameters()

추가 설정 파라미터 기록

# 파라미터 하나 추가
wandb_logger.experiment.config["key"] = value

# 여러 파라미터 추가
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# wandb 모듈 직접 사용
wandb.config["key"] = value
wandb.config.update()

그래디언트, 파라미터 히스토그램 및 모델 토폴로지 로깅

학습하는 동안 모델의 그래디언트와 파라미터를 모니터링하려면 wandblogger.watch()에 모델 객체를 전달하세요. 자세한 내용은 PyTorch Lightning WandbLogger 문서를 참조하세요.

메트릭 로깅

WandbLogger를 사용할 때는 LightningModuletraining_step이나 validation_step 메서드 안에서 self.log('my_metric_name', metric_vale)를 호출하여 메트릭을 W&B에 로깅할 수 있습니다.아래 코드 스니펫은 메트릭과 LightningModule 하이퍼파라미터를 로깅하도록 LightningModule을 정의하는 방법을 보여줍니다. 이 예제는 메트릭을 계산하기 위해 torchmetrics 라이브러리를 사용합니다.
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """method used to define the model parameters"""
        super().__init__()

        # MNIST 이미지는 (1, 28, 28) (채널, 너비, 높이)입니다
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        self.loss = CrossEntropyLoss()
        self.lr = lr

        # 하이퍼파라미터를 self.hparams에 저장합니다 (W&B가 자동으로 로깅)
        self.save_hyperparameters()

    def forward(self, x):
        """method used for inference input -> output"""

        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # (linear + relu)를 3번 수행합니다
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        """needs to return a loss from a single batch"""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # 손실과 메트릭 로깅
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """used for logging metrics"""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # 손실과 메트릭 로깅
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

    def configure_optimizers(self):
        """defines model optimizer"""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """convenience function since train/valid/test steps are similar"""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y)
        return preds, loss, acc

메트릭의 최소/최대값 로그하기

wandb의 define_metric 함수를 사용하면 W&B summary 메트릭에 대해 해당 메트릭의 최소값, 최대값, 평균값 또는 베스트값 중 무엇을 표시할지 정의할 수 있습니다. define_metric _ 를 사용하지 않으면 마지막으로 로깅된 값이 summary 메트릭에 표시됩니다. 자세한 내용은 define_metric 레퍼런스 문서가이드를 참고하세요. W&B summary 메트릭에서 최대 검증 정확도를 추적하려면, 트레이닝 시작 시 한 번만 wandb.define_metric 을 호출하세요:
class My_LitModule(LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0:
            wandb.define_metric("val_accuracy", summary="max")

        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

모델 체크포인트 저장하기

모델 체크포인트를 W&B Artifacts로 저장하려면 Lightning ModelCheckpoint 콜백을 사용하고, WandbLoggerlog_model 매개변수를 설정하세요.
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
latestbest 별칭은 W&B Artifact에서 모델 체크포인트를 쉽게 가져올 수 있도록 자동으로 설정됩니다.
# 참조(reference)는 Artifacts 패널에서 가져올 수 있습니다
# "VERSION"은 버전(예: "v2") 또는 별칭("latest" 또는 "best")일 수 있습니다
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
# 체크포인트를 로컬로 다운로드합니다 (아직 캐시되지 않은 경우)
wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# 체크포인트를 로드합니다
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
기록한 모델 체크포인트는 W&B Artifacts UI에서 확인할 수 있으며, 전체 모델 계보(lineage) 정보를 포함합니다(UI에서 모델 체크포인트 예시는 여기에서 볼 수 있습니다). 팀 전체에서 최고의 모델 체크포인트를 북마크하고 중앙에서 관리하려면, 이를 W&B Model Registry에 연결할 수 있습니다. 여기에서 작업별로 최고의 모델을 구성하고, 모델 라이프사이클을 관리하며, ML 라이프사이클 전반에 걸쳐 손쉬운 추적과 감사를 수행하고, 웹훅이나 잡(jobs)으로 후속 작업을 자동화할 수 있습니다.

이미지, 텍스트 등 다양한 미디어 로깅하기

WandbLogger에는 미디어를 로깅하기 위한 log_image, log_text, log_table 메서드가 있습니다. 또한 wandb.log 또는 trainer.logger.experiment.log를 직접 호출해 오디오, 분자(Molecules), 포인트 클라우드(Point Clouds), 3D 오브젝트(3D Objects) 등의 다른 미디어 타입도 로깅할 수 있습니다.
# tensors, numpy 배열 또는 PIL 이미지를 사용
wandb_logger.log_image(key="samples", images=[img1, img2])

# 캡션 추가
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# 파일 경로 사용
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# trainer에서 .log 사용
trainer.logger.experiment.log(
    {"samples": [wandb.Image(img, caption=caption) for (img, caption) in my_images]},
    step=current_trainer_global_step,
)
Lightning의 Callbacks 시스템을 사용해 WandbLogger를 통해 W&B에 언제 로깅할지 제어할 수 있으며, 이 예시에서는 검증 이미지와 예측 결과 샘플을 로깅합니다:
import torch
import wandb
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

# or
# from wandb.integration.lightning.fabric import WandbLogger


class LogPredictionSamplesCallback(Callback):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """검증 배치가 끝날 때 호출됩니다."""

        # `outputs`는 `LightningModule.validation_step`에서 옵니다.
        # 이 경우 모델 예측값에 해당합니다.

        # 첫 번째 배치에서 샘플 이미지 예측값 20개를 로깅합니다.
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f"Ground Truth: {y_i} - Prediction: {y_pred}"
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # 옵션 1: `WandbLogger.log_image`로 이미지 로깅
            wandb_logger.log_image(key="sample_images", images=images, caption=captions)

            # 옵션 2: 이미지와 예측값을 W&B Table로 로깅
            columns = ["image", "ground truth", "prediction"]
            data = [
                [wandb.Image(x_i), y_i, y_pred] or x_i,
                y_i,
                y_pred in list(zip(x[:n], y[:n], outputs[:n])),
            ]
            wandb_logger.log_table(key="sample_table", columns=columns, data=data)


trainer = pl.Trainer(callbacks=[LogPredictionSamplesCallback()])

Lightning과 W&B로 여러 개의 GPU 사용하기

PyTorch Lightning은 DDP 인터페이스를 통해 멀티 GPU를 지원합니다. 하지만 PyTorch Lightning의 설계 때문에 GPU를 어떻게 생성/초기화하는지에 주의해야 합니다. Lightning은 트레이닝 루프에서 각 GPU(또는 rank)가 정확히 동일한 방식, 즉 동일한 초기 조건으로 인스턴스화된다고 가정합니다. 그러나 rank 0 프로세스만 wandb.run 객체에 접근할 수 있고, 0이 아닌 rank 프로세스에서는 wandb.run = None입니다. 이로 인해 0이 아닌 rank 프로세스가 실패할 수 있습니다. 이런 상황에서는 rank 0 프로세스가 이미 크래시된 0이 아닌 rank 프로세스가 조인하기를 기다리게 되므로, **교착 상태(데드락)**에 빠질 수 있습니다. 이러한 이유로 트레이닝 코드를 구성하는 방식에 주의해야 합니다. 권장되는 방식은 코드가 wandb.run 객체에 의존하지 않도록 작성하는 것입니다.
class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        return {"train_loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # 모든 랜덤 시드를 동일한 값으로 설정합니다.
    # 이는 분산 트레이닝 환경에서 중요합니다.
    # 각 rank는 고유한 초기 가중치 세트를 갖게 됩니다.
    # 이 값들이 일치하지 않으면 그래디언트도 일치하지 않아,
    # 트레이닝이 수렴하지 않을 수 있습니다.
    pl.seed_everything(1)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project="<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3, gpus=2, logger=wandb_logger, strategy="ddp", callbacks=callbacks
    )
    trainer.fit(model, train_loader, val_loader)

예시

Colab 노트북이 포함된 동영상 튜토리얼을 보면서 함께 따라할 수 있습니다.

자주 묻는 질문

W&B는 Lightning과 어떻게 인테그레이션되나요?

핵심 인테그레이션은 Lightning loggers API를 기반으로 하며, 이를 통해 로그를 남기는 코드를 특정 프레임워크에 의존하지 않는 방식으로 대부분 작성할 수 있습니다. LoggerLightning Trainer에 전달되며, 해당 API의 풍부한 hook-and-callback 시스템에 따라 호출됩니다. 이를 통해 연구 코드를 엔지니어링 및 로깅 코드와 깔끔하게 분리할 수 있습니다.

추가 코드를 작성하지 않아도 이 인테그레이션은 무엇을 기록하나요?

W&B에 모델 체크포인트를 저장하므로, 이후에 이를 확인하거나 다운로드해서 향후 run에서 사용할 수 있습니다. 또한 GPU 사용량과 네트워크 I/O 같은 시스템 메트릭, 하드웨어 및 OS 정보와 같은 환경 정보, git 커밋과 diff 패치, 노트북 내용 및 세션 기록을 포함한 코드 상태, 그리고 표준 출력으로 나가는 모든 내용을 수집합니다.

트레이닝 설정에서 wandb.run을 사용해야 하면 어떻게 하나요?

직접 접근해야 하는 변수의 스코프를 더 넓게 잡아야 합니다. 다시 말해, 모든 프로세스에서 초기 조건이 동일하도록 설정해야 합니다.
if os.environ.get("LOCAL_RANK", None) is None:
    os.environ["WANDB_DIR"] = wandb.run.dir
그 경우에는 os.environ["WANDB_DIR"]를 사용해 모델 체크포인트 디렉터리를 설정할 수 있습니다. 이렇게 하면 rank가 0이 아닌 프로세스도 wandb.run.dir에 접근할 수 있습니다.