NumPy入門 行列の演算

NumPyで基本的な行列を生成する方法について学習しました。ここでは生成した行列を計算する方法について学習します。

ndarrayの四則演算

まず、ndarrayの四則演算について確認しましょう。ndarrayには演算が定義されていますが、掛け算も加減算と同様、要素ごとに計算を行うだけのもの(いわゆるアダマール積)なので注意してください。つまり、一般的な行列の積(内積)とは異なります。

import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# array([[1, 2, 3],
#        [4, 5, 6],
#        [7, 8, 9]])

b = np.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
# array([[10, 20, 30],
#        [40, 50, 60],
#        [70, 80, 90]])


a + b # 足し算
# array([[11, 22, 33],
#        [44, 55, 66],
#        [77, 88, 99]])

a - b # 引き算
# array([[ -9, -18, -27],
#        [-36, -45, -54],
#        [-63, -72, -81]])

a * b # 掛け算
# array([[ 10,  40,  90],
#        [160, 250, 360],
#        [490, 640, 810]])

a / b # 割り算
# array([[ 0.1,  0.1,  0.1],
#        [ 0.1,  0.1,  0.1],
#        [ 0.1,  0.1,  0.1]])

上記サンプルでは2つの行列a, bに対して演算を行っています。和差は通常の行列計算と同様ですが、積は要素ごとに掛け算されていることが確認できます。また、一般的な行列演算とは異なり要素ごとの割り算も定義されています。

行列の計算

では行列の和、差、積についてです。和、差はndarrayの演算になりますので細かい説明は省略します。

和・差

ndarrayの+、-で和・差の計算が可能です。先ほどのサンプルの和差部分のみ再掲します。

import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# array([[1, 2, 3],
#        [4, 5, 6],
#        [7, 8, 9]])

b = np.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
# array([[10, 20, 30],
#        [40, 50, 60],
#        [70, 80, 90]])


a + b # 足し算
# array([[11, 22, 33],
#        [44, 55, 66],
#        [77, 88, 99]])

a - b # 引き算
# array([[ -9, -18, -27],
#        [-36, -45, -54],
#        [-63, -72, -81]])

積(内積)

それでは掛け算(内積)です。np.dotを使用します。行列a、bに対し、a×bを行う場合はnp.dot(a, b)と記述します。

import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])

np.dot(a, b)
# array([[ 6, 12, 18],
#        [15, 30, 45],
#        [24, 48, 72]])
# 

もうひとつサンプルです。単位行列を左右から掛け算し、値が変わらないことを確認してみましょう。

import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
e = np.eye(3)

np.dot(a, e)

# array([[ 1.,  2.,  3.],
#        [ 4.,  5.,  6.],
#        [ 7.,  8.,  9.]])


np.dot(e, a)
# array([[ 1.,  2.,  3.],
#        [ 4.,  5.,  6.],
#        [ 7.,  8.,  9.]])
# 

左右どちらからでも値が変わらないことが確認できました。