Description
Hello, I am a student research assistant at the Creative Machines Lab at Columbia, contributing to the Smart Building project. While training a PPO agent, I encountered a series of TypeErrors when using tf_agents.agents.ppo.ppo_policy.PPOPolicy with an actor network that outputs a SquashToSpecNormal distribution (tf_agents.distributions.utils.SquashToSpecNormal).
The actor network is based on tf_agents.agents.ppo.ppo_actor_network, with one modification: the MultivariateNormalDiag output is transformed into a SquashToSpecNormal to suit our bounded action spec. For context, the value network is defined with tf_agents.networks.value_network.ValueNetwork, as recommended by the PPO agent documentation.
```python
def create_dist(loc_and_scale):
  # `tanh_and_scale_to_spec`, `action_tensor_spec`, `output_spec`, and
  # `distribution_utils` come from the enclosing scope in
  # tf_agents/agents/ppo/ppo_actor_network.py, as in the stock implementation.
  loc = loc_and_scale['loc']
  loc = tanh_and_scale_to_spec(loc, action_tensor_spec)
  scale = loc_and_scale['scale']
  scale = tf.nn.softplus(scale)
  dist = output_spec.build_distribution(loc=loc, scale=scale)
  # Change here: squash the distribution to the bounded action spec,
  # which yields a SquashToSpecNormal.
  return distribution_utils.scale_distribution_to_spec(
      dist, action_tensor_spec
  )
```
Running with this actor network then raises errors such as:
```
TypeError: Expected binary or unicode string, got BoundedTensorSpec(shape=(2,), dtype=tf.float32, name='action', minimum=array(-1., dtype=float32), maximum=array(1., dtype=float32))

TypeError: Failed to convert elements of BoundedTensorSpec(shape=(2,), dtype=tf.float32, name='action', minimum=array(-1., dtype=float32), maximum=array(1., dtype=float32)) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

TypeError: To be compatible with tf.function, Python functions must return zero or more Tensors or ExtensionTypes or None values; in compilation of <function TFPolicy.action at 0x701338615120>, found return value of type BoundedTensorSpec, which is not a Tensor or ExtensionType.
```
These errors all point to a BoundedTensorSpec being passed into functions that expect a Tensor, but I have not been able to identify the exact cause. The issue can be reproduced by running the collect actor / driver of a PPO agent that uses such an actor network, as in the sketch below.
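For concreteness, here is a minimal reproduction sketch (hypothetical and untested as written; the specs, layer sizes, and optimizer are illustrative). In place of my modified ppo_actor_network it uses TanhNormalProjectionNetwork, which also emits a SquashToSpecNormal, so it should exercise the same code path:

```python
import tensorflow as tf
from tf_agents.agents.ppo import ppo_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import value_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

obs_spec = tf.TensorSpec((10,), tf.float32, name='observation')
action_spec = tensor_spec.BoundedTensorSpec(
    (2,), tf.float32, minimum=-1.0, maximum=1.0, name='action')
time_step_spec = ts.time_step_spec(obs_spec)

# Actor network whose output distribution is a SquashToSpecNormal,
# mirroring the modified ppo_actor_network above.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    obs_spec,
    action_spec,
    fc_layer_params=(64,),
    continuous_projection_net=(
        tanh_normal_projection_network.TanhNormalProjectionNetwork))

value_net = value_network.ValueNetwork(obs_spec, fc_layer_params=(64,))

agent = ppo_agent.PPOAgent(
    time_step_spec,
    action_spec,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    actor_net=actor_net,
    value_net=value_net)
agent.initialize()

# Calling the collect policy (a PPOPolicy) under tf.function is enough to
# trigger the TypeErrors above in my setup; no environment or replay buffer
# is needed.
time_step = ts.restart(tf.zeros((1, 10)), batch_size=1)
action_step = tf.function(agent.collect_policy.action)(time_step)
```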
Note:
- The errors only arise when a collect actor / driver using PPOPolicy is running. An eval actor / driver (in my case using tf_agents.policies.greedy_policy.GreedyPolicy) runs without any issue; see the sketch after this list.
- The SquashToSpecNormal distribution works with the SAC agent, as shown in Google's open-sourced smart building notebook: https://github.com/google/sbsim/blob/copybara_push/smart_control/notebooks/SAC_Demo.ipynb
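For comparison, the eval-side call below runs cleanly (a sketch reusing the hypothetical `agent` and `time_step` from the reproduction above; as far as I can tell, PPOAgent already wraps its eval policy in a GreedyPolicy):

```python
# Eval path: the greedy eval policy consumes the same SquashToSpecNormal
# actor output without raising any TypeError.
eval_action_step = tf.function(agent.policy.action)(time_step)
```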