3Dグラフをmatplotlibのmplot3dで作成する方法

今回はmatplotlibで3Dグラフを作成する方法を紹介していきます。matplotlibでは3Dグラフを生成することもできます。mplot3dが用意されており、3Dの散布図や等高線プロットなどを描画できます。

公式のチュートリアルがよくまとまっているので、詳細はこちらでご覧ください。本記事も公式を参考に進めていきます。

参考 mplot3d tutorialMatplotlib

mplot3dの下準備

まずはインポート部分から見ていきましょう。一般的なmatplotlibのインポート(import matplotlib.pyplot as plt)に加えて、Axes3Dをインポートします。

空間だけ描画すると以下のようになります。ここにこれから描画していきましょう!

Python
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
注意
Axes3Dを使用していないように見えますが、ここでインポートされていないとKeyError: ‘3D’が出力されるので、要注意です。

Line plotsを描画してみる

先ほどの空間にline plotsを描画してみると、以下のようになります。parametric curveというみたいですね。

Python
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl

mpl.rcParams['legend.fontsize'] = 10

fig = plt.figure()
ax = fig.gca(projection='3d')
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
ax.plot(x, y, z, label='parametric curve')
ax.legend()

plt.show()

Scatter plotsを描画してみる

散布図を3D上に表示することもできます。

Python
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def randrange(n, vmin, vmax):
return (vmax - vmin)*np.random.rand(n) + vmin

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

n = 100
for c, m, zlow, zhigh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]:
xs = randrange(n, 23, 32)
ys = randrange(n, 0, 100)
zs = randrange(n, zlow, zhigh)
ax.scatter(xs, ys, zs, c=c, marker=m)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

plt.show()

Surface plotsを描画してみる

surface plotsは以下のようなグラフになります。

Python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter

fig = plt.figure()
ax = fig.gca(projection='3d')

X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)

ax.set_zlim(-1.01, 1.01)
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()

まとめ

様々な3Dプロットを見てきましたが、他にもたくさんの3Dグラフを描画できます。公式のページにはサンプルコード付きですべて載っているので、ぜひ見てみてください!

参考 mplot3d tutorialMatplotlib

また、Numpyの知識も必要になると思うので、関連記事をこちらに載せておきます。ぜひチェックしてみてください!

numpyのrandomで生成できる乱数を総まとめ