🔥알림🔥
① 테디노트 유튜브 - 구경하러 가기!
② LangChain 한국어 튜토리얼 바로가기 👀
③ 랭체인 노트 무료 전자책(wikidocs) 바로가기 🙌

2 분 소요

WandBweights and biases 의 약어입니다. 머신러닝을 하시는 분들은 weights & biases 와 굉장히 친숙할텐데요. WandB의 네이밍에서 알 수 있듯이 모델이 학습할 때 실험 결과를 저장 및 시각화, 하이퍼파라미터를 저장, 모델 뿐만아니라 시스템(GPU, CPU, 메모리 등)을 모니터링하고 트래킹 해주는 도구 입니다. 대표적인 모델의 학습 결과 & 트래킹 도구인 TensorBoard와 비슷한 역할을 하는데, 클라우드 공간에 하이퍼파라미터 및 모델의 학습(실험) 결과를 저장하고 싱크해준다는 점 그리고 사용성이 매우 편이하다는 점이 TensorBoard 대비 WandB가 가지는 더 큰 매력이라고 생각합니다.

WandB home

wandb home

설치

WandB는 PYPI 패키지로 쉽게 설치할 수 있습니다.

설치

pip install wandb

사용법

로그인

WandB 설치 후 다음의 명령어를 터미널에서 실행하여 WandB 계정에 로그인 합니다.

따라서, 사전에 WandB 웹사이트에서 회원가입이 되어 있어야 합니다.

터미널에서 다음의 명령어 입력

wandb login

init()

초기 설정

import wandb

wandb.init(p)
# init
wandb.init(project='프로젝트명', 
           group='그룹명',
           name='모델의 이름',
           notes='메모(batch_size, seed, epoch 등)',
           tags='태그',
           entity='teddynote')

project 에 기입한 프로젝트명이 대시보드 상에 표기되며, 보통 1개의 가장 큰 범주의 프로젝트 명을 설정합니다. 아래의 예시에서는 Dacon 대회에서 모델 학습시 지정한 프로젝트명인 “Dacon Car Crash”가 표기되는 것을 확인할 수 있습니다.

wandb project

group 에 기입한 내용 별로 대시보드 상에서 그룹화하여 시각화할 수 있습니다. 아래의 예시에서는 총 4개의 그룹이 생성되었음을 확인할 수 있습니다. 저는 보통 Label 별로 혹은 주요 Feature의 범주 별로 그룹화하여 학습결과를 시각화할 때 사용하고 있습니다.

wandb group

name에 표기한 내용은 group 안에 name 별로 구분되어 보여집니다.

wandb name

위에 표기된 모델을 클릭하면 상세내용을 확인할 수 있습니다. 상세 내용에 표기되는 내용은 아래와 같습니다.

wandb summary

특히, tags에는 여러 개의 tag를 달아 나중에 tag 별로 모델을 필터할 수 있습니다. 예를 들면 cnn, rnn, attention 등과 같은 태그를 달아놓고 나중에 cnn tag가 달린 모델만 필터하여 성능 차이를 확인할 수 있습니다.

마지막으로 entity는 대시보드에서 설정한 entity를 그대로 지정하면 됩니다. 보통 팀이름이나 개인으로 이용하는 분들은 id를 입력하면 됩니다.

wandb.config

config 에는 hyper parameter 를 지정하면 추후 학습이 완료된 후 hyper parameter 가 metric에 표기됩니다. Hyper parameter 튜닝시 세부 옵션별로 성능을 비교하는 것도 가능합니다.

# config
wandb.config = {
  "learning_rate": 5e5,
  "epochs": 20,
  "batch_size": 16, 
  "seed": 123
}

등록한 config 값은 hyper-parameter 추적이 가능합니다. 나중에 sweep 이란 기능을 활용하여 hyper parameter optimization을 수행하는데 매번 trial 마다 어떠한 hyper-parameter가 셋팅되어있었는지 추적하기에 용이합니다.

wandb config

아래 그림은 모델별 / hyper parameter 변화에 따른 성능 측정 결과표입니다.

wandb hyperparameters

wandb.define_metric

대시보드에 모델의 평가(evaluation) 결과를 추적하기 위해서는 평가지표를 정의합니다.

다음은 몇 개의 평가지표를 정의하고, 추적하기 위한 코드 예시입니다.

예시

# loss 추적
wandb.define_metric('train_loss', summary='min')
wandb.define_metric('val_loss', summary='min')
# f1 score 추적
wandb.define_metric('train_f1', summary='max')
wandb.define_metric('val_f1', summary='max')

대시보드에서 추적 예시

wandb metric

wandb.watch

모델의 학습/검증, 평가 결과 추적을 위해서는 모델이 학습을 시작하기 전에 wandb.watch(model) 코드를 실행하여 추적을 시작하는 것을 명시적으로 알립니다.

예시

# watch model
wandb.watch(model)

# model training
...

wandb.log

매 epoch 마다 학습이 완료되면 loss, val_loss 그 밖에 평가 지표를 wandb.log() 를 호출하여 등록합니다.

예시

wandb.log({'fold': idx, 'loss': t_loss, 'val_loss': v_loss, 'train_f1': t_score, 'val_f1': v_score}, step=epoch)

여기서 step에 epoch를 지정하면 매 epoch과 함께 지표를 저장하게 됩니다.

wandb.alert

코드가 돌아가다가 exception이 발생하거나 중요한 이벤트가 발생하여 User 에게 알림을 주어야 하는 경우에는 wandb.alert()를 호출하여 알릴 수 있습니다.

예시

# exception 알림
wandb.alert('Exception occured at', f"Epoch: {epoch+1}")

# traning 이 완료되었음을 알림
wandb.alert('Training Task Finished', f"VAL_LOSS: {val_loss:.5f}, F1 SCORE: {val_f1:.5f}")

설정에서 Slack 에 연결하여 wandb.alert() 호출시 Slack으로 알림을 받을 수 있습니다.

wandb slack

댓글남기기