🔥알림🔥
① 테디노트 유튜브 - 구경하러 가기!
② LangChain 한국어 튜토리얼 바로가기 👀
③ 랭체인 노트 무료 전자책(wikidocs) 바로가기 🙌
④ RAG 비법노트 LangChain 강의오픈 바로가기 🙌
⑤ 서울대 PyTorch 딥러닝 강의 바로가기 🙌

6 분 소요

이번 포스팅에서는 Facebook Prophet을 활용하여 시계열데이터 예측 튜토리얼을 진행해 보겠습니다.

Facebook Prophet을 활용한 시계열 데이터 예측은 Kaggle에서도 종종 노트북을 찾아보실 수 있을 정도로 캐글에서는 이미 꽤 많이 알려져 있습니다.

쉬운 사용성과 시각화, 그리고 인터랙티브 시각화인 plotly도 지원합니다.

부스팅 계열의 알고리즘, 딥러닝 모델과 함께 앙상블하여 예측한다면 꽤 좋은 성능을 기대해볼 수 있을 것 같습니다.

코드

Colab으로 열기 Colab으로 열기

GitHub GitHub에서 소스보기



Facebook Prophet을 활용한 주가 예측 모델

이번 튜토리얼 에서는 다음과 같은 프로세스 파이프라인으로 주가 예측을 진행합니다.

  • FinanceDataReader를 활용하여 주가 데이터 받아오기
  • Facebook Prophet을 활용하여 주가 예측

필요한 모듈 import

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os

%matplotlib inline
warnings.filterwarnings('ignore')

데이터 (FinanceDataReader)

FinanceDataReader는 주가 데이터를 편리하게 가져올 수 있는 파이썬 패키지입니다.

FinanceDataReader가 아직 설치 되지 않으신 분들은 아래의 주석을 해제한 후 명령어로 설치해 주시기 바랍니다.

# !pip install finance-datareader
import FinanceDataReader as fdr
# 삼성전자 종목코드: 005930
samsung = fdr.DataReader('005930')

매우 편리하게 삼성전자 주가 데이터를 DataFrame형식으로 받아옵니다.

기본 오름차순 정렬이 된 데이터임을 알 수 있습니다.

컬럼 설명

  • Open: 시가
  • High: 고가
  • Low: 저가
  • Close: 종가
  • Volume: 거래량
  • Change: 대비
samsung.tail()
Open High Low Close Volume Change
Date
2020-12-23 72400 74000 72300 73900 19411326 0.022130
2020-12-24 74100 78800 74000 77800 32502870 0.052774
2020-12-28 79000 80100 78200 78700 40085044 0.011568
2020-12-29 78800 78900 77300 78300 30339449 -0.005083
2020-12-30 77400 81300 77300 81000 29122199 0.034483

미국 주식 데이터도 가져올 수 있습니다.

# Apple(AAPL), 애플
apple = fdr.DataReader('AAPL')
apple.tail()
Close Open High Low Volume Change
Date
2020-12-24 131.97 131.32 133.46 131.10 54930000.0 0.0077
2020-12-28 136.69 133.99 137.34 133.52 124490000.0 0.0358
2020-12-29 134.87 138.05 138.78 134.36 121050000.0 -0.0133
2020-12-30 133.72 135.71 135.96 133.40 96450000.0 -0.0085
2020-12-31 132.69 134.09 134.70 131.77 99120000.0 -0.0077

비트코인 시세

btc = fdr.DataReader('BTC/KRW', '2016-01-01')
btc
Close Open High Low Volume Change
Date
2017-05-23 3206000 3104000 3281000 3081000 21580.0 0.0329
2017-05-24 4175000 3206000 4314000 3206000 34680.0 0.3022
2017-05-25 4199000 4175000 4840000 3102000 35910.0 0.0057
2017-05-26 3227000 4199000 4200000 2900000 36650.0 -0.2315
2017-05-27 3152000 3227000 3288000 2460000 33750.0 -0.0232
... ... ... ... ... ... ...
2020-12-28 30199000 29356000 30199000 29356000 1760.0 0.0288
2020-12-29 30450000 30199000 30450000 30199000 2110.0 0.0083
2020-12-30 31886000 30445000 31886000 30445000 2500.0 0.0472
2020-12-31 32026000 31886000 32026000 31886000 1410.0 0.0044
2021-01-01 32247000 32023000 32247000 32023000 3140.0 0.0069

1318 rows × 6 columns

plt.figure(figsize=(16, 9))
sns.lineplot(x=btc.index, y='Close', data=btc)
plt.show()

시작 날짜를 지정하여 범위 데이터를 가져올 수 있습니다.

# 비트코인 시세
btc = fdr.DataReader('BTC/KRW', '2019-01-01', '2020-12-01')
btc
Close Open High Low Volume Change
Date
2019-01-01 4289000 4199000 4300000 4137000 3230.0 0.0214
2019-01-02 4345000 4294000 4360000 4244000 3860.0 0.0131
2019-01-03 4282000 4352000 4367000 4259000 15370.0 -0.0145
2019-01-04 4309000 4286000 4334000 4243000 19200.0 0.0063
2019-01-05 4297000 4309000 4354000 4278000 24870.0 -0.0028
... ... ... ... ... ... ...
2020-11-27 19119000 19236000 19236000 19119000 1440.0 -0.0067
2020-11-28 19480000 19118000 19480000 19118000 1610.0 0.0189
2020-11-29 20002000 19480000 20002000 19281000 830.0 0.0268
2020-11-30 21302000 19996000 21302000 19996000 2520.0 0.0650
2020-12-01 20848000 21302000 21302000 20848000 1830.0 -0.0213

699 rows × 6 columns

plt.figure(figsize=(16, 9))
sns.lineplot(x=btc.index, y='Close', data=btc)
plt.show()

그 밖에 금, 은과 같은 현물, 달러와 같은 화폐 데이터도 가져올 수 있습니다.

더욱 자세한 내용은 GitHub 페이지 링크를 참고해 보시기 바랍니다.

주가데이터 가져오기

# 삼성전자 주식코드: 005930
STOCK_CODE = '005930'
stock = fdr.DataReader(STOCK_CODE)
stock.head()
Open High Low Close Volume Change
Date
1997-01-23 786 798 770 776 74200 NaN
1997-01-24 745 793 745 783 98260 0.009021
1997-01-25 787 812 787 795 47620 0.015326
1997-01-27 793 793 764 765 41010 -0.037736
1997-01-28 764 825 757 809 181900 0.057516
stock.index
DatetimeIndex(['1997-01-23', '1997-01-24', '1997-01-25', '1997-01-27',
               '1997-01-28', '1997-01-29', '1997-01-30', '1997-01-31',
               '1997-02-01', '1997-02-03',
               ...
               '2020-12-16', '2020-12-17', '2020-12-18', '2020-12-21',
               '2020-12-22', '2020-12-23', '2020-12-24', '2020-12-28',
               '2020-12-29', '2020-12-30'],
              dtype='datetime64[ns]', name='Date', length=6000, freq=None)
stock.head()
Open High Low Close Volume Change
Date
1997-01-23 786 798 770 776 74200 NaN
1997-01-24 745 793 745 783 98260 0.009021
1997-01-25 787 812 787 795 47620 0.015326
1997-01-27 793 793 764 765 41010 -0.037736
1997-01-28 764 825 757 809 181900 0.057516

시각화

plt.figure(figsize=(16, 9))
sns.lineplot(y=stock['Close'], x=stock.index)
plt.xlabel('time')
plt.ylabel('price')
Text(0, 0.5, 'price')
time_steps = [['1990', '2000'], 
              ['2000', '2010'], 
              ['2010', '2015'], 
              ['2015', '2020']]

fig, axes = plt.subplots(2, 2)
fig.set_size_inches(16, 9)
for i in range(4):
    ax = axes[i//2, i%2]
    df = stock.loc[(stock.index > time_steps[i][0]) & (stock.index < time_steps[i][1])]
    sns.lineplot(y=df['Close'], x=df.index, ax=ax)
    ax.set_title(f'{time_steps[i][0]}~{time_steps[i][1]}')
    ax.set_xlabel('time')
    ax.set_ylabel('price')
plt.tight_layout()
plt.show()
stock = fdr.DataReader(STOCK_CODE, '2019')

Prophet

모듈 import

from fbprophet import Prophet
from fbprophet.plot import plot_plotly, plot_components_plotly

컬럼

  • 반드시 y 컬럼과 ds 컬럼이 존재해야합니다.
  • 예측 값은 y, 시계열 데이터는 ds에 지정합니다.
stock['y'] = stock['Close']
stock['ds'] = stock.index
stock.head()
Open High Low Close Volume Change y ds
Date
2019-01-02 39400 39400 38550 38750 7847664 0.001292 38750 2019-01-02
2019-01-03 38300 38550 37450 37600 12471493 -0.029677 37600 2019-01-03
2019-01-04 37450 37600 36850 37450 14108958 -0.003989 37450 2019-01-04
2019-01-07 38000 38900 37800 38750 12748997 0.034713 38750 2019-01-07
2019-01-08 38000 39200 37950 38100 12756554 -0.016774 38100 2019-01-08

prophet 객체 선언 및 학습

m = Prophet()
m.fit(stock)
INFO:fbprophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:fbprophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
<fbprophet.forecaster.Prophet at 0x7f9d0e758630>

period에 예측 하고 싶은 기간을 입력 합니다.

예측할 시계열 row가 추가 됩니다.

future = m.make_future_dataframe(periods=30)
future.tail()
ds
519 2021-01-25
520 2021-01-26
521 2021-01-27
522 2021-01-28
523 2021-01-29
  • predict로 예측을 진행합니다.
  • predict 안에는 이전 단계에서 만들어준 future 데이터프레임을 입력합니다.
forecast = m.predict(future)
forecast.tail()
ds trend yhat_lower yhat_upper trend_lower trend_upper additive_terms additive_terms_lower additive_terms_upper weekly weekly_lower weekly_upper multiplicative_terms multiplicative_terms_lower multiplicative_terms_upper yhat
519 2021-01-25 73285.088050 70702.567694 77023.796560 72940.877469 73616.041398 469.630675 469.630675 469.630675 469.630675 469.630675 469.630675 0.0 0.0 0.0 73754.718725
520 2021-01-26 73395.948372 70873.498076 77187.365071 73022.093425 73748.649883 726.870827 726.870827 726.870827 726.870827 726.870827 726.870827 0.0 0.0 0.0 74122.819199
521 2021-01-27 73506.808693 71092.107683 77501.089933 73104.528236 73886.709624 792.284285 792.284285 792.284285 792.284285 792.284285 792.284285 0.0 0.0 0.0 74299.092978
522 2021-01-28 73617.669015 71171.065890 77267.618146 73187.564846 74030.914574 649.781483 649.781483 649.781483 649.781483 649.781483 649.781483 0.0 0.0 0.0 74267.450498
523 2021-01-29 73728.529336 71078.903010 77757.016683 73279.870362 74174.946452 557.043551 557.043551 557.043551 557.043551 557.043551 557.043551 0.0 0.0 0.0 74285.572887
forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].iloc[-40:-20]
ds yhat yhat_lower yhat_upper
484 2020-12-16 69642.959477 66396.388852 72835.727750
485 2020-12-17 69611.316997 66617.065846 72552.880007
486 2020-12-18 69629.439387 66399.419405 72497.558388
487 2020-12-21 69874.607474 66601.737433 73178.316089
488 2020-12-22 70242.707948 66976.536342 73272.156857
489 2020-12-23 70418.981727 67486.273376 73490.031885
490 2020-12-24 70387.339247 67253.097407 73446.621386
491 2020-12-28 70650.629725 67246.943834 73781.407399
492 2020-12-29 71018.730198 67901.892703 74238.234211
493 2020-12-30 71195.003978 67828.564333 74347.140582
494 2020-12-31 71163.361497 67954.424397 74181.009132
495 2021-01-01 71181.483887 67749.560871 74398.696899
496 2021-01-02 69137.495246 66134.418892 72411.005515
497 2021-01-03 69248.355568 66100.302212 72451.292939
498 2021-01-04 71426.651975 67967.931199 74681.654384
499 2021-01-05 71794.752448 68560.213468 74960.044652
500 2021-01-06 71971.026228 68841.422101 75149.782270
501 2021-01-07 71939.383748 68858.951212 75378.688742
502 2021-01-08 71957.506137 68635.150888 75290.507898
503 2021-01-09 69913.517496 66570.436828 73067.466928

시각화

plot은 트렌드와 함께 예측된 결과물을 시각화하여 보여줍니다.

fig = m.plot(forecast)

plotly 활용

fig = plot_plotly(m, forecast)
fig

컴포넌트 별 시각화

컴포넌트 별 시각화에서는 seasonality 별 시각화를 진행해 볼 수 있습니다.

trend, yearly, weekly 데이터를 시각화하여 보여 줍니다.

plot_components_plotly(m, forecast)

change points

By default, Prophet specifies 25 potential changepoints which are uniformly placed in the first 80% of the time series. The vertical lines in this figure indicate where the potential changepoints were placed:

처음 80%의 시계열 데이터에 대하여 잠재적인 25개의 changepoints를 만들고, 그 중 선별하여 최종 changepoints를 그래프에서 vertical line으로 그려주게 됩니다.

from fbprophet.plot import add_changepoints_to_plot
fig = m.plot(forecast)
a = add_changepoints_to_plot(fig.gca(), m, forecast)

flexibility 조절

If the trend changes are being overfit (too much flexibility) or underfit (not enough flexibility), you can adjust the strength of the sparse prior using the input argument changepoint_prior_scale. By default, this parameter is set to 0.05. Increasing it will make the trend more flexible:

flexibility 계수가 낮으면 과소적합, 높으면 과대적합하여 예측하게 됩니다.

중요한 hyperparameter 입니다.

m = Prophet(changepoint_prior_scale=0.8)
forecast = m.fit(stock).predict(future)
fig = m.plot(forecast)
INFO:fbprophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:fbprophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
m = Prophet(changepoint_prior_scale=0.01)
forecast = m.fit(stock).predict(future)
fig = m.plot(forecast)
INFO:fbprophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:fbprophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.

댓글남기기