TensorFlow Datasets API 활용법
Apr 17, 2020

TensorFlow Datasets 는 다양한 데이터셋을 TensorFlow에서 활용하기 쉽도록 제공합니다. 굉장히 많고, 다양한 데이터셋이 학습하기 편한 형태로 제공 되기 때문에, 간단한 사용법만 알아두어도, 샘플로 모델을 돌려보고 학습하기에 매우 유용합니다.

References

텐서플로우 가이드를 살펴보시면, 세부 예제와 함께 소스코드를 제공합니다.

TensorFlow Datasets 라이브러리 설치

tensorflow만 설치하신 분들은 tensorflow-datasets를 별도로 설치해 주셔야 합니다.

!pip install tensorflow-datasets
import tensorflow_datasets as tfds
tfds.__version__
'3.0.0'

필요한 library import

import tensorflow_datasets as tfds
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

list_builders: 데이터셋 종류 확인

tfds.list_builders()[:10]
['abstract_reasoning',
 'aeslc',
 'aflw2k3d',
 'amazon_us_reviews',
 'arc',
 'bair_robot_pushing_small',
 'beans',
 'big_patent',
 'bigearthnet',
 'billsum']

info: 데이터셋의 정보 확인

dataset, info = tfds.load(name='horses_or_humans', split=tfds.Split.TRAIN, with_info=True)
info
tfds.core.DatasetInfo(
    name='horses_or_humans',
    version=3.0.0,
    description='A large set of images of horses and humans.',
    homepage='http://laurencemoroney.com/horses-or-humans-dataset',
    features=FeaturesDict({
        'image': Image(shape=(300, 300, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
    }),
    total_num_examples=1283,
    splits={
        'test': 256,
        'train': 1027,
    },
    supervised_keys=('image', 'label'),
    citation="""@ONLINE {horses_or_humans,
    author = "Laurence Moroney",
    title = "Horses or Humans Dataset",
    month = "feb",
    year = "2019",
    url = "http://laurencemoroney.com/horses-or-humans-dataset"
    }""",
    redistribution_info=,
)

info 에서 데이터셋의 크기를 제공합니다.

이 정보를 활용하여 나중에 steps_per_epoch에 지정해 주어야할 값을 손쉽게 계산할 수 있습니다.

# info.splits에 train/test의 샘플 사이즈를 제공합니다.
info.splits
{'test': <tfds.core.SplitInfo num_examples=256>,
 'train': <tfds.core.SplitInfo num_examples=1027>}
train_size = info.splits['train'].num_examples
test_size = info.splits['test'].num_examples
print('train_size: {}\ntest_size: {} 개'.format(train_size, test_size))
train_size: 1027 개
test_size: 256 개

load: 데이터셋 로드

다음 1줄로 train dataset을 손쉽게 로드할 수 있습니다.

train_dataset = tfds.load(name="horses_or_humans", split=tfds.Split.TRAIN)
for data in train_dataset.take(1):
    image, label = data['image'], data['label']
    plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
    plt.axis('off')
    print("Label: %d" % label.numpy())
Label: 0

test dataset도 마찬가지로 split=tfds.Split.TEST 만 지정해 주면 로드 됩니다.

test_dataset = tfds.load("horses_or_humans", split=tfds.Split.TEST)
for data in test_dataset.take(1):
    image, label = data['image'], data['label']
    plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
    plt.axis('off')
    print("Label: %d" % label.numpy())
Label: 0

slicing: 원하는 만큼 데이터를 가져오기

slicing 기능을 내부적으로 지원합니다.

데이터의 갯수가 많은 경우 부분만 slicing 해서 가져올 수 있습니다.

문법은 다음과 같습니다. split=에 적절한 문자열로 지정해 줄 수 잇습니다.

# 원본 데이터 size
d, info = tfds.load('mnist', split='train', with_info=True)
info.splits['train'].num_examples, info.splits['test'].num_examples
(60000, 10000)
# 데이터 갯수를 세는 함수
def count_datasets(dataset):
    cnt = [x for x, y in enumerate(dataset)][-1] + 1
    return cnt
# 전체 데이터 가져오기
train_ds = tfds.load('mnist', split='train')

count_datasets(train_ds)
60000
# train, test 분리된 형태의 데이터
train_ds, test_ds = tfds.load('mnist', split=['train', 'test'])
print('train_ds: {} / test_ds {}'.format(count_datasets(train_ds), count_datasets(test_ds)))

# train과 test를 합친 데이터
train_test_ds = tfds.load('mnist', split='train+test')
print('train_test_ds: {}'.format(count_datasets(train_test_ds)))

# train: 10~20 index의 데이터
train_10_20_ds = tfds.load('mnist', split='train[10:20]')
print('train_10_20_ds: {}'.format(count_datasets(train_10_20_ds)))

# train: 처음 10% 데이터
train_10pct_ds = tfds.load('mnist', split='train[:10%]')
print('train_10pct_ds: {}'.format(count_datasets(train_10pct_ds)))

# train: 처음 10% + 마지막 80%
train_10_80pct_ds = tfds.load('mnist', split='train[:10%]+train[-80%:]')
print('train_10_80pct_ds: {}'.format(count_datasets(train_10_80pct_ds)))
train_ds: 60000 / test_ds 10000
train_test_ds: 70000
train_10_20_ds: 10
train_10pct_ds: 6000
train_10_80pct_ds: 54000

shuffle_files

dataset, info = tfds.load(name='mnist', 
                          split=tfds.Split.TRAIN, 
                          # shuffle_files 옵션을 주면, 로드 할 때 shuffle 수행
                          shuffle_files=True, 
                          with_info=True)

as_supervised: dict / tuple 형식으로 feed 받기

기본 값은 as_supervised=False입니다.

dataset으로부터 return 되는 data는 dict 형식을 갖습니다.

dataset = tfds.load(name='horses_or_humans', split=tfds.Split.TRAIN)

# dict 형식으로 받는 경우 (as_supervised=False)
for data in dataset.take(1):
    plt.imshow(data['image'])
    print(data['label'].numpy())
0

as_supervised=True 옵션을 주면 dict 형태가 아닌 tuple 형태로 데이터를 return 받습니다.

dataset = tfds.load(name='horses_or_humans', split=tfds.Split.TRAIN, as_supervised=True)
for image, label in dataset.take(1):
    plt.imshow(image)
    print(label.numpy())
0

map (함수를 매핑하기)

dataset_name = 'horses_or_humans'
dataset, info = tfds.load(name=dataset_name, split=tfds.Split.TRAIN, with_info=True)
def normalize(dataset):    
    image, label = tf.cast(dataset['image'], tf.float32) / 255.0, dataset['label']
    return image, label

위에서 만든 normalize 함수를 map할 수 있습니다.

train_dataset = dataset.map(normalize).batch(32)
dataset.map(normalize).batch(32)
<BatchDataset shapes: ((None, 300, 300, 3), (None,)), types: (tf.float32, tf.int64)>
 

shuffle

dataset에서 load 할 때 셔플하는 옵션과는 별개로 다시 셔플을 진행할 수 있습니다.

buffer_size 옵션을 반드시 지정해 주어야 합니다.

full shuffle을 위해서는 buffer_size == 전체 이미지 갯수로 정합니다.

dataset.map(normalize).shuffle(buffer_size=1000).batch(32)
<BatchDataset shapes: ((None, 300, 300, 3), (None,)), types: (tf.float32, tf.int64)>

repeat

솔직히 잘 안사용하는 함수이긴 한데, 데이터가 부족할 때 repeat 메소드를 통해 계속 피드될 수 있도록 합니다.

 

시각화

label_map은 단지 시각화에 대한 label을 지정하기 위함입니다.

train_dataset = dataset.map(normalize).batch(32)
label_map = {
    0: 'horse',
    1: 'human'
}
for image, label in train_dataset.take(1):
    fig, axes = plt.subplots(8, 4)
    fig.set_size_inches(10, 16)
    for i in range(32):
        axes[i//4, i%4].imshow(image[i])
        axes[i//4, i%4].axis('off')
        axes[i//4, i%4].set_title(label_map[label[i].numpy()], fontsize=15)
    plt.show()