
こんにちは、機械学習の講師をしている木下です!
ディープラーニングのコーディングをしていく際には、PyTorch などのフレームワークを使う機会が多いと思います。そんな中でよく出てくる、テンソルや次元確認・次元操作という言葉に面食らった経験もあるのではないでしょうか。
PyTorch で次元操作をする際にも、permute,transpose,reshape,view など様々な関数・メソッドが存在し混乱を招く原因となっています。
そこで、本記事では PyTorch の次元操作を徹底解説していきます!今までなんとなく関数やメソッドを使用していた方もこの記事を読めば、きちんと理解しながら PyTorch のコーディングを行うことができます!
こんな人におすすめ!
今回扱う関数・メソッドは以下の通りです。また、これらに加えて次元確認のメソッドについても確認していきましょう。
目次
予め PyTorch をインポートしておきましょう。PyTorch は Google Colabratory 上には予めインストールされているので、インストールは不要です。
# インポート import torch # バージョンの確認 torch.__version__ """ 1.12.1+cu113 """
それでは、これから扱っていくサンプルデータを作成します。PyTorch では randn を利用することで、任意の次元のデータをランダムに作成することができます。ここでは、各次元の要素数が 2, 3, 5 の 3 次元データを作成してみます。
# データの作成 x = torch.randn(2, 3, 5) # データの確認(環境によって値は代わります) print(x) """ tensor([[[-0.6246, 1.0766, 1.4077, 1.3389, 1.3747], [ 1.4110, -0.6644, -0.2251, 0.3428, 1.1569], [ 0.0392, 0.5391, 0.0154, 1.3016, -0.8394]], [[ 2.3794, -0.0364, 1.2753, -0.6133, -0.8352], [ 1.0868, 1.0627, 1.3589, -1.4522, -1.7674], [-2.6390, -0.6833, 2.4647, 1.1037, 0.2417]]])"""
PyTorch ではデータの次元数や要素数が異なるとコードの実行時にエラーがでます。以下のコードを活用し、常にデータの形状を確認する習慣をつけましょう。
# データの次元数確認のメソッド dim x.dim() # データの次元数確認のメソッド ndimension も全く同等のメソッド x.ndimension() # ndim 属性を参照しても同様 x.ndim
いずれを実行しても、今回の次元数である 3 が表示されます。
それぞれの要素数は以下のコードで確認しましょう。
# PyTorch の size メソッドを用いることが一般的 x.size() # numpy と同じように shape 属性でも確認可能 x.shape
いずれもそれぞれの要素数、torch.Size([2, 3, 5]) が表示されます。
それでは、各関数・メソッドの違いについてみていきましょう。
まずは、permute について解説していきます。
permute は軸(次元)を並び替えます。第一引数に並び替えたいテンソル、第二引数に並び替える順番をタプル型で指定します。
# 第一引数に並び替えたいテンソル、第二引数に並び替える順番をタプル型 x_permute = torch.permute(x, (2, 0, 1)) x_permute.size()
例えば、上記のように 2, 0, 1 と指定すると元々の軸(次元)の 2 番目、0 番目、1 番目の軸の順番に並び替えられるので、torch.Size([5, 2, 3]) のように形が変わります。
少し発展的ですが、2 次元のテンソルの場合、転置(torch.t())した後のテンソルにも適用可能です。
# 転置用データの作成(2 次元である必要がある) y = torch.randn(2, 3) # データの確認(環境によって値は代わります) print(y) """ tensor([[ 0.6869, 0.3730, -0.0598], [ 0.8805, -0.0262, 1.2460]])"""
# 転置後に要素を入れ替える y_permute = torch.t(y).permute(1, 0) y_permute.size()
上記を実行すると転置で次元が入れ替わった後、さらに入れ替えをおこなっているため、作成時と同様の torch.Size([2, 3]) という形になっていることがわかります。
それでは次に transpose を見てみましょう。
transpose は二つの軸(次元)を入れ替えるための関数です。
# 第一引数にデータ、第二引数、第三引数に入れ替えたい軸を指定 x_transpose = torch.transpose(x, 2, 1) x_transpose.size()
上記のコードでは 2 番目と 1 番目の軸を入れ替えているためサイズは torch.Size([2, 5, 3]) のようになります(Python は 0 はじまりなことに注意してください)。
ここで、3 つ以上の軸を並び替えることを試みてみます。
# 第一引数にデータ、第二引数、第三引数、第四引数に入れ替えたい軸を指定 x_transpose = torch.transpose(x, 2, 1, 0) x_transpose.size() # 実行結果 """ TypeError: transpose() received an invalid combination of arguments - got (Tensor, int, int, int), but expected one of: (Tensor input, int dim0, int dim1) (Tensor input, name dim0, name dim1) """
実は上記のように、エラーが起きてしまうため 3 つ以上の軸を同時に並び替えることはできません。
また、transpose も permute と同様に転置後の処理が可能な関数です。
次に見かけることも多い reshape を見ていきます。
reshape は軸の並び替えだけでなく、次元数や要素数を変更することができます。
ただし、合計の要素数が合うように気をつける必要があります。例えば、今回の torch.Size([2, 3, 5]) であれば、全てを掛け合わせた 30 に要素数を合わせる必要があります。例えば、
# 第一引数にデータ、第二引数に各軸の要素数を指定 x_reshape = torch.reshape(x, (1, 5, 6)) x_reshape.size()
上記のコードでは 3 次元を保ったまま各要素数を変更しています。torch.Size([1, 5, 6]) と表示されたのではないでしょうか。以下のように軸の数を増減させることも可能です。
# 第一引数にデータ、第二引数に各軸の要素数を指定(要素数の掛け算が 30 になっていることに注意) x_reshape = torch.reshape(x, (1, 5, 2, 3, 1)) x_reshape.size()
試しに、要素数の合計が異なるため、エラーが起きてしまう例も見ておきましょう。
# 第一引数にデータ、第二引数に各軸の要素数を指定(要素数の掛け算が 30 になっていないことに注意) x_reshape = torch.reshape(x, (1, 2, 3)) x_reshape.size() # 実行結果 """ RuntimeError: shape '[1, 2, 3]' is invalid for input of size 30 """
また、 reshape は転置後の処理が可能です。これは view と異なる点になります。
発展的な内容ですが、reshape はメモリ上で要素順に並んでいない場合は、コピーを作ってから処理をするので、このような処理が可能になっています。
最後に view について解説を行っていきます。view は基本的に reshape と同じような操作が可能です。
# 第一引数にデータ、第二引数に各軸の要素数を指定 x_view = x.view(1, 5, 6) x_view.size()
上記のように reshape と全く同じ使い方で、出力のサイズも torch.Size([1, 5, 6]) と同様のものになります。
それでは何が違うのでしょうか。試しに以下のコードを実行してみてください。
# 転置後のデータの第一引数にデータ、第二引数に各軸の要素数を指定 y_view = torch.t(y).view(1, 6) y_view.size() # 実行結果 """ RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. """
エラーが出てしまいます。実は view は転置後の処理ができないのです。発展的な内容ですが view はメモリ上で要素順に並んでいない場合は、処理ができずエラーとなってしまいます。
view で無理やり実行する場合には以下のように一度メモリ上でデータを並び替える必要がある点に注意しましょう。
# contiguous() でメモリ上の並びかえをしてから実行 y_view = torch.t(y).contiguous().view(1, 6) y_view.size()
これらの使い分けには決まったルールはありませんが、以下の点を意識しておくとコーディングが楽になるかもしれません!
permute の方が上位互換のように見えますが、実際には transpose による2軸の入れ替えで事足りてしまうことの方が多いです。
引数の指定が transpose の方が少なくて済むので、ネットで調べると transpose を使う例が多く見つかります。
メモリについては、発展的な内容なため、慣れてくるまで理解する必要はあまりないと思います。
重要なのは reshape はメモリ上にコピーを取る可能性があるので、演算処理の負荷がその分重くなる場合があるという点を意識しておくことです。
ディープラーニングの層が深くなると、その分演算処理に負荷がかかるので、処理を軽くするために view を使うことが多いというのが現状です。
簡単なモデルであれば reshape, 複雑なモデルであれば view と使い分けてみてください。
本記事では、初学者がつまづきがちな次元削減の関数について整理し、使い分けのアドバイスまで行いました。
PyTorch を用いると 100% 使用する関数なのでここでぜひ理解しておきましょう!
以上、Python 学習している方々のお力添えになれば幸いです!
.jpg&w=3840&q=75)
キカガクの長期コースはプログラミング経験ゼロの初学者が最先端技術を使いこなすAIエンジニアになるためのサポート体制が整っています!
実際に未経験からの転職・キャリアアップに続々と成功中です
まずは無料説明会で、キカガクのサポート体制を確認しにきてください!
説明会ではこんなことをお話します!
.png&w=3840&q=75)
AI・機械学習を学び始めるならまずはここから!経産省の Web サイトでも紹介されているわかりやすいと評判の Python&機械学習入門コースが無料で受けられます!
さらにステップアップした脱ブラックボックスコースや、IT パスポートをはじめとした資格取得を目指すコースもなんと無料です!
参考 | Python 3.9.2 ドキュメント
確認する参考 | Chainer チュートリアル
確認するSHARE
AI/データサイエンス学びはじめの方におすすめの記事
コース一覧
注目記事
新着記事