11import os
22import pytest
3+ import torch
34
45import jiant .utils .python .io as py_io
56from jiant .proj .simple import runscript as run
67import jiant .scripts .download_data .runscript as downloader
8+ import jiant .utils .torch_utils as torch_utils
79
810
911@pytest .mark .parametrize ("task_name" , ["copa" ])
@@ -29,3 +31,83 @@ def test_simple_runscript(tmpdir, task_name, model_type):
2931
3032 val_metrics = py_io .read_json (os .path .join (exp_dir , "runs" , RUN_NAME , "val_metrics.json" ))
3133 assert val_metrics ["aggregated" ] > 0
34+
35+
36+ @pytest .mark .gpu
37+ @pytest .mark .parametrize ("task_name" , ["copa" ])
38+ @pytest .mark .parametrize ("model_type" , ["roberta-large" ])
39+ def test_simple_runscript_save (tmpdir , task_name , model_type ):
40+ run_name = f"{ test_simple_runscript .__name__ } _{ task_name } _{ model_type } _save"
41+ data_dir = str (tmpdir .mkdir ("data" ))
42+ exp_dir = str (tmpdir .mkdir ("exp" ))
43+
44+ downloader .download_data ([task_name ], data_dir )
45+
46+ args = run .RunConfiguration (
47+ run_name = run_name ,
48+ exp_dir = exp_dir ,
49+ data_dir = data_dir ,
50+ model_type = model_type ,
51+ tasks = task_name ,
52+ max_steps = 1 ,
53+ train_batch_size = 32 ,
54+ do_save = True ,
55+ eval_every_steps = 10 ,
56+ learning_rate = 0.01 ,
57+ num_train_epochs = 5 ,
58+ )
59+ run .run_simple (args )
60+
61+ # check best_model and last_model exist
62+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "best_model.p" ))
63+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "best_model.metadata.json" ))
64+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "last_model.p" ))
65+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "last_model.metadata.json" ))
66+
67+ # assert best_model not equal to last_model
68+ best_model_weights = torch .load (
69+ os .path .join (exp_dir , "runs" , run_name , "best_model.p" ), map_location = torch .device ("cpu" )
70+ )
71+ last_model_weights = torch .load (
72+ os .path .join (exp_dir , "runs" , run_name , "last_model.p" ), map_location = torch .device ("cpu" )
73+ )
74+ assert not torch_utils .eq_state_dicts (best_model_weights , last_model_weights )
75+
76+ run_name = f"{ test_simple_runscript .__name__ } _{ task_name } _{ model_type } _save_best"
77+ args = run .RunConfiguration (
78+ run_name = run_name ,
79+ exp_dir = exp_dir ,
80+ data_dir = data_dir ,
81+ model_type = model_type ,
82+ tasks = task_name ,
83+ max_steps = 1 ,
84+ train_batch_size = 16 ,
85+ do_save_best = True ,
86+ )
87+ run .run_simple (args )
88+
89+ # check only best_model saved
90+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "best_model.p" ))
91+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "best_model.metadata.json" ))
92+ assert not os .path .exists (os .path .join (exp_dir , "runs" , run_name , "last_model.p" ))
93+ assert not os .path .exists (os .path .join (exp_dir , "runs" , run_name , "last_model.metadata.json" ))
94+
95+ # check output last model
96+ run_name = f"{ test_simple_runscript .__name__ } _{ task_name } _{ model_type } _save_last"
97+ args = run .RunConfiguration (
98+ run_name = run_name ,
99+ exp_dir = exp_dir ,
100+ data_dir = data_dir ,
101+ model_type = model_type ,
102+ tasks = task_name ,
103+ max_steps = 1 ,
104+ train_batch_size = 16 ,
105+ do_save_last = True ,
106+ )
107+ run .run_simple (args )
108+
109+ # check only last_model saved
110+ assert not os .path .exists (os .path .join (exp_dir , "runs" , run_name , "best_model.p" ))
111+ assert not os .path .exists (os .path .join (exp_dir , "runs" , run_name , "best_model.metadata.json" ))
112+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "last_model.p" ))
113+ assert os .path .exists (os .path .join (exp_dir , "runs" , run_name , "last_model.metadata.json" ))
0 commit comments