🔥알림🔥
① 테디노트 유튜브 -
구경하러 가기!
② LangChain 한국어 튜토리얼
바로가기 👀
③ 랭체인 노트 무료 전자책(wikidocs)
바로가기 🙌
④ RAG 비법노트 LangChain 강의오픈
바로가기 🙌
⑤ 서울대 PyTorch 딥러닝 강의
바로가기 🙌
WandB 를 활용하여 모델의 학습을 추적하는 방법
WandB는 weights and biases 의 약어입니다. 머신러닝을 하시는 분들은 weights
& biases
와 굉장히 친숙할텐데요. WandB의 네이밍에서 알 수 있듯이 모델이 학습할 때 실험 결과를 저장 및 시각화, 하이퍼파라미터를 저장, 모델 뿐만아니라 시스템(GPU, CPU, 메모리 등)을 모니터링하고 트래킹 해주는 도구 입니다. 대표적인 모델의 학습 결과 & 트래킹 도구인 TensorBoard와 비슷한 역할을 하는데, 클라우드 공간에 하이퍼파라미터 및 모델의 학습(실험) 결과를 저장하고 싱크해준다는 점 그리고 사용성이 매우 편이하다는 점이 TensorBoard 대비 WandB가 가지는 더 큰 매력이라고 생각합니다.
WandB home
설치
WandB는 PYPI 패키지로 쉽게 설치할 수 있습니다.
설치
pip install wandb
사용법
로그인
WandB 설치 후 다음의 명령어를 터미널에서 실행하여 WandB 계정에 로그인 합니다.
따라서, 사전에 WandB 웹사이트에서 회원가입이 되어 있어야 합니다.
터미널에서 다음의 명령어 입력
wandb login
init()
초기 설정
import wandb
wandb.init(p)
# init
wandb.init(project='프로젝트명',
group='그룹명',
name='모델의 이름',
notes='메모(batch_size, seed, epoch 등)',
tags='태그',
entity='teddynote')
project
에 기입한 프로젝트명이 대시보드 상에 표기되며, 보통 1개의 가장 큰 범주의 프로젝트 명을 설정합니다. 아래의 예시에서는 Dacon 대회에서 모델 학습시 지정한 프로젝트명인 “Dacon Car Crash”가 표기되는 것을 확인할 수 있습니다.
group
에 기입한 내용 별로 대시보드 상에서 그룹화하여 시각화할 수 있습니다. 아래의 예시에서는 총 4개의 그룹이 생성되었음을 확인할 수 있습니다. 저는 보통 Label 별로 혹은 주요 Feature의 범주 별로 그룹화하여 학습결과를 시각화할 때 사용하고 있습니다.
name
에 표기한 내용은 group
안에 name
별로 구분되어 보여집니다.
위에 표기된 모델을 클릭하면 상세내용을 확인할 수 있습니다. 상세 내용에 표기되는 내용은 아래와 같습니다.
특히, tags
에는 여러 개의 tag를 달아 나중에 tag 별로 모델을 필터할 수 있습니다. 예를 들면 cnn
, rnn
, attention
등과 같은 태그를 달아놓고 나중에 cnn
tag가 달린 모델만 필터하여 성능 차이를 확인할 수 있습니다.
마지막으로 entity
는 대시보드에서 설정한 entity를 그대로 지정하면 됩니다. 보통 팀이름이나 개인으로 이용하는 분들은 id를 입력하면 됩니다.
wandb.config
config 에는 hyper parameter 를 지정하면 추후 학습이 완료된 후 hyper parameter 가 metric에 표기됩니다. Hyper parameter 튜닝시 세부 옵션별로 성능을 비교하는 것도 가능합니다.
# config
wandb.config = {
"learning_rate": 5e5,
"epochs": 20,
"batch_size": 16,
"seed": 123
}
등록한 config 값은 hyper-parameter 추적이 가능합니다. 나중에 sweep
이란 기능을 활용하여 hyper parameter optimization을 수행하는데 매번 trial 마다 어떠한 hyper-parameter가 셋팅되어있었는지 추적하기에 용이합니다.
아래 그림은 모델별 / hyper parameter 변화에 따른 성능 측정 결과표입니다.
wandb.define_metric
대시보드에 모델의 평가(evaluation) 결과를 추적하기 위해서는 평가지표를 정의합니다.
다음은 몇 개의 평가지표를 정의하고, 추적하기 위한 코드 예시입니다.
예시
# loss 추적
wandb.define_metric('train_loss', summary='min')
wandb.define_metric('val_loss', summary='min')
# f1 score 추적
wandb.define_metric('train_f1', summary='max')
wandb.define_metric('val_f1', summary='max')
대시보드에서 추적 예시
wandb.watch
모델의 학습/검증, 평가 결과 추적을 위해서는 모델이 학습을 시작하기 전에 wandb.watch(model)
코드를 실행하여 추적을 시작하는 것을 명시적으로 알립니다.
예시
# watch model
wandb.watch(model)
# model training
...
wandb.log
매 epoch 마다 학습이 완료되면 loss
, val_loss
그 밖에 평가 지표를 wandb.log()
를 호출하여 등록합니다.
예시
wandb.log({'fold': idx, 'loss': t_loss, 'val_loss': v_loss, 'train_f1': t_score, 'val_f1': v_score}, step=epoch)
여기서 step에 epoch를 지정하면 매 epoch과 함께 지표를 저장하게 됩니다.
wandb.alert
코드가 돌아가다가 exception이 발생하거나 중요한 이벤트가 발생하여 User 에게 알림을 주어야 하는 경우에는 wandb.alert()
를 호출하여 알릴 수 있습니다.
예시
# exception 알림
wandb.alert('Exception occured at', f"Epoch: {epoch+1}")
# traning 이 완료되었음을 알림
wandb.alert('Training Task Finished', f"VAL_LOSS: {val_loss:.5f}, F1 SCORE: {val_f1:.5f}")
설정에서 Slack 에 연결하여 wandb.alert()
호출시 Slack으로 알림을 받을 수 있습니다.
댓글남기기