Helve’s Python memo

Pythonを使った機械学習やデータ分析の備忘録

<NumPy> 統計の関数

目次

最大・最小

>>> import numpy as np
>>> a = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.amax(a)         # 全要素の最大値
6
>>> np.amax(a, axis=0) # 列方向の最大値
array([4, 5, 6])
>>> np.amax(a, axis=1) # 行方向の最大値
array([3, 6])
>>> np.amax(a, axis=1, keepdims=True) # 元の配列構造を維持
array([[3],
       [6]])
>>> np.amin(a)         # 全要素の最小値
1
>>> np.argmax(a)       # 最大要素のインデックスを返す
>>> # 配列は1次元に変換 (flatten) される
5
>>> np.argmax(a, axis=0) # 列方向の最大要素インデックス
array([1, 1, 1], dtype=int64)
>>> np.argmin(a)         # 最小要素のインデックスを返す
0
>>> np.ptp(a)            # 最大値-最小値 (Peak To Peak)
5
>>> np.ptp(a, axis=0)    # 列方向の最大値-最小値
array([3, 3, 3])

平均と分散

>>> a = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.median(a)          # 全要素の中央値
3.5
>>> np.median(a, axis=0)  # 列方向の中央値
array([ 2.5,  3.5,  4.5])
>>> np.average(a)         # 全要素の平均値
3.5
>>> np.average(a, axis=0) # 列方向の平均値
array([ 2.5,  3.5,  4.5])
>>> np.average([1, 2, 3], weights=[2.5, 1.5, 0.5]) # 重みづけ平均
>>> # (2.5*1 + 1.5*2 + 0.5*3) / (2.5+1.5+0.5)
1.5555555555555556
>>> np.mean(a) # 全要素の平均値(axisオプションはあるが、weightsオプションはない)
3.5
>>> np.std(a)          # 標準偏差 (axisオプションあり)
1.707825127659933
>>> np.std(a, ddof=1)  # 不偏標準偏差 (標本数-ddofで標準偏差を計算。ddof=0がデフォルト)
1.8708286933869707
>>> np.var(a)          # 標本分散 (axisオプションあり)
2.9166666666666665
>>> np.var(a, ddof=1)  # 不偏分散 (標本数-ddofで分散を計算。ddof=0がデフォルト)
3.5

相関係数

>>> a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> b = np.array([1, 0, 4, 2, 6, 3, 5, 9, 8, 7])
>>> c = np.array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
>>> x = np.vstack([a, b, c])
>>> np.corrcoef(x)         # 相関係数行列
>>> # [0, 1]要素:aとbの相関係数
>>> # [0, 2]要素:aとcの相関係数
>>> # [1, 2]要素:bとcの相関係数
array([[ 1.        ,  0.85454545, -1.        ],
       [ 0.85454545,  1.        , -0.85454545],
       [-1.        , -0.85454545,  1.        ]])
>>> np.corrcoef(a, b) # 2つの1次元配列を引数にとっても良い
array([[ 1.        ,  0.85454545],
       [ 0.85454545,  1.        ]])
>>> np.cov(x)  # 共分散行列 (ddof=1がデフォルト)
array([[ 9.16666667,  7.83333333, -9.16666667],
       [ 7.83333333,  9.16666667, -7.83333333],
       [-9.16666667, -7.83333333,  9.16666667]])

参考
Statistics — NumPy v1.13 Manual