[Keras] 콜백함수 (2) - weight 중간 저장: ModelCheckpoint
Jan 19, 2020

keras의 콜백함수인 ModelCheckpoint는 모델이 학습하면서 정의한 조건을 만족했을 때 Model의 weight 값을 중간 저장해 줍니다. 학습시간이 꽤 오래걸린다면, 모델이 개선된 validation score를 도출해낼 때마다 weight를 중간 저장함으로써, 혹시 중간에 memory overflow나 crash가 나더라도 다시 weight를 불러와서 학습을 이어나갈 수 있기 때문에, 시간을 save해 줄 수 있습니다.

사용법은 매우 간단합니다.

from keras.callbacks import ModelCheckpoint

filename = 'checkpoint-epoch-{}-batch-{}-trial-001.h5'.format(EPOCH, BATCH_SIZE)
checkpoint = ModelCheckpoint(filename,             # file명을 지정합니다
                             monitor='val_loss',   # val_loss 값이 개선되었을때 호출됩니다
                             verbose=1,            # 로그를 출력합니다
                             save_best_only=True,  # 가장 best 값만 저장합니다
                             mode='auto'           # auto는 알아서 best를 찾습니다. min/max
                            )

model.fit(x_train, y_train, 
      validation_data=(x_valid, y_valid),
      epochs=EPOCH, 
      batch_size=BATCH_SIZE, 
      callbacks=[checkpoint], # checkpoint 콜백
     )


관련 글 더보기

- Digit Recognizer (Kaggle) - over 99% accuracy

- 딥러닝(LSTM)을 활용하여 삼성전자 주가 예측을 해보았습니다

- [Keras] 콜백함수 (3) - 조기종료: EarlyStopping

- [Keras] 콜백함수 (1) - 학습률(learning rate): ReduceLROnPlateau

- [Keras] 손실함수(Loss Function)와 평가지표(metric) 커스텀하기

데이터 분석, 머신러닝, 딥러닝의 대중화를 꿈 꿉니다.