🔥알림🔥
① 테디노트 유튜브 - 구경하러 가기!
② LangChain 한국어 튜토리얼 바로가기 👀
③ 랭체인 노트 무료 전자책(wikidocs) 바로가기 🙌
④ RAG 비법노트 LangChain 강의오픈 바로가기 🙌
⑤ 서울대 PyTorch 딥러닝 강의 바로가기 🙌

2분 소요

matplotlib이나 seaborn으로 시각화할 때 `color`, `cmap`, `palette` 옵션을 설정하면 그래프의 색상을 쉽게 바꿀 수 있습니다. `color`는 단일 색상을, `cmap`은 값에 따라 색이 연속적으로 변하는 컬러맵을, `palette`는 seaborn에서 사용할 색상 묶음(팔레트)을 지정합니다.

어떤 색상을 고르느냐에 따라 그래프가 더 완성도 있어 보이기도 하고, 더 세련된 시각적 효과를 줄 수도 있습니다.

색상 코드를 매번 검색이나 도큐먼트에서 찾는 것이 번거로워 이참에 정리해 보았습니다.

색상 코드는 도큐먼트에 있는 예제 코드를 그대로 활용하였으며, 색상 코드를 적용하는 간단한 예시도 같이 담아봤습니다.

Colors

from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# Default size (in inches) for every figure that does not pass its own figsize.
plt.rcParams.update({'figure.figsize': (14, 10)})


def plot_colortable(colors, title, sort_colors=True, emptycols=0):
    """Draw a grid of named color swatches and return the figure.

    Each entry is rendered as a filled rectangle followed by its name, and the
    grid is filled column by column (consecutive names go down one column).

    Parameters
    ----------
    colors : mapping
        Color name -> color spec (e.g. ``mcolors.CSS4_COLORS``). When sorting
        is enabled each value must be accepted by ``mcolors.to_rgb``.
    title : str
        Heading drawn left-aligned above the grid.
    sort_colors : bool, default True
        If true, order entries by (hue, saturation, value) and then by name;
        otherwise keep the mapping's iteration order.
    emptycols : int, default 0
        How many of the 4 grid columns to leave unused (0-3). The figure
        always keeps the full 4-column width.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If ``emptycols`` is not in the range 0-3.
    """
    # ncols = 4 - emptycols must be positive: 4 would divide by zero below,
    # and values outside 0-3 place columns outside the 4-column axes.
    if not 0 <= emptycols < 4:
        raise ValueError(f"emptycols must be between 0 and 3, got {emptycols!r}")

    cell_width = 212
    cell_height = 22
    swatch_width = 48
    margin = 12
    topmargin = 40

    # Sort colors by hue, saturation, value and name.
    if sort_colors:
        by_hsv = sorted((tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(color))),
                         name)
                        for name, color in colors.items())
        names = [name for hsv, name in by_hsv]
    else:
        names = list(colors)

    n = len(names)
    ncols = 4 - emptycols
    # Ceiling division; keep at least one row so an empty mapping still gets
    # distinct (non-degenerate) y-limits.
    nrows = max(1, n // ncols + int(n % ncols > 0))

    width = cell_width * 4 + 2 * margin
    height = cell_height * nrows + margin + topmargin
    dpi = 72

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin/width, margin/height,
                        (width-margin)/width, (height-topmargin)/height)
    ax.set_xlim(0, cell_width * 4)
    # Bottom limit > top limit: the y-axis is inverted so row 0 is at the top.
    ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=24, loc="left", pad=10)

    for i, name in enumerate(names):
        # Column-major placement: fill a whole column before moving right.
        row = i % nrows
        col = i // nrows
        y = row * cell_height

        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7

        ax.text(text_pos_x, y, name, fontsize=14,
                horizontalalignment='left',
                verticalalignment='center')

        # Swatch spans y-9 .. y+9 in data units, i.e. centred on the row.
        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
                      height=18, facecolor=colors[name], edgecolor='0.7')
        )

    return fig

# Render the three built-in named-color tables, then display them together.
# BASE and Tableau keep their definition order and leave some of the four
# columns empty; CSS4 uses plot_colortable's default HSV sorting.
colortable_jobs = [
    (mcolors.BASE_COLORS, "Base Colors",
     {"sort_colors": False, "emptycols": 1}),
    (mcolors.TABLEAU_COLORS, "Tableau Palette",
     {"sort_colors": False, "emptycols": 2}),
    (mcolors.CSS4_COLORS, "CSS Colors", {}),
]
for palette, heading, options in colortable_jobs:
    plot_colortable(palette, heading, **options)

plt.show()

Color를 plot에 적용한 예시

# Seven straight lines through the origin, y = k * x for slopes k = 1..7,
# each drawn with a different named matplotlib color.
x = np.arange(100)
y1 = x
y2, y3, y4, y5, y6, y7 = (x * k for k in range(2, 8))

line_specs = (
    (y1, 'lightcoral'),
    (y2, 'orangered'),
    (y3, 'olive'),
    (y4, 'dodgerblue'),
    (y5, 'midnightblue'),
    (y6, 'darkviolet'),
    (y7, 'deeppink'),
)
for y_values, color_name in line_specs:
    plt.plot(x, y_values, color=color_name)

# Hide axes so only the colored lines are visible.
plt.axis('off')
plt.show()

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
from colorspacious import cspace_converter

# Category name -> list of colormap names; populated by plot_color_gradients.
cmaps = {}

# 2 x 256 image whose rows both ramp linearly from 0 to 1; imshow() maps it
# through a colormap to draw a horizontal color strip.
gradient = np.tile(np.linspace(0, 1, 256), (2, 1))


def plot_color_gradients(category, cmap_list):
    """Show one labelled horizontal gradient strip per colormap in one figure.

    The figure height scales with the number of colormaps. Each strip is
    labelled with its colormap name to the left, and all ticks and spines are
    hidden. The list is also recorded in the module-level ``cmaps`` dict under
    ``category``.

    Parameters
    ----------
    category : str
        Family name, used in the figure title and as the ``cmaps`` key.
    cmap_list : list of str
        Names of colormaps to display, looked up with ``plt.get_cmap``.
    """
    count = len(cmap_list)
    # Height in inches: fixed 0.35 top / 0.15 bottom margins plus a term that
    # grows with the number of strips.
    fig_height = 0.35 + 0.15 + (count + (count - 1) * 0.1) * 0.22
    # One extra axes beyond the strips; it stays empty and is hidden below.
    fig, axes = plt.subplots(nrows=count + 1, figsize=(6.4, fig_height))
    fig.subplots_adjust(
        left=0.2,
        right=0.99,
        bottom=0.15 / fig_height,
        top=1 - 0.35 / fig_height,
    )
    axes[0].set_title(f'{category} colormaps', fontsize=14)

    for row, cmap_name in enumerate(cmap_list):
        strip = axes[row]
        strip.imshow(gradient, aspect='auto', cmap=plt.get_cmap(cmap_name))
        # Label sits just left of the strip, in axes-fraction coordinates.
        strip.text(-0.01, 0.5, cmap_name, va='center', ha='right',
                   fontsize=12, transform=strip.transAxes)

    # Hide ticks and spines on every axes, including the spare last one.
    for strip in axes:
        strip.set_axis_off()

    # Remember which colormaps were shown for this category.
    cmaps[category] = cmap_list

Colormaps

# Colormap families in display order (dicts preserve insertion order); each
# family is drawn as its own figure of gradient strips.
colormap_families = {
    'Perceptually Uniform Sequential':
        ['viridis', 'plasma', 'inferno', 'magma', 'cividis'],
    'Sequential':
        ['Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
         'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
         'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'],
    'Sequential (2)':
        ['binary', 'gist_yarg', 'gist_gray', 'gray', 'bone',
         'pink', 'spring', 'summer', 'autumn', 'winter', 'cool',
         'Wistia', 'hot', 'afmhot', 'gist_heat', 'copper'],
    'Diverging':
        ['PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu',
         'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic'],
    'Cyclic':
        ['twilight', 'twilight_shifted', 'hsv'],
    'Qualitative':
        ['Pastel1', 'Pastel2', 'Paired', 'Accent', 'Dark2',
         'Set1', 'Set2', 'Set3', 'tab10', 'tab20', 'tab20b',
         'tab20c'],
    'Miscellaneous':
        ['flag', 'prism', 'ocean', 'gist_earth', 'terrain',
         'gist_stern', 'gnuplot', 'gnuplot2', 'CMRmap',
         'cubehelix', 'brg', 'gist_rainbow', 'rainbow', 'jet',
         'turbo', 'nipy_spectral', 'gist_ncar'],
}

for family, members in colormap_families.items():
    plot_color_gradients(family, members)

plt.show()

Palette를 적용한 예시

import seaborn as sns
import warnings

# Silence all warnings for this demo (presumably to hide seaborn API
# deprecation notices).
warnings.filterwarnings('ignore')

df = sns.load_dataset('titanic')

# Pass the column as ``x=``: since seaborn 0.12 the first positional parameter
# of ``countplot`` is ``data``, so a bare Series passed positionally is not
# used as the variable whose values are counted.
# NOTE(review): seaborn 0.13 deprecates ``palette`` without ``hue``; the
# forward-compatible form is ``hue=df['age'], legend=False`` (0.13+ only).
for palette_name in ['viridis', 'Oranges', 'afmhot', 'coolwarm', 'hsv',
                     'tab20c', 'rainbow']:
    sns.countplot(x=df['age'], palette=palette_name)
    plt.title(palette_name)
    # Hide axes so the palette colors themselves are the focus.
    plt.axis('off')
    plt.show()

참고 (References)

  • https://matplotlib.org/stable/tutorials/colors/colormaps.html

  • https://matplotlib.org/stable/gallery/color/named_colors.html

댓글 남기기