
Commit 5e828af

kernel-tyety and ty authored

The improved model cannot load properly (#475)

* Modify AgentBase, config, and run scripts

Co-authored-by: ty <[email protected]>
1 parent 3bdc958 commit 5e828af

File tree (3 files changed: +8 −4 lines):

elegantrl/agents/AgentBase.py
elegantrl/train/config.py
elegantrl/train/run.py

elegantrl/agents/AgentBase.py

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ def save_or_load_agent(self, cwd: str, if_save: bool):
             continue

         if if_save:
-            th.save(getattr(self, attr_name).state_dict(), file_path)
+            th.save(getattr(self, attr_name), file_path)
         elif os.path.isfile(file_path):
             setattr(self, attr_name, th.load(file_path, map_location=self.device))

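Why the one-line change fixes the load failure: the load branch assigns the return value of th.load() straight back onto the attribute, so whatever object was saved is exactly what the agent gets back. A minimal standalone sketch of the asymmetry (the act.pth filename is illustrative, not from the repo):

import torch as th
import torch.nn as nn

net = nn.Linear(4, 2)  # stands in for a network attribute such as self.act

# Before this commit: only the parameters were saved.
th.save(net.state_dict(), 'act.pth')
loaded = th.load('act.pth')  # an OrderedDict of tensors, not a module
# setattr(self, 'act', loaded) would replace the network with a plain dict,
# so a later self.act(state) call fails -- the reported load bug.

# After this commit: the whole module is saved, so the round trip works.
th.save(net, 'act.pth')
loaded = th.load('act.pth')  # note: PyTorch >= 2.6 needs weights_only=False here
assert isinstance(loaded, nn.Module)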

elegantrl/train/config.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ def __init__(self, agent_class=None, env_class=None, env_args=None):
         self.clip_grad_norm = 3.0  # 0.1 ~ 4.0, clip the gradient after normalization
         self.state_value_tau = 0  # the tau of normalize for value and state `std = (1-std)*std + tau*std`
         self.soft_update_tau = 5e-3  # 2 ** -8 ~= 5e-3. the tau of soft target update `net = (1-tau)*net + tau*net1`
+        self.continue_train = False  # if True, resume training from the models saved by the last run
         if self.if_off_policy:  # off-policy
             self.batch_size = int(64)  # num of transitions sampled from replay buffer.
             self.horizon_len = int(512)  # collect horizon_len step while exploring, then update networks
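A hedged illustration of the new option, assuming Config constructs with the all-default signature shown in the hunk header:

from elegantrl.train.config import Config

args = Config()             # agent_class/env_class/env_args left at None for brevity
print(args.continue_train)  # False: fresh runs now ignore old checkpoints by default
args.continue_train = True  # opt in to loading the models saved in args.cwd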

elegantrl/train/run.py

Lines changed: 6 additions & 3 deletions
@@ -45,7 +45,8 @@ def train_agent_single_process(args: Config):

     '''init agent'''
     agent = args.agent_class(args.net_dims, args.state_dim, args.action_dim, gpu_id=args.gpu_id, args=args)
-    agent.save_or_load_agent(args.cwd, if_save=False)
+    if args.continue_train:
+        agent.save_or_load_agent(args.cwd, if_save=False)

     '''init agent.last_state'''
     state, info_dict = env.reset()

@@ -234,7 +235,8 @@ def run(self):

         '''Learner init agent'''
         agent = args.agent_class(args.net_dims, args.state_dim, args.action_dim, gpu_id=args.gpu_id, args=args)
-        agent.save_or_load_agent(args.cwd, if_save=False)
+        if args.continue_train:
+            agent.save_or_load_agent(args.cwd, if_save=False)

         '''Learner init buffer'''
         if args.if_off_policy:

@@ -373,7 +375,8 @@ def run(self):

         '''init agent'''
         agent = args.agent_class(args.net_dims, args.state_dim, args.action_dim, gpu_id=args.gpu_id, args=args)
-        agent.save_or_load_agent(args.cwd, if_save=False)
+        if args.continue_train:
+            agent.save_or_load_agent(args.cwd, if_save=False)

         '''init agent.last_state'''
         state, info_dict = env.reset()
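Taken together, the three changes make resuming opt-in and make the saved files loadable. An end-to-end sketch, hedged: AgentPPO, gym.make, and the env_args values are illustrative stand-ins for a real experiment setup, not something this commit prescribes:

import gymnasium as gym
from elegantrl.agents import AgentPPO
from elegantrl.train.config import Config
from elegantrl.train.run import train_agent_single_process

# placeholder environment description in ElegantRL's usual env_args style
env_args = {'env_name': 'Pendulum-v1', 'state_dim': 3, 'action_dim': 1, 'if_discrete': False}
args = Config(agent_class=AgentPPO, env_class=gym.make, env_args=env_args)

train_agent_single_process(args)  # first run: trains from scratch; on exit,
                                  # save_or_load_agent(if_save=True) stores whole modules in args.cwd
args.continue_train = True        # second run: the new flag gates the load call,
train_agent_single_process(args)  # so training resumes from the saved modules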
