11import sys
2- import tensorflow as tf
2+
33import tensorflow .compat .v1 as tfv1
44
55from .flags import FLAGS
6- from .logging import log_info , log_error , log_warn
6+ from .logging import log_error , log_info , log_warn
77
88
99def _load_checkpoint (session , checkpoint_path , allow_drop_layers , allow_lr_init = True ):
@@ -19,47 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
1919 # compatibility with older checkpoints.
2020 lr_var = set (v for v in load_vars if v .op .name == 'learning_rate' )
2121 if lr_var and ('learning_rate' not in vars_in_ckpt or
22- (FLAGS .force_initialize_learning_rate and allow_lr_init )):
22+ (FLAGS .force_initialize_learning_rate and allow_lr_init )):
2323 assert len (lr_var ) <= 1
2424 load_vars -= lr_var
2525 init_vars |= lr_var
2626
27- if FLAGS .load_cudnn :
28- # Initialize training from a CuDNN RNN checkpoint
29- # Identify the variables which we cannot load, and set them
30- # for initialization
31- missing_vars = set ()
32- for v in load_vars :
33- if v .op .name not in vars_in_ckpt :
34- log_warn ('CUDNN variable not found: %s' % (v .op .name ))
35- missing_vars .add (v )
27+ # After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
28+ # are missing because they were not used. This might also occur when loading a cudnn checkpoint
29+ # Therefore we have to initialize them again to continue training on such checkpoints
30+ print_msg = False
31+ for v in load_vars :
32+ if v .op .name not in vars_in_ckpt :
33+ if 'Adam' in v .name :
3634 init_vars .add (v )
35+ print_msg = True
36+ if print_msg :
37+ msg = "Some Adam tensors are missing, they will be initialized automatically."
38+ log_info (msg )
39+ load_vars -= init_vars
3740
38- load_vars -= init_vars
39-
40- # Check that the only missing variables (i.e. those to be initialised)
41- # are the Adam moment tensors, if they aren't then we have an issue
42- missing_var_names = [v .op .name for v in missing_vars ]
43- if any ('Adam' not in v for v in missing_var_names ):
44- log_error ('Tried to load a CuDNN RNN checkpoint but there were '
45- 'more missing variables than just the Adam moment '
46- 'tensors. Missing variables: {}' .format (missing_var_names ))
47- sys .exit (1 )
48-
49- if FLAGS .load_frozen_graph :
50- # After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
51- # existing anymore because they were not used
52- # Therefore we have to initialize them again to continue training on such checkpoints
41+ if FLAGS .load_cudnn :
42+ # Check all required tensors are included in the cudnn checkpoint we want to load
5343 for v in load_vars :
54- if v .op .name not in vars_in_ckpt :
55- if 'Adam' in v .name :
56- init_vars .add (v )
57- else :
58- msg = "Tried to load a frozen checkpoint but there was a missing " \
59- "variable other than the Adam tensors: {}"
60- log_error (msg .format (v ))
61- sys .exit (1 )
62- load_vars -= init_vars
44+ if v .op .name not in vars_in_ckpt and 'Adam' not in v .op .name :
45+ msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
46+ ' variable other than an Adam moment tensor: {}'
47+ log_error (msg .format (v .op .name ))
48+ sys .exit (1 )
6349
6450 if allow_drop_layers and FLAGS .drop_source_layers > 0 :
6551 # This transfer learning approach requires supplying
@@ -74,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
7460 'dropping only 5 layers.' )
7561 FLAGS .drop_source_layers = 5
7662
77- dropped_layers = [ '2' , '3' , 'lstm' , '5' , '6' ][ - 1 * int (FLAGS .drop_source_layers ):]
63+ dropped_layers = drop_freeze_number_to_layers (FLAGS .drop_source_layers , "drop" )
7864 # Initialize all variables needed for DS, but not loaded from ckpt
7965 for v in load_vars :
8066 if any (layer in v .op .name for layer in dropped_layers ):
@@ -90,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
9076 session .run (v .initializer )
9177
9278
def drop_freeze_number_to_layers(drop_freeze_number, mode):
    """Convert a count of layers to drop or freeze into layer-name keys.

    Args:
        drop_freeze_number: Number of layers to drop or freeze. Values of
            6 or more are clamped to 5 (with a warning), since the model
            checkpoint only has 6 layers and at least one must remain.
        mode: Either ``"drop"`` (select the last N layer keys) or
            ``"freeze"`` (select all but the last N layer keys).

    Returns:
        A list of layer-name substrings to match against variable names.

    Raises:
        ValueError: If ``mode`` is neither ``"drop"`` nor ``"freeze"``.
    """
    if drop_freeze_number >= 6:
        log_warn('The checkpoint only has 6 layers, but you are trying '
                 'to drop or freeze all of them or more. Continuing with 5 layers.')
        drop_freeze_number = 5

    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
    if mode == "drop":
        # Select the last N layers (closest to the output).
        return layer_keys[-int(drop_freeze_number):]
    if mode == "freeze":
        # Select everything except the last N layers.
        # NOTE(review): with drop_freeze_number == 0 this slice is
        # layer_keys[:0] == []; callers appear to guard with `> 0`
        # before calling — confirm before relying on the 0 case.
        return layer_keys[:-int(drop_freeze_number)]
    # Name the bad value instead of raising a bare, message-less ValueError.
    raise ValueError("mode must be 'drop' or 'freeze', got: {!r}".format(mode))
95+
96+
9397def _checkpoint_path_or_none (checkpoint_filename ):
9498 checkpoint = tfv1 .train .get_checkpoint_state (FLAGS .load_checkpoint_dir , checkpoint_filename )
9599 if not checkpoint :
0 commit comments