Skip to content

Commit a093b4a

Browse files
Marvin182copybara-github
authored andcommitted
Remove default for checkpoint argument.
PiperOrigin-RevId: 476332142
1 parent bc60709 commit a093b4a

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

vmoe/evaluate/evaluator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _create_dataset_and_expected_state(cls):
7575
# 0 or loss[i].
7676
sum_loss=tf.reduce_sum(loss * valid).numpy(),
7777
rngs={})
78-
return TfDatasetIterator(dataset), expected_eval_state
78+
return TfDatasetIterator(dataset, checkpoint=False), expected_eval_state
7979

8080
def test_evaluate_dataset(self):
8181
# Create random test dataset.

vmoe/evaluate/fewshot_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,12 @@ def setUp(self):
112112
'label': labels,
113113
fewshot.VALID_KEY: valid,
114114
})
115-
self.mock_get_dataset = self.enter_context(mock.patch.object(
116-
fewshot.vmoe.data.input_pipeline, 'get_dataset',
117-
side_effect=lambda *a, **kw: clu.data.TfDatasetIterator(dataset)))
115+
self.mock_get_dataset = self.enter_context(
116+
mock.patch.object(
117+
fewshot.vmoe.data.input_pipeline,
118+
'get_dataset',
119+
side_effect=lambda *a, **kw: clu.data.TfDatasetIterator( # pylint: disable=g-long-lambda
120+
dataset, checkpoint=False)))
118121

119122
@classmethod
120123
def _apply_fn(cls, variables, images, rngs=None):

0 commit comments

Comments
 (0)