
Commit 6c12f7d

TF-Agents Team authored and copybara-github committed
Fixes LOG_PROBABILITY action info in RewardPredictionBasePolicy
The `RewardPredictionBasePolicy` overrides the `_distribution` method to return a deterministic distribution over sampled actions and to populate the action probabilities in the returned policy info. However, it also sets the `emit_log_probability` parameter in the base `TFPolicy` initializer, which causes the base `TFPolicy.action` method to ignore the log probabilities returned by subclasses and instead query the distribution object returned by subclasses for the log probabilities. Because `RewardPredictionBasePolicy` returns a deterministic distribution, the emitted log probabilities are always 0.

To fix this, we change `RewardPredictionBasePolicy._distribution` to return the actual distribution instead of a deterministic one built from sampled actions. Since it now returns the actual distribution, it can no longer populate `chosen_arm_features`, because action sampling has not happened yet at that point. We therefore also override the `_action` method in `RewardPredictionBasePolicy` to populate `chosen_arm_features` after action sampling.

In addition, we add a few tests for `BoltzmannRewardPredictionPolicy` to verify the expected behavior around its parameters.

PiperOrigin-RevId: 470985015
Change-Id: I0bce336d9ab4b60811401129da87e0d84866d9ac
1 parent aed7287 commit 6c12f7d
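To make the new control flow concrete, here is a minimal sketch of the `_action` override described above. It assumes the standard TF-Agents `TFPolicy._action(time_step, policy_state, seed)` signature and the `_maybe_save_chosen_arm_features` helper that appears in the greedy-policy diff below; it is an illustration, not the committed implementation.

  def _action(self, time_step, policy_state, seed):
    # The base class samples from the distribution returned by `_distribution`,
    # and `TFPolicy.action` derives the log probability from that distribution
    # when `emit_log_probability` is set.
    step = super()._action(time_step, policy_state, seed)
    # Sampling has now happened, so the features of the chosen arm can be
    # copied into the policy info.
    return self._maybe_save_chosen_arm_features(time_step, step.action, step)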

File tree: 4 files changed, +244, -56 lines

tf_agents/bandits/policies/boltzmann_reward_prediction_policy.py
Lines changed: 17 additions & 13 deletions

@@ -30,6 +30,10 @@
 from tf_agents.policies import utils as policy_utilities
 from tf_agents.typing import types
 
+# The temperature parameter is internally lower-bounded at this value to avoid
+# numerical issues.
+_MIN_TEMPERATURE = 1e-12
+
 
 class BoltzmannRewardPredictionPolicy(
     reward_prediction_base_policy.RewardPredictionBasePolicy):
@@ -45,7 +49,7 @@ def __init__(
       observation_and_action_constraint_splitter: Optional[
           types.Splitter] = None,
       accepts_per_arm_features: bool = False,
-      constraints: Iterable[constr.NeuralConstraint] = (),
+      constraints: Iterable[constr.BaseConstraint] = (),
       emit_policy_info: Tuple[Text, ...] = (),
       num_samples_list: Sequence[tf.Variable] = (),
       name: Optional[Text] = None):
@@ -77,7 +81,7 @@ def __init__(
       accepts_per_arm_features: (bool) Whether the policy accepts per-arm
         features.
       constraints: iterable of constraints objects that are instances of
-        `tf_agents.bandits.agents.NeuralConstraint`.
+        `tf_agents.bandits.agents.BaseConstraint`.
       emit_policy_info: (tuple of strings) what side information we want to get
         as part of the policy info. Allowed values can be found in
         `policy_utilities.PolicyInfo`.
@@ -119,11 +123,12 @@ def __init__(
         self._expected_num_actions)
 
   def _get_temperature_value(self):
-    if callable(self._temperature):
-      return self._temperature()
-    return self._temperature
+    return tf.math.maximum(
+        _MIN_TEMPERATURE,
+        self._temperature()
+        if callable(self._temperature) else self._temperature)
 
-  def _sample_action(self, mask, predicted_rewards):
+  def _action_distribution(self, mask, predicted_rewards):
     batch_size = tf.shape(predicted_rewards)[0]
     if self._boltzmann_gumbel_exploration_constant is not None:
       logits = predicted_rewards
@@ -146,9 +151,11 @@ def _sample_action(self, mask, predicted_rewards):
       final_logits = logits + exploration_weights * gumbel_samples
       actions = tf.cast(
           tf.math.argmax(final_logits, axis=1), self._action_spec.dtype)
-      # Log probability is not available in closed form. We treat this as a
-      # deterministic policy at the moment.
-      log_probability = tf.zeros([batch_size], tf.float32)
+      # To conform with the return type, we construct a deterministic
+      # distribution here. Note that this results in the log_probability of
+      # the chosen arm being 0. The true sampling probability here has no simple
+      # closed-form.
+      distribution = tfp.distributions.Deterministic(loc=actions)
     else:
       # Apply the temperature scaling, needed for Boltzmann exploration.
       logits = predicted_rewards / self._get_temperature_value()
@@ -170,9 +177,6 @@ def _sample_action(self, mask, predicted_rewards):
           logits=logits,
           dtype=self._action_spec.dtype)
 
-      actions = distribution.sample()
-      log_probability = distribution.log_prob(actions)
-
     bandit_policy_values = tf.fill([batch_size, 1],
                                    policy_utilities.BanditPolicyType.BOLTZMANN)
-    return actions, log_probability, bandit_policy_values
+    return distribution, bandit_policy_values
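For readers unfamiliar with the Boltzmann branch above, the following standalone sketch uses TensorFlow Probability directly (it is not the policy class itself, and the reward values and temperature below are made up) to show why returning the categorical distribution yields informative log probabilities, whereas the previously returned `Deterministic` distribution always reports a log probability of 0.

  import tensorflow as tf
  import tensorflow_probability as tfp

  predicted_rewards = tf.constant([[1.0, 2.0, 3.0]])
  # Mirror the `_MIN_TEMPERATURE` lower bound added in this commit.
  temperature = tf.math.maximum(1e-12, 0.5)
  distribution = tfp.distributions.Categorical(
      logits=predicted_rewards / temperature, dtype=tf.int32)
  action = distribution.sample(seed=1)              # shape [1]
  log_probability = distribution.log_prob(action)   # log of a softmax probability
  print(action.numpy(), log_probability.numpy())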

tf_agents/bandits/policies/boltzmann_reward_prediction_policy_test.py
Lines changed: 133 additions & 4 deletions

@@ -14,13 +14,13 @@
 # limitations under the License.
 
 """Test for boltzmann_reward_prediction_policy."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+
+import numpy as np
 
 import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
 from tf_agents.bandits.policies import boltzmann_reward_prediction_policy as boltzmann_reward_policy
 from tf_agents.networks import network
+from tf_agents.policies import utils
 from tf_agents.specs import tensor_spec
 from tf_agents.trajectories import time_step as ts
 from tf_agents.utils import test_utils
@@ -73,7 +73,7 @@ def testBoltzmannGumbelPredictedRewards(self):
         self._action_spec,
         reward_network=DummyNet(self._obs_spec),
         boltzmann_gumbel_exploration_constant=10.0,
-        emit_policy_info=('predicted_rewards_mean',),
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,),
         num_samples_list=num_samples_list)
     observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
     time_step = ts.restart(observations, batch_size=2)
@@ -85,6 +85,135 @@ def testBoltzmannGumbelPredictedRewards(self):
     p_info = self.evaluate(action_step.info)
     self.assertAllEqual(p_info.predicted_rewards_mean.shape, [2, 3])
 
+  def testLargeTemperature(self):
+    # With a very large temperature, the sampling probability will be uniform.
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        temperature=10e8,
+        emit_policy_info=(utils.InfoFields.LOG_PROBABILITY,))
+    batch_size = 3000
+    observations = tf.constant([[1, 2]] * batch_size, dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=batch_size)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    p_info = self.evaluate(action_step.info)
+    # Check the log probabilities in the policy info are uniform.
+    self.assertAllEqual(p_info.log_probability,
+                        tf.math.log([1.0 / 3] * batch_size))
+    # Check the empirical distribution of the chosen arms is uniform.
+    actions = self.evaluate(action_step.action)
+    self.assertAllInSet(actions, [0, 1, 2])
+    # Set tolerance in the chosen count to be 4 std.
+    tol = 4.0 * np.sqrt(batch_size * 1.0 / 3 * 2.0 / 3)
+    for action in range(3):
+      action_chosen_count = np.sum(actions == action)
+      self.assertNear(
+          action_chosen_count,
+          1000,
+          tol,
+          msg=f'action: {action} is expected to be chosen between {1000 - tol} '
+          f'and {1000 + tol} times, but was actually chosen '
+          f'{action_chosen_count} times.')
+
+  def testZeroTemperature(self):
+    # With zero temperature, the chosen actions should be greedy.
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        temperature=0.0,
+        emit_policy_info=(utils.InfoFields.LOG_PROBABILITY,))
+    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=2)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllEqual(actions, [1, 2])
+
+  def testZeroGumbelExploration(self):
+    # When the Boltzmann-Gumbel exploration constant is almost 0, the chosen
+    # actions should be greedy actions.
+    num_samples_list = []
+    for k in range(3):
+      num_samples_list.append(
+          tf.compat.v2.Variable(
+              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+    num_samples_list[0].assign_add(2)
+    num_samples_list[1].assign_add(4)
+    num_samples_list[2].assign_add(1)
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        boltzmann_gumbel_exploration_constant=1e-12,
+        num_samples_list=num_samples_list,
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,))
+    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=2)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllEqual(actions, [1, 2])
+
+  def testAllLargeNumSamples(self):
+    # When every action has a very large number of samples, the chosen actions
+    # should be greedy actions.
+    num_samples_list = []
+    for k in range(3):
+      num_samples_list.append(
+          tf.compat.v2.Variable(
+              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+    num_samples_list[0].assign_add(tf.int32.max - 10)
+    num_samples_list[1].assign_add(tf.int32.max - 10)
+    num_samples_list[2].assign_add(tf.int32.max - 10)
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        boltzmann_gumbel_exploration_constant=100.0,
+        num_samples_list=num_samples_list,
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,))
+    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=2)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllEqual(actions, [1, 2])
+
+  def testSomeSmallNumSamples(self):
+    # When some action has a much smaller number of samples, it should be chosen
+    # more frequently than other actions.
+    num_samples_list = []
+    for k in range(3):
+      num_samples_list.append(
+          tf.compat.v2.Variable(
+              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+    num_samples_list[0].assign_add(tf.int32.max - 10)
+    num_samples_list[1].assign_add(1)
+    num_samples_list[2].assign_add(tf.int32.max - 10)
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        boltzmann_gumbel_exploration_constant=10.0,
+        num_samples_list=num_samples_list,
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,))
+    batch_size = 3000
+    observations = tf.constant([[1, 2]] * batch_size, dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=batch_size)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllInSet(actions, [0, 1, 2])
+    action_counts = {action: np.sum(actions == action) for action in range(3)}
+    self.assertAllLess([action_counts[0], action_counts[2]], action_counts[1])
 
 if __name__ == '__main__':
   tf.test.main()
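As a side note on `testLargeTemperature` above, the 4-standard-deviation tolerance follows from ordinary binomial statistics (the check below is not part of the commit): with `batch_size = 3000` and a uniform choice over 3 arms, each arm count is Binomial(3000, 1/3).

  import math

  batch_size = 3000
  p = 1.0 / 3
  expected_count = batch_size * p            # 1000.0
  std = math.sqrt(batch_size * p * (1 - p))  # ~25.8
  tol = 4.0 * std                            # ~103.3, the tolerance used in the test
  print(expected_count, std, tol)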

tf_agents/bandits/policies/greedy_reward_prediction_policy.py
Lines changed: 15 additions & 8 deletions

@@ -15,11 +15,8 @@
 
 """Policy for greedy reward prediction."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
+import tensorflow_probability as tfp
 
 from tf_agents.bandits.policies import reward_prediction_base_policy
 from tf_agents.policies import utils as policy_utilities
@@ -29,7 +26,7 @@ class GreedyRewardPredictionPolicy(
     reward_prediction_base_policy.RewardPredictionBasePolicy):
   """Class to build GreedyNNPredictionPolicies."""
 
-  def _sample_action(self, mask, predicted_rewards):
+  def _action_distribution(self, mask, predicted_rewards):
     """Returns the action with largest predicted reward."""
     # Argmax.
     batch_size = tf.shape(predicted_rewards)[0]
@@ -44,6 +41,16 @@ def _sample_action(self, mask, predicted_rewards):
 
     bandit_policy_values = tf.fill([batch_size, 1],
                                    policy_utilities.BanditPolicyType.GREEDY)
-    # This deterministic policy chooses the greedy action with probability 1.
-    log_probability = tf.zeros([batch_size], tf.float32)
-    return actions, log_probability, bandit_policy_values
+    return tfp.distributions.Deterministic(loc=actions), bandit_policy_values
+
+  def _distribution(self, time_step, policy_state):
+    step = super(GreedyRewardPredictionPolicy,
+                 self)._distribution(time_step, policy_state)
+    # Greedy is deterministic, so we know the chosen arm features here. We
+    # save them here so the chosen arm features get correctly returned when
+    # `tf_agents.policies.epsilon_greedy_policy.EpsilonGreedyPolicy` wraps a
+    # `GreedyRewardPredictionPolicy`, because `EpsilonGreedyPolicy` only
+    # accesses the `distribution` method of the wrapped policy via
+    # `tf_agents.policies.greedy_policy.GreedyPolicy`.
+    action = step.action.sample()
+    return self._maybe_save_chosen_arm_features(time_step, action, step)
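To illustrate why the `_distribution` override above is needed, here is a hypothetical wrapper function (not the real `EpsilonGreedyPolicy` source) that, like the wrappers mentioned in the comment, only goes through the wrapped policy's public `distribution` method and therefore never reaches its `_action`:

  def exploit_via_distribution(wrapped_policy, time_step, policy_state=()):
    # Only `distribution` is called, so any info that should reach the caller,
    # including `chosen_arm_features`, must already be populated there.
    step = wrapped_policy.distribution(time_step, policy_state)
    # For `GreedyRewardPredictionPolicy` the returned distribution is
    # Deterministic, so sampling recovers the greedy action.
    return step.action.sample(), step.info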
