たった1行でデータ拡張!torchvision のコードまとめ!

たった一行でデータ拡張!torchvision のコードまとめ

こんにちは!キカガクでインターンをしている倉田です!

早速ですが、DataAugmentation を手軽に行える torchvision を知っていますか?
データ拡張をたった1行で実装できるすぐれものです!
今回はそんな torchvision のコードを実行例とともにわかりやすく紹介していきます!

この記事はこんな方にオススメです
  • DataAugmentationのやり方がわからない。。
  • データ拡張を使って高性能なモデルを作りたい!
  • torchvisionを自由に扱えるようになりたい!

なぜデータ拡張が必要?

データ拡張(DataAugmentation) では、訓練データに幾何変換(変形や回転など)をかけてデータを増幅させます。

データ拡張のメリットは、少数データの増幅画像認識性能の向上です。
具体的には、以下のようなメリットがあります。

  • 少数データを増幅して、大規模なデータに拡張できる。
  • 変形や回転した画像も正確に識別できるようになる。
  • 過学習が抑制できる。

データ拡張を行うことで、モデルの性能を大きく高めることができます!

それでは実際に、torchvision を使ってデータ拡張を体験してみましょう!

事前準備

CIFAR10 の画像の用意

今回はデモ用の画像として CIFAR10 を使ってみようと思います。
まずは事前準備として、モジュールと描画用の関数を実行しておきましょう!

事前準備
# モジュールのインポート
import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import torch
import torchvision
import numpy as np

# 描画用の関数(チャンネル数の関係で、グレースケール画像とカラー画像で表示を分けています!)
def show_image(x):
  fig = plt.figure(figsize=(10,10))

  for s in range(len(x)):
      img = x[s]
      npimg = img.numpy()
      if npimg.shape[0]==1 :  # グレースケール用
        npimg = npimg.reshape(32, 32)
        ax1 = fig.add_subplot(1, len(x), s+1)
        plt.axis('off')
        plt.imshow(npimg, cmap='gray')
      else :  # カラー用
        npimg = np.transpose(npimg, (1, 2, 0))
        ax1 = fig.add_subplot(1, len(x), s+1)
        plt.axis('off')
        plt.imshow(npimg)

CIFAR10 の画像をダウンロードして,、 DataLoader を用意しましょう!

CIFAR10 の画像をダウンロード
# CIFAR10 をダウンロード
cifar10set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
cifar10loader = torch.utils.data.DataLoader(cifar10set, batch_size=3, shuffle=False)

では、実際に CIFAR10 の画像を 3 枚表示してみます!

変換前の画像の描画
# 変換前の画像の描画
iter_ = iter(cifar10loader)
x, t = next(iter_)
x, t = next(iter_)  # <- 2回繰り返してるのはカエルの画像が嫌だからです!

show_image(x)

シカと車の画像が出てきましたね!
この3枚の画像に torchvision.transforms を使って、様々なデータ拡張を施していきましょう!

torchvision.transforms コード一覧(形状変換)

リサイズ : Resize

画像サイズの変更を行います。今回は 32*32 の画像を 100*100 にリサイズしてみます。

引数

size : リサイズ後のサイズ ; ex.) 100*100

リサイズ : Resize
transform = transforms.Resize(size=(100, 100))
transed_x = transform(x)

show_image(transed_x)

中央切り抜き : CenterCrop

画像中央を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。

引数

size : 切り抜きサイズ ; ex.) 20*20

中央切り抜き : CenterCrop
transform = transforms.CenterCrop(size=(20, 20))
transed_x = transform(x)

show_image(transed_x)

5 枚切り抜き: FiveCrop

画像の 4 隅と中央を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。

引数

size : 切り抜きサイズ ; ex.) 20*20

5 枚切り抜き: FiveCrop
transform = transforms.FiveCrop(size=(20, 20))
transed_x = transform(x)

for i in range(5): # 4隅+中央で5セット出てくるので, 5回ループ
  show_image(transed_x[i])

10 枚切り抜き : TenCrop

画像の 4 隅と中央を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。

引数

size : 切り抜きサイズ ; ex.) 20*20

10 枚切り抜き : TenCrop
transform = transforms.TenCrop(size=(20, 20))
transed_x = transform(x)

for i in range(10):# 4隅+中央で5セット * 画像2枚 = 10セット出てくるので, 10回ループ
  show_image(transed_x[i])

ランダム切り抜き : RandomCrop

画像のランダムな場所を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。

引数

size : 切り抜きサイズ ; ex.) 20*20

ランダム切り抜き : RandomCrop
transform = transforms.RandomCrop(size=(20, 20))
transed_x = transform(x)

show_image(transed_x)

高度なランダム切り抜き : RandomResizedCrop

画像のランダムな場所を scaleretioに基づいて切り抜きます。その後, size の大きさにリサイズします。

引数
  • size : 切り抜き後にリサイズするサイズ
  • scale : スケール (縦, 横)
  • retio : アスペクト比 (縦, 横)
高度なランダム切り抜き : RandomResizedCrop
transform = transforms.RandomResizedCrop(size=(40, 40), scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3))
transed_x = transform(x)

show_image(transed_x)

外周埋め込み : Pad

画像の周囲を数値で埋めます。今回は画像の周囲 3 ピクセルを様々な方法 (paddingmode) で埋めてみます。

引数
  • padding : 埋めるピクセル幅
  • padding_mode : 埋める方法 (パディング方法)

"constant" : 引数の fill 値で埋める (今回は fill=0)

外周埋め込み : Pad
transform = transforms.Pad(padding=3, padding_mode="constant", fill=0)
transed_x = transform(x)

show_image(transed_x)

"edge" : 画像の一番端の数値を埋める

外周埋め込み : Pad
transform = transforms.Pad(padding=3, padding_mode="edge")
transed_x = transform(x)

show_image(transed_x)

"reflect" : 画像の端から折り返して埋める (端の数値を繰り返さない)

外周埋め込み : Pad
transform = transforms.Pad(padding=3, padding_mode="reflect")
transed_x = transform(x)

show_image(transed_x)

"symmetric" : 画像の端から折り返して埋める (端の数値を繰り返す)

外周埋め込み : Pad
transform = transforms.Pad(padding=3, padding_mode="symmetric")
transed_x = transform(x)

show_image(transed_x)

左右反転 : RandomHorizontalFlip

確率 p で左右反転します。

引数

p : 適用される確率

左右反転 : RandomHorizontalFlip
transform = transforms.RandomHorizontalFlip(p=1.0)
transed_x = transform(x)

show_image(transed_x)

上下反転 : RandomVerticalFlip

確率 p で上下反転します。

引数

p : 適用される確率

上下反転 : RandomVerticalFlip
transform = transforms.RandomVerticalFlip(p=1.0)
transed_x = transform(x)

show_image(transed_x)

アフィン変換 : RandomAffine

ランダムにアフィン変換 ( = 回転, 平行移動, スケール変換) を施します。

引数
  • p : 適用される確率
  • degrees※ : 回転角の変動幅 ; ex.) -10° ~ 10°
  • translare※ : 平行移動量の変動幅 ; ex.) 0.1 ~ 0.3
  • scale※ : 画像の拡大(縮小)率の変動幅 ; ex.) ×0.5 ~ ×0.8※ 指定範囲内からランダムに選択されます。
アフィン変換 : RandomAffine
transform = transforms.RandomAffine(
    degrees=(-10, 10), translate=(0.1, 0.3), scale=(0.5, 0.8)
    ) 
transed_x = transform(x)

show_image(transed_x)

射影変換 : RandomPerspective

確率 p で射影変換 (画像を歪ませるような変換) を行います。

引数
  • p : 適用される確率
  • distortion_scale : 傾き度合い
射影変換 : RandomPerspective
transform = transforms.transforms.RandomPerspective(distortion_scale=0.3, p=1.0)
transed_x = transform(x)

show_image(transed_x)

回転変換 : RandomRotation

ランダムな回転角で回転変換を行います。

※ 指定範囲内からランダムに選択されます。

引数

degree※ : 回転角の変動幅 ; ex.) -30° ~ 30°※ 指定範囲内からランダムに選択されます。

回転変換 : RandomRotation
transform = transforms.RandomRotation(degrees=(-30, 30))
transed_x = transform(x)

show_image(transed_x)

torchvision.transforms コード一覧(濃度変換)

グレースケール化 : Grayscale

画像をグレースケール化します。
グレースケール化とは、カラー画像を白黒画像にする処理のことです。

引数

num_output_channels : 1でグレースケール/ 3でカラー

グレースケール化 : Grayscale
transform = transforms.Grayscale(num_output_channels=1)
transed_x = transform(x)

show_image(transed_x)

ランダムにグレースケール化 : Random Grayscale

画像のグレースケール化をランダムに行います。
先程のグレースケール化を確率pで実施します。

引数

p : 変換を行う確率

ランダムにグレースケール化 : Random Grayscale
transform = transforms.RandomGrayscale(p=1)
transed_x = transform(x)

show_image(transed_x)

ぼかす : GaussianBlur

画像をぼかすために、ガウシアンフィルタをかけます。

引数
  • kernel_size : ガウシアンフィルタの大きさ (大きいほど境界線が強くぼかされます。)
  • sigma : ぼかしの強さ ; (横方向, 縦方向)
ぼかす : GaussianBlur
transform = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
transed_x = transform(x)

show_image(transed_x)

色調整 : Color Jitter

ランダムに明るさ、コントラスト、彩度、色相を変化させます。

引数
  • brightness : 明るさの変動幅
  • contrast : コントラスト
  • saturation : 彩度
  • hue : 色相
色調整 : Color Jitter
transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
transed_x = transform(x)

show_image(transed_x)

ネガポジ反転 : RandomInvert

明暗や色調を反転させるネガポジ反転を行います。写真のネガのような画像にする処理です。

引数

p : 変換を行う確率

ネガポジ反転 : RandomInvert
transform = transforms.RandomInvert(p=1.0)
transed_x = transform(x)

show_image(transed_x)

階調変換 : RandomPosterize

ランダムに階調変換を行います。
階調変換は、指定した階調数 \(2^{bits}\) 内で画像を再表現します。

注意 : 入力画像 x は, uint8 である必要があります。

引数
  • bits: 階調後の色の種類数(\(2^{bits}\));max=8
  • p : 変換を行う確率
階調変換 : RandomPosterize
transform = transforms.RandomPosterize(bits=3, p=1.0)
x_uint8 = (255*x).to(torch.uint8) # 0-1 -> 0-255 への逆正規化, float32->uint8 への変更

transed_x = transform(x_uint8)

show_image(transed_x)

ソラリゼーション : RandomSolarize

ランダムにソラリゼーションを行います。
ソラリゼーションとは、閾値threshhold以上の値をネガポジ反転する処理です。

引数

threshhold : この値以上の画像をネガポジ反転

ソラリゼーション : RandomSolarize
transform = transforms.RandomSolarize(threshold=0.5)
transed_x = transform(x)

show_image(transed_x)

鮮鋭化 : RandomAdjustSharpness

ランダムに鮮鋭化を行います。
鮮鋭化とは、ぼやけた画像のエッジを強調するような処理です。

引数
  • sharpness_factor : 0-ぼかす, 1-元画像, 2以上-値が多いほど鮮鋭化
  • p : 変換を行う確率
鮮鋭化 : RandomAdjustSharpness
transform = transforms.RandomAdjustSharpness(sharpness_factor=5, p=0.5)
transed_x = transform(x)

show_image(transed_x)

コントラスト調整: RandomAutocontrast

ランダムにコントラストの調整を行います。

引数

p : 変換を行う確率

コントラスト調整: RandomAutocontrast
transform = transforms.RandomAutocontrast(p=1.0)
transed_x = transform(x)

show_image(transed_x)

ヒストグラム平坦化: RandomEqualize

画素値のヒストグラムを全体的に平らにするヒストグラム平坦化を行います。
注意 : 入力画像 x は、uint8である必要があります。

引数

p : 変換を行う確率

ヒストグラム平坦化: RandomEqualize
transform = transforms.RandomEqualize(p=1.0)
x_uint8 = (255*x).to(torch.uint8) # 0-1 -> 0-255 への逆正規化, float32->uint8への変更

transed_x = transform(x_uint8)

show_image(transed_x)

まとめ

いかがだったでしょうか? torchvision.transforms を用いれば、多様なデータ拡張を簡単に実装できることが伝わったかと思います!

torchvision.transforms には、上記の変換処理を組み合わせて用いる Compose() など様々な便利ツールがあります! ぜひ皆さんも torchvision.transforms を使って、データ拡張や DataAugmentation にトライしてみてください!

参考:公式リファレンス | TRANSFORMING AND AUGMENTING IMAGES

こちらの記事もオススメ

まずは無料で学びたい方・最速で学びたい方へ

まずは無料で学びたい方: Python&機械学習入門コースがおすすめ

Python&機械学習入門コース

AI・機械学習を学び始めるならまずはここから!経産省の Web サイトでも紹介されているわかりやすいと評判の Python&機械学習入門コースが無料で受けられます!
さらにステップアップした脱ブラックボックスコースや、IT パスポートをはじめとした資格取得を目指すコースもなんと無料です!

無料で学ぶ

最速で学びたい方:キカガクの長期コースがおすすめ

一生学び放題

続々と転職・キャリアアップに成功中!受講生ファーストのサポートが人気のポイントです!

AI・機械学習・データサイエンスといえばキカガク!
非常に需要が高まっている最先端スキルを「今のうちに」習得しませんか?

無料説明会を週 2 開催しています。毎月受講生の定員がございますので確認はお早めに!

説明会ではこんなことをお話します!
  • 国も企業も育成に力を入れている先端 IT 人材とは
  • キカガクの研修実績
  • 長期コースでの学び方、できるようになること
  • 料金・給付金について
  • 質疑応答