Skip to content

Commit fc08e1b

Browse files
authored
[Hook] Add 'before_create_session' interface to SessionRunHook. (#991)
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 04413cf commit fc08e1b

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tensorflow/python/training/monitored_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,8 @@ def __init__(self, session_creator, hooks, stop_grace_period_secs):
957957
def create_session(self):
958958
"""Creates a coordinated session."""
959959
# Keep the tf_sess for unit testing.
960+
for hook in self._hooks:
961+
hook.before_create_session()
960962
self.tf_sess = self._session_creator.create_session()
961963
# We don't want coordinator to suppress any exception.
962964
self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
@@ -1027,6 +1029,7 @@ class MonitoredSession(_MonitoredSession):
10271029
in given order:
10281030
10291031
* calls `hook.begin()` for each given hook
1032+
* calls `hook.before_create_session()`
10301033
* finalizes the graph via `scaffold.finalize()`
10311034
* create session
10321035
* initializes the model via initialization ops provided by `Scaffold`

tensorflow/python/training/session_run_hook.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,20 @@ def begin(self):
109109
"""
110110
pass
111111

112+
def before_create_session(self):
113+
"""Called before new TensorFlow session is created.
114+
115+
This has two essential differences with the situation in which `begin` is
116+
called:
117+
118+
* Do not modify the graph in this method, ops should not be added to graph.
119+
The modification of the graph should take place within the begin
120+
interface.
121+
* This method will also be called prior to the recovery of a wrapped
122+
session, not just at the beginning of the overall session.
123+
"""
124+
pass
125+
112126
def after_create_session(self, session, coord): # pylint: disable=unused-argument
113127
"""Called when new TensorFlow session is created.
114128

0 commit comments

Comments
 (0)