ORB_SLAM3 軌跡データの取得とグラフ化
ORB_SLAM3で軌跡データの取得とグラフ化を行います。ORB_SLAM3をCtrl+Cで終了した際にterminalのカレントディレクトリに、KeyFrameTrajectory.txtというファイルが作成されます。このファイルがキーフレームの軌跡データになります。タイムスタンプ(赤)、X座標(青)、Y座標(黄)、Z座標(緑)、クゥオータニオン(灰)の順に記録されています。このデータをグラフ化していきます。まず、出力されたファイルをエクセルでも読み込めるCSV形式にしておきます。各データは半角空白で区切られているのでカンマに文字列置換します。ORB_SLAM終了後、同じウィンドウのTerminalで
gedit KeyFrameTrajectory.txt
を実行します。geditで「検索」>「置換」を選択します。置換する文字に半角空白「 」を入力し、置換後の文字にカンマ「,」を入力して全て置換します。さらにファイル名をKeyFrameTrajectory.csvに変更します。これでエクセルで読み込めるようになりました。次にこのデータを、pythonで三次元グラフ化していきます。エクセルでは三次元の散布図グラフは作れません。pythonで三次元のグラフを作成するためmatplotlibというpythonライブラリを利用します。matplotlibの導入方法については以下のページを参考にしました。
ai-gaminglife.hatenablog.com
pythonで書いたプログラム(orb_slam_3d_graph.py)は以下のようなものです。
csvを横1行ずつ読み込んで各データごとの配列を用意し、グラフにプロットします。自分の卒業研究についての話ですが、csvファイル内では複数のスマホの軌跡のデータも、一つの時間軸として統合されています。どこで2つ目の軌跡に切り替わったか分かりづらいですが、別々の軌跡としてグラフ化する必要があるため工夫が必要です。三次元の点を線でつないで軌跡のグラフ化が行われますが、プログラム中では時間の変化量が一秒以上あった場合には線で繋がないようにしています。pythonで実行する際には日本語のコメントを削除してから実行してください。
import csv import pprint import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import math plotData_x=[]; plotData_y=[]; plotData_z=[]; plotData_time=[]; fig = plt.figure() ax = Axes3D(fig) #Xの回転行列 def rotate_x(deg): r = np.radians(deg) R_x = np.array([[1, 0, 0], [0, np.cos(r), np.sin(r)], [0, -np.sin(r), np.cos(r)]]) return R_x #Yの回転行列 def rotate_y(deg): r = np.radians(deg) R_y = np.array([[np.cos(r), 0, -np.sin(r)], [0, 1, 0], [np.sin(r), 0, np.cos(r)]]) return R_y #Zの回転行列 def rotate_z(deg): r = np.radians(deg) R_z = np.array([[np.cos(r), np.sin(r), 0], [-np.sin(r), np.cos(r), 0], [0, 0, 1]]) return R_z def plot_Data(): global ax #表示するx,y,zの範囲 ax.set_xlim3d(0, 1400.0) ax.set_ylim3d(0, 1400.0) ax.set_zlim3d(0, 1400.0) #x,y,zのラベル ax.set_xlabel('x(m)') ax.set_ylabel('y(m)') ax.set_zlabel('z(m)') #時間の変化量を計算する準備 plotData_time.append(0.00) with open('KeyFrameTrajectory.csv') as f: reader = csv.reader(f) #csvファイルを横1行ずつ読み込むforループ # 読み込んだ横1行の各データ(timestamp,x,y,z,...)を配列(row)にする for row in reader: #row[0]にtimestamp、row[4~7]にクウォータニオンが入っている #配列(rawArray)にx,y,zの値を入れる際には文字列から数値化(float)する rawArray = np.array((float(row[1]),float(row[2]),float(row[3]))) #前の1行との時間の変化量を計算 timeDiff=np.abs(float(row[0])-plotData_time[-1]) #時間の変化量が一秒以上ある場合、もしくは二つ目の軌跡に切り替わった指定秒数の場合に実行 if timeDiff>1 or float(row[0])==78.786749: #配列に入っているデータをグラフ上に一度プロット ax.plot(plotData_x, plotData_y,plotData_z,color = "b",linestyle = "solid") #配列を一度空にする plotData_x.clear() plotData_y.clear() plotData_z.clear() #回転行列を計算 rotX = rotate_x(10) rotY = rotate_y(10) rotZ = rotate_z(10) rotatedArray = np.dot(rotX,rawArray) rotatedArray = np.dot(rotY,rotatedArray) rotatedArray = np.dot(rotZ,rotatedArray) #x,y,zの各値はスケールと平行移動を調整して配列に代入。 plotData_time.append(float(row[0])) plotData_x.append(rotatedArray[0]*630+30) plotData_y.append(rotatedArray[1]*630) plotData_z.append(rotatedArray[2]*630+200) #ループ終了後配列に残っている値をグラフにプロット ax.plot(plotData_x, plotData_y, plotData_z,color='r', linestyle = "solid") #m単位の目盛り文字列を指定。以下のコードを書かかなくてもcm単位の目盛りが表示される。 ax.set_xticklabels(["0", "2", "4", "6", "8", "10", "12", "14"]) ax.set_yticklabels(["0", "2", "4", "6", "8", "10", "12", "14"]) ax.set_zticklabels(["0", "2", "4", "6", "8", "10", "12", "14"]) #最初に呼び出される関数 def main(): plot_Data() #グラフの凡例を表示 plt.legend(loc="lower left", fontsize=15) plt.show() if __name__ == '__main__': main()
自分はORB_SLAMのデータ出力に関するプログラムを若干変更しています。
sh090.hatenablog.com
動画入力時のUnixTimeでなく、動画入力開始からの経過時間が記録されるようにしているため、上記のプログラムはそのままでは使えないかもです。if文の箇所だけ抜けばプログラムの変更をしてなくても使えると思います。
プログラムのファイル(orb_slam_3d_graph.py)とKeyFrameTrajectory.csvを同じディレクトリに置いてTerminalで以下を実行することでグラフ化できます。
python orb_slam_3d_graph.py
上記のコードだと一部の関数(list.clear)がpython3.3以降でしか実行できなくて別の環境だとエラーが発生したので、一部変更したものを載せておきます。(3d_graph_for_elapsedTime.py)
# coding: UTF-8 import csv import pprint import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import math plotData_x=[]; plotData_y=[]; plotData_z=[]; plotData_time=[]; fig = plt.figure() ax = Axes3D(fig) def rotate_x(deg): r = np.radians(deg) R_x = np.array([[1, 0, 0], [0, np.cos(r), np.sin(r)], [0, -np.sin(r), np.cos(r)]]) return R_x def rotate_y(deg): r = np.radians(deg) R_y = np.array([[np.cos(r), 0, -np.sin(r)], [0, 1, 0], [np.sin(r), 0, np.cos(r)]]) return R_y def rotate_z(deg): r = np.radians(deg) R_z = np.array([[np.cos(r), np.sin(r), 0], [-np.sin(r), np.cos(r), 0], [0, 0, 1]]) return R_z def plot_plarail_type1(): global ax x1 = np.linspace( -21.4, 21.4, 70) y1 = [21.4] * 70 rad = np.linspace( math.pi/2, math.pi*3/2, 40) x2 = -21.4*np.cos(rad)+21.4 y2 = 21.4*np.sin(rad) x3 = np.linspace( -21.4, 21.4, 70)*-1 y3 = [-21.4] * 70 rad = np.linspace( math.pi*3/2, math.pi*5/2, 40) x4 = -21.4*np.cos(rad)-21.4 y4 = 21.4*np.sin(rad) x=np.hstack([x1,x2,x3,x4]) y=np.hstack([y1,y2,y3,y4]) ax.plot(x, y, np.zeros(len(x)),color = "r", label = "Plarail") def plot_Data(): global ax global plotData_x global plotData_y global plotData_z ax.set_xlim3d(0, 1400.0) ax.set_ylim3d(0, 1400.0) ax.set_zlim3d(0, 1400.0) ax.set_xlabel('x(m)') ax.set_ylabel('y(m)') ax.set_zlabel('z(m)') plotData_time.append(0.00) with open('KeyFrameTrajectory.csv') as f: reader = csv.reader(f) for row in reader: if row[0] == "-----------------------": print("\nKeyFrameTrajectory.csvを編集してください。マップの区切り文字(-----------------------)が含まれています。\n") break rawArray = np.array((float(row[1]),float(row[2]),float(row[3]))) timeDiff=np.abs(float(row[0])-plotData_time[-1]) rotX = rotate_x(10) rotY = rotate_y(10) rotZ = rotate_z(10) rotatedArray = np.dot(rotX,rawArray) rotatedArray = np.dot(rotY,rotatedArray) rotatedArray = np.dot(rotZ,rotatedArray) if timeDiff>2 or float(row[0])==78.786749: ax.plot(plotData_x, plotData_y,plotData_z,color = "b",linestyle = "solid") plotData_x=[] plotData_y=[] plotData_z=[] plotData_time.append(float(row[0])) plotData_x.append(rotatedArray[0]*630+30) plotData_y.append(rotatedArray[1]*630) plotData_z.append(rotatedArray[2]*630+200) ax.plot(plotData_x, plotData_y, plotData_z,color='r', linestyle = "solid") ax.set_xticklabels(["0", "2", "4", "6", "8", "10", "12", "14"]) ax.set_yticklabels(["0", "2", "4", "6", "8", "10", "12", "14"]) ax.set_zticklabels(["0", "2", "4", "6", "8", "10", "12", "14"]) def main(): plot_Data() plt.legend(loc="lower left", fontsize=15) plt.show() if __name__ == '__main__': main()
あとORB_SLAMのプログラム変更前の出力データ(時間軸がUnixTime)に対応したソースも載せておきます。(3d_graph_for_unixTime.py)
import csv import pprint import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import math plotData_x=[]; plotData_y=[]; plotData_z=[]; plotData_time=[]; fig = plt.figure() ax = Axes3D(fig) def rotate_x(deg): r = np.radians(deg) R_x = np.array([[1, 0, 0], [0, np.cos(r), np.sin(r)], [0, -np.sin(r), np.cos(r)]]) return R_x def rotate_y(deg): r = np.radians(deg) R_y = np.array([[np.cos(r), 0, -np.sin(r)], [0, 1, 0], [np.sin(r), 0, np.cos(r)]]) return R_y def rotate_z(deg): r = np.radians(deg) R_z = np.array([[np.cos(r), np.sin(r), 0], [-np.sin(r), np.cos(r), 0], [0, 0, 1]]) return R_z def plot_Data(): global ax global plotData_x global plotData_y global plotData_z ax.set_xlim3d(0, 1200.0) ax.set_ylim3d(0, 1200.0) ax.set_zlim3d(0, 1200.0) ax.set_xlabel('x(m)') ax.set_ylabel('y(m)') ax.set_zlabel('z(m)') plotData_time.append(0.00) with open('KeyFrameTrajectory_unixTime.csv') as f: reader = csv.reader(f) for row in reader: rawArray = np.array((float(row[1]),float(row[2]),float(row[3]))) timeDiff=np.abs(float(row[0])-plotData_time[-1]) if timeDiff>5 and len(plotData_time)>1: print(len(plotData_x)) ax.plot(plotData_x, plotData_y,plotData_z,color = "b",linestyle = "solid") plotData_x=[] plotData_y=[] plotData_z=[] rotX = rotate_x(90) rotY = rotate_y(-3) rotZ = rotate_z(10) rotatedArray = np.dot(rotX,rawArray) rotatedArray = np.dot(rotY,rotatedArray) rotatedArray = np.dot(rotZ,rotatedArray) plotData_time.append(float(row[0])) plotData_x.append(rotatedArray[0]*500+30) plotData_y.append(rotatedArray[1]*700+600) plotData_z.append(rotatedArray[2]*700+200) ax.plot(plotData_x, plotData_y, plotData_z,color='r', linestyle = "solid") ax.set_xticklabels(["0", "2", "4", "6", "8", "10", "12"]) ax.set_yticklabels(["0", "2", "4", "6", "8", "10", "12"]) ax.set_zticklabels(["0", "2", "4", "6", "8", "10", "12"]) def main(): plot_Data() plt.legend(loc="lower left", fontsize=15) plt.show() if __name__ == '__main__': main()