
Commit 62a026b

TF-Agents Team authored and copybara-github committed
FalconRewardPredictionPolicy
This policy implements the action sampling strategy (Step 6 of Algorithm 1) in the FALCON paper: https://arxiv.org/pdf/2003.12699.pdf

PiperOrigin-RevId: 472726674
Change-Id: I4b304f02ba1dcbd01737eeb7824618b8a463de17
1 parent 141cefe commit 62a026b
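
For context, the sampling rule implemented by the new policy's `_action_distribution` (Step 6 of Algorithm 1 in the paper) can be summarized with a short plain-TensorFlow sketch. The standalone helper below is illustrative, not part of the commit: it omits action masks and action offsets and assumes gamma has already been computed (the policy derives it as exploitation_coefficient * sqrt(num_actions * num_training_examples / num_trainable_parameters)).

import tensorflow as tf


def falcon_probs_sketch(predicted_rewards: tf.Tensor,
                        gamma: tf.Tensor) -> tf.Tensor:
  """Per-action sampling probabilities for a [batch_size, num_actions] input."""
  num_actions = tf.shape(predicted_rewards)[-1]
  greedy_value = tf.reduce_max(predicted_rewards, axis=-1, keepdims=True)
  # Every action scored as if it were non-greedy: 1 / (K + gamma * reward gap).
  others = 1.0 / (tf.cast(num_actions, predicted_rewards.dtype) +
                  gamma * (greedy_value - predicted_rewards))
  # The greedy action absorbs whatever probability mass remains.
  greedy_mask = tf.one_hot(
      tf.argmax(predicted_rewards, axis=-1),
      depth=num_actions, on_value=True, off_value=False)
  greedy_prob = 1.0 - tf.reduce_sum(
      tf.where(greedy_mask, tf.zeros_like(others), others),
      axis=-1, keepdims=True)
  return tf.where(greedy_mask, greedy_prob, others)


probs = falcon_probs_sketch(
    predicted_rewards=tf.constant([[0.5, 0.2, 0.1]]), gamma=tf.constant(10.0))
# Each row sums to 1; a larger gamma concentrates mass on the greedy action.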

File tree: 5 files changed, +456 -2 lines changed


tf_agents/bandits/policies/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 """Module importing all policies."""

 from tf_agents.bandits.policies import categorical_policy
+from tf_agents.bandits.policies import falcon_reward_prediction_policy
 from tf_agents.bandits.policies import greedy_multi_objective_neural_policy
 from tf_agents.bandits.policies import greedy_reward_prediction_policy
 from tf_agents.bandits.policies import lin_ucb_policy
tf_agents/bandits/policies/falcon_reward_prediction_policy.py

Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Policy that samples actions based on the FALCON algorithm.

This policy implements an action sampling distribution based on the following
paper: David Simchi-Levi and Yunzong Xu, "Bypassing the Monster: A Faster and
Simpler Optimal Algorithm for Contextual Bandits under Realizability",
Mathematics of Operations Research, 2021. https://arxiv.org/pdf/2003.12699.pdf
"""

from typing import Iterable, Optional, Text, Tuple, Sequence

import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp

from tf_agents.bandits.policies import constraints as constr
from tf_agents.bandits.policies import reward_prediction_base_policy
from tf_agents.distributions import shifted_categorical
from tf_agents.policies import utils as policy_utilities
from tf_agents.typing import types


def get_number_of_trainable_elements(network: types.Network) -> types.Float:
  """Gets the total # of elements in the network's trainable variables.

  Args:
    network: A `types.Network`.

  Returns:
    The total number of elements in the network's trainable variables.
  """
  num_elements_list = []
  for var in network.trainable_variables:
    num_elements = var.get_shape().num_elements()
    if num_elements is None:
      raise ValueError(
          f'Variable:{var} is expected to have a known shape, but found '
          'otherwise.')
    num_elements_list.append(num_elements)
  return sum(num_elements_list)


class FalconRewardPredictionPolicy(
    reward_prediction_base_policy.RewardPredictionBasePolicy):
  """Policy that samples actions based on the FALCON algorithm."""

  def __init__(self,
               time_step_spec: types.TimeStep,
               action_spec: types.NestedTensorSpec,
               reward_network: types.Network,
               exploitation_coefficient: types.FloatOrReturningFloat = 1.0,
               observation_and_action_constraint_splitter: Optional[
                   types.Splitter] = None,
               accepts_per_arm_features: bool = False,
               constraints: Iterable[constr.NeuralConstraint] = (),
               emit_policy_info: Tuple[Text, ...] = (),
               num_samples_list: Sequence[tf.Variable] = (),
               name: Optional[Text] = None):
    """Builds a FalconRewardPredictionPolicy given a reward network.

    This policy takes a tf_agents.Network predicting rewards and samples an
    action based on predicted rewards with the action distribution described
    in Step 6 of Algorithm 1 in the paper:

    David Simchi-Levi and Yunzong Xu, "Bypassing the Monster: A Faster and
    Simpler Optimal Algorithm for Contextual Bandits under Realizability",
    Mathematics of Operations Research, 2021.
    https://arxiv.org/pdf/2003.12699.pdf

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`, callable
        via `network(observation, step_type) -> (output, final_state)`.
      exploitation_coefficient: float or callable that returns a float. Its
        value will be internally lower-bounded at 0. It controls how
        exploitative the policy behaves w.r.t. the predicted rewards: a larger
        value makes the policy sample the greedy action (the one with the best
        predicted reward) with a higher probability.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask. The mask should be a 0-1 `Tensor` of shape `[batch_size,
        num_actions]`. This function should also work with a `TensorSpec` as
        input, and should output `TensorSpec` objects for the observation and
        mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      constraints: iterable of constraints objects that are instances of
        `tf_agents.bandits.agents.BaseConstraint`.
      emit_policy_info: (tuple of strings) what side information we want to
        get as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      num_samples_list: `Sequence` of tf.Variable's representing the number of
        examples for every action that the policy was trained with. For
        per-arm features, the size of the list is expected to be 1,
        representing the total number of examples the policy was trained with.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
    super(FalconRewardPredictionPolicy,
          self).__init__(time_step_spec, action_spec, reward_network,
                         observation_and_action_constraint_splitter,
                         accepts_per_arm_features, constraints,
                         emit_policy_info, name)

    self._exploitation_coefficient = exploitation_coefficient
    self._num_samples_list = num_samples_list if num_samples_list else (
        [tf.Variable(0, dtype=tf.int64)] * self._expected_num_actions)
    if len(self._num_samples_list) != self._expected_num_actions:
      raise ValueError('Size of num_samples_list: ',
                       len(self._num_samples_list),
                       ' does not match the expected number of actions:',
                       self._expected_num_actions)
    self._num_trainable_elements = get_number_of_trainable_elements(
        self._reward_network)

  def _get_exploitation_coefficient(self) -> types.FloatOrReturningFloat:
    coef = self._exploitation_coefficient() if callable(
        self._exploitation_coefficient) else self._exploitation_coefficient
    return tf.maximum(coef, 0.0)

  @property
  def num_trainable_elements(self):
    return self._num_trainable_elements

  @property
  def num_samples_list(self):
    return self._num_samples_list

  def _compute_gamma(self, dtype: tf.DType) -> types.Float:
    """Computes the gamma parameter in the sampling probability.

    This helper method implements a simple heuristic for computing the gamma
    parameter in Step 2 of Algorithm 1 in the paper
    https://arxiv.org/pdf/2003.12699.pdf. A higher gamma makes the action
    sampling distribution concentrate more on the greedy action.

    Args:
      dtype: Type of the returned value, expected to be a float type.

    Returns:
      The gamma parameter.
    """
    num_samples_list_float = tf.maximum(
        [tf.cast(x.read_value(), tf.float32) for x in self.num_samples_list],
        0.0)
    num_trainable_elements_float = tf.cast(
        tf.math.maximum(self.num_trainable_elements, 1), tf.float32)
    return self._get_exploitation_coefficient() * tf.sqrt(
        self._expected_num_actions * tf.reduce_sum(num_samples_list_float) /
        num_trainable_elements_float)

  def _action_distribution(self, mask, predicted_rewards):
    gamma = self._compute_gamma(predicted_rewards.dtype)
    batch_size = tf.shape(predicted_rewards)[0]
    # Replace predicted rewards of masked actions with -inf.
    predictions = predicted_rewards if mask is None else tf.where(
        tf.cast(mask, tf.bool), predicted_rewards,
        -float('Inf') * tf.ones_like(predicted_rewards))

    # Get the predicted rewards of the greedy actions.
    greedy_action_predictions = tf.reshape(
        tf.reduce_max(predictions, axis=-1), shape=[-1, 1])

    # `other_actions_probs` is a tensor shaped as [batch_size, num_actions]
    # that contains valid sampling probabilities for all non-greedy actions.
    other_actions_probs = tf.math.divide_no_nan(
        1.0, self._expected_num_actions + gamma *
        (greedy_action_predictions - predictions))
    # Although `predictions` has accounted for the action mask, we still need
    # to mask the action probabilities in the case of zero gamma.
    other_actions_probs = (
        other_actions_probs if mask is None else tf.where(
            tf.cast(mask, tf.bool), other_actions_probs,
            tf.zeros_like(other_actions_probs)))

    # Get the greedy action.
    greedy_actions = tf.reshape(
        tf.argmax(predictions, axis=-1, output_type=self.action_spec.dtype),
        [-1, 1])

    # Compute the probability of sampling the greedy action, which is
    # 1 - (the total probability of sampling the other actions).
    greedy_action_prob = 1.0 - tf.reshape(
        tf.reduce_sum(other_actions_probs, axis=1), [-1, 1]) + tf.gather(
            other_actions_probs, greedy_actions, axis=1, batch_dims=1)

    # Compute the sampling probabilities for all actions by combining
    # `greedy_action_prob` and `other_actions_probs`.
    greedy_action_mask = tf.equal(
        tf.tile([
            tf.range(self._expected_num_actions, dtype=self.action_spec.dtype)
        ], [batch_size, 1]), greedy_actions)
    action_probs = tf.where(
        greedy_action_mask,
        tf.tile(greedy_action_prob, [1, self._expected_num_actions]),
        other_actions_probs)

    if self._action_offset != 0:
      distribution = shifted_categorical.ShiftedCategorical(
          probs=action_probs,
          dtype=self._action_spec.dtype,
          shift=self._action_offset)
    else:
      distribution = tfp.distributions.Categorical(
          probs=action_probs, dtype=self._action_spec.dtype)

    bandit_policy_values = tf.fill([batch_size, 1],
                                   policy_utilities.BanditPolicyType.FALCON)
    return distribution, bandit_policy_values
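
For completeness, here is a minimal usage sketch. It is not part of the commit; it assumes the standard TF-Agents spec, time-step, and QNetwork helpers, and the shapes and 3-armed setup are arbitrary.

import tensorflow as tf

from tf_agents.bandits.policies import falcon_reward_prediction_policy
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# A 3-armed bandit with an 8-dimensional context vector.
observation_spec = tf.TensorSpec(shape=(8,), dtype=tf.float32)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=2)

# Any reward-prediction network with one output per action should work here.
reward_net = q_network.QNetwork(
    input_tensor_spec=observation_spec, action_spec=action_spec)

policy = falcon_reward_prediction_policy.FalconRewardPredictionPolicy(
    time_step_spec=ts.time_step_spec(observation_spec),
    action_spec=action_spec,
    reward_network=reward_net,
    exploitation_coefficient=1.0)

# Actions are sampled from the FALCON distribution instead of taken greedily.
time_step = ts.restart(tf.random.normal([2, 8]), batch_size=2)
action_step = policy.action(time_step)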
