Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rsl_rl/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .memory import HiddenState, Memory
from .mlp import MLP
from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
from .spatial_softmax import SpatialSoftmax

__all__ = [
"CNN",
Expand All @@ -17,4 +18,5 @@
"EmpiricalNormalization",
"HiddenState",
"Memory",
"SpatialSoftmax",
]
52 changes: 37 additions & 15 deletions rsl_rl/networks/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import torch
from torch import nn as nn

from rsl_rl.networks.spatial_softmax import SpatialSoftmax
from rsl_rl.utils import get_param, resolve_nn_activation


class CNN(nn.Sequential):
"""Convolutional Neural Network (CNN).

The CNN network is a sequence of convolutional layers, optional normalization layers, optional activation functions,
and optional pooling. The final output can be flattened.
and optional pooling. The final output can be flattened or passed through spatial softmax.
"""

def __init__(
Expand All @@ -32,6 +33,8 @@ def __init__(
activation: str = "elu",
max_pool: bool | tuple[bool] | list[bool] = False,
global_pool: str = "none",
spatial_softmax: bool = False,
spatial_softmax_temperature: float = 1.0,
flatten: bool = True,
) -> None:
"""Initialize the CNN.
Expand All @@ -50,10 +53,15 @@ def __init__(
max_pool: List of booleans indicating whether to apply max pooling after each convolutional layer or a
single boolean for all layers.
global_pool: Global pooling type to apply at the end. Either 'none', 'max', or 'avg'.
flatten: Whether to flatten the output tensor.
spatial_softmax: Whether to apply spatial softmax instead of global pooling.
spatial_softmax_temperature: Temperature parameter for spatial softmax.
flatten: Whether to flatten the output tensor (ignored if spatial_softmax=True).
"""
super().__init__()

if spatial_softmax and global_pool != "none":
raise ValueError("Cannot use both spatial_softmax and global_pool. Set global_pool='none'.")

# Resolve activation function
activation_function = resolve_nn_activation(activation)

Expand Down Expand Up @@ -110,28 +118,42 @@ def __init__(
last_channels = output_channels[idx]
last_dim = _compute_output_dim(last_dim, k, s, d, p, is_max_pool=get_param(max_pool, idx))

# Apply global pooling if specified
if global_pool == "none":
pass
# Apply spatial softmax or global pooling
if spatial_softmax:
layers.append(SpatialSoftmax(last_dim[0], last_dim[1], spatial_softmax_temperature))
self._output_channels = None
self._output_dim = last_channels * 2
elif global_pool == "none":
if flatten:
layers.append(nn.Flatten(start_dim=1))
self._output_channels = None
self._output_dim = last_channels * last_dim[0] * last_dim[1]
else:
self._output_channels = last_channels
self._output_dim = last_dim
elif global_pool == "max":
layers.append(nn.AdaptiveMaxPool2d((1, 1)))
last_dim = (1, 1)
if flatten:
layers.append(nn.Flatten(start_dim=1))
self._output_channels = None
self._output_dim = last_channels
else:
self._output_channels = last_channels
self._output_dim = (1, 1)
elif global_pool == "avg":
layers.append(nn.AdaptiveAvgPool2d((1, 1)))
last_dim = (1, 1)
if flatten:
layers.append(nn.Flatten(start_dim=1))
self._output_channels = None
self._output_dim = last_channels
else:
self._output_channels = last_channels
self._output_dim = (1, 1)
else:
raise ValueError(
f"Unsupported global pooling type: {global_pool}. Supported types are 'none', 'max', and 'avg'."
)

# Apply flattening if specified
if flatten:
layers.append(nn.Flatten(start_dim=1))

# Store final output dimension
self._output_channels = last_channels if not flatten else None
self._output_dim = last_dim if not flatten else last_channels * last_dim[0] * last_dim[1]

# Register the layers
for idx, layer in enumerate(layers):
self.add_module(f"{idx}", layer)
Expand Down
39 changes: 39 additions & 0 deletions rsl_rl/networks/spatial_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialSoftmax(nn.Module):
"""Spatial Softmax layer for extracting spatial features from feature maps.

Given feature maps of shape (B, C, H, W), computes a spatial soft-argmax to
produce (x, y) coordinates for each channel, resulting in output (B, C*2).
"""

def __init__(self, height: int, width: int, temperature: float = 1.0) -> None:
super().__init__()
self.height = height
self.width = width
self.temperature = temperature

pos_y, pos_x = torch.meshgrid(
torch.linspace(-1.0, 1.0, height), torch.linspace(-1.0, 1.0, width), indexing="ij"
)
self.register_buffer("pos_x", pos_x.reshape(-1))
self.register_buffer("pos_y", pos_y.reshape(-1))

def forward(self, features: torch.Tensor) -> torch.Tensor:
b, c, _, _ = features.shape
features_flat = features.view(b, c, -1) # (B, C, H*W)
attention = F.softmax(features_flat / self.temperature, dim=-1) # (B, C, H*W)
x_exp = torch.sum(self.pos_x * attention, dim=-1, keepdim=True) # (B, C, 1)
y_exp = torch.sum(self.pos_y * attention, dim=-1, keepdim=True) # (B, C, 1)
coords = torch.cat([x_exp, y_exp], dim=-1) # (B, C, 2)
return coords.view(b, -1) # (B, C*2)