[huggingface] Text Classification with BERT (BBC News Dataset)
This tutorial uses the HuggingFace transformers library.
We download a pretrained BERT text-classification model and its tokenizer, then train a classifier that predicts the category of each article in the BBC News dataset and use it for prediction.
Fixing the random seed
import os
import random
import numpy as np
import torch
import warnings
warnings.filterwarnings('ignore')
# Set the seed
SEED = 123
def seed_everything(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-select (possibly non-deterministic) algorithms, so turn it off for reproducibility
    torch.backends.cudnn.benchmark = False
seed_everything(SEED)
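Seeding works because re-seeding a generator replays exactly the same sequence of draws. A minimal stdlib-only sketch of that property with Python's `random` module (the helper `draw` is hypothetical; the NumPy and PyTorch generators seeded above behave the same way):

```python
import random

def draw(seed, n=5):
    """Seed the global RNG, then return n random integers in [0, 100)."""
    random.seed(seed)
    return [random.randrange(100) for _ in range(n)]

first = draw(123)
second = draw(123)  # same seed, so the same sequence is produced

print(first == second)  # True
```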
Downloading the sample file
import urllib.request
# Download the bbc-text.csv dataset
url = 'https://storage.googleapis.com/download.tensorflow.org/data/bbc-text.csv'
urllib.request.urlretrieve(url, 'bbc-text.csv')
('bbc-text.csv', <http.client.HTTPMessage at 0x7fa4187bb460>)
Loading the data
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
# Load the data into a DataFrame.
df = pd.read_csv('bbc-text.csv')
df.head()
| | category | text |
|---|---|---|
| 0 | tech | tv future in the hands of viewers with home th... |
| 1 | business | worldcom boss left books alone former worldc... |
| 2 | sport | tigers wary of farrell gamble leicester say ... |
| 3 | sport | yeading face newcastle in fa cup premiership s... |
| 4 | entertainment | ocean s twelve raids box office ocean s twelve... |
Downloading the HuggingFace model & tokenizer
- Link: https://huggingface.co/abhishek/autonlp-bbc-news-classification-37229289
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("abhishek/autonlp-bbc-news-classification-37229289")
model = AutoModelForSequenceClassification.from_pretrained("abhishek/autonlp-bbc-news-classification-37229289")
df['text'].iloc[0]
'tv future in the hands of viewers with home theatre systems plasma high-definition tvs and digital video recorders moving into the living room the way people watch tv will be radically different in five years time. that is according to an expert panel which gathered at the annual consumer electronics show in las vegas to discuss how these new technologies will impact one of our favourite pastimes. with the us leading the trend programmes and other content will be delivered to viewers via home networks through cable satellite telecoms companies and broadband service providers to front rooms and portable devices. one of the most talked-about technologies of ces has been digital and personal video recorders (dvr and pvr). these set-top boxes like the us s tivo and the uk s sky+ system allow people to record store play pause and forward wind tv programmes when they want. essentially the technology allows for much more personalised tv. they are also being built-in to high-definition tv sets which are big business in japan and the us but slower to take off in europe because of the lack of high-definition programming. not only can people forward wind through adverts they can also forget about abiding by network and channel schedules putting together their own a-la-carte entertainment. but some us networks and cable and satellite companies are worried about what it means for them in terms of advertising revenues as well as brand identity and viewer loyalty to channels. although the us leads in this technology at the moment it is also a concern that is being raised in europe particularly with the growing uptake of services like sky+. what happens here today we will see in nine months to a years time in the uk adam hume the bbc broadcast s futurologist told the bbc news website. for the likes of the bbc there are no issues of lost advertising revenue yet. it is a more pressing issue at the moment for commercial uk broadcasters but brand loyalty is important for everyone. 
we will be talking more about content brands rather than network brands said tim hanlon from brand communications firm starcom mediavest. the reality is that with broadband connections anybody can be the producer of content. he added: the challenge now is that it is hard to promote a programme with so much choice. what this means said stacey jolna senior vice president of tv guide tv group is that the way people find the content they want to watch has to be simplified for tv viewers. it means that networks in us terms or channels could take a leaf out of google s book and be the search engine of the future instead of the scheduler to help people find what they want to watch. this kind of channel model might work for the younger ipod generation which is used to taking control of their gadgets and what they play on them. but it might not suit everyone the panel recognised. older generations are more comfortable with familiar schedules and channel brands because they know what they are getting. they perhaps do not want so much of the choice put into their hands mr hanlon suggested. on the other end you have the kids just out of diapers who are pushing buttons already - everything is possible and available to them said mr hanlon. ultimately the consumer will tell the market they want. of the 50 000 new gadgets and technologies being showcased at ces many of them are about enhancing the tv-watching experience. high-definition tv sets are everywhere and many new models of lcd (liquid crystal display) tvs have been launched with dvr capability built into them instead of being external boxes. one such example launched at the show is humax s 26-inch lcd tv with an 80-hour tivo dvr and dvd recorder. one of the us s biggest satellite tv companies directtv has even launched its own branded dvr at the show with 100-hours of recording capability instant replay and a search function. the set can pause and rewind tv for up to 90 hours. 
and microsoft chief bill gates announced in his pre-show keynote speech a partnership with tivo called tivotogo which means people can play recorded programmes on windows pcs and mobile devices. all these reflect the increasing trend of freeing up multimedia so that people can watch what they want when they want.'
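Inspecting the tokenized output
- The output consists of input_ids, token_type_ids, and attention_mask.
- input_ids: the text converted into token IDs.
- token_type_ids: segment IDs that mark which sentence (first = 0, second = 1) each token belongs to when two sentences are fed together, e.g. for sentence-pair tasks. For a single text like this one, they are all 0.
- attention_mask: 1 for real tokens and 0 for padding tokens, so the model ignores the padding.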
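To see how the three fields line up, here is a pure-Python sketch for a two-sentence input. It assumes the ids visible in the output above ([CLS] = 101, [SEP] = 102, [PAD] = 0); `encode_pair` and the short example id lists are hypothetical, and the real tokenizer additionally handles truncation:

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # ids seen in the tokenizer output above

def encode_pair(ids_a, ids_b, max_len):
    """Hypothetical helper: build BERT-style inputs for a sentence pair (no truncation)."""
    input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # Segment A ([CLS] ... first [SEP]) gets 0; segment B (... second [SEP]) gets 1
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    # Every real token is attended to
    attention_mask = [1] * len(input_ids)
    # Pad all three lists up to max_len; padded positions get mask 0
    n_pad = max_len - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    token_type_ids += [0] * n_pad
    attention_mask += [0] * n_pad
    return {"input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask}

# 'tv future' as sentence A and 'in' as sentence B, padded to 8 positions
enc = encode_pair([2694, 2925], [1999], max_len=8)
print(enc["input_ids"])       # [101, 2694, 2925, 102, 1999, 102, 0, 0]
print(enc["token_type_ids"])  # [0, 0, 0, 0, 1, 1, 0, 0]
print(enc["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```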
tokenized = tokenizer(df['text'].iloc[0], padding=True, truncation=True)
tokenized
{'input_ids': [101, 2694, 2925, 1999, 1996, 2398, 1997, 7193, 2007, 2188, 3004, 3001, 12123, 2152, 1011, 6210, 2694, 2015, 1998, 3617, 2678, 14520, 2015, 3048, 2046, 1996, 2542, 2282, 1996, 2126, 2111, 3422, 2694, 2097, 2022, 25796, 2367, 1999, 2274, 2086, 2051, 1012, 2008, 2003, 2429, 2000, 2019, 6739, 5997, 2029, 5935, 2012, 1996, 3296, 7325, 8139, 2265, 1999, 5869, 7136, 2000, 6848, 2129, 2122, 2047, 6786, 2097, 4254, 2028, 1997, 2256, 8837, 2627, 14428, 2015, 1012, 2007, 1996, 2149, 2877, 1996, 9874, 8497, 1998, 2060, 4180, 2097, 2022, 5359, 2000, 7193, 3081, 2188, 6125, 2083, 5830, 5871, 18126, 2015, 3316, 1998, 19595, 2326, 11670, 2000, 2392, 4734, 1998, 12109, 5733, 1012, 2028, 1997, 1996, 2087, 5720, 1011, 2055, 6786, 1997, 8292, 2015, 2038, 2042, 3617, 1998, 3167, 2678, 14520, 2015, 1006, 1040, 19716, 1998, 26189, 2099, 1007, 1012, 2122, 2275, 1011, 2327, 8378, 2066, 1996, 2149, 1055, 14841, 6767, 1998, 1996, 2866, 1055, 3712, 1009, 2291, 3499, 2111, 2000, 2501, 3573, 2377, 8724, 1998, 2830, 3612, 2694, 8497, 2043, 2027, 2215, 1012, 7687, 1996, 2974, 4473, 2005, 2172, 2062, 3167, 5084, 2694, 1012, 2027, 2024, 2036, 2108, 2328, 1011, 1999, 2000, 2152, 1011, 6210, 2694, 4520, 2029, 2024, 2502, 2449, 1999, 2900, 1998, 1996, 2149, 2021, 12430, 2000, 2202, 2125, 1999, 2885, 2138, 1997, 1996, 3768, 1997, 2152, 1011, 6210, 4730, 1012, 2025, 2069, 2064, 2111, 2830, 3612, 2083, 4748, 16874, 2015, 2027, 2064, 2036, 5293, 2055, 11113, 28173, 3070, 2011, 2897, 1998, 3149, 20283, 5128, 2362, 2037, 2219, 1037, 1011, 2474, 1011, 11122, 2063, 4024, 1012, 2021, 2070, 2149, 6125, 1998, 5830, 1998, 5871, 3316, 2024, 5191, 2055, 2054, 2009, 2965, 2005, 2068, 1999, 3408, 1997, 6475, 12594, 2004, 2092, 2004, 4435, 4767, 1998, 13972, 9721, 2000, 6833, 1012, 2348, 1996, 2149, 5260, 1999, 2023, 2974, 2012, 1996, 2617, 2009, 2003, 2036, 1037, 5142, 2008, 2003, 2108, 2992, 1999, 2885, 3391, 2007, 1996, 3652, 2039, 15166, 1997, 2578, 2066, 3712, 1009, 1012, 2054, 6433, 2182, 2651, 
2057, 2097, 2156, 1999, 3157, 2706, 2000, 1037, 2086, 2051, 1999, 1996, 2866, 4205, 20368, 1996, 4035, 3743, 1055, 11865, 20689, 8662, 2409, 1996, 4035, 2739, 4037, 1012, 2005, 1996, 7777, 1997, 1996, 4035, 2045, 2024, 2053, 3314, 1997, 2439, 6475, 6599, 2664, 1012, 2009, 2003, 1037, 2062, 7827, 3277, 2012, 1996, 2617, 2005, 3293, 2866, 18706, 2021, 4435, 9721, 2003, 2590, 2005, 3071, 1012, 2057, 2097, 2022, 3331, 2062, 2055, 4180, 9639, 2738, 2084, 2897, 9639, 2056, 5199, 7658, 7811, 2013, 4435, 4806, 3813, 2732, 9006, 2865, 6961, 2102, 1012, 1996, 4507, 2003, 2008, 2007, 19595, 7264, 10334, 2064, 2022, 1996, 3135, 1997, 4180, 1012, 2002, 2794, 1024, 1996, 4119, 2085, 2003, 2008, 2009, 2003, 2524, 2000, 5326, 1037, 4746, 2007, 2061, 2172, 3601, 1012, 2054, 2023, 2965, 2056, 19997, 8183, 19666, 2050, 3026, 3580, 2343, 1997, 2694, 5009, 2694, 2177, 2003, 2008, 1996, 2126, 2111, 2424, 1996, 4180, 2027, 2215, 2000, 3422, 2038, 2000, 2022, 11038, 2005, 2694, 7193, 1012, 2009, 2965, 2008, 6125, 1999, 2149, 3408, 2030, 6833, 2071, 2202, 1037, 7053, 2041, 1997, 8224, 1055, 2338, 1998, 2022, 1996, 3945, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Converting input_ids back to tokens
# Print only the first 10 input_ids
input_ids = tokenized['input_ids']
input_ids[:10]
[101, 2694, 2925, 1999, 1996, 2398, 1997, 7193, 2007, 2188]
- The sequence starts with the [CLS] token and ends with the [SEP] token.
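Because this article is longer than the 512-token limit, truncation cuts the text tokens but [SEP] is still appended, which is why the last token printed below is [SEP] directly after 'search'. A minimal sketch of that behavior for a single sequence, assuming two reserved slots for [CLS] and [SEP]; `truncate_and_wrap` and the made-up ids are hypothetical:

```python
CLS_ID, SEP_ID = 101, 102

def truncate_and_wrap(ids, max_len=512):
    """Hypothetical helper: cut ids so that [CLS] + ids + [SEP] fits in max_len."""
    body = ids[: max_len - 2]  # reserve two slots for the special tokens
    return [CLS_ID] + body + [SEP_ID]

long_text_ids = list(range(1000, 1700))  # 700 made-up token ids
wrapped = truncate_and_wrap(long_text_ids)
print(len(wrapped), wrapped[0], wrapped[-1])  # 512 101 102
```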
# Check the converted tokens
# First 20 tokens
print(tokenizer.convert_ids_to_tokens(input_ids)[:20])
['[CLS]', 'tv', 'future', 'in', 'the', 'hands', 'of', 'viewers', 'with', 'home', 'theatre', 'systems', 'plasma', 'high', '-', 'definition', 'tv', '##s', 'and', 'digital']
# Last 20 tokens
print(tokenizer.convert_ids_to_tokens(input_ids)[-20:])
['networks', 'in', 'us', 'terms', 'or', 'channels', 'could', 'take', 'a', 'leaf', 'out', 'of', 'google', 's', 'book', 'and', 'be', 'the', 'search', '[SEP]']
Creating the label map
label_map = {
'sport': 0,
'business': 1,
'politics': 2,
'tech': 3,
'entertainment': 4
}
# Encode the English category labels as integers
df['category_num'] = df['category'].map(label_map)
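To read the model's integer predictions later, the mapping can be inverted. A small sketch using the same label_map as above; the `predicted` list is hypothetical:

```python
label_map = {'sport': 0, 'business': 1, 'politics': 2, 'tech': 3, 'entertainment': 4}

# Invert the mapping so predicted class indices can be turned back into names
id2label = {idx: name for name, idx in label_map.items()}

predicted = [3, 1, 0]  # hypothetical argmax outputs from the model
print([id2label[i] for i in predicted])  # ['tech', 'business', 'sport']
```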
Splitting the dataset
- Split ratio: 80% train, 20% test
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(df['text'], df['category_num'],
stratify=df['category_num'],
test_size=0.2,
random_state=SEED
)
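Passing stratify=df['category_num'] keeps each category's share roughly the same in both splits. The idea can be sketched in pure Python as a per-class shuffle and cut; this is a simplified illustration, not scikit-learn's actual algorithm, and `stratified_split` plus the toy labels are hypothetical:

```python
import random
from collections import Counter

def stratified_split(labels, test_size=0.2, seed=123):
    """Hypothetical helper: return (train_idx, test_idx) with per-class proportions preserved."""
    rng = random.Random(seed)
    by_class = {}
    for i, y in enumerate(labels):
        by_class.setdefault(y, []).append(i)
    train_idx, test_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)                      # shuffle within each class
        n_test = round(len(idxs) * test_size)  # same fraction from every class
        test_idx += idxs[:n_test]
        train_idx += idxs[n_test:]
    return train_idx, test_idx

labels = [0] * 50 + [1] * 30 + [2] * 20  # toy, imbalanced labels
train_idx, test_idx = stratified_split(labels)
print(Counter(labels[i] for i in test_idx))  # Counter({0: 10, 1: 6, 2: 4})
```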
Batch tokenization
- Tokenize several texts at once, grouped into a batch
# Check the maximum length used for truncation
tokenizer.model_max_length
512
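- truncation=True: texts longer than model_max_length are cut to that length.
- padding=True: shorter texts are padded with the [PAD] token (id 0 for this tokenizer) up to the length of the longest text in the batch. To always pad to model_max_length, use padding='max_length' instead.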
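The difference between the two padding strategies can be sketched in pure Python. `pad_batch` is a hypothetical helper that only mimics the tokenizer's padding=True ('longest') and padding='max_length' behavior; its naive truncation, unlike the real tokenizer, does not preserve [SEP]:

```python
PAD_ID = 0

def pad_batch(batch, padding="longest", max_length=512):
    """Hypothetical helper mimicking padding=True ('longest') vs padding='max_length'."""
    batch = [ids[:max_length] for ids in batch]  # naive truncation
    if padding == "longest":
        target = max(len(ids) for ids in batch)
    else:
        target = max_length
    input_ids = [ids + [PAD_ID] * (target - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (target - len(ids)) for ids in batch]
    return input_ids, attention_mask

batch = [[101, 7, 8, 102], [101, 9, 102]]
ids_longest, mask_longest = pad_batch(batch, padding="longest")
ids_max, _ = pad_batch(batch, padding="max_length", max_length=6)
print(ids_longest)   # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(mask_longest)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(ids_max)       # [[101, 7, 8, 102, 0, 0], [101, 9, 102, 0, 0, 0]]
```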
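# Tokenize a whole batch so that every sequence ends up with the same length
batch_tokenized = tokenizer(df['text'].iloc[:10].tolist(), padding=True, truncation=True)
# (batch_size, sequence_length): 512 here, because at least one of these 10 texts is truncated to model_max_length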
np.array(batch_tokenized['input_ids']).shape
(10, 512)
Creating the Dataset
from torch.utils.data import DataLoader, Dataset
class CustomDataset(Dataset):
    def __init__(self, texts, labels):
        super().__init__()
        self.texts = texts
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        # texts and labels are pandas Series, so use positional indexing via .iloc
        text = self.texts.iloc[idx]
        label = self.labels.iloc[idx]
        return text, label
# Create the custom Datasets
train_ds = CustomDataset(x_train, y_train)
valid_ds = CustomDataset(x_test, y_test)
# Fetch one sample (text, label)
text, label = next(iter(train_ds))
len(text), label
(1665, 2)
Creating the DataLoader
import torch
import torch.nn as nn
# Select the torch device ('cpu', 'cuda:0', or 'cuda:1')
# 'cuda:1' is the second GPU; on a single-GPU machine use 'cuda:0'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(device)
cuda:1
def collate_batch(batch, tokenizer):
    text_list, label_list = [], []
    for text, label in batch:
        text_list.append(text)
        label_list.append(label)
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # Pad shorter texts so every sequence in the batch has the same length (texts over 512 tokens are truncated)
    text_tokenized = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt')
    return text_tokenized, label_list
train_loader = DataLoader(train_ds,
batch_size=8,
shuffle=True,
collate_fn=lambda x: collate_batch(x, tokenizer))
valid_loader = DataLoader(valid_ds,
batch_size=8,
shuffle=False,
collate_fn=lambda x: collate_batch(x, tokenizer))
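With batch_size=8 the DataLoader yields ceil(N / 8) batches, and the final batch holds the remainder. A small pure-Python sketch, assuming the 80% training split contains 1,780 articles (an assumption, but consistent with the 223 batches shown in the training log); `make_batches` is hypothetical and mirrors a loader with drop_last=False:

```python
import math

def make_batches(items, batch_size):
    """Split items into consecutive batches, keeping the incomplete last batch."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

train_items = list(range(1780))  # assumed training-set size
batches = make_batches(train_items, batch_size=8)
print(len(batches), math.ceil(1780 / 8), len(batches[-1]))  # 223 223 4
```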
x, y = next(iter(train_loader))
x = x.to(device)
y = y.to(device)
x['input_ids'].shape
torch.Size([8, 512])
x['input_ids']
tensor([[ 101, 7513, 3084, ..., 4358, 2012, 102], [ 101, 4977, 1011, ..., 0, 0, 0], [ 101, 19267, 7016, ..., 0, 0, 0], ..., [ 101, 3945, 5233, ..., 4471, 18288, 102], [ 101, 2203, 5747, ..., 7806, 2075, 102], [ 101, 24829, 2278, ..., 0, 0, 0]], device='cuda:1')
Model
from tqdm import tqdm  # progress bar
import numpy as np
import torch.nn as nn
import torch.optim as optim
# Inspect the architecture of the pretrained model
print(model)
BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(30522, 1024, padding_idx=0) (position_embeddings): Embedding(512, 1024) (token_type_embeddings): Embedding(2, 1024) (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0-23): 24 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=1024, out_features=1024, bias=True) (key): Linear(in_features=1024, out_features=1024, bias=True) (value): Linear(in_features=1024, out_features=1024, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=1024, out_features=1024, bias=True) (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=1024, out_features=4096, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=4096, out_features=1024, bias=True) (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=1024, out_features=1024, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1, inplace=False) (classifier): Linear(in_features=1024, out_features=5, bias=True) )
# Freeze all pretrained weights
for param in model.parameters():
    param.requires_grad = False
# Replace the model's classifier head (newly created layers are trainable by default)
model.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 32),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Linear(32, 5)
)
# Check that the new classifier's weights can be updated
for param in model.classifier.parameters():
    print(param.requires_grad)
True True True True True True True True True True
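The ten True values are the weight and bias tensors of the 3 Linear and 2 BatchNorm1d layers in the new head. The arithmetic below counts them and the number of trainable values; BatchNorm's running_mean and running_var are buffers rather than parameters, so they are not included:

```python
# (in_features, out_features) of the three Linear layers in the new head
linear_shapes = [(1024, 256), (256, 32), (32, 5)]
# num_features of the two BatchNorm1d layers
batchnorm_sizes = [256, 32]

linear_params = sum(i * o + o for i, o in linear_shapes)  # weight + bias
batchnorm_params = sum(2 * n for n in batchnorm_sizes)    # affine weight + bias
tensor_count = 2 * len(linear_shapes) + 2 * len(batchnorm_sizes)

print(tensor_count)                      # 10
print(linear_params + batchnorm_params)  # 271365
```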
Model inference output
- Extract the class scores via the logits key. Logits are raw, unnormalized scores, not probabilities.
# Move each input tensor (input_ids, token_type_ids, attention_mask) to the device
inputs = {k: v.to(device) for k, v in x.items()}
inputs
{'input_ids': tensor([[ 101, 7513, 3084, ..., 4358, 2012, 102], [ 101, 4977, 1011, ..., 0, 0, 0], [ 101, 19267, 7016, ..., 0, 0, 0], ..., [ 101, 3945, 5233, ..., 4471, 18288, 102], [ 101, 2203, 5747, ..., 7806, 2075, 102], [ 101, 24829, 2278, ..., 0, 0, 0]], device='cuda:1'), 'token_type_ids': tensor([[0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], ..., [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0]], device='cuda:1'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], ..., [1, 1, 1, ..., 1, 1, 1], [1, 1, 1, ..., 1, 1, 1], [1, 1, 1, ..., 0, 0, 0]], device='cuda:1')}
# Move the model to the device
model.to(device)
# Run inference on the inputs
output = model(**inputs)
# Inspect the SequenceClassifierOutput
output
SequenceClassifierOutput(loss=None, logits=tensor([[-0.2875, 0.3375, -0.1978, -0.3443, -0.0309], [ 0.9405, 0.1261, 0.2335, 0.4588, 0.7446], [-0.0254, -0.6984, -0.3783, 0.4920, 0.4153], [ 0.5402, 1.0535, -0.5411, -0.0864, -0.1885], [-0.1618, 0.4428, -0.1006, -0.4557, 0.1331], [-0.2652, 0.4402, -0.1474, -0.0528, -0.0064], [-0.1604, -0.3877, -0.6900, 0.4542, 0.3729], [ 1.1010, 0.3947, 0.4235, 0.3663, 1.0580]], device='cuda:1', ), hidden_states=None, attentions=None)
# Extract the logits (unnormalized class scores)
output.logits
tensor([[-0.2875, 0.3375, -0.1978, -0.3443, -0.0309], [ 0.9405, 0.1261, 0.2335, 0.4588, 0.7446], [-0.0254, -0.6984, -0.3783, 0.4920, 0.4153], [ 0.5402, 1.0535, -0.5411, -0.0864, -0.1885], [-0.1618, 0.4428, -0.1006, -0.4557, 0.1331], [-0.2652, 0.4402, -0.1474, -0.0528, -0.0064], [-0.1604, -0.3877, -0.6900, 0.4542, 0.3729], [ 1.1010, 0.3947, 0.4235, 0.3663, 1.0580]], device='cuda:1')
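Softmax turns a row of logits into probabilities, and argmax picks the same class either way. A pure-Python sketch applied to the first row of output.logits above (in PyTorch the equivalent is torch.softmax(output.logits, dim=-1)). Since the new head is still untrained at this point, the resulting class is not meaningful yet:

```python
import math

def softmax(row):
    """Convert one row of logits into probabilities (shifted by the max for numerical stability)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

first_row = [-0.2875, 0.3375, -0.1978, -0.3443, -0.0309]  # first row of output.logits
probs = softmax(first_row)
print([round(p, 3) for p in probs])
print(probs.index(max(probs)))  # 1
```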
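Defining the loss function and optimizer
# Move the model to the device
model.to(device)
# Loss: CrossEntropyLoss
loss_fn = nn.CrossEntropyLoss()
# Optimizer: Adam over model.parameters() with a learning rate of 5e-5; frozen parameters get no gradients, so only the new classifier is updated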
optimizer = optim.Adam(model.parameters(), lr=0.00005)
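nn.CrossEntropyLoss takes raw logits (it applies log-softmax internally) and integer class targets, which is why the model's logits and the category_num labels can be passed to it directly. A pure-Python sketch of its default mean reduction; `cross_entropy` is a hypothetical stand-in, not the PyTorch implementation:

```python
import math

def cross_entropy(logits_batch, targets):
    """Mean of -log(softmax(logits)[target]) over the batch."""
    losses = []
    for logits, t in zip(logits_batch, targets):
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        losses.append(log_sum_exp - logits[t])  # equals -log p_t
    return sum(losses) / len(losses)

# With uniform logits over 5 classes the loss equals log(5), a baseline for an untrained head
uniform = [[0.0] * 5]
print(cross_entropy(uniform, [2]), math.log(5))
```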
def model_train(model, data_loader, loss_fn, optimizer, device):
    # Switch to training mode so that layers such as Dropout and BatchNorm behave as they should during training.
    model.train()
    # Accumulators for loss and accuracy, initialized to 0.
    running_loss = 0
    corr = 0
    counts = 0
    # Wrap the loader with tqdm to monitor training with a progress bar.
    progress_bar = tqdm(data_loader, unit='batch', total=len(data_loader), mininterval=1)
    # Mini-batch training loop.
    for idx, (txt, lbl) in enumerate(progress_bar):
        # Move the inputs and labels to the device (cuda or cpu).
        inputs = {k: v.to(device) for k, v in txt.items()}
        lbl = lbl.to(device)
        # Reset the accumulated gradients.
        optimizer.zero_grad()
        # Forward pass.
        output = model(**inputs)
        # Keep only the predicted logits.
        output = output.logits
        # Compute the loss from the logits and the labels.
        loss = loss_fn(output, lbl)
        # Backpropagation: compute the gradients.
        loss.backward()
        # Update the weights with the computed gradients.
        optimizer.step()
        # Predicted class = index of the largest logit.
        output = output.argmax(dim=1)
        # Count correct predictions.
        corr += (output == lbl).sum().item()
        counts += len(lbl)
        # Accumulate the batch loss.
        running_loss += loss.item()
        # Show the running loss and accuracy on the progress bar.
        progress_bar.set_description(f"training loss: {running_loss/(idx+1):.5f}, training accuracy: {corr / counts:.5f}")
    # Accuracy = total correct predictions / number of samples.
    acc = corr / len(data_loader.dataset)
    # Return the average loss and the accuracy (train_loss, train_acc).
    return running_loss / len(data_loader), acc
def model_evaluate(model, data_loader, loss_fn, device):
    # model.eval() switches the model to evaluation mode: Dropout is disabled and BatchNorm uses its running statistics.
    # Always call it before evaluating.
    model.eval()
    # Disable gradient tracking; evaluation needs no gradients, which saves memory and time.
    with torch.no_grad():
        # Accumulators for loss and accuracy, initialized to 0.
        corr = 0
        running_loss = 0
        # Evaluate batch by batch.
        for txt, lbl in data_loader:
            # Move the inputs and labels to the device (cuda or cpu).
            inputs = {k: v.to(device) for k, v in txt.items()}
            lbl = lbl.to(device)
            # Forward pass.
            output = model(**inputs)
            # Keep only the predicted logits.
            output = output.logits
            # Compute the validation loss.
            loss = loss_fn(output, lbl)
            # Predicted class = index of the largest logit.
            output = output.argmax(dim=1)
            # Count correct predictions.
            corr += (output == lbl).sum().item()
            # Accumulate the batch loss.
            running_loss += loss.item()
        # Validation accuracy = total correct predictions / size of the whole dataset.
        acc = corr / len(data_loader.dataset)
    # Return the average loss and the accuracy (val_loss, val_acc).
    return running_loss / len(data_loader), acc
# Maximum number of epochs.
num_epochs = 5
# Name used for the saved checkpoint file.
model_name = 'BBC-Text-CLF-BERT'
min_loss = np.inf
# Train and validate for each epoch.
for epoch in range(num_epochs):
    # Model training: returns the training loss and accuracy.
    train_loss, train_acc = model_train(model, train_loader, loss_fn, optimizer, device)
    # Validation: returns the validation loss and accuracy.
    val_loss, val_acc = model_evaluate(model, valid_loader, loss_fn, device)
    # If val_loss improved, update min_loss and save the model weights.
    if val_loss < min_loss:
        print(f'[INFO] val_loss has been improved from {min_loss:.5f} to {val_loss:.5f}. Saving Model!')
        min_loss = val_loss
        torch.save(model.state_dict(), f'{model_name}.pth')
    # Print the results for this epoch.
    print(f'epoch {epoch+1:02d}, loss: {train_loss:.5f}, acc: {train_acc:.5f}, val_loss: {val_loss:.5f}, val_accuracy: {val_acc:.5f}')
training loss: 0.67771, training accuracy: 0.94213: 100%|█| 223/223 [00:46<00:00
[INFO] val_loss has been improved from inf to 0.36258. Saving Model!
epoch 01, loss: 0.67771, acc: 0.94213, val_loss: 0.36258, val_accuracy: 0.99326
training loss: 0.48028, training accuracy: 0.98764: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.36258 to 0.31261. Saving Model!
epoch 02, loss: 0.48028, acc: 0.98764, val_loss: 0.31261, val_accuracy: 0.99326
training loss: 0.42029, training accuracy: 0.99101: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.31261 to 0.27140. Saving Model!
epoch 03, loss: 0.42029, acc: 0.99101, val_loss: 0.27140, val_accuracy: 0.99551
training loss: 0.37186, training accuracy: 0.98820: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.27140 to 0.24776. Saving Model!
epoch 04, loss: 0.37186, acc: 0.98820, val_loss: 0.24776, val_accuracy: 0.99326
training loss: 0.34230, training accuracy: 0.98820: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.24776 to 0.19589. Saving Model!
epoch 05, loss: 0.34230, acc: 0.98820, val_loss: 0.19589, val_accuracy: 0.99326
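The checkpoint rule overwrites the .pth file only when the validation loss improves. Replaying that rule on the logged validation losses shows that every epoch improved, so the checkpoint reloaded next holds the epoch-5 weights, which matches the final val_loss of 0.19589. A pure-Python sketch (the torch.save call is replaced by recording the epoch number):

```python
import math

# Validation losses copied from the training log above
val_losses = [0.36258, 0.31261, 0.27140, 0.24776, 0.19589]

min_loss = math.inf
saved_epochs = []
for epoch, val_loss in enumerate(val_losses, start=1):
    if val_loss < min_loss:  # same rule as the training loop
        min_loss = val_loss
        saved_epochs.append(epoch)  # the real loop calls torch.save(...) here

print(saved_epochs, min_loss)  # [1, 2, 3, 4, 5] 0.19589
```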
Loading the saved weights
# Load the best checkpoint
model.load_state_dict(torch.load(f'{model_name}.pth'))
Checking the final validation loss and accuracy
model.eval()
with torch.no_grad():
    val_loss, val_acc = model_evaluate(model, valid_loader, loss_fn, device)
    print(f'loss: {val_loss:.5f}, accuracy: {val_acc:.5f}')
loss: 0.19589, accuracy: 0.99326