修改陣列形狀
在處理多維陣列資料時,會遇到改變資料結構的狀況,例如一維陣列改成三維陣列、二維陣列改成一維陣列...等,而 NumPy 提供了 reshape、flatten...等方法,可以很快速的改變陣列的排列,這篇教學將會介紹相關的用法。
本篇使用的 Python 版本為 3.7.12,所有範例可使用 Google Colab 實作,不用安裝任何軟體 ( 參考:使用 Google Colab )
修改陣列形狀的方法
NumPy 有下列幾種改變陣列形狀的方法:
方法 | 說明 |
---|---|
reshape() | 改變陣列形狀。 |
flatten()、numpy.ravel() | 扁平化陣列。 |
numpy.transpose()、T | 互換維度。 |
numpy.rollaxis()、numpy.swapaxes() | 根據指定「軸」,將陣列項目「滾動」或「交換」位置。 |
reshape()
reshape() 可以將現有的陣列,轉換為特定維度的陣列,使用時必須注意特定維度的項目總數,要和原本的陣列相同,下方的例子會將一個一維陣列,轉換成 4x2 以及 2x4 的陣列。
import numpy as np
a = np.array([1,2,3,4,5,6,7,8])
b = a.reshape((4,2))
print(b)
'''
[[1 2]
[3 4]
[5 6]
[7 8]]
'''
c = a.reshape((2,4))
print(c)
'''
[[1 2 3 4]
[5 6 7 8]]
'''
如果轉換時不確定該維度要具有多少個項目,可使用「-1」代替,下方的例子會將一個 4x2 的二維陣列,變成一維陣列、2x4 陣列以及三維陣列。
import numpy as np
a = np.array([[1,2],[3,4],[5,6],[7,8]])
b = a.reshape(-1) # 轉換成一維陣列
print(b) # [1 2 3 4 5 6 7 8]
c = a.reshape((2,-1)) # 等同 a.reshape((2,4))
print(c)
'''
[[1 2 3 4]
[5 6 7 8]]
'''
d = a.reshape((2,2,-1)) # 等同 a.reshape((2,2,2))
print(d)
'''
[[[1 2]
[3 4]]
[[5 6]
[7 8]]]
'''
flatten()、numpy.ravel()
flatten() 和 numpy.ravel() 能將多維度的陣列,扁平化成一維陣列,可以設定 order 參數調整扁平的順序,預設為 C,表示先水平再垂直,設定為 F 表示先垂直再水平。
import numpy as np
a = np.array([[1,2],[3,4],[5,6],[7,8]])
b = a.flatten('C') # [1 2 3 4 5 6 7 8]
c = a.flatten('F') # [1 2 3 4 5 6 7 8]
b1 = np.ravel(a,'C') # [1 3 5 7 2 4 6 8]
c1 = np.ravel(a,'K') # [1 2 3 4 5 6 7 8]
除了 flatten 的方法,也可以直接使用 flat 轉換成一維陣列並取值。
import numpy as np
a = np.array([[1,2],[3,4],[5,6],[7,8]])
b = a.flat[3]
print(b) # 4
numpy.transpose()、T
詳細可參考:numpy.transpose()、T
numpy.transpose() 和 T 可以將陣列的行與列互換,產生新的陣列。
import numpy as np
a = np.array([[1,2],[3,4]])
b = a.T
c = np.transpose(a)
print(b)
'''
[[1 3]
[2 4]]
'''
print(c)
'''
[[1 3]
[2 4]]
'''
numpy.rollaxis()、numpy.swapaxes()
numpy.rollaxis() 和 numpy.swapaxes() 會根據指定「軸」,將陣列項目滾動改變或交換位置,下圖為一個陣列的「軸」的示意。
詳細可參考:numpy.rollaxis、numpy.swapaxes
numpy.rollaxis 共有三個參數 arr 表示陣列,axis 表示起始的軸,start 表示滾動的特定位置。
import numpy as np
a = np.array([[1,1,1],[2,2,2],[3,3,3]])
b = np.rollaxis(a,0,0)
c = np.rollaxis(a,1,0)
print(b)
'''
[[1 1 1]
[2 2 2]
[3 3 3]]
'''
print(c)
'''
[[1 2 3]
[1 2 3]
[1 2 3]]
'''
numpy.swapaxes 共有三個參數 arr 表示陣列,axis1 是對應第一個軸的整數,axis2 是對應第二個軸的整數。
import numpy as np
a = np.array([[1,1,1],[2,2,2],[3,3,3]])
b = np.swapaxes(a,0,0)
c = np.swapaxes(a,0,1)
print(b)
'''
[[1 1 1]
[2 2 2]
[3 3 3]]
'''
print(c)
'''
[[1 2 3]
[1 2 3]
[1 2 3]]
'''
意見回饋
如果有任何建議或問題,可傳送「意見表單」給我,謝謝~