Commit b17be68: Add networks

1 parent 398c0a6

File tree

14 files changed: +1186 -29 lines

poetry.lock

Lines changed: 691 additions & 1 deletion
Generated file; diff not rendered by default.

pyproject.toml

Lines changed: 4 additions & 0 deletions

@@ -65,6 +65,10 @@ torch = "^2.3.0"
 empy = "=3.3.4"
 squaternion = "^2023.9.2"
 tqdm = "^4.66.5"
+huggingface-hub = "^0.24.6"
+transformers = "^4.44.2"
+accelerate = "^0.33.0"
+bitsandbytes = "^0.43.3"


 [tool.poetry.group.dev.dependencies]

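Note: the four new dependencies are the usual Hugging Face stack for loading and quantizing pretrained models. A minimal sketch of how these libraries are commonly combined; the model id and quantization settings are illustrative assumptions, not taken from this commit:

# Illustrative only: the model id and 4-bit settings are assumptions, not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # bitsandbytes 4-bit weights
    bnb_4bit_compute_dtype=torch.float16,   # compute in fp16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=quant_config,
    device_map="auto",                      # accelerate handles device placement
)
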
src/drl_navigation_ros2/SAC/SAC.py

Lines changed: 217 additions & 0 deletions

@@ -0,0 +1,217 @@
import numpy as np
import torch
import torch.nn.functional as F
import SAC.SAC_utils as utils
from SAC.SAC_critic import DoubleQCritic as critic_model
from SAC.SAC_actor import DiagGaussianActor as actor_model
from torch.utils.tensorboard import SummaryWriter


class SAC(object):
    """SAC algorithm."""

    def __init__(
        self,
        obs_dim,
        action_dim,
        action_range,
        device,
        discount,
        init_temperature,
        alpha_lr,
        alpha_betas,
        actor_lr,
        actor_betas,
        actor_update_frequency,
        critic_lr,
        critic_betas,
        critic_tau,
        critic_target_update_frequency,
        batch_size,
        learnable_temperature,
    ):
        super().__init__()

        self.state_dim = obs_dim
        self.action_dim = action_dim
        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature

        self.critic = critic_model(
            obs_dim=obs_dim, action_dim=action_dim, hidden_dim=1024, hidden_depth=2
        ).to(self.device)
        self.critic_target = critic_model(
            obs_dim=obs_dim, action_dim=action_dim, hidden_dim=1024, hidden_depth=2
        ).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = actor_model(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_dim=1024,
            hidden_depth=2,
            log_std_bounds=[-5, 2],
        ).to(self.device)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=actor_betas
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=critic_betas
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=alpha_betas
        )

        self.critic_target.train()

        self.actor.train(True)
        self.critic.train(True)
        self.step = 0
        self.writer = SummaryWriter()

    def train(self, replay_buffer, iterations, batch_size):
        for _ in range(iterations):
            self.update(
                replay_buffer=replay_buffer, step=self.step, batch_size=batch_size
            )
            self.step += 1

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def get_action(self, obs, add_noise):
        if add_noise:
            return (
                self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
            ).clip(self.action_range[0], self.action_range[1])
        else:
            return self.act(obs)

    def act(self, obs, sample=False):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)
        dist = self.actor(obs)
        action = dist.sample() if sample else dist.mean
        action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, action, reward, next_obs, done, step):
        dist = self.actor(next_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
        target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_prob
        target_Q = reward + ((1 - done) * self.discount * target_V)
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q
        )
        self.writer.add_scalar("train_critic/loss", critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(self.writer, step)

    def update_actor_and_alpha(self, obs, step):
        dist = self.actor(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, action)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        self.writer.add_scalar("train_actor/loss", actor_loss, step)
        self.writer.add_scalar("train_actor/target_entropy", self.target_entropy, step)
        self.writer.add_scalar("train_actor/entropy", -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(self.writer, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (
                self.alpha * (-log_prob - self.target_entropy).detach()
            ).mean()
            self.writer.add_scalar("train_alpha/loss", alpha_loss, step)
            self.writer.add_scalar("train_alpha/value", self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(self, replay_buffer, step, batch_size):
        (
            batch_states,
            batch_actions,
            batch_rewards,
            batch_dones,
            batch_next_states,
        ) = replay_buffer.sample_batch(batch_size)

        state = torch.Tensor(batch_states).to(self.device)
        next_state = torch.Tensor(batch_next_states).to(self.device)
        action = torch.Tensor(batch_actions).to(self.device)
        reward = torch.Tensor(batch_rewards).to(self.device)
        done = torch.Tensor(batch_dones).to(self.device)

        self.writer.add_scalar("train/batch_reward", batch_rewards.mean(), step)

        self.update_critic(state, action, reward, next_state, done, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(state, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)

    def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
        # update the returned data from ROS into a form used for learning in the current model
        latest_scan = np.array(latest_scan)

        inf_mask = np.isinf(latest_scan)
        latest_scan[inf_mask] = 7.0

        max_bins = self.state_dim - 5
        bin_size = int(np.ceil(len(latest_scan) / max_bins))

        # Initialize the list to store the minimum values of each bin
        min_values = []

        # Loop through the data and create bins
        for i in range(0, len(latest_scan), bin_size):
            # Get the current bin
            bin = latest_scan[i : i + min(bin_size, len(latest_scan) - i)]
            # Find the minimum value in the current bin and append it to the min_values list
            min_values.append(min(bin))
        state = min_values + [distance, cos, sin] + [action[0], action[1]]

        assert len(state) == self.state_dim
        terminal = 1 if collision or goal else 0

        return state, terminal
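
Usage sketch (not part of this commit): how the SAC agent above might be constructed and driven. The replay buffer below is a stand-in that only mimics the sample_batch() interface used by update(), the hyperparameter values and dimensions are illustrative, and the SAC.SAC_utils module imported by these files is assumed to exist.

# Illustrative sketch; hyperparameters, dimensions and the buffer are assumptions.
import numpy as np
import torch
from SAC.SAC import SAC

class DummyReplayBuffer:
    """Stand-in exposing the sample_batch() interface that SAC.update() expects."""
    def __init__(self, obs_dim, action_dim, size=10_000):
        self.obs = np.random.randn(size, obs_dim).astype(np.float32)
        self.act = np.random.uniform(-1, 1, (size, action_dim)).astype(np.float32)
        self.rew = np.random.randn(size, 1).astype(np.float32)
        self.done = np.zeros((size, 1), dtype=np.float32)
        self.next_obs = np.random.randn(size, obs_dim).astype(np.float32)

    def sample_batch(self, batch_size):
        idx = np.random.randint(0, len(self.obs), batch_size)
        # Order matches the unpacking in SAC.update(): states, actions, rewards, dones, next_states.
        return self.obs[idx], self.act[idx], self.rew[idx], self.done[idx], self.next_obs[idx]

agent = SAC(
    obs_dim=25,                 # e.g. 20 scan bins + distance, cos, sin + 2 previous actions
    action_dim=2,
    action_range=(-1.0, 1.0),
    device="cuda" if torch.cuda.is_available() else "cpu",
    discount=0.99,
    init_temperature=0.1,
    alpha_lr=3e-4, alpha_betas=(0.9, 0.999),
    actor_lr=3e-4, actor_betas=(0.9, 0.999), actor_update_frequency=1,
    critic_lr=3e-4, critic_betas=(0.9, 0.999), critic_tau=0.005,
    critic_target_update_frequency=2,
    batch_size=64,
    learnable_temperature=True,
)

buffer = DummyReplayBuffer(obs_dim=25, action_dim=2)
agent.train(replay_buffer=buffer, iterations=10, batch_size=64)
action = agent.get_action(np.random.randn(25).astype(np.float32), add_noise=True)
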
src/drl_navigation_ros2/SAC/SAC_actor.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
import torch
import math
from torch import nn
import torch.nn.functional as F
from torch import distributions as pyd

import SAC.SAC_utils as utils


class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # one should use `cache_size=1` instead
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu


class DiagGaussianActor(nn.Module):
    """torch.distributions implementation of a diagonal Gaussian policy."""

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth, log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs):
        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)

        std = log_std.exp()

        self.outputs["mu"] = mu
        self.outputs["std"] = std

        dist = SquashedNormal(mu, std)
        return dist

    def log(self, writer, step):
        for k, v in self.outputs.items():
            writer.add_histogram(f"train_actor/{k}_hist", v, step)
src/drl_navigation_ros2/SAC/SAC_critic.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
import torch
from torch import nn

import SAC.SAC_utils as utils


class DoubleQCritic(nn.Module):
    """Critic network, employs double Q-learning."""

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        assert obs.size(0) == action.size(0)

        obs_action = torch.cat([obs, action], dim=-1)
        q1 = self.Q1(obs_action)
        q2 = self.Q2(obs_action)

        self.outputs["q1"] = q1
        self.outputs["q2"] = q2

        return q1, q2

    def log(self, writer, step):
        for k, v in self.outputs.items():
            writer.add_histogram(f"train_critic/{k}_hist", v, step)
