
Commit 141cefe

bartokg authored and copybara-github committed
Make average return metric work even if the reward is a vector.
In that case the return is the sum of the reward vector.

PiperOrigin-RevId: 471044125
Change-Id: Ifdb1d4e477fe9bbdc5c2946aaeaa78d7d302b154
1 parent 6c12f7d commit 141cefe

3 files changed: 37 additions, 3 deletions

tf_agents/bandits/agents/examples/v2/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def baseline_reward_fn(observation, per_action_reward_fns):
     metrics += [tf_metrics.AverageReturnMultiMetric(
         reward_spec=environment.reward_spec(),
         batch_size=environment.batch_size)]
-  else:
+  if not isinstance(environment.reward_spec(), dict):
     metrics += [
         tf_metrics.AverageReturnMetric(batch_size=environment.batch_size)]
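The edit replaces the bare "else:" with an explicit check on the reward spec, so the plain AverageReturnMetric is now added for every non-dict reward spec, vector-valued ones included, independently of whatever condition guards the AverageReturnMultiMetric block above (that condition is outside this hunk). Below is a minimal sketch of what the isinstance check distinguishes; the spec values are illustrative, not taken from the repository:

import tensorflow as tf
from tf_agents.specs import tensor_spec

# Vector-valued (non-dict) reward spec: the check is False, so the trainer
# now also attaches tf_metrics.AverageReturnMetric for this environment.
vector_reward_spec = tensor_spec.TensorSpec((3,), tf.float32, 'reward')

# Dict reward spec: the check is True, so AverageReturnMetric is not added.
dict_reward_spec = {
    'watch_time': tensor_spec.TensorSpec((), tf.float32, 'watch_time'),
    'clicks': tensor_spec.TensorSpec((), tf.float32, 'clicks'),
}

print(isinstance(vector_reward_spec, dict))  # False -> AverageReturnMetric added
print(isinstance(dict_reward_spec, dict))    # True  -> skipped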

tf_agents/metrics/tf_metrics.py

Lines changed: 5 additions & 2 deletions
@@ -177,8 +177,11 @@ def call(self, trajectory):
         tf.where(trajectory.is_first(), tf.zeros_like(self._return_accumulator),
                  self._return_accumulator))
 
-    # Update accumulator with received rewards.
-    self._return_accumulator.assign_add(trajectory.reward)
+    # Update accumulator with received rewards. We are summing over all
+    # non-batch dimensions in case the reward is a vector.
+    self._return_accumulator.assign_add(
+        tf.reduce_sum(
+            trajectory.reward, axis=range(1, len(trajectory.reward.shape))))
 
     # Add final returns to buffer.
     last_episode_indices = tf.squeeze(tf.where(trajectory.is_last()), axis=-1)
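AverageReturnMetric keeps one running-return slot per batch entry, so a vector reward has to be collapsed to a single number per batch element before it can be accumulated; the commit does this by summing over every non-batch axis. A small standalone sketch of that reduction, using made-up reward values (not taken from the test suite):

import tensorflow as tf

vector_reward = tf.constant([[1.0, 2.0, 3.0],
                             [0.5, 0.5, 0.0]])  # shape [batch=2, reward_dim=3]
scalar_reward = tf.constant([4.0, 5.0])         # shape [batch=2]

def sum_to_batch(reward):
  # Sum over all non-batch dimensions, mirroring the assign_add above.
  # For a rank-1 (scalar-per-step) reward the axis range is empty, so the
  # tensor passes through unchanged and the old behaviour is preserved.
  return tf.reduce_sum(reward, axis=range(1, len(reward.shape)))

print(sum_to_batch(vector_reward))  # [6.0, 1.0]
print(sum_to_batch(scalar_reward))  # [4.0, 5.0]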

tf_agents/metrics/tf_metrics_test.py

Lines changed: 31 additions & 0 deletions
@@ -328,6 +328,37 @@ def testChosenActionHistogram(self, run_mode):
       self.evaluate(metric.reset())
       self.assertEmpty(self.evaluate(metric.result()))
 
+  @parameterized.named_parameters([
+      ('testAverageReturnMetricVectorGraph', context.graph_mode, 6,
+       tensor_spec.TensorSpec((2,), tf.float32, 'r'), 18.0),
+      ('testAverageReturnMetricVectorEager', context.eager_mode, 6,
+       tensor_spec.TensorSpec((5,), tf.float32, 'r'), 45.0),])
+  def testAverageReturnMetricVector(self, run_mode, num_trajectories,
+                                    reward_spec, expected_result):
+    with run_mode():
+      trajectories = self._create_trajectories()
+      multi_trajectories = []
+      for traj in trajectories:
+        new_reward = tf.stack([traj.reward] * reward_spec.shape.as_list()[0],
+                              axis=1)
+        new_traj = trajectory.Trajectory(
+            step_type=traj.step_type,
+            observation=traj.observation,
+            action=traj.action,
+            policy_info=traj.policy_info,
+            next_step_type=traj.next_step_type,
+            reward=new_reward,
+            discount=traj.discount)
+        multi_trajectories.append(new_traj)
+
+      metric = tf_metrics.AverageReturnMetric(batch_size=2)
+      self.evaluate(tf.compat.v1.global_variables_initializer())
+      self.evaluate(metric.init_variables())
+      for i in range(num_trajectories):
+        self.evaluate(metric(multi_trajectories[i]))
+
+      self.assertAllClose(expected_result, self.evaluate(metric.result()))
+
   @parameterized.named_parameters([
       ('testAverageReturnMultiMetricGraph', context.graph_mode, 6,
        tensor_spec.TensorSpec((2,), tf.float32, 'r'), [9.0, 9.0]),
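The expected values follow from the scalar baseline: each vector reward is built by stacking copies of the original scalar reward, and the metric now sums over the reward dimension, so the per-episode return is multiplied by the number of copies. With the shared trajectories yielding an average return of 9.0 per reward component (as the AverageReturnMultiMetric cases below report for the same trajectories), the vector variants expect 2 × 9.0 = 18.0 and 5 × 9.0 = 45.0.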
