1- import numpy as np
1+ import numpy as np
22import uuid
3+ import math
34from scipy .optimize import linear_sum_assignment
45
56from nav2_dynamic_msgs .msg import Obstacle , ObstacleArray
67from visualization_msgs .msg import Marker , MarkerArray
78
89import rclpy
10+ import copy
911from rclpy .node import Node
1012import colorsys
1113from kf_hungarian_tracker .obstacle_class import ObstacleClass
1618from tf2_geometry_msgs import do_transform_point , do_transform_vector3
1719from geometry_msgs .msg import PointStamped , Vector3Stamped
1820
21+
1922class KFHungarianTracker (Node ):
20- ''' Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
23+ """ Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
2124
2225 Use Hungarian algorithm to match presenting obstacles with new detection and maintain a kalman filter for each obstacle.
2326 spawn ObstacleClass when new obstacles come and delete when they disappear for certain number of frames
@@ -28,25 +31,28 @@ class KFHungarianTracker(Node):
2831 detection_sub: subscrib detection result from detection node
2932 tracker_obstacle_pub: publish tracking obstacles with ObstacleArray
3033 tracker_pose_pub: publish tracking obstacles with PoseArray, for rviz visualization
31- '''
34+ """
3235
3336 def __init__ (self ):
34- ''' initialize attributes and setup subscriber and publisher'''
37+ """ initialize attributes and setup subscriber and publisher"""
3538
36- super ().__init__ (' kf_hungarian_node' )
39+ super ().__init__ (" kf_hungarian_node" )
3740 self .declare_parameters (
38- namespace = '' ,
41+ namespace = "" ,
3942 parameters = [
40- ('global_frame' , "camera_link" ),
41- ('process_noise_cov' , [2. , 2. , 0.5 ]),
42- ('top_down' , False ),
43- ('death_threshold' , 3 ),
44- ('measurement_noise_cov' , [1. , 1. , 1. ]),
45- ('error_cov_post' , [1. , 1. , 1. , 10. , 10. , 10. ]),
46- ('vel_filter' , [0.1 , 2.0 ]),
47- ('height_filter' , [- 2.0 , 2.0 ]),
48- ('cost_filter' , 1.0 )
49- ])
43+ ("global_frame" , "camera_link" ),
44+ ("process_noise_cov" , [2.0 , 2.0 , 0.5 ]),
45+ ("top_down" , False ),
46+ ("death_threshold" , 3 ),
47+ ("measurement_noise_cov" , [1.0 , 1.0 , 1.0 ]),
48+ ("error_cov_post" , [1.0 , 1.0 , 1.0 , 10.0 , 10.0 , 10.0 ]),
49+ ("vel_filter" , [0.1 , 2.0 ]),
50+ ("height_filter" , [- 2.0 , 2.0 ]),
51+ ("cost_filter" , 1.0 ),
52+ ("transform_to_global_frame" , True ),
53+ ("infer_orientation_from_velocity" , True ),
54+ ],
55+ )
5056 self .global_frame = self .get_parameter ("global_frame" )._value
5157 self .death_threshold = self .get_parameter ("death_threshold" )._value
5258 self .measurement_noise_cov = self .get_parameter ("measurement_noise_cov" )._value
@@ -56,31 +62,39 @@ def __init__(self):
5662 self .height_filter = self .get_parameter ("height_filter" )._value
5763 self .top_down = self .get_parameter ("top_down" )._value
5864 self .cost_filter = self .get_parameter ("cost_filter" )._value
65+ self .transform_to_global_frame = self .get_parameter (
66+ "transform_to_global_frame"
67+ )._value
68+ self .infer_orientation_from_velocity = self .get_parameter (
69+ "infer_orientation_from_velocity"
70+ )._value
5971
6072 self .obstacle_list = []
6173 self .sec = 0
6274 self .nanosec = 0
6375
64- # subscribe to detector
76+ # subscribe to detector
6577 self .detection_sub = self .create_subscription (
66- ObstacleArray ,
67- 'detection' ,
68- self .callback ,
69- 10 )
78+ ObstacleArray , "detection" , self .callback , 10
79+ )
7080
7181 # publisher for tracking result
72- self .tracker_obstacle_pub = self .create_publisher (ObstacleArray , 'tracking' , 10 )
73- self .tracker_marker_pub = self .create_publisher (MarkerArray , 'tracking_marker' , 10 )
82+ self .tracker_obstacle_pub = self .create_publisher (ObstacleArray , "tracking" , 10 )
83+ self .tracker_marker_pub = self .create_publisher (
84+ MarkerArray , "tracking_marker" , 10
85+ )
7486
7587 # setup tf related
7688 self .tf_buffer = Buffer ()
7789 self .tf_listener = TransformListener (self .tf_buffer , self )
7890
7991 def callback (self , msg ):
80- ''' callback function for detection result'''
92+ """ callback function for detection result"""
8193
8294 # update delta time
83- dt = (msg .header .stamp .sec - self .sec ) + (msg .header .stamp .nanosec - self .nanosec ) / 1e9
95+ dt = (msg .header .stamp .sec - self .sec ) + (
96+ msg .header .stamp .nanosec - self .nanosec
97+ ) / 1e9
8498 self .sec = msg .header .stamp .sec
8599 self .nanosec = msg .header .stamp .nanosec
86100
@@ -93,10 +107,11 @@ def callback(self, msg):
93107 for obj in self .obstacle_list :
94108 obj .predict (dt )
95109
96- # transform to global frame
97- if self .global_frame is not None :
110+ if (self .transform_to_global_frame ) and (self .global_frame is not None ):
98111 try :
99- trans = self .tf_buffer .lookup_transform (self .global_frame , msg .header .frame_id , rclpy .time .Time ())
112+ trans = self .tf_buffer .lookup_transform (
113+ self .global_frame , msg .header .frame_id , rclpy .time .Time ()
114+ )
100115 msg .header .frame_id = self .global_frame
101116 # do_transform_vector3(vector, trans) resets trans.transform.translation
102117 # values to 0.0, so we need to preserve them for future usage in the loop below
@@ -122,8 +137,10 @@ def callback(self, msg):
122137
123138 except TransformException as ex :
124139 self .get_logger ().error (
125- 'fail to get tf from {} to {}: {}' .format (
126- msg .header .frame_id , self .global_frame , ex ))
140+ "fail to get tf from {} to {}: {}" .format (
141+ msg .header .frame_id , self .global_frame , ex
142+ )
143+ )
127144 return
128145
129146 # hungarian matching
@@ -154,19 +171,27 @@ def callback(self, msg):
154171 # apply velocity and height filter
155172 filtered_obstacle_list = []
156173 for obs in self .obstacle_list :
157- obs_vel = np .linalg .norm (np .array ([obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ]))
174+ obs_vel = np .linalg .norm (
175+ np .array ([obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ])
176+ )
158177 obs_height = obs .msg .position .z
159- if obs_vel > self .vel_filter [0 ] and obs_vel < self .vel_filter [1 ] and obs_height > self .height_filter [0 ] and obs_height < self .height_filter [1 ]:
178+ if (
179+ obs_vel > self .vel_filter [0 ]
180+ and obs_vel < self .vel_filter [1 ]
181+ and obs_height > self .height_filter [0 ]
182+ and obs_height < self .height_filter [1 ]
183+ ):
160184 filtered_obstacle_list .append (obs )
161185
162186 # construct ObstacleArray
163187 if self .tracker_obstacle_pub .get_subscription_count () > 0 :
164188 obstacle_array = ObstacleArray ()
165189 obstacle_array .header = msg .header
166190 track_list = []
191+
167192 for obs in filtered_obstacle_list :
168- # do not publish obstacles with low speed
169193 track_list .append (obs .msg )
194+
170195 obstacle_array .obstacles = track_list
171196 self .tracker_obstacle_pub .publish (obstacle_array )
172197
@@ -177,49 +202,60 @@ def callback(self, msg):
177202 # add current active obstacles
178203 for obs in filtered_obstacle_list :
179204 obstacle_uuid = uuid .UUID (bytes = bytes (obs .msg .uuid .uuid ))
180- (r , g , b ) = colorsys .hsv_to_rgb (obstacle_uuid .int % 360 / 360. , 1. , 1. ) # encode id with rgb color
181- # make a cube
205+ (r , g , b ) = colorsys .hsv_to_rgb (
206+ obstacle_uuid .int % 360 / 360.0 , 1.0 , 1.0
207+ ) # encode id with rgb color
208+
209+ # make a cube
182210 marker = Marker ()
183211 marker .header = msg .header
184212 marker .ns = str (obstacle_uuid )
185213 marker .id = 0
186- marker .type = 1 # CUBE
214+ marker .type = 1 # CUBE
187215 marker .action = 0
188216 marker .color .a = 0.5
189217 marker .color .r = r
190218 marker .color .g = g
191219 marker .color .b = b
192220 marker .pose .position = obs .msg .position
193221 angle = np .arctan2 (obs .msg .velocity .y , obs .msg .velocity .x )
194- marker .pose .orientation .z = np .float (np .sin (angle / 2 ))
195- marker .pose .orientation .w = np .float (np .cos (angle / 2 ))
222+ if self .infer_orientation_from_velocity :
223+ marker .pose .orientation .z = float (np .sin (angle / 2 ))
224+ marker .pose .orientation .w = float (np .cos (angle / 2 ))
225+ else :
226+ marker .pose .orientation .z = 0.0
227+ marker .pose .orientation .w = 1.0
228+
196229 marker .scale = obs .msg .size
197230 marker_list .append (marker )
198231 # make an arrow
199232 arrow = Marker ()
200233 arrow .header = msg .header
201234 arrow .ns = str (obstacle_uuid )
202- arrow .id = 1
235+ arrow .id = 1
203236 arrow .type = 0
204237 arrow .action = 0
205238 arrow .color .a = 1.0
206239 arrow .color .r = r
207240 arrow .color .g = g
208241 arrow .color .b = b
209242 arrow .pose .position = obs .msg .position
210- arrow .pose .orientation .z = np .float (np .sin (angle / 2 ))
211- arrow .pose .orientation .w = np .float (np .cos (angle / 2 ))
212- arrow .scale .x = np .linalg .norm ([obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ])
243+ arrow .pose .orientation .z = float (np .sin (angle / 2 ))
244+ arrow .pose .orientation .w = float (np .cos (angle / 2 ))
245+ arrow .scale .x = np .linalg .norm (
246+ [obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ]
247+ )
213248 arrow .scale .y = 0.05
214249 arrow .scale .z = 0.05
215250 marker_list .append (arrow )
251+
216252 # add dead obstacles to delete in rviz
217253 for dead_uuid in dead_object_list :
218254 marker = Marker ()
219255 marker .header = msg .header
220256 marker .ns = str (dead_uuid )
221257 marker .id = 0
222- marker .action = 2 # delete
258+ marker .action = 2 # delete
223259 arrow = Marker ()
224260 arrow .header = msg .header
225261 arrow .ns = str (dead_uuid )
@@ -231,14 +267,20 @@ def callback(self, msg):
231267 self .tracker_marker_pub .publish (marker_array )
232268
233269 def birth (self , det_ind , num_of_detect , detections ):
234- ''' generate new ObstacleClass for detections that do not match any in current obstacle list'''
270+ """ generate new ObstacleClass for detections that do not match any in current obstacle list"""
235271 for det in range (num_of_detect ):
236272 if det not in det_ind :
237- obstacle = ObstacleClass (detections [det ], self .top_down , self .measurement_noise_cov , self .error_cov_post , self .process_noise_cov )
273+ obstacle = ObstacleClass (
274+ detections [det ],
275+ self .top_down ,
276+ self .measurement_noise_cov ,
277+ self .error_cov_post ,
278+ self .process_noise_cov ,
279+ )
238280 self .obstacle_list .append (obstacle )
239281
240282 def death (self , obj_ind , num_of_obstacle ):
241- ''' count obstacles' missing frames and delete when reach threshold'''
283+ """ count obstacles' missing frames and delete when reach threshold"""
242284 new_object_list = []
243285 dead_object_list = []
244286 # for previous obstacles
@@ -251,16 +293,19 @@ def death(self, obj_ind, num_of_obstacle):
251293 if self .obstacle_list [obs ].dying < self .death_threshold :
252294 new_object_list .append (self .obstacle_list [obs ])
253295 else :
254- obstacle_uuid = uuid .UUID (bytes = bytes (self .obstacle_list [obs ].msg .uuid .uuid ))
296+ obstacle_uuid = uuid .UUID (
297+ bytes = bytes (self .obstacle_list [obs ].msg .uuid .uuid )
298+ )
255299 dead_object_list .append (obstacle_uuid )
256-
300+
257301 # add newly born obstacles
258302 for obs in range (num_of_obstacle , len (self .obstacle_list )):
259303 new_object_list .append (self .obstacle_list [obs ])
260304
261305 self .obstacle_list = new_object_list
262306 return dead_object_list
263307
308+
264309def main (args = None ):
265310 rclpy .init (args = args )
266311
@@ -271,5 +316,6 @@ def main(args=None):
271316
272317 rclpy .shutdown ()
273318
319+
274320if __name__ == "__main__" :
275321 main ()
0 commit comments