# limitations under the License.

"""Test for boltzmann_reward_prediction_policy."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+
+import numpy as np

import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.bandits.policies import boltzmann_reward_prediction_policy as boltzmann_reward_policy
from tf_agents.networks import network
+from tf_agents.policies import utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import test_utils
@@ -73,7 +73,7 @@ def testBoltzmannGumbelPredictedRewards(self):
        self._action_spec,
        reward_network=DummyNet(self._obs_spec),
        boltzmann_gumbel_exploration_constant=10.0,
-        emit_policy_info=('predicted_rewards_mean',),
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,),
        num_samples_list=num_samples_list)
    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_step = ts.restart(observations, batch_size=2)
@@ -85,6 +85,135 @@ def testBoltzmannGumbelPredictedRewards(self):
    p_info = self.evaluate(action_step.info)
    self.assertAllEqual(p_info.predicted_rewards_mean.shape, [2, 3])

+  def testLargeTemperature(self):
+    # With a very large temperature, the sampling probability will be uniform.
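+    # Boltzmann exploration samples arms with probability proportional to
+    # exp(predicted_reward / temperature); as the temperature grows, the
+    # softmax flattens toward the uniform distribution over arms.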
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        temperature=10e8,
+        emit_policy_info=(utils.InfoFields.LOG_PROBABILITY,))
+    batch_size = 3000
+    observations = tf.constant([[1, 2]] * batch_size, dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=batch_size)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    p_info = self.evaluate(action_step.info)
+    # Check the log probabilities in the policy info are uniform.
+    self.assertAllEqual(p_info.log_probability,
+                        tf.math.log([1.0 / 3] * batch_size))
+    # Check the empirical distribution of the chosen arms is uniform.
+    actions = self.evaluate(action_step.action)
+    self.assertAllInSet(actions, [0, 1, 2])
+    # Set the tolerance on the chosen counts to 4 standard deviations.
+    tol = 4.0 * np.sqrt(batch_size * 1.0 / 3 * 2.0 / 3)
+    for action in range(3):
+      action_chosen_count = np.sum(actions == action)
+      self.assertNear(
+          action_chosen_count,
+          1000,
+          tol,
+          msg=f'action: {action} is expected to be chosen between {1000 - tol} '
+          f'and {1000 + tol} times, but was actually chosen '
+          f'{action_chosen_count} times.')
+
+  def testZeroTemperature(self):
+    # With zero temperature, the chosen actions should be greedy.
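+    # In the zero-temperature limit the softmax collapses onto the arm with
+    # the highest predicted reward, so the policy acts greedily.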
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        temperature=0.0,
+        emit_policy_info=(utils.InfoFields.LOG_PROBABILITY,))
+    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=2)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllEqual(actions, [1, 2])
+
+  def testZeroGumbelExploration(self):
+    # When the Boltzmann-Gumbel exploration constant is almost 0, the chosen
+    # actions should be greedy actions.
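+    # Boltzmann-Gumbel exploration (Cesa-Bianchi et al., 2017) perturbs each
+    # predicted reward with Gumbel noise scaled roughly by C / sqrt(N_k), where
+    # C is the exploration constant and N_k the arm's sample count; a
+    # near-zero C effectively removes the noise, so the argmax is greedy.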
+    num_samples_list = []
+    for k in range(3):
+      num_samples_list.append(
+          tf.compat.v2.Variable(
+              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+    num_samples_list[0].assign_add(2)
+    num_samples_list[1].assign_add(4)
+    num_samples_list[2].assign_add(1)
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        boltzmann_gumbel_exploration_constant=1e-12,
+        num_samples_list=num_samples_list,
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,))
+    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=2)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllEqual(actions, [1, 2])
+
+  def testAllLargeNumSamples(self):
+    # When every action has a very large number of samples, the chosen actions
+    # should be greedy actions.
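+    # With N_k near the int32 maximum for every arm, the noise scale
+    # C / sqrt(N_k) is negligible even for a sizable C, so the policy again
+    # picks the greedy actions.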
+    num_samples_list = []
+    for k in range(3):
+      num_samples_list.append(
+          tf.compat.v2.Variable(
+              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+    num_samples_list[0].assign_add(tf.int32.max - 10)
+    num_samples_list[1].assign_add(tf.int32.max - 10)
+    num_samples_list[2].assign_add(tf.int32.max - 10)
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        boltzmann_gumbel_exploration_constant=100.0,
+        num_samples_list=num_samples_list,
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,))
+    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=2)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllEqual(actions, [1, 2])
+
+  def testSomeSmallNumSamples(self):
+    # When some action has a much smaller number of samples, it should be
+    # chosen more frequently than the other actions.
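+    # Under-sampled arms get a larger noise scale (roughly C / sqrt(N_k)), so
+    # arm 1, with a single sample, should win the argmax far more often than
+    # the two heavily sampled arms.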
+    num_samples_list = []
+    for k in range(3):
+      num_samples_list.append(
+          tf.compat.v2.Variable(
+              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+    num_samples_list[0].assign_add(tf.int32.max - 10)
+    num_samples_list[1].assign_add(1)
+    num_samples_list[2].assign_add(tf.int32.max - 10)
+    policy = boltzmann_reward_policy.BoltzmannRewardPredictionPolicy(
+        self._time_step_spec,
+        self._action_spec,
+        reward_network=DummyNet(self._obs_spec),
+        boltzmann_gumbel_exploration_constant=10.0,
+        num_samples_list=num_samples_list,
+        emit_policy_info=(utils.InfoFields.PREDICTED_REWARDS_MEAN,))
+    batch_size = 3000
+    observations = tf.constant([[1, 2]] * batch_size, dtype=tf.float32)
+    time_step = ts.restart(observations, batch_size=batch_size)
+    action_step = policy.action(time_step, seed=1)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    actions = self.evaluate(action_step.action)
+    self.assertAllInSet(actions, [0, 1, 2])
+    action_counts = {action: np.sum(actions == action) for action in range(3)}
+    self.assertAllLess([action_counts[0], action_counts[2]], action_counts[1])

if __name__ == '__main__':
  tf.test.main()