webエンジニアの日常

RubyやPython, JSなど、IT関連の記事を書いています

(真)多次元配列のドット積の次元について

こんにちは、エンジニアのさもです

以前、「多次元配列のドット積の次元について」というタイトルで記事を書きました。

numpyのdot関数についての考察でしたが、今回はもう少し詳しく調べてみました。

www.uosansatox.biz

スポンサーリンク

dot関数については、こちらのサイトに詳細が書いてあります

numpy.dot — NumPy v1.13 Manual

dot関数は1次元配列のときは、通常のベクトルの内積、2次元配列のときは、行列の積と同じように計算されます。

N次元の場合は以下のようなルールで計算されます。

dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])

具体的には、例えば、

import numpy as np
a = np.ones([2,3,4])
b = np.ones([5,4,2])

であったとします。

このとき、np.dot(a,b)(2, 3, 5, 2)という形の多次元配列になります。

aからは一番最後の軸以外の数、bからは一番最後から2番目の軸以外の数をとってきて結合させると、出力のサイズになりますね。

さて、計算ルールに当てはめると、

np.dot(a,b)[1,0,3,1] = sum(a[1,0,:]*b[3,:,1])

ということになります。

右辺のそれぞれは、どちらも同じサイズの1次配列です。

a[1,0,:]
#=>array([ 1.,  1.,  1.,  1.])
b[3,:,1]
#=>array([ 1.,  1.,  1.,  1.])

なので、N次元の配列をドット積する場合は、aの一番最後の軸のサイズと、bの最後から2番目の軸のサイズが同じでないといけません。

配列がより深い次元になっても同様です。

たとえば、aが(2,3,4,5), bが(3,4,5,6)の場合は、

a = np.arange(2*3*4*5).reshape((2,3,4,5))
b = np.arange(3*4*5*6).reshape((3,4,5,6))
np.dot(a,b).shape          #=> (2, 3, 4, 3, 4, 6)
np.dot(a,b)[0,0,0,1,1,1]   #=> 1690
sum(a[0,0,0,:]*b[1,1,:,1]) #=> 1690

となります。

ちなみに、行列の積は、最初の計算ルールからj,kを除いた式

dot(a, b)[i,m] = sum(a[i,:] * b[:,m])

と同じになっていて、ベクトルの内積は、さらにi,mを除いた式

dot(a, b) = sum(a[:] * b[:])

と同じになっています。