🔥알림🔥
① 테디노트 유튜브 -
구경하러 가기!
② LangChain 한국어 튜토리얼
바로가기 👀
③ 랭체인 노트 무료 전자책(wikidocs)
바로가기 🙌
④ RAG 비법노트 LangChain 강의오픈
바로가기 🙌
⑤ 서울대 PyTorch 딥러닝 강의
바로가기 🙌
[pytorch] Seq2Seq 으로 챗봇 만들기 (코드 구현)
이번 튜토리얼에서는 seq2seq
모델을 pytorch로 구현하고 한글 챗봇 데이터를 학습시켜 추론해 보는 단계까지 진행해 보도록 하겠습니다.
[참고]
본 튜토리얼에서 활용한 한글 챗봇 데이터는 songys/Chatbot_data 를 활용하였습니다.
예제 코드
모듈 import & 데이터 로드
import os
import numpy as np
import pandas as pd
import random
data_dir = 'data'
df = pd.read_csv(os.path.join(data_dir, 'ChatbotData.csv'))
df
Q | A | label | |
---|---|---|---|
0 | 12시 땡! | 하루가 또 가네요. | 0 |
1 | 1지망 학교 떨어졌어 | 위로해 드립니다. | 0 |
2 | 3박4일 놀러가고 싶다 | 여행은 언제나 좋죠. | 0 |
3 | 3박4일 정도 놀러가고 싶다 | 여행은 언제나 좋죠. | 0 |
4 | PPL 심하네 | 눈살이 찌푸려지죠. | 0 |
... | ... | ... | ... |
11818 | 훔쳐보는 것도 눈치 보임. | 티가 나니까 눈치가 보이는 거죠! | 2 |
11819 | 훔쳐보는 것도 눈치 보임. | 훔쳐보는 거 티나나봐요. | 2 |
11820 | 흑기사 해주는 짝남. | 설렜겠어요. | 2 |
11821 | 힘든 연애 좋은 연애라는게 무슨 차이일까? | 잘 헤어질 수 있는 사이 여부인 거 같아요. | 2 |
11822 | 힘들어서 결혼할까봐 | 도피성 결혼은 하지 않길 바라요. | 2 |
11823 rows × 3 columns
question = df['Q']
answer = df['A']
question[:5]
0 12시 땡! 1 1지망 학교 떨어졌어 2 3박4일 놀러가고 싶다 3 3박4일 정도 놀러가고 싶다 4 PPL 심하네 Name: Q, dtype: object
answer[:5]
0 하루가 또 가네요. 1 위로해 드립니다. 2 여행은 언제나 좋죠. 3 여행은 언제나 좋죠. 4 눈살이 찌푸려지죠. Name: A, dtype: object
1. 데이터 전처리
1-1. 한글 정규화
import re
# 한글, 영어, 숫자, 공백, ?!.,을 제외한 나머지 문자 제거
korean_pattern = r'[^ ?,.!A-Za-z0-9가-힣+]'
# 패턴 컴파일
normalizer = re.compile(korean_pattern)
normalizer
re.compile(r'[^ ?,.!A-Za-z0-9가-힣+]', re.UNICODE)
print(f'수정 전: {question[10]}')
print(f'수정 후: {normalizer.sub("", question[10])}')
수정 전: SNS보면 나만 빼고 다 행복해보여 수정 후: SNS보면 나만 빼고 다 행복해보여
print(f'수정 전: {answer[10]}')
print(f'수정 후: {normalizer.sub("", answer[10])}')
수정 전: 자랑하는 자리니까요. 수정 후: 자랑하는 자리니까요.
def normalize(sentence):
return normalizer.sub("", sentence)
normalize(question[10])
'SNS보면 나만 빼고 다 행복해보여'
1-2. 한글 형태소 분석기
from konlpy.tag import Mecab, Okt
# 형태소 분석기
mecab = Mecab()
okt = Okt()
# mecab
mecab.morphs(normalize(question[10]))
['SNS', '보', '면', '나', '만', '빼', '고', '다', '행복', '해', '보여']
# okt
okt.morphs(normalize(answer[10]))
['자랑', '하는', '자리', '니까', '요', '.']
# 한글 전처리를 함수화
def clean_text(sentence, tagger):
sentence = normalize(sentence)
sentence = tagger.morphs(sentence)
sentence = ' '.join(sentence)
sentence = sentence.lower()
return sentence
# 한글
clean_text(question[10], okt)
'sns 보면 나 만 빼고 다 행복 해보여'
# 영어
clean_text(answer[10], okt)
'자랑 하는 자리 니까 요 .'
len(question), len(answer)
(11823, 11823)
questions = [clean_text(sent, okt) for sent in question.values[:1000]]
answers = [clean_text(sent, okt) for sent in answer.values[:1000]]
questions[:5]
['12시 땡 !', '1 지망 학교 떨어졌어', '3 박 4일 놀러 가고 싶다', '3 박 4일 정도 놀러 가고 싶다', 'ppl 심하네']
answers[:5]
['하루 가 또 가네요 .', '위로 해 드립니다 .', '여행 은 언제나 좋죠 .', '여행 은 언제나 좋죠 .', '눈살 이 찌푸려지죠 .']
1-3. 단어 사전 생성
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)
PAD_TOKEN = 0
SOS_TOKEN = 1
EOS_TOKEN = 2
class WordVocab():
def __init__(self):
self.word2index = {
'<PAD>': PAD_TOKEN,
'<SOS>': SOS_TOKEN,
'<EOS>': EOS_TOKEN,
}
self.word2count = {}
self.index2word = {
PAD_TOKEN: '<PAD>',
SOS_TOKEN: '<SOS>',
EOS_TOKEN: '<EOS>'
}
self.n_words = 3 # PAD, SOS, EOS 포함
def add_sentence(self, sentence):
for word in sentence.split(' '):
self.add_word(word)
def add_word(self, word):
if word not in self.word2index:
self.word2index[word] = self.n_words
self.word2count[word] = 1
self.index2word[self.n_words] = word
self.n_words += 1
else:
self.word2count[word] += 1
questions[10]
'sns 보면 나 만 빼고 다 행복 해보여'
print(f'원문: {questions[10]}')
lang = WordVocab()
lang.add_sentence(questions[10])
print('==='*10)
print('단어사전')
print(lang.word2index)
원문: sns 보면 나 만 빼고 다 행복 해보여
==============================
단어사전
{'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, 'sns': 3, '보면': 4, '나': 5, '만': 6, '빼고': 7, '다': 8, '행복': 9, '해보여': 10}
1-4. padding to sequences
-
하나의 배치 구성을 위해서는 문장의 길이가 맞아야 합니다.
-
하지만, 문장 별로 길이가 다르기 때문에 길이를 맞춰 주는 작업을 수행해야 합니다.
-
짧은 문장은 남은 공간에 PAD 토큰을 추가하여 길이를 맞춰 주도록 합니다.
max_length = 10
sentence_length = 6
sentence_tokens = np.random.randint(low=3, high=100, size=(sentence_length,))
sentence_tokens = sentence_tokens.tolist()
print(f'Generated Sentence: {sentence_tokens}')
sentence_tokens = sentence_tokens[:(max_length-1)]
token_length = len(sentence_tokens)
# 문장의 맨 끝부분에 <EOS> 토큰 추가
sentence_tokens.append(2)
for i in range(token_length, max_length-1):
# 나머지 빈 곳에 <PAD> 토큰 추가
sentence_tokens.append(0)
print(f'Output: {sentence_tokens}')
print(f'Total Length: {len(sentence_tokens)}')
Generated Sentence: [92, 29, 34, 32, 10, 26] Output: [92, 29, 34, 32, 10, 26, 2, 0, 0, 0] Total Length: 10
1-5. 전처리 프로세스 클래스화
-
torch.utils.data.Dataset
을 상속 받아TextDataset
클래스를 구현합니다. -
데이터를 로드하고, 정규화 및 전처리, 토큰화를 진행합니다.
-
단어 사전을 생성하고 이에 따라, 시퀀스로 변환합니다.
from konlpy.tag import Mecab, Okt
class TextDataset(Dataset):
def __init__(self, csv_path, min_length=3, max_length=32):
super(TextDataset, self).__init__()
data_dir = 'data'
# TOKEN 정의
self.PAD_TOKEN = 0 # Padding 토큰
self.SOS_TOKEN = 1 # SOS 토큰
self.EOS_TOKEN = 2 # EOS 토큰
self.tagger = Mecab() # 형태소 분석기
self.max_length = max_length # 한 문장의 최대 길이 지정
# CSV 데이터 로드
df = pd.read_csv(os.path.join(data_dir, csv_path))
# 한글 정규화
korean_pattern = r'[^ ?,.!A-Za-z0-9가-힣+]'
self.normalizer = re.compile(korean_pattern)
# src: 질의, tgt: 답변
src_clean = []
tgt_clean = []
# 단어 사전 생성
wordvocab = WordVocab()
for _, row in df.iterrows():
src = row['Q']
tgt = row['A']
# 한글 전처리
src = self.clean_text(src)
tgt = self.clean_text(tgt)
if len(src.split()) > min_length and len(tgt.split()) > min_length:
# 최소 길이를 넘어가는 문장의 단어만 추가
wordvocab.add_sentence(src)
wordvocab.add_sentence(tgt)
src_clean.append(src)
tgt_clean.append(tgt)
self.srcs = src_clean
self.tgts = tgt_clean
self.wordvocab = wordvocab
def normalize(self, sentence):
# 정규표현식에 따른 한글 정규화
return self.normalizer.sub("", sentence)
def clean_text(self, sentence):
# 한글 정규화
sentence = self.normalize(sentence)
# 형태소 처리
sentence = self.tagger.morphs(sentence)
sentence = ' '.join(sentence)
sentence = sentence.lower()
return sentence
def texts_to_sequences(self, sentence):
# 문장 -> 시퀀스로 변환
return [self.wordvocab.word2index[w] for w in sentence.split()]
def pad_sequence(self, sentence_tokens):
# 문장의 맨 끝 토큰은 제거
sentence_tokens = sentence_tokens[:(self.max_length-1)]
token_length = len(sentence_tokens)
# 문장의 맨 끝부분에 <EOS> 토큰 추가
sentence_tokens.append(self.EOS_TOKEN)
for i in range(token_length, (self.max_length-1)):
# 나머지 빈 곳에 <PAD> 토큰 추가
sentence_tokens.append(self.PAD_TOKEN)
return sentence_tokens
def __getitem__(self, idx):
inputs = self.srcs[idx]
inputs_sequences = self.texts_to_sequences(inputs)
inputs_padded = self.pad_sequence(inputs_sequences)
outputs = self.tgts[idx]
outputs_sequences = self.texts_to_sequences(outputs)
outputs_padded = self.pad_sequence(outputs_sequences)
return torch.tensor(inputs_padded), torch.tensor(outputs_padded)
def __len__(self):
return len(self.srcs)
# 한 문장의 최대 단어길이를 25로 설정
MAX_LENGTH = 25
dataset = TextDataset('ChatbotData.csv', min_length=3, max_length=MAX_LENGTH)
# 10번째 데이터 임의 추출
x, y = dataset[10]
-
문장의 맨 끝에는 2번 토큰(EOS 토큰)이 위치합니다.
-
EOS 토큰부터 max_length 까지는 PAD 토큰으로 채워집니다. 여기서 0번 토큰이 PAD 토큰 입니다.
-
x, y 데이터 모두
max_length=25
의 크기를 가집니다.
print(f'x shape: {x.shape}')
print(x)
x shape: torch.Size([25]) tensor([83, 84, 51, 85, 86, 18, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
print(f'y shape: {y.shape}')
print(y)
y shape: torch.Size([25]) tensor([87, 88, 58, 89, 63, 90, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
1-6. train / test 데이터셋 분할
# 80%의 데이터를 train에 할당합니다.
train_size = int(len(dataset) * 0.8)
train_size
8168
# 나머지 20% 데이터를 test에 할당합니다.
test_size = len(dataset) - train_size
test_size
2042
from torch.utils.data import random_split
# 랜덤 스플릿으로 분할을 완료합니다.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
1-7. DataLoader 생성
-
배치 구성을 쉽게 하기 위해서
torch.utils.data.DataLoader
를 활용합니다. -
train/test 데이터셋 모두
batch_size=16
으로 설정하겠습니다.
from torch.utils.data import DataLoader, SubsetRandomSampler
train_loader = DataLoader(train_dataset,
batch_size=16,
shuffle=True)
test_loader = DataLoader(test_dataset,
batch_size=16,
shuffle=True)
# 1개의 배치 데이터를 추출합니다.
x, y = next(iter(train_loader))
# shape: (batch_size, sequence_length)
x.shape, y.shape
(torch.Size([16, 25]), torch.Size([16, 25]))
2. 모델
2-1. Encoder
class Encoder(nn.Module):
def __init__(self, num_vocabs, hidden_size, embedding_dim, num_layers):
super(Encoder, self).__init__()
# 단어 사전의 개수 지정
self.num_vocabs = num_vocabs
# 임베딩 레이어 정의 (number of vocabs, embedding dimension)
self.embedding = nn.Embedding(num_vocabs, embedding_dim)
# GRU (embedding dimension)
self.gru = nn.GRU(embedding_dim,
hidden_size,
num_layers=num_layers,
bidirectional=False)
def forward(self, x):
x = self.embedding(x).permute(1, 0, 2)
output, hidden = self.gru(x)
return output, hidden
2-1-1. Embedding Layer의 입/출력 shape에 대한 이해
embedding_dim = 64 # 임베딩 차원
embedding = nn.Embedding(dataset.wordvocab.n_words, embedding_dim)
# x의 shape을 변경합니다.
# (batch_size, sequence_length) => (sequence_length, batch_size)
embedded = embedding(x)
print(x.shape)
print(embedded.shape)
# input: (sequence_length, batch_size)
# output: (sequence_length, batch_size, embedding_dim)
torch.Size([16, 25]) torch.Size([16, 25, 64])
embedding 레이어를 통과한 출력을 (batch_size, sequence_length, embedding_dim)
=> (sequence_length, batch_size, embedding_dim)
shape 변환을 위하여 permute(1, 0, 2)
를 수행합니다.
여기서 shape를 변환하는 이유는 GRU 레이어의 입력이 (sequence_length, batch_size, embedding_dim)
을 수용하기 때문입니다.
embedded = embedded.permute(1, 0, 2)
print(embedded.shape)
# (sequence_length, batch_size, embedding_dim)
torch.Size([25, 16, 64])
2-1-2. GRU Layer의 입/출력 shape에 대한 이해
hidden_size = 32
gru = nn.GRU(embedding_dim, # embedding 차원
hidden_size,
num_layers=1,
bidirectional=False)
# input : (sequence_length, batch_size, embedding_dim)
# h0 : (Bidirectional(1) x number of layers(1), batch_size, hidden_size)
o, h = gru(embedded, None)
print(o.shape)
print(h.shape)
# output : (sequence_length, batch_size, hidden_size x bidirectional(1))
# hidden_state: (bidirectional(1) x number of layers(1), batch_size, hidden_size)
torch.Size([25, 16, 32]) torch.Size([1, 16, 32])
2-1-3. Encoder의 입/출력 shape에 대한 이해
NUM_VOCABS = dataset.wordvocab.n_words
print(f'number of vocabs: {NUM_VOCABS}')
number of vocabs: 6417
# Encoder 정의
encoder = Encoder(NUM_VOCABS,
hidden_size=32,
embedding_dim=64,
num_layers=1)
# Encoder에 x 통과 후 output, hidden_size 의 shape 확인
# input(x) : (batch_size, sequence_length)
o, h = encoder(x)
print(o.shape)
print(h.shape)
# output : (sequence_length, batch_size, hidden_size x bidirectional(1))
# hidden_state: (bidirectional(1) x number of layers(1), batch_size, hidden_size)
torch.Size([25, 16, 32]) torch.Size([1, 16, 32])
2-2. Decoder
class Decoder(nn.Module):
def __init__(self, num_vocabs, hidden_size, embedding_dim, num_layers=1, dropout=0.2):
super(Decoder, self).__init__()
# 단어사전 개수
self.num_vocabs = num_vocabs
self.embedding = nn.Embedding(num_vocabs, embedding_dim)
self.dropout = nn.Dropout(dropout)
self.gru = nn.GRU(embedding_dim,
hidden_size,
num_layers=num_layers,
bidirectional=False)
# 최종 출력은 단어사전의 개수
self.fc = nn.Linear(hidden_size, num_vocabs)
def forward(self, x, hidden_state):
x = x.unsqueeze(0) # (1, batch_size) 로 변환
embedded = F.relu(self.embedding(x))
embedded = self.dropout(embedded)
output, hidden = self.gru(embedded, hidden_state)
output = self.fc(output.squeeze(0)) # (sequence_length, batch_size, hidden_size(32) x bidirectional(1))
return output, hidden
2-2-1. Embedding Layer의 입/출력 shape에 대한 이해
x = torch.abs(torch.randn(size=(1, 16)).long())
print(x)
x.shape
# batch_size = 16 이라 가정했을 때,
# (1, batch_size)
# 여기서 batch_size => (1, batch_size) 로 shape 변환을 선행
tensor([[0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 2, 1, 0]])
torch.Size([1, 16])
embedding_dim = 64 # 임베딩 차원
embedding = nn.Embedding(dataset.wordvocab.n_words, embedding_dim)
embedded = embedding(x)
embedded.shape
# embedding 출력
# (1, batch_size, embedding_dim)
torch.Size([1, 16, 64])
2-2-2. GRU Layer의 입/출력 shape에 대한 이해
hidden_size = 32
gru = nn.GRU(embedding_dim,
hidden_size,
num_layers=1,
bidirectional=False,
batch_first=False, # batch_first=False로 지정
)
o, h = gru(embedded)
print(o.shape)
# output shape: (sequence_length, batch_size, hidden_size(32) x bidirectional(1))
print(h.shape)
# hidden_state shape: (Bidirectional(1) x number of layers(1), batch_size, hidden_size(32))
torch.Size([1, 16, 32]) torch.Size([1, 16, 32])
2-2-3. 최종 출력층(FC) shape에 대한 이해
fc = nn.Linear(32, NUM_VOCABS) # 출력은 단어사전의 개수로 가정
output = fc(o[0])
print(o[0].shape)
print(output.shape)
# input : (batch_size, output from GRU)
# output: (batch_size, output dimension)
torch.Size([16, 32]) torch.Size([16, 6417])
2-3. 인코더 -> 디코더 입출력 shape
decoder = Decoder(num_vocabs=dataset.wordvocab.n_words,
hidden_size=32,
embedding_dim=64,
num_layers=1)
디코더에 입력될 인코더의 output
, hidden_state
의 shape을 확인합니다.
-
여기서
hidden_state
만 디코더의 입력 으로 활용합니다. -
x
는 SOS 토큰이 첫 번째 입력으로 들어갑니다.
x, y = next(iter(train_loader))
o, h = encoder(x)
print(o.shape, h.shape)
# output : (batch_size, sequence_length, hidden_size(32) x bidirectional(1))
# hidden_state: (Bidirectional(1) x number of layers(1), batch_size, hidden_size(32))
인코더(Encoder)로부터 생성된 hidden_state(h)와 SOS 토큰을 디코더(Decoder)의 입력으로 넣어줍니다
x = torch.abs(torch.full(size=(16,), fill_value=SOS_TOKEN).long())
print(x)
x.shape
# batch_size = 16 이라 가정(16개의 SOS 토큰)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torch.Size([16])
decoder_output, decoder_hidden = decoder(x, h)
decoder_output.shape, decoder_hidden.shape
# (batch_size, num_vocabs), (1, batch_size, hidden_size)
(torch.Size([16, 6417]), torch.Size([1, 16, 32]))
-
decoder_output
은(batch_size, num_vocabs)
shape로 출력 -
decoder_hidden
의 shape는 입력으로 넣어준 shape와 동일함을 확인
2-4. Seq2Seq
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder, device):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.device = device
def forward(self, inputs, outputs, teacher_forcing_ratio=0.5):
# inputs : (batch_size, sequence_length)
# outputs: (batch_size, sequence_length)
batch_size, output_length = outputs.shape
output_num_vocabs = self.decoder.num_vocabs
# 리턴할 예측된 outputs를 저장할 임시 변수
# (sequence_length, batch_size, num_vocabs)
predicted_outputs = torch.zeros(output_length, batch_size, output_num_vocabs).to(self.device)
# 인코더에 입력 데이터 주입, encoder_output은 버리고 hidden_state 만 살립니다.
# 여기서 hidden_state가 디코더에 주입할 context vector 입니다.
# (Bidirectional(1) x number of layers(1), batch_size, hidden_size)
_, decoder_hidden = self.encoder(inputs)
# (batch_size) shape의 SOS TOKEN으로 채워진 디코더 입력 생성
decoder_input = torch.full((batch_size,), SOS_TOKEN, device=self.device)
# 순회하면서 출력 단어를 생성합니다.
# 0번째는 SOS TOKEN이 위치하므로, 1번째 인덱스부터 순회합니다.
for t in range(0, output_length):
# decoder_input : 디코더 입력 (batch_size) 형태의 SOS TOKEN로 채워진 입력
# decoder_output: (batch_size, num_vocabs)
# decoder_hidden: (Bidirectional(1) x number of layers(1), batch_size, hidden_size), context vector와 동일 shape
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
# t번째 단어에 디코더의 output 저장
predicted_outputs[t] = decoder_output
# teacher forcing 적용 여부 확률로 결정
# teacher forcing 이란: 정답치를 다음 RNN Cell의 입력으로 넣어주는 경우. 수렴속도가 빠를 수 있으나, 불안정할 수 있음
teacher_force = random.random() < teacher_forcing_ratio
# top1 단어 토큰 예측
top1 = decoder_output.argmax(1)
# teacher forcing 인 경우 ground truth 값을, 그렇지 않은 경우, 예측 값을 다음 input으로 지정
decoder_input = outputs[:, t] if teacher_force else top1
return predicted_outputs.permute(1, 0, 2) # (batch_size, sequence_length, num_vocabs)로 변경
2-4-1. Seq2Seq 입출력 확인
# Encoder 정의
encoder = Encoder(num_vocabs=dataset.wordvocab.n_words,
hidden_size=32,
embedding_dim=64,
num_layers=1)
# Decoder 정의
decoder = Decoder(num_vocabs=dataset.wordvocab.n_words,
hidden_size=32,
embedding_dim=64,
num_layers=1)
# Seq2Seq 정의
seq2seq = Seq2Seq(encoder, decoder, 'cpu')
x, y = next(iter(train_loader))
print(x.shape, y.shape)
# (batch_size, sequence_length), (batch_size, sequence_length)
torch.Size([16, 25]) torch.Size([16, 25])
output = seq2seq(x, y)
print(output.shape)
# (batch_size, sequence_length, num_vocabs)
torch.Size([16, 25, 6417])
3. Training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_VOCABS = dataset.wordvocab.n_words
HIDDEN_SIZE = 512
EMBEDDIMG_DIM = 256
print(f'num_vocabs: {NUM_VOCABS}\n======================')
# Encoder 정의
encoder = Encoder(num_vocabs=NUM_VOCABS,
hidden_size=HIDDEN_SIZE,
embedding_dim=EMBEDDIMG_DIM,
num_layers=1)
# Decoder 정의
decoder = Decoder(num_vocabs=NUM_VOCABS,
hidden_size=HIDDEN_SIZE,
embedding_dim=EMBEDDIMG_DIM,
num_layers=1)
# Seq2Seq 생성
# encoder, decoder를 device 모두 지정
model = Seq2Seq(encoder.to(device), decoder.to(device), device)
print(model)
num_vocabs: 6417 ====================== Seq2Seq( (encoder): Encoder( (embedding): Embedding(6417, 256) (gru): GRU(256, 512) ) (decoder): Decoder( (embedding): Embedding(6417, 256) (dropout): Dropout(p=0.2, inplace=False) (gru): GRU(256, 512) (fc): Linear(in_features=512, out_features=6417, bias=True) ) )
3-1. Hyperparamter 정의
class EarlyStopping:
def __init__(self, patience=3, delta=0.0, mode='min', verbose=True):
"""
patience (int): loss or score가 개선된 후 기다리는 기간. default: 3
delta (float): 개선시 인정되는 최소 변화 수치. default: 0.0
mode (str): 개선시 최소/최대값 기준 선정('min' or 'max'). default: 'min'.
verbose (bool): 메시지 출력. default: True
"""
self.early_stop = False
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = np.Inf if mode == 'min' else 0
self.mode = mode
self.delta = delta
def __call__(self, score):
if self.best_score is None:
self.best_score = score
self.counter = 0
elif self.mode == 'min':
if score < (self.best_score - self.delta):
self.counter = 0
self.best_score = score
if self.verbose:
print(f'[EarlyStopping] (Update) Best Score: {self.best_score:.5f}')
else:
self.counter += 1
if self.verbose:
print(f'[EarlyStopping] (Patience) {self.counter}/{self.patience}, ' \
f'Best: {self.best_score:.5f}' \
f', Current: {score:.5f}, Delta: {np.abs(self.best_score - score):.5f}')
elif self.mode == 'max':
if score > (self.best_score + self.delta):
self.counter = 0
self.best_score = score
if self.verbose:
print(f'[EarlyStopping] (Update) Best Score: {self.best_score:.5f}')
else:
self.counter += 1
if self.verbose:
print(f'[EarlyStopping] (Patience) {self.counter}/{self.patience}, ' \
f'Best: {self.best_score:.5f}' \
f', Current: {score:.5f}, Delta: {np.abs(self.best_score - score):.5f}')
if self.counter >= self.patience:
if self.verbose:
print(f'[EarlyStop Triggered] Best Score: {self.best_score:.5f}')
# Early Stop
self.early_stop = True
else:
# Continue
self.early_stop = False
훈련에 적용할 하이퍼파라미터 설정
LR = 1e-3
optimizer = optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()
es = EarlyStopping(patience=5,
delta=0.001,
mode='min',
verbose=True
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
mode='min',
factor=0.5,
patience=2,
threshold_mode='abs',
min_lr=1e-8,
verbose=True)
3-2. train 함수 정의
def train(model, data_loader, optimizer, loss_fn, device):
model.train()
running_loss = 0
for x, y in data_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
# output: (batch_size, sequence_length, num_vocabs)
output = model(x, y)
output_dim = output.size(2)
# 1번 index 부터 슬라이싱한 이유는 0번 index가 SOS TOKEN 이기 때문
# (batch_size*sequence_length, num_vocabs) 로 변경
output = output.reshape(-1, output_dim)
# (batch_size*sequence_length) 로 변경
y = y.view(-1)
# Loss 계산
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
running_loss += loss.item() * x.size(0)
return running_loss / len(data_loader)
3-3. evaluation 함수 정의
def evaluate(model, data_loader, loss_fn, device):
model.eval()
eval_loss = 0
with torch.no_grad():
for x, y in data_loader:
x, y = x.to(device), y.to(device)
output = model(x, y)
output_dim = output.size(2)
output = output.reshape(-1, output_dim)
y = y.view(-1)
# Loss 계산
loss = loss_fn(output, y)
eval_loss += loss.item() * x.size(0)
return eval_loss / len(data_loader)
3-4. 랜덤 샘플링 후 결과 추론
def sequence_to_sentence(sequences, index2word):
outputs = []
for p in sequences:
word = index2word[p]
if p not in [SOS_TOKEN, EOS_TOKEN, PAD_TOKEN]:
outputs.append(word)
if word == EOS_TOKEN:
break
return ' '.join(outputs)
sequence를 다시 문장으로 바꾸어 문장 형식으로 출력하기 위한 함수
def random_evaluation(model, dataset, index2word, device, n=10):
n_samples = len(dataset)
indices = list(range(n_samples))
np.random.shuffle(indices) # Shuffle
sampled_indices = indices[:n] # Sampling N indices
# 샘플링한 데이터를 기반으로 DataLoader 생성
sampler = SubsetRandomSampler(sampled_indices)
sampled_dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)
model.eval()
with torch.no_grad():
for x, y in sampled_dataloader:
x, y = x.to(device), y.to(device)
output = model(x, y, teacher_forcing_ratio=0)
# output: (number of samples, sequence_length, num_vocabs)
preds = output.detach().cpu().numpy()
x = x.detach().cpu().numpy()
y = y.detach().cpu().numpy()
for i in range(n):
print(f'질문 : {sequence_to_sentence(x[i], index2word)}')
print(f'답변 : {sequence_to_sentence(y[i], index2word)}')
print(f'예측답변: {sequence_to_sentence(preds[i].argmax(1), index2word)}')
print('==='*10)
3-5. 훈련 시작
NUM_EPOCHS = 20
STATEDICT_PATH = 'models/seq2seq-chatbot-kor.pt'
best_loss = np.inf
for epoch in range(NUM_EPOCHS):
loss = train(model, train_loader, optimizer, loss_fn, device)
val_loss = evaluate(model, test_loader, loss_fn, device)
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), STATEDICT_PATH)
if epoch % 5 == 0:
print(f'epoch: {epoch+1}, loss: {loss:.4f}, val_loss: {val_loss:.4f}')
# Early Stop
es(loss)
if es.early_stop:
break
# Scheduler
scheduler.step(val_loss)
model.load_state_dict(torch.load(STATEDICT_PATH))
torch.save(model.state_dict(), f'models/seq2seq-chatbot-kor-{best_loss:.4f}.pt')
epoch: 1, loss: 32.2713, val_loss: 28.9118 [EarlyStopping] (Update) Best Score: 32.27133 [EarlyStopping] (Update) Best Score: 27.90689 [EarlyStopping] (Update) Best Score: 25.78647 [EarlyStopping] (Update) Best Score: 23.60776 [EarlyStopping] (Update) Best Score: 20.68632 epoch: 6, loss: 17.8025, val_loss: 27.6363 [EarlyStopping] (Update) Best Score: 17.80247 [EarlyStopping] (Update) Best Score: 14.78749 [EarlyStopping] (Update) Best Score: 12.07858 Epoch 8: reducing learning rate of group 0 to 5.0000e-04. [EarlyStopping] (Update) Best Score: 8.72596 [EarlyStopping] (Update) Best Score: 6.87622 epoch: 11, loss: 5.6122, val_loss: 31.1174 [EarlyStopping] (Update) Best Score: 5.61224 Epoch 11: reducing learning rate of group 0 to 2.5000e-04. [EarlyStopping] (Update) Best Score: 4.26539 [EarlyStopping] (Update) Best Score: 3.53884 [EarlyStopping] (Update) Best Score: 3.03033 Epoch 14: reducing learning rate of group 0 to 1.2500e-04. [EarlyStopping] (Update) Best Score: 2.41656 epoch: 16, loss: 2.1939, val_loss: 33.0993 [EarlyStopping] (Update) Best Score: 2.19391 [EarlyStopping] (Update) Best Score: 2.03166 Epoch 17: reducing learning rate of group 0 to 6.2500e-05. [EarlyStopping] (Update) Best Score: 1.75492 [EarlyStopping] (Update) Best Score: 1.62426 [EarlyStopping] (Update) Best Score: 1.50042 Epoch 20: reducing learning rate of group 0 to 3.1250e-05.
4. 결과
model.load_state_dict(torch.load(STATEDICT_PATH))
random_evaluation(model, test_dataset, dataset.wordvocab.index2word, device)
질문 : 대 기업 아니 어도 될까 ? 답변 : 어디 에서 일 하 든 상관 없 어요 . 예측답변: 먼저 고백 해 보 세요 . ============================== 질문 : 술기운 에 연락 했 는데 . 답변 : 후회 하 지 않 길 바라 요 . 예측답변: 이렇게 고민 한 일 이 었 나 봐요 . ============================== 질문 : 별 이 안 보여 답변 : 한적 한 시골 에서 하늘 을 올려 봐 보 세요 . 예측답변: 눈 을 해 보 세요 . ============================== 질문 : 좋 아 하 는 이상형 이 계속 바뀌 어 . 답변 : 성격 도 계속 바뀌 니 걱정 말 아요 . 예측답변: 이상형 을 사랑 하 는 게 좋 겠 어요 . ============================== 질문 : 이 기회 잡 고 싶 다 . 답변 : 행운 을 빌 게 요 ! 예측답변: 저 도 그럴 거 예요 . ============================== 질문 : 건강 이 최고 인 것 같 아 답변 : 가장 중요 한 목표 네요 . 예측답변: 눈 이 는 게 이 죠 . ============================== 질문 : 사무실 에 나왔 지만 손 에 안 잡히 네 답변 : 복잡 한 심경 인가 봐요 . 예측답변: as 에 에 가 고 나 봐요 . ============================== 질문 : 명치 쪽 이 답답 해 답변 : 많이 답답 할 거 라 생각 해요 . 예측답변: 기분 전환 해 보 세요 . ============================== 질문 : 6 년 그리고 남 은 것 들 답변 : 좋 은 기억 들 만 남 았 길 바랄게요 . 예측답변: 시간 이 시간 이 필요 하 겠 지만 잘 이겨낼 수 도 있 어요 . ============================== 질문 : 운동 좀 해야겠다 . 답변 : 건강 생각 해서 꾸준히 하 세요 . 예측답변: 네 말씀 해 보 세요 . ==============================
댓글남기기