こんにちは!キカガクでインターンをしている倉田です!
早速ですが、DataAugmentation を手軽に行える torchvision を知っていますか?
データ拡張をたった1行で実装できるすぐれものです!
今回はそんな torchvision のコードを実行例とともにわかりやすく紹介していきます!
- DataAugmentationのやり方がわからない。。
- データ拡張を使って高性能なモデルを作りたい!
- torchvisionを自由に扱えるようになりたい!

学習について無料で相談してみませんか?
独学で悩まれている方へ。理想のキャリアをヒアリングし、カウンセラーが最適な学習方法を一緒に考えます。希望者には講義を無料で体感できるセミナーもご案内!
なぜデータ拡張が必要?
データ拡張(DataAugmentation) では、訓練データに幾何変換(変形や回転など)をかけてデータを増幅させます。

データ拡張のメリットは、少数データの増幅と画像認識性能の向上です。
具体的には、以下のようなメリットがあります。
- 少数データを増幅して、大規模なデータに拡張できる。
- 変形や回転した画像も正確に識別できるようになる。
- 過学習が抑制できる。
データ拡張を行うことで、モデルの性能を大きく高めることができます!
それでは実際に、torchvision を使ってデータ拡張を体験してみましょう!
事前準備
CIFAR10 の画像の用意
今回はデモ用の画像として CIFAR10 を使ってみようと思います。
まずは事前準備として、モジュールと描画用の関数を実行しておきましょう!
# モジュールのインポート
import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import torch
import torchvision
import numpy as np
# 描画用の関数(チャンネル数の関係で、グレースケール画像とカラー画像で表示を分けています!)
def show_image(x):
fig = plt.figure(figsize=(10,10))
for s in range(len(x)):
img = x[s]
npimg = img.numpy()
if npimg.shape[0]==1 : # グレースケール用
npimg = npimg.reshape(32, 32)
ax1 = fig.add_subplot(1, len(x), s+1)
plt.axis('off')
plt.imshow(npimg, cmap='gray')
else : # カラー用
npimg = np.transpose(npimg, (1, 2, 0))
ax1 = fig.add_subplot(1, len(x), s+1)
plt.axis('off')
plt.imshow(npimg)
CIFAR10 の画像をダウンロードして,、 DataLoader を用意しましょう!
# CIFAR10 をダウンロード
cifar10set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
cifar10loader = torch.utils.data.DataLoader(cifar10set, batch_size=3, shuffle=False)
では、実際に CIFAR10 の画像を 3 枚表示してみます!
# 変換前の画像の描画
iter_ = iter(cifar10loader)
x, t = next(iter_)
x, t = next(iter_) # <- 2回繰り返してるのはカエルの画像が嫌だからです!
show_image(x)

シカと車の画像が出てきましたね!
この3枚の画像に torchvision.transforms を使って、様々なデータ拡張を施していきましょう!
torchvision.transforms コード一覧(形状変換)
リサイズ : Resize
画像サイズの変更を行います。今回は 32*32 の画像を 100*100 にリサイズしてみます。
transform = transforms.Resize(size=(100, 100))
transed_x = transform(x)
show_image(transed_x)

中央切り抜き : CenterCrop
画像中央を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。
transform = transforms.CenterCrop(size=(20, 20))
transed_x = transform(x)
show_image(transed_x)

5 枚切り抜き: FiveCrop
画像の 4 隅と中央を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。
transform = transforms.FiveCrop(size=(20, 20))
transed_x = transform(x)
for i in range(5): # 4隅+中央で5セット出てくるので, 5回ループ
show_image(transed_x[i])

10 枚切り抜き : TenCrop
画像の 4 隅と中央を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。
transform = transforms.TenCrop(size=(20, 20))
transed_x = transform(x)
for i in range(10):# 4隅+中央で5セット * 画像2枚 = 10セット出てくるので, 10回ループ
show_image(transed_x[i])

ランダム切り抜き : RandomCrop
画像のランダムな場所を指定サイズで切り抜きます。今回は 20*20 で切り抜いてみます。
transform = transforms.RandomCrop(size=(20, 20))
transed_x = transform(x)
show_image(transed_x)

高度なランダム切り抜き : RandomResizedCrop
画像のランダムな場所を scale
とretio
に基づいて切り抜きます。その後, size
の大きさにリサイズします。
transform = transforms.RandomResizedCrop(size=(40, 40), scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3))
transed_x = transform(x)
show_image(transed_x)

外周埋め込み : Pad
画像の周囲を数値で埋めます。今回は画像の周囲 3 ピクセルを様々な方法 (paddingmode
) で埋めてみます。
– "constant"
: 引数の fill 値で埋める (今回は fill=0)
transform = transforms.Pad(padding=3, padding_mode="constant", fill=0)
transed_x = transform(x)
show_image(transed_x)

– "edge"
: 画像の一番端の数値を埋める
transform = transforms.Pad(padding=3, padding_mode="edge")
transed_x = transform(x)
show_image(transed_x)

– "reflect"
: 画像の端から折り返して埋める (端の数値を繰り返さない)
transform = transforms.Pad(padding=3, padding_mode="reflect")
transed_x = transform(x)
show_image(transed_x)

– "symmetric"
: 画像の端から折り返して埋める (端の数値を繰り返す)
transform = transforms.Pad(padding=3, padding_mode="symmetric")
transed_x = transform(x)
show_image(transed_x)

左右反転 : RandomHorizontalFlip
確率 p で左右反転します。
transform = transforms.RandomHorizontalFlip(p=1.0)
transed_x = transform(x)
show_image(transed_x)

上下反転 : RandomVerticalFlip
確率 p で上下反転します。
transform = transforms.RandomVerticalFlip(p=1.0)
transed_x = transform(x)
show_image(transed_x)

アフィン変換 : RandomAffine
ランダムにアフィン変換 ( = 回転, 平行移動, スケール変換) を施します。
transform = transforms.RandomAffine(
degrees=(-10, 10), translate=(0.1, 0.3), scale=(0.5, 0.8)
)
transed_x = transform(x)
show_image(transed_x)

射影変換 : RandomPerspective
確率 p で射影変換 (画像を歪ませるような変換) を行います。
transform = transforms.transforms.RandomPerspective(distortion_scale=0.3, p=1.0)
transed_x = transform(x)
show_image(transed_x)

回転変換 : RandomRotation
ランダムな回転角で回転変換を行います。
※ 指定範囲内からランダムに選択されます。
transform = transforms.RandomRotation(degrees=(-30, 30))
transed_x = transform(x)
show_image(transed_x)

torchvision.transforms コード一覧(濃度変換)
グレースケール化 : Grayscale
画像をグレースケール化します。
グレースケール化とは、カラー画像を白黒画像にする処理のことです。
transform = transforms.Grayscale(num_output_channels=1)
transed_x = transform(x)
show_image(transed_x)

ランダムにグレースケール化 : Random Grayscale
画像のグレースケール化をランダムに行います。
先程のグレースケール化を確率p
で実施します。
transform = transforms.RandomGrayscale(p=1)
transed_x = transform(x)
show_image(transed_x)

ぼかす : GaussianBlur
画像をぼかすために、ガウシアンフィルタをかけます。
transform = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
transed_x = transform(x)
show_image(transed_x)

色調整 : Color Jitter
ランダムに明るさ、コントラスト、彩度、色相を変化させます。
transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
transed_x = transform(x)
show_image(transed_x)

ネガポジ反転 : RandomInvert
明暗や色調を反転させるネガポジ反転を行います。写真のネガのような画像にする処理です。
transform = transforms.RandomInvert(p=1.0)
transed_x = transform(x)
show_image(transed_x)

階調変換 : RandomPosterize
ランダムに階調変換を行います。
階調変換は、指定した階調数 \(2^{bits}\) 内で画像を再表現します。
注意 : 入力画像 x
は, uint8 である必要があります。
transform = transforms.RandomPosterize(bits=3, p=1.0)
x_uint8 = (255*x).to(torch.uint8) # 0-1 -> 0-255 への逆正規化, float32->uint8 への変更
transed_x = transform(x_uint8)
show_image(transed_x)

ソラリゼーション : RandomSolarize
ランダムにソラリゼーションを行います。
ソラリゼーションとは、閾値threshhold
以上の値をネガポジ反転する処理です。
transform = transforms.RandomSolarize(threshold=0.5)
transed_x = transform(x)
show_image(transed_x)

鮮鋭化 : RandomAdjustSharpness
ランダムに鮮鋭化を行います。
鮮鋭化とは、ぼやけた画像のエッジを強調するような処理です。
transform = transforms.RandomAdjustSharpness(sharpness_factor=5, p=0.5)
transed_x = transform(x)
show_image(transed_x)

コントラスト調整: RandomAutocontrast
ランダムにコントラストの調整を行います。
transform = transforms.RandomAutocontrast(p=1.0)
transed_x = transform(x)
show_image(transed_x)

ヒストグラム平坦化: RandomEqualize
画素値のヒストグラムを全体的に平らにするヒストグラム平坦化を行います。
注意 : 入力画像 x
は、uint8である必要があります。
transform = transforms.RandomEqualize(p=1.0)
x_uint8 = (255*x).to(torch.uint8) # 0-1 -> 0-255 への逆正規化, float32->uint8への変更
transed_x = transform(x_uint8)
show_image(transed_x)

まとめ
いかがだったでしょうか? torchvision.transforms を用いれば、多様なデータ拡張を簡単に実装できることが伝わったかと思います!
torchvision.transforms には、上記の変換処理を組み合わせて用いる Compose() など様々な便利ツールがあります! ぜひ皆さんも torchvision.transforms を使って、データ拡張や DataAugmentation にトライしてみてください!
参考:公式リファレンス | TRANSFORMING AND AUGMENTING IMAGES
最速で学びたい方:キカガクの長期コースがおすすめ

続々と転職・キャリアアップに成功中!受講生ファーストのサポートが人気のポイントです!
AI・機械学習・データサイエンスといえばキカガク!
非常に需要が高まっている最先端スキルを「今のうちに」習得しませんか?
無料説明会を週 2 開催しています。毎月受講生の定員がございますので確認はお早めに!
- 国も企業も育成に力を入れている先端 IT 人材とは
- キカガクの研修実績
- 長期コースでの学び方、できるようになること
- 料金・給付金について
- 質疑応答
まずは無料で学びたい方: Python&機械学習入門コースがおすすめ

AI・機械学習を学び始めるならまずはここから!経産省の Web サイトでも紹介されているわかりやすいと評判の Python&機械学習入門コースが無料で受けられます!
さらにステップアップした脱ブラックボックスコースや、IT パスポートをはじめとした資格取得を目指すコースもなんと無料です!

学習について無料で相談してみませんか?
独学で悩まれている方へ。理想のキャリアをヒアリングし、カウンセラーが最適な学習方法を一緒に考えます。希望者には講義を無料で体感できるセミナーもご案内!