# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Policy that samples actions based on the FALCON algorithm.

This policy implements an action sampling distribution based on the following
paper: David Simchi-Levi and Yunzong Xu, "Bypassing the Monster: A Faster and
Simpler Optimal Algorithm for Contextual Bandits under Realizability",
Mathematics of Operations Research, 2021. https://arxiv.org/pdf/2003.12699.pdf
"""

from typing import Iterable, Optional, Text, Tuple, Sequence

import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp

from tf_agents.bandits.policies import constraints as constr
from tf_agents.bandits.policies import reward_prediction_base_policy
from tf_agents.distributions import shifted_categorical
from tf_agents.policies import utils as policy_utilities
from tf_agents.typing import types


def get_number_of_trainable_elements(network: types.Network) -> types.Float:
  """Gets the total # of elements in the network's trainable variables.

  Args:
    network: A `types.Network`.

  Returns:
    The total number of elements in the network's trainable variables.
  """
  num_elements_list = []
  for var in network.trainable_variables:
    num_elements = var.get_shape().num_elements()
    if num_elements is None:
      raise ValueError(
          f'Variable: {var} is expected to have a known shape, but found '
          'otherwise.')
    num_elements_list.append(num_elements)
  return sum(num_elements_list)


class FalconRewardPredictionPolicy(
    reward_prediction_base_policy.RewardPredictionBasePolicy):
  """Policy that samples actions based on the FALCON algorithm."""

  def __init__(self,
               time_step_spec: types.TimeStep,
               action_spec: types.NestedTensorSpec,
               reward_network: types.Network,
               exploitation_coefficient: types.FloatOrReturningFloat = 1.0,
               observation_and_action_constraint_splitter: Optional[
                   types.Splitter] = None,
               accepts_per_arm_features: bool = False,
               constraints: Iterable[constr.NeuralConstraint] = (),
               emit_policy_info: Tuple[Text, ...] = (),
               num_samples_list: Sequence[tf.Variable] = (),
               name: Optional[Text] = None):
    """Builds a FalconRewardPredictionPolicy given a reward network.

    This policy takes a tf_agents.Network predicting rewards and samples an
    action based on predicted rewards with the action distribution described
    in Step 6 of Algorithm 1 in the paper:

    David Simchi-Levi and Yunzong Xu, "Bypassing the Monster: A Faster and
    Simpler Optimal Algorithm for Contextual Bandits under Realizability",
    Mathematics of Operations Research, 2021.
    https://arxiv.org/pdf/2003.12699.pdf
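
    Concretely, for every non-greedy action `a`, the sampling probability is
    `1 / (num_actions + gamma * (max_predicted_reward - predicted_reward(a)))`,
    and the greedy action receives all of the remaining probability mass; see
    `_action_distribution` for the exact computation.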

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`, callable via
        `network(observation, step_type) -> (output, final_state)`.
      exploitation_coefficient: float or callable that returns a float. Its
        value will be internally lower-bounded at 0. It controls how
        exploitative the policy behaves w.r.t. the predicted rewards: a larger
        value makes the policy more likely to sample the greedy action (the
        one with the highest predicted reward).
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask. The mask should be a 0-1 `Tensor` of shape `[batch_size,
        num_actions]`. This function should also work with a `TensorSpec` as
        input, and should output `TensorSpec` objects for the observation and
        mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      constraints: iterable of constraint objects that are instances of
        `tf_agents.bandits.agents.BaseConstraint`.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      num_samples_list: `Sequence` of `tf.Variable`s representing the number of
        examples for every action that the policy was trained with. For per-arm
        features, the size of the list is expected to be 1, representing the
        total number of examples the policy was trained with.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
    super(FalconRewardPredictionPolicy,
          self).__init__(time_step_spec, action_spec, reward_network,
                         observation_and_action_constraint_splitter,
                         accepts_per_arm_features, constraints,
                         emit_policy_info, name)

    self._exploitation_coefficient = exploitation_coefficient
    # Use a distinct `tf.Variable` per action so that the per-action counts
    # can be updated independently of each other.
    self._num_samples_list = num_samples_list if num_samples_list else [
        tf.Variable(0, dtype=tf.int64)
        for _ in range(self._expected_num_actions)
    ]
    if len(self._num_samples_list) != self._expected_num_actions:
      raise ValueError(
          f'Size of num_samples_list: {len(self._num_samples_list)} does not '
          'match the expected number of actions: '
          f'{self._expected_num_actions}.')
    self._num_trainable_elements = get_number_of_trainable_elements(
        self._reward_network)

  def _get_exploitation_coefficient(self) -> types.Float:
    coef = self._exploitation_coefficient() if callable(
        self._exploitation_coefficient) else self._exploitation_coefficient
    # Cast to float so integer-valued coefficients are accepted, then
    # lower-bound at 0.
    coef = tf.cast(coef, tf.float32)
    return tf.maximum(coef, 0.0)

  @property
  def num_trainable_elements(self):
    return self._num_trainable_elements

  @property
  def num_samples_list(self):
    return self._num_samples_list

  def _compute_gamma(self, dtype: tf.DType) -> types.Float:
    """Computes the gamma parameter in the sampling probability.

    This helper method implements a simple heuristic for computing the
    gamma parameter in Step 2 of Algorithm 1 in the paper
    https://arxiv.org/pdf/2003.12699.pdf. A higher gamma makes the action
    sampling distribution concentrate more on the greedy action.

    Args:
      dtype: Type of the returned value, expected to be a float type.

    Returns:
      The gamma parameter.
    """
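    # The heuristic below computes gamma as
    #   exploitation_coefficient
    #     * sqrt(num_actions * total_num_samples / num_trainable_elements),
    # using the number of trainable elements in the reward network as a rough
    # proxy for the complexity of the reward model.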
    num_samples_list_float = tf.maximum(
        [tf.cast(x.read_value(), dtype) for x in self.num_samples_list], 0.0)
    num_trainable_elements_float = tf.cast(
        tf.math.maximum(self.num_trainable_elements, 1), dtype)
    return tf.cast(self._get_exploitation_coefficient(), dtype) * tf.sqrt(
        tf.cast(self._expected_num_actions, dtype) *
        tf.reduce_sum(num_samples_list_float) / num_trainable_elements_float)

  def _action_distribution(self, mask, predicted_rewards):
    gamma = self._compute_gamma(predicted_rewards.dtype)
    batch_size = tf.shape(predicted_rewards)[0]
    # Replace predicted rewards of masked actions with -inf.
    predictions = predicted_rewards if mask is None else tf.where(
        tf.cast(mask, tf.bool), predicted_rewards, -float('Inf') *
        tf.ones_like(predicted_rewards))
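    # With a positive gamma, the -inf entries give masked actions an infinite
    # reward gap and hence zero sampling probability below; the gamma == 0 case
    # is handled by re-masking the probabilities explicitly.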

    # Get the predicted rewards of the greedy actions.
    greedy_action_predictions = tf.reshape(
        tf.reduce_max(predictions, axis=-1), shape=[-1, 1])

    # `other_actions_probs` is a [batch_size, num_actions] tensor holding the
    # sampling probabilities of the non-greedy actions; the entry at each
    # greedy action's index is a placeholder that gets replaced further below.
    other_actions_probs = tf.math.divide_no_nan(
        1.0, self._expected_num_actions + gamma *
        (greedy_action_predictions - predictions))
    # Although `predictions` has accounted for the action mask, we still need
    # to mask the action probabilities in the case of zero gamma.
    other_actions_probs = (
        other_actions_probs if mask is None else tf.where(
            tf.cast(mask, tf.bool), other_actions_probs,
            tf.zeros_like(other_actions_probs)))

    # Get the greedy action.
    greedy_actions = tf.reshape(
        tf.argmax(predictions, axis=-1, output_type=self.action_spec.dtype),
        [-1, 1])

    # Compute the probability of sampling the greedy action, which is
    # 1 - (the total probability of sampling the other actions). The gathered
    # term adds back the placeholder value that `other_actions_probs` holds at
    # the greedy action's index, since that index is not an "other" action.
    greedy_action_prob = 1.0 - tf.reshape(
        tf.reduce_sum(other_actions_probs, axis=1), [-1, 1]) + tf.gather(
            other_actions_probs, greedy_actions, axis=1, batch_dims=1)

    # Compute the sampling probabilities for all actions by combining
    # `greedy_action_prob` and `other_actions_probs`.
    greedy_action_mask = tf.equal(
        tf.tile([
            tf.range(self._expected_num_actions, dtype=self.action_spec.dtype)
        ], [batch_size, 1]), greedy_actions)
    action_probs = tf.where(
        greedy_action_mask,
        tf.tile(greedy_action_prob, [1, self._expected_num_actions]),
        other_actions_probs)
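    # By construction, every row of `action_probs` sums to 1: the greedy action
    # absorbs exactly the probability mass not taken by the other actions.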

    if self._action_offset != 0:
      distribution = shifted_categorical.ShiftedCategorical(
          probs=action_probs,
          dtype=self._action_spec.dtype,
          shift=self._action_offset)
    else:
      distribution = tfp.distributions.Categorical(
          probs=action_probs, dtype=self._action_spec.dtype)

    bandit_policy_values = tf.fill([batch_size, 1],
                                   policy_utilities.BanditPolicyType.FALCON)
    return distribution, bandit_policy_values
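

# Example usage (a sketch kept as a comment; the specs and the reward network
# below are illustrative placeholders, not part of this module):
#
#   from tf_agents.specs import tensor_spec
#   from tf_agents.trajectories import time_step as ts
#
#   observation_spec = tensor_spec.TensorSpec([4], tf.float32)
#   action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0,
#                                               maximum=2)
#   reward_net = ...  # Any `tf_agents` Network mapping observations to one
#                     # predicted reward per action.
#   policy = FalconRewardPredictionPolicy(
#       ts.time_step_spec(observation_spec), action_spec, reward_net,
#       exploitation_coefficient=10.0)
#   action_step = policy.action(time_step)  # `time_step` from the environment.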