

In this tutorial we use Hugging Face's transformers library.

We download a pre-trained BERT text-classification model and its tokenizer, then train and evaluate a classifier that predicts the news category of articles in the BBC News dataset.

Fixing the Seed

import os
import random
import numpy as np
import torch
import warnings

warnings.filterwarnings('ignore')

# set the seed
SEED = 123

def seed_everything(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # must be False for deterministic results
    
seed_everything(SEED)

Downloading the Sample Data File

import urllib.request

# download the bbc-text.csv dataset
url = 'https://storage.googleapis.com/download.tensorflow.org/data/bbc-text.csv'
urllib.request.urlretrieve(url, 'bbc-text.csv')
('bbc-text.csv', <http.client.HTTPMessage at 0x7fa4187bb460>)

Loading the Data

import json
from tqdm import tqdm
import numpy as np
import pandas as pd
    
# load the dataset into a DataFrame
df = pd.read_csv('bbc-text.csv')
df.head()
category text
0 tech tv future in the hands of viewers with home th...
1 business worldcom boss left books alone former worldc...
2 sport tigers wary of farrell gamble leicester say ...
3 sport yeading face newcastle in fa cup premiership s...
4 entertainment ocean s twelve raids box office ocean s twelve...

Hugging Face

  • Download the model & tokenizer

  • Link: https://huggingface.co/abhishek/autonlp-bbc-news-classification-37229289

from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("abhishek/autonlp-bbc-news-classification-37229289")

model = AutoModelForSequenceClassification.from_pretrained("abhishek/autonlp-bbc-news-classification-37229289")
df.head()
category text
0 tech tv future in the hands of viewers with home th...
1 business worldcom boss left books alone former worldc...
2 sport tigers wary of farrell gamble leicester say ...
3 sport yeading face newcastle in fa cup premiership s...
4 entertainment ocean s twelve raids box office ocean s twelve...
df['text'].iloc[0]
'tv future in the hands of viewers with home theatre systems  plasma high-definition tvs  and digital video recorders moving into the living room  the way people watch tv will be radically different in five years  time.  that is according to an expert panel which gathered at the annual consumer electronics show in las vegas to discuss how these new technologies will impact one of our favourite pastimes. with the us leading the trend  programmes and other content will be delivered to viewers via home networks  through cable  satellite  telecoms companies  and broadband service providers to front rooms and portable devices.  one of the most talked-about technologies of ces has been digital and personal video recorders (dvr and pvr). these set-top boxes  like the us s tivo and the uk s sky+ system  allow people to record  store  play  pause and forward wind tv programmes when they want.  essentially  the technology allows for much more personalised tv. they are also being built-in to high-definition tv sets  which are big business in japan and the us  but slower to take off in europe because of the lack of high-definition programming. not only can people forward wind through adverts  they can also forget about abiding by network and channel schedules  putting together their own a-la-carte entertainment. but some us networks and cable and satellite companies are worried about what it means for them in terms of advertising revenues as well as  brand identity  and viewer loyalty to channels. although the us leads in this technology at the moment  it is also a concern that is being raised in europe  particularly with the growing uptake of services like sky+.  what happens here today  we will see in nine months to a years  time in the uk   adam hume  the bbc broadcast s futurologist told the bbc news website. for the likes of the bbc  there are no issues of lost advertising revenue yet. it is a more pressing issue at the moment for commercial uk broadcasters  but brand loyalty is important for everyone.  we will be talking more about content brands rather than network brands   said tim hanlon  from brand communications firm starcom mediavest.  the reality is that with broadband connections  anybody can be the producer of content.  he added:  the challenge now is that it is hard to promote a programme with so much choice.   what this means  said stacey jolna  senior vice president of tv guide tv group  is that the way people find the content they want to watch has to be simplified for tv viewers. it means that networks  in us terms  or channels could take a leaf out of google s book and be the search engine of the future  instead of the scheduler to help people find what they want to watch. this kind of channel model might work for the younger ipod generation which is used to taking control of their gadgets and what they play on them. but it might not suit everyone  the panel recognised. older generations are more comfortable with familiar schedules and channel brands because they know what they are getting. they perhaps do not want so much of the choice put into their hands  mr hanlon suggested.  on the other end  you have the kids just out of diapers who are pushing buttons already - everything is possible and available to them   said mr hanlon.  ultimately  the consumer will tell the market they want.   of the 50 000 new gadgets and technologies being showcased at ces  many of them are about enhancing the tv-watching experience. 
high-definition tv sets are everywhere and many new models of lcd (liquid crystal display) tvs have been launched with dvr capability built into them  instead of being external boxes. one such example launched at the show is humax s 26-inch lcd tv with an 80-hour tivo dvr and dvd recorder. one of the us s biggest satellite tv companies  directtv  has even launched its own branded dvr at the show with 100-hours of recording capability  instant replay  and a search function. the set can pause and rewind tv for up to 90 hours. and microsoft chief bill gates announced in his pre-show keynote speech a partnership with tivo  called tivotogo  which means people can play recorded programmes on windows pcs and mobile devices. all these reflect the increasing trend of freeing up multimedia so that people can watch what they want  when they want.'

Inspecting the Tokenized Result

  • The result consists of input_ids, token_type_ids, and attention_mask.

  • input_ids are the token ids the text has been converted into.

  • token_type_ids marks which segment (sentence) each token belongs to when two or more sentences are passed together as a pair (see the short sketch after the tokenization output below).

  • attention_mask marks real tokens with 1 and padding tokens with 0.

tokenized = tokenizer(df['text'].iloc[0], padding=True, truncation=True)
tokenized
{'input_ids': [101, 2694, 2925, 1999, 1996, 2398, 1997, 7193, 2007, 2188, 3004, 3001, 12123, 2152, 1011, 6210, 2694, 2015, 1998, 3617, 2678, 14520, 2015, 3048, 2046, 1996, 2542, 2282, 1996, 2126, 2111, 3422, 2694, 2097, 2022, 25796, 2367, 1999, 2274, 2086, 2051, 1012, 2008, 2003, 2429, 2000, 2019, 6739, 5997, 2029, 5935, 2012, 1996, 3296, 7325, 8139, 2265, 1999, 5869, 7136, 2000, 6848, 2129, 2122, 2047, 6786, 2097, 4254, 2028, 1997, 2256, 8837, 2627, 14428, 2015, 1012, 2007, 1996, 2149, 2877, 1996, 9874, 8497, 1998, 2060, 4180, 2097, 2022, 5359, 2000, 7193, 3081, 2188, 6125, 2083, 5830, 5871, 18126, 2015, 3316, 1998, 19595, 2326, 11670, 2000, 2392, 4734, 1998, 12109, 5733, 1012, 2028, 1997, 1996, 2087, 5720, 1011, 2055, 6786, 1997, 8292, 2015, 2038, 2042, 3617, 1998, 3167, 2678, 14520, 2015, 1006, 1040, 19716, 1998, 26189, 2099, 1007, 1012, 2122, 2275, 1011, 2327, 8378, 2066, 1996, 2149, 1055, 14841, 6767, 1998, 1996, 2866, 1055, 3712, 1009, 2291, 3499, 2111, 2000, 2501, 3573, 2377, 8724, 1998, 2830, 3612, 2694, 8497, 2043, 2027, 2215, 1012, 7687, 1996, 2974, 4473, 2005, 2172, 2062, 3167, 5084, 2694, 1012, 2027, 2024, 2036, 2108, 2328, 1011, 1999, 2000, 2152, 1011, 6210, 2694, 4520, 2029, 2024, 2502, 2449, 1999, 2900, 1998, 1996, 2149, 2021, 12430, 2000, 2202, 2125, 1999, 2885, 2138, 1997, 1996, 3768, 1997, 2152, 1011, 6210, 4730, 1012, 2025, 2069, 2064, 2111, 2830, 3612, 2083, 4748, 16874, 2015, 2027, 2064, 2036, 5293, 2055, 11113, 28173, 3070, 2011, 2897, 1998, 3149, 20283, 5128, 2362, 2037, 2219, 1037, 1011, 2474, 1011, 11122, 2063, 4024, 1012, 2021, 2070, 2149, 6125, 1998, 5830, 1998, 5871, 3316, 2024, 5191, 2055, 2054, 2009, 2965, 2005, 2068, 1999, 3408, 1997, 6475, 12594, 2004, 2092, 2004, 4435, 4767, 1998, 13972, 9721, 2000, 6833, 1012, 2348, 1996, 2149, 5260, 1999, 2023, 2974, 2012, 1996, 2617, 2009, 2003, 2036, 1037, 5142, 2008, 2003, 2108, 2992, 1999, 2885, 3391, 2007, 1996, 3652, 2039, 15166, 1997, 2578, 2066, 3712, 1009, 1012, 2054, 6433, 2182, 2651, 2057, 2097, 2156, 1999, 3157, 2706, 2000, 1037, 2086, 2051, 1999, 1996, 2866, 4205, 20368, 1996, 4035, 3743, 1055, 11865, 20689, 8662, 2409, 1996, 4035, 2739, 4037, 1012, 2005, 1996, 7777, 1997, 1996, 4035, 2045, 2024, 2053, 3314, 1997, 2439, 6475, 6599, 2664, 1012, 2009, 2003, 1037, 2062, 7827, 3277, 2012, 1996, 2617, 2005, 3293, 2866, 18706, 2021, 4435, 9721, 2003, 2590, 2005, 3071, 1012, 2057, 2097, 2022, 3331, 2062, 2055, 4180, 9639, 2738, 2084, 2897, 9639, 2056, 5199, 7658, 7811, 2013, 4435, 4806, 3813, 2732, 9006, 2865, 6961, 2102, 1012, 1996, 4507, 2003, 2008, 2007, 19595, 7264, 10334, 2064, 2022, 1996, 3135, 1997, 4180, 1012, 2002, 2794, 1024, 1996, 4119, 2085, 2003, 2008, 2009, 2003, 2524, 2000, 5326, 1037, 4746, 2007, 2061, 2172, 3601, 1012, 2054, 2023, 2965, 2056, 19997, 8183, 19666, 2050, 3026, 3580, 2343, 1997, 2694, 5009, 2694, 2177, 2003, 2008, 1996, 2126, 2111, 2424, 1996, 4180, 2027, 2215, 2000, 3422, 2038, 2000, 2022, 11038, 2005, 2694, 7193, 1012, 2009, 2965, 2008, 6125, 1999, 2149, 3408, 2030, 6833, 2071, 2202, 1037, 7053, 2041, 1997, 8224, 1055, 2338, 1998, 2022, 1996, 3945, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
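To make token_type_ids more concrete, the short sketch below tokenizes a made-up sentence pair with the same tokenizer; tokens from the first sentence get segment id 0 and tokens from the second get 1. The example sentences are illustrative only.

# sketch: tokenize a sentence pair to see token_type_ids change between segments
pair = tokenizer("the match was cancelled", "heavy rain flooded the pitch")
print(pair['token_type_ids'])                              # 0s for the first sentence, 1s for the second
print(tokenizer.convert_ids_to_tokens(pair['input_ids']))  # [CLS] ... [SEP] ... [SEP]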

Converting input_ids Back to Tokens

# print only the first 10 input_ids
input_ids = tokenized['input_ids']
input_ids[:10]
[101, 2694, 2925, 1999, 1996, 2398, 1997, 7193, 2007, 2188]
  • The sequence starts with the [CLS] token and ends with the [SEP] token.
# check the converted tokens
# the first 20 tokens
print(tokenizer.convert_ids_to_tokens(input_ids)[:20])
['[CLS]', 'tv', 'future', 'in', 'the', 'hands', 'of', 'viewers', 'with', 'home', 'theatre', 'systems', 'plasma', 'high', '-', 'definition', 'tv', '##s', 'and', 'digital']
# the last 20 tokens
print(tokenizer.convert_ids_to_tokens(input_ids)[-20:])
['networks', 'in', 'us', 'terms', 'or', 'channels', 'could', 'take', 'a', 'leaf', 'out', 'of', 'google', 's', 'book', 'and', 'be', 'the', 'search', '[SEP]']

Creating the Label Map

label_map = {
    'sport': 0, 
    'business': 1, 
    'politics': 2, 
    'tech': 3, 
    'entertainment': 4
}

# encode the text labels as integers
df['category_num'] = df['category'].map(label_map)
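The inverse of this map is handy later for turning predicted ids back into category names. A minimal sketch (id_to_label is not part of the original notebook):

# inverse map: numeric id -> category name, used to decode predictions
id_to_label = {v: k for k, v in label_map.items()}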

Splitting the Dataset

  • Split ratio: 0.8 : 0.2
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(df['text'], df['category_num'], 
                                                    stratify=df['category_num'], 
                                                    test_size=0.2, 
                                                    random_state=SEED
                                                   )
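Since stratify keeps the class ratios identical in both splits, this can be verified with a quick sanity check (an illustrative sketch, not required for training):

# the class proportions of the two splits should match
print(y_train.value_counts(normalize=True).sort_index())
print(y_test.value_counts(normalize=True).sort_index())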

Batch Tokenization

  • Tokenize the texts in batches

# check the maximum length used for truncation
tokenizer.model_max_length
512
  • truncation: sequences longer than model_max_length are cut to that length

  • padding: shorter sequences are padded (token id 0) up to the longest sequence in the batch; here that is 512, because the longest article is truncated to model_max_length (see the sketch after the batch example below)

# tokenize a whole batch at once with a uniform length
batch_tokenized = tokenizer(df['text'].iloc[:10].tolist(), padding=True, truncation=True)
# (batch_size, model_max_length)
np.array(batch_tokenized['input_ids']).shape
(10, 512)
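As noted above, padding=True only pads up to the longest sequence in the batch. If a fixed length of model_max_length is always required, padding='max_length' can be used instead; the sketch below contrasts the two on a couple of short, made-up strings.

# sketch: padding=True pads to the longest sequence in the batch,
# padding='max_length' pads every sequence up to tokenizer.model_max_length (512)
short_batch = ['short text', 'a slightly longer piece of text']
print(len(tokenizer(short_batch, padding=True)['input_ids'][0]))
print(len(tokenizer(short_batch, padding='max_length', truncation=True)['input_ids'][0]))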

Creating the Dataset

from torch.utils.data import DataLoader, Dataset


class CustomDataset(Dataset):
    def __init__(self, texts, labels):
        super().__init__()
        self.texts = texts
        self.labels = labels        
        
    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, idx):
        text = self.texts.iloc[idx]
        label = self.labels.iloc[idx]
        return text, label
# create the custom datasets
train_ds = CustomDataset(x_train, y_train)
valid_ds = CustomDataset(x_test, y_test)
# fetch a single example
text, label = next(iter(train_ds))
len(text), label
(1665, 2)

Creating the DataLoader

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

# select the torch device ('cpu', 'cuda:0', or 'cuda:1')
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(device)
cuda:1
def collate_batch(batch, tokenizer):
    text_list, label_list = [], []
    
    for text, label in batch:
        text_list.append(text)
        label_list.append(label)
    
    label_list = torch.tensor(label_list, dtype=torch.int64)
    
    # pad so that shorter sequences match the length of the longest one in the batch
    text_tokenized = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt')
    
    return text_tokenized, label_list
train_loader = DataLoader(train_ds, 
                          batch_size=8, 
                          shuffle=True, 
                          collate_fn=lambda x: collate_batch(x, tokenizer))

valid_loader = DataLoader(valid_ds, 
                          batch_size=8, 
                          shuffle=False, 
                          collate_fn=lambda x: collate_batch(x, tokenizer))
x, y = next(iter(train_loader))
x = x.to(device)
y = y.to(device)
x['input_ids'].shape
torch.Size([8, 512])
x['input_ids']
tensor([[  101,  7513,  3084,  ...,  4358,  2012,   102],
        [  101,  4977,  1011,  ...,     0,     0,     0],
        [  101, 19267,  7016,  ...,     0,     0,     0],
        ...,
        [  101,  3945,  5233,  ...,  4471, 18288,   102],
        [  101,  2203,  5747,  ...,  7806,  2075,   102],
        [  101, 24829,  2278,  ...,     0,     0,     0]], device='cuda:1')

Model

from tqdm import tqdm  # for the progress bar
import numpy as np
import torch.nn as nn
import torch.optim as optim
# inspect the structure of the pre-trained model
print(model)
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=4096, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=1024, out_features=5, bias=True)
)
# freeze the pre-trained weights
for param in model.parameters():
    param.requires_grad = False
# replace the model's classifier head
model.classifier = nn.Sequential(
    nn.Linear(1024, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Linear(32, 5)
)
# confirm that the new classifier's weights are trainable
for param in model.classifier.parameters():
    print(param.requires_grad)
True
True
True
True
True
True
True
True
True
True
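Because the BERT backbone is frozen, only the new classifier head is trainable. A quick sketch to count trainable versus total parameters (for illustration only):

# count trainable vs. total parameters after freezing the backbone
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'trainable: {trainable:,} / total: {total:,}')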

Model Inference Output

  • Extract the raw prediction scores via the logits key
# move each input tensor (input_ids, token_type_ids, attention_mask) to the device
inputs = {k: v.to(device) for k, v in x.items()}
inputs
{'input_ids': tensor([[  101,  7513,  3084,  ...,  4358,  2012,   102],
         [  101,  4977,  1011,  ...,     0,     0,     0],
         [  101, 19267,  7016,  ...,     0,     0,     0],
         ...,
         [  101,  3945,  5233,  ...,  4471, 18288,   102],
         [  101,  2203,  5747,  ...,  7806,  2075,   102],
         [  101, 24829,  2278,  ...,     0,     0,     0]], device='cuda:1'),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]], device='cuda:1'),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:1')}
# move the model to the device
model.to(device)

# run inference on the inputs
output = model(**inputs)
# inspect the SequenceClassifierOutput result
output
SequenceClassifierOutput(loss=None, logits=tensor([[-0.2875,  0.3375, -0.1978, -0.3443, -0.0309],
        [ 0.9405,  0.1261,  0.2335,  0.4588,  0.7446],
        [-0.0254, -0.6984, -0.3783,  0.4920,  0.4153],
        [ 0.5402,  1.0535, -0.5411, -0.0864, -0.1885],
        [-0.1618,  0.4428, -0.1006, -0.4557,  0.1331],
        [-0.2652,  0.4402, -0.1474, -0.0528, -0.0064],
        [-0.1604, -0.3877, -0.6900,  0.4542,  0.3729],
        [ 1.1010,  0.3947,  0.4235,  0.3663,  1.0580]], device='cuda:1',
      ), hidden_states=None, attentions=None)
# extract the logits (the unnormalized prediction scores)
output.logits
tensor([[-0.2875,  0.3375, -0.1978, -0.3443, -0.0309],
        [ 0.9405,  0.1261,  0.2335,  0.4588,  0.7446],
        [-0.0254, -0.6984, -0.3783,  0.4920,  0.4153],
        [ 0.5402,  1.0535, -0.5411, -0.0864, -0.1885],
        [-0.1618,  0.4428, -0.1006, -0.4557,  0.1331],
        [-0.2652,  0.4402, -0.1474, -0.0528, -0.0064],
        [-0.1604, -0.3877, -0.6900,  0.4542,  0.3729],
        [ 1.1010,  0.3947,  0.4235,  0.3663,  1.0580]], device='cuda:1')
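The logits are unnormalized scores, so applying a softmax turns them into class probabilities, and argmax gives the predicted category index. A minimal sketch:

import torch.nn.functional as F

# convert logits to probabilities and predicted class ids (sketch)
probs = F.softmax(output.logits, dim=1)
preds = probs.argmax(dim=1)
print(preds)  # predicted category ids for this batch (the new classifier head is still untrained here)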

Defining the Loss Function and Optimizer

# move the model to the device
model.to(device)

# define the loss function: CrossEntropyLoss
loss_fn = nn.CrossEntropyLoss()

# define the optimizer: model.parameters() and the learning rate
optimizer = optim.Adam(model.parameters(), lr=0.00005)
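Since the backbone is frozen, an equivalent alternative is to hand the optimizer only the trainable parameters; this is a sketch, not what the notebook above does:

# alternative sketch: optimize only parameters with requires_grad=True
# (behaviour is the same here, because frozen parameters never receive gradients)
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.Adam(trainable_params, lr=0.00005)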
def model_train(model, data_loader, loss_fn, optimizer, device):
    # Set the model to training mode. Gradients are only updated in training mode,
    # so calling train() is required.
    model.train()
    
    # Temporary variables for accumulating loss and accuracy, initialized to 0.
    running_loss = 0
    corr = 0
    counts = 0
    
    # Wrap the loader with tqdm to monitor training with a progress bar.
    progress_bar = tqdm(data_loader, unit='batch', total=len(data_loader), mininterval=1)
    
    # Start mini-batch training.
    for idx, (txt, lbl) in enumerate(progress_bar):
        # Move txt and lbl to the device (cuda or cpu).
        inputs = {k: v.to(device) for k, v in txt.items()}
        lbl = lbl.to(device)
        
        # Reset the accumulated gradients.
        optimizer.zero_grad()
        
        # Run forward propagation to get the output.
        output = model(**inputs)
        
        # Extract only the logits (the predictions).
        output = output.logits
        
        # Compute the loss from the output and the labels.
        loss = loss_fn(output, lbl)
        
        # Run back propagation to compute the gradients.
        loss.backward()
        
        # Update the weights with the computed gradients.
        optimizer.step()
        
        # Take the index of the maximum score as the prediction.
        output = output.argmax(dim=1)
        
        # Count the number of correct predictions.
        corr += (output == lbl).sum().item()
        counts += len(lbl)
        
        # Accumulate the per-batch loss.
        running_loss += loss.item()
        
        # Update the progress bar with the current training status.
        progress_bar.set_description(f"training loss: {running_loss/(idx+1):.5f}, training accuracy: {corr / counts:.5f}")
        
    # Divide the number of correct predictions by the dataset size to get the accuracy.
    acc = corr / len(data_loader.dataset)
    
    # Return the average loss and the accuracy.
    # train_loss, train_acc
    return running_loss / len(data_loader), acc
def model_evaluate(model, data_loader, loss_fn, device):
    # model.eval() switches the model to evaluation mode.
    # This is required so that layers such as dropout behave correctly during evaluation.
    model.eval()
    
    # torch.no_grad() is required to prevent gradients from being computed and updated.
    with torch.no_grad():
        # Temporary variables for accumulating loss and accuracy, initialized to 0.
        corr = 0
        running_loss = 0
        
        # Evaluate batch by batch.
        for txt, lbl in data_loader:
            # Move txt and lbl to the device (cuda or cpu).
            inputs = {k: v.to(device) for k, v in txt.items()}
            lbl = lbl.to(device)
    
            # Run forward propagation to get the output.
            output = model(**inputs)
            
            # Extract only the logits (the predictions).
            output = output.logits
            
            # Compute the validation loss.
            loss = loss_fn(output, lbl)
            
            # Take the index of the maximum score as the prediction.
            output = output.argmax(dim=1)
            
            # Count the number of correct predictions.
            corr += (output == lbl).sum().item()
            
            # Accumulate the per-batch loss.
            running_loss += loss.item()
        
        # Compute the validation accuracy by dividing the number of correct
        # predictions by the size of the whole dataset.
        acc = corr / len(data_loader.dataset)
        
        # Return the results.
        # val_loss, val_acc
        return running_loss / len(data_loader), acc
# Maximum number of epochs.
num_epochs = 5

# Name of the model checkpoint to save.
model_name = 'BBC-Text-CLF-BERT'

min_loss = np.inf

# Run training and validation for each epoch.
for epoch in range(num_epochs):
    # Model training: returns the training loss and accuracy.
    train_loss, train_acc = model_train(model, train_loader, loss_fn, optimizer, device)

    # Returns the validation loss and accuracy.
    val_loss, val_acc = model_evaluate(model, valid_loader, loss_fn, device)   
    
    # If val_loss has improved, update min_loss and save the model weights.
    if val_loss < min_loss:
        print(f'[INFO] val_loss has been improved from {min_loss:.5f} to {val_loss:.5f}. Saving Model!')
        min_loss = val_loss
        torch.save(model.state_dict(), f'{model_name}.pth')
    
    # Print the results for each epoch.
    print(f'epoch {epoch+1:02d}, loss: {train_loss:.5f}, acc: {train_acc:.5f}, val_loss: {val_loss:.5f}, val_accuracy: {val_acc:.5f}')
training loss: 0.67771, training accuracy: 0.94213: 100%|█| 223/223 [00:46<00:00
[INFO] val_loss has been improved from inf to 0.36258. Saving Model!
epoch 01, loss: 0.67771, acc: 0.94213, val_loss: 0.36258, val_accuracy: 0.99326
training loss: 0.48028, training accuracy: 0.98764: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.36258 to 0.31261. Saving Model!
epoch 02, loss: 0.48028, acc: 0.98764, val_loss: 0.31261, val_accuracy: 0.99326
training loss: 0.42029, training accuracy: 0.99101: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.31261 to 0.27140. Saving Model!
epoch 03, loss: 0.42029, acc: 0.99101, val_loss: 0.27140, val_accuracy: 0.99551
training loss: 0.37186, training accuracy: 0.98820: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.27140 to 0.24776. Saving Model!
epoch 04, loss: 0.37186, acc: 0.98820, val_loss: 0.24776, val_accuracy: 0.99326
training loss: 0.34230, training accuracy: 0.98820: 100%|█| 223/223 [00:47<00:00
[INFO] val_loss has been improved from 0.24776 to 0.19589. Saving Model!
epoch 05, loss: 0.34230, acc: 0.98820, val_loss: 0.19589, val_accuracy: 0.99326

Loading the Saved Weights

# load the saved weights
model.load_state_dict(torch.load(f'{model_name}.pth'))

Checking the Final Validation Loss and Accuracy

model.eval()

with torch.no_grad():
    val_loss, val_acc = model_evaluate(model, valid_loader, loss_fn, device)
    
    print(f'loss: {val_loss:.5f}, accuracy: {val_acc:.5f}')
loss: 0.19589, accuracy: 0.99326
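Finally, a rough sketch of how the fine-tuned classifier could be applied to a new article. The sample headline is made up, and id_to_label is the inverse map sketched in the Label Map section above.

# sketch: classify a new, made-up headline with the fine-tuned model
sample = "the champions league final drew record television audiences this year"

inputs = tokenizer(sample, truncation=True, return_tensors='pt')
inputs = {k: v.to(device) for k, v in inputs.items()}

model.eval()
with torch.no_grad():
    logits = model(**inputs).logits

# map the predicted id back to a category name (most likely 'sport' here)
print(id_to_label[logits.argmax(dim=1).item()])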
