Skip to content

Commit e1b0d92

Browse files
Merge pull request #33 from alexanderjyuen/upstream_fork
Added Position and Velocity Covariances, Added option to not transform obstacles into global frame
2 parents 6eae17e + 3299d8e commit e1b0d92

File tree

7 files changed

+194
-97
lines changed

7 files changed

+194
-97
lines changed

kf_hungarian_tracker/kf_hungarian_tracker/kf_hungarian_node.py

Lines changed: 94 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import numpy as np
1+
import numpy as np
22
import uuid
3+
import math
34
from scipy.optimize import linear_sum_assignment
45

56
from nav2_dynamic_msgs.msg import Obstacle, ObstacleArray
67
from visualization_msgs.msg import Marker, MarkerArray
78

89
import rclpy
10+
import copy
911
from rclpy.node import Node
1012
import colorsys
1113
from kf_hungarian_tracker.obstacle_class import ObstacleClass
@@ -16,8 +18,9 @@
1618
from tf2_geometry_msgs import do_transform_point, do_transform_vector3
1719
from geometry_msgs.msg import PointStamped, Vector3Stamped
1820

21+
1922
class KFHungarianTracker(Node):
20-
'''Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
23+
"""Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
2124
2225
Use Hungarian algorithm to match presenting obstacles with new detection and maintain a kalman filter for each obstacle.
2326
spawn ObstacleClass when new obstacles come and delete when they disappear for certain number of frames
@@ -28,25 +31,28 @@ class KFHungarianTracker(Node):
2831
detection_sub: subscrib detection result from detection node
2932
tracker_obstacle_pub: publish tracking obstacles with ObstacleArray
3033
tracker_pose_pub: publish tracking obstacles with PoseArray, for rviz visualization
31-
'''
34+
"""
3235

3336
def __init__(self):
34-
'''initialize attributes and setup subscriber and publisher'''
37+
"""initialize attributes and setup subscriber and publisher"""
3538

36-
super().__init__('kf_hungarian_node')
39+
super().__init__("kf_hungarian_node")
3740
self.declare_parameters(
38-
namespace='',
41+
namespace="",
3942
parameters=[
40-
('global_frame', "camera_link"),
41-
('process_noise_cov', [2., 2., 0.5]),
42-
('top_down', False),
43-
('death_threshold', 3),
44-
('measurement_noise_cov', [1., 1., 1.]),
45-
('error_cov_post', [1., 1., 1., 10., 10., 10.]),
46-
('vel_filter', [0.1, 2.0]),
47-
('height_filter', [-2.0, 2.0]),
48-
('cost_filter', 1.0)
49-
])
43+
("global_frame", "camera_link"),
44+
("process_noise_cov", [2.0, 2.0, 0.5]),
45+
("top_down", False),
46+
("death_threshold", 3),
47+
("measurement_noise_cov", [1.0, 1.0, 1.0]),
48+
("error_cov_post", [1.0, 1.0, 1.0, 10.0, 10.0, 10.0]),
49+
("vel_filter", [0.1, 2.0]),
50+
("height_filter", [-2.0, 2.0]),
51+
("cost_filter", 1.0),
52+
("transform_to_global_frame", True),
53+
("infer_orientation_from_velocity", True),
54+
],
55+
)
5056
self.global_frame = self.get_parameter("global_frame")._value
5157
self.death_threshold = self.get_parameter("death_threshold")._value
5258
self.measurement_noise_cov = self.get_parameter("measurement_noise_cov")._value
@@ -56,31 +62,39 @@ def __init__(self):
5662
self.height_filter = self.get_parameter("height_filter")._value
5763
self.top_down = self.get_parameter("top_down")._value
5864
self.cost_filter = self.get_parameter("cost_filter")._value
65+
self.transform_to_global_frame = self.get_parameter(
66+
"transform_to_global_frame"
67+
)._value
68+
self.infer_orientation_from_velocity = self.get_parameter(
69+
"infer_orientation_from_velocity"
70+
)._value
5971

6072
self.obstacle_list = []
6173
self.sec = 0
6274
self.nanosec = 0
6375

64-
# subscribe to detector
76+
# subscribe to detector
6577
self.detection_sub = self.create_subscription(
66-
ObstacleArray,
67-
'detection',
68-
self.callback,
69-
10)
78+
ObstacleArray, "detection", self.callback, 10
79+
)
7080

7181
# publisher for tracking result
72-
self.tracker_obstacle_pub = self.create_publisher(ObstacleArray, 'tracking', 10)
73-
self.tracker_marker_pub = self.create_publisher(MarkerArray, 'tracking_marker', 10)
82+
self.tracker_obstacle_pub = self.create_publisher(ObstacleArray, "tracking", 10)
83+
self.tracker_marker_pub = self.create_publisher(
84+
MarkerArray, "tracking_marker", 10
85+
)
7486

7587
# setup tf related
7688
self.tf_buffer = Buffer()
7789
self.tf_listener = TransformListener(self.tf_buffer, self)
7890

7991
def callback(self, msg):
80-
'''callback function for detection result'''
92+
"""callback function for detection result"""
8193

8294
# update delta time
83-
dt = (msg.header.stamp.sec - self.sec) + (msg.header.stamp.nanosec - self.nanosec) / 1e9
95+
dt = (msg.header.stamp.sec - self.sec) + (
96+
msg.header.stamp.nanosec - self.nanosec
97+
) / 1e9
8498
self.sec = msg.header.stamp.sec
8599
self.nanosec = msg.header.stamp.nanosec
86100

@@ -93,10 +107,11 @@ def callback(self, msg):
93107
for obj in self.obstacle_list:
94108
obj.predict(dt)
95109

96-
# transform to global frame
97-
if self.global_frame is not None:
110+
if (self.transform_to_global_frame) and (self.global_frame is not None):
98111
try:
99-
trans = self.tf_buffer.lookup_transform(self.global_frame, msg.header.frame_id, rclpy.time.Time())
112+
trans = self.tf_buffer.lookup_transform(
113+
self.global_frame, msg.header.frame_id, rclpy.time.Time()
114+
)
100115
msg.header.frame_id = self.global_frame
101116
# do_transform_vector3(vector, trans) resets trans.transform.translation
102117
# values to 0.0, so we need to preserve them for future usage in the loop below
@@ -122,8 +137,10 @@ def callback(self, msg):
122137

123138
except TransformException as ex:
124139
self.get_logger().error(
125-
'fail to get tf from {} to {}: {}'.format(
126-
msg.header.frame_id, self.global_frame, ex))
140+
"fail to get tf from {} to {}: {}".format(
141+
msg.header.frame_id, self.global_frame, ex
142+
)
143+
)
127144
return
128145

129146
# hungarian matching
@@ -154,19 +171,27 @@ def callback(self, msg):
154171
# apply velocity and height filter
155172
filtered_obstacle_list = []
156173
for obs in self.obstacle_list:
157-
obs_vel = np.linalg.norm(np.array([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z]))
174+
obs_vel = np.linalg.norm(
175+
np.array([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z])
176+
)
158177
obs_height = obs.msg.position.z
159-
if obs_vel > self.vel_filter[0] and obs_vel < self.vel_filter[1] and obs_height > self.height_filter[0] and obs_height < self.height_filter[1]:
178+
if (
179+
obs_vel > self.vel_filter[0]
180+
and obs_vel < self.vel_filter[1]
181+
and obs_height > self.height_filter[0]
182+
and obs_height < self.height_filter[1]
183+
):
160184
filtered_obstacle_list.append(obs)
161185

162186
# construct ObstacleArray
163187
if self.tracker_obstacle_pub.get_subscription_count() > 0:
164188
obstacle_array = ObstacleArray()
165189
obstacle_array.header = msg.header
166190
track_list = []
191+
167192
for obs in filtered_obstacle_list:
168-
# do not publish obstacles with low speed
169193
track_list.append(obs.msg)
194+
170195
obstacle_array.obstacles = track_list
171196
self.tracker_obstacle_pub.publish(obstacle_array)
172197

@@ -177,49 +202,60 @@ def callback(self, msg):
177202
# add current active obstacles
178203
for obs in filtered_obstacle_list:
179204
obstacle_uuid = uuid.UUID(bytes=bytes(obs.msg.uuid.uuid))
180-
(r, g, b) = colorsys.hsv_to_rgb(obstacle_uuid.int % 360 / 360., 1., 1.) # encode id with rgb color
181-
# make a cube
205+
(r, g, b) = colorsys.hsv_to_rgb(
206+
obstacle_uuid.int % 360 / 360.0, 1.0, 1.0
207+
) # encode id with rgb color
208+
209+
# make a cube
182210
marker = Marker()
183211
marker.header = msg.header
184212
marker.ns = str(obstacle_uuid)
185213
marker.id = 0
186-
marker.type = 1 # CUBE
214+
marker.type = 1 # CUBE
187215
marker.action = 0
188216
marker.color.a = 0.5
189217
marker.color.r = r
190218
marker.color.g = g
191219
marker.color.b = b
192220
marker.pose.position = obs.msg.position
193221
angle = np.arctan2(obs.msg.velocity.y, obs.msg.velocity.x)
194-
marker.pose.orientation.z = np.float(np.sin(angle / 2))
195-
marker.pose.orientation.w = np.float(np.cos(angle / 2))
222+
if self.infer_orientation_from_velocity:
223+
marker.pose.orientation.z = float(np.sin(angle / 2))
224+
marker.pose.orientation.w = float(np.cos(angle / 2))
225+
else:
226+
marker.pose.orientation.z = 0.0
227+
marker.pose.orientation.w = 1.0
228+
196229
marker.scale = obs.msg.size
197230
marker_list.append(marker)
198231
# make an arrow
199232
arrow = Marker()
200233
arrow.header = msg.header
201234
arrow.ns = str(obstacle_uuid)
202-
arrow.id = 1
235+
arrow.id = 1
203236
arrow.type = 0
204237
arrow.action = 0
205238
arrow.color.a = 1.0
206239
arrow.color.r = r
207240
arrow.color.g = g
208241
arrow.color.b = b
209242
arrow.pose.position = obs.msg.position
210-
arrow.pose.orientation.z = np.float(np.sin(angle / 2))
211-
arrow.pose.orientation.w = np.float(np.cos(angle / 2))
212-
arrow.scale.x = np.linalg.norm([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z])
243+
arrow.pose.orientation.z = float(np.sin(angle / 2))
244+
arrow.pose.orientation.w = float(np.cos(angle / 2))
245+
arrow.scale.x = np.linalg.norm(
246+
[obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z]
247+
)
213248
arrow.scale.y = 0.05
214249
arrow.scale.z = 0.05
215250
marker_list.append(arrow)
251+
216252
# add dead obstacles to delete in rviz
217253
for dead_uuid in dead_object_list:
218254
marker = Marker()
219255
marker.header = msg.header
220256
marker.ns = str(dead_uuid)
221257
marker.id = 0
222-
marker.action = 2 # delete
258+
marker.action = 2 # delete
223259
arrow = Marker()
224260
arrow.header = msg.header
225261
arrow.ns = str(dead_uuid)
@@ -231,14 +267,20 @@ def callback(self, msg):
231267
self.tracker_marker_pub.publish(marker_array)
232268

233269
def birth(self, det_ind, num_of_detect, detections):
234-
'''generate new ObstacleClass for detections that do not match any in current obstacle list'''
270+
"""generate new ObstacleClass for detections that do not match any in current obstacle list"""
235271
for det in range(num_of_detect):
236272
if det not in det_ind:
237-
obstacle = ObstacleClass(detections[det], self.top_down, self.measurement_noise_cov, self.error_cov_post, self.process_noise_cov)
273+
obstacle = ObstacleClass(
274+
detections[det],
275+
self.top_down,
276+
self.measurement_noise_cov,
277+
self.error_cov_post,
278+
self.process_noise_cov,
279+
)
238280
self.obstacle_list.append(obstacle)
239281

240282
def death(self, obj_ind, num_of_obstacle):
241-
'''count obstacles' missing frames and delete when reach threshold'''
283+
"""count obstacles' missing frames and delete when reach threshold"""
242284
new_object_list = []
243285
dead_object_list = []
244286
# for previous obstacles
@@ -251,16 +293,19 @@ def death(self, obj_ind, num_of_obstacle):
251293
if self.obstacle_list[obs].dying < self.death_threshold:
252294
new_object_list.append(self.obstacle_list[obs])
253295
else:
254-
obstacle_uuid = uuid.UUID(bytes=bytes(self.obstacle_list[obs].msg.uuid.uuid))
296+
obstacle_uuid = uuid.UUID(
297+
bytes=bytes(self.obstacle_list[obs].msg.uuid.uuid)
298+
)
255299
dead_object_list.append(obstacle_uuid)
256-
300+
257301
# add newly born obstacles
258302
for obs in range(num_of_obstacle, len(self.obstacle_list)):
259303
new_object_list.append(self.obstacle_list[obs])
260304

261305
self.obstacle_list = new_object_list
262306
return dead_object_list
263307

308+
264309
def main(args=None):
265310
rclpy.init(args=args)
266311

@@ -271,5 +316,6 @@ def main(args=None):
271316

272317
rclpy.shutdown()
273318

319+
274320
if __name__ == "__main__":
275321
main()

0 commit comments

Comments
 (0)