Skip to content

Commit c80b502

Browse files
prajjwal1 authored and facebook-github-bot committed
Add Metric compatibility test for RecMetricsModule (#3586)
Summary: In response to https://www.internalfb.com/sevmanager/view/592632, we add a Metric module compatibility test wherein we try to load an older metric module with the latest RecMetricModule. We have a predefined state dict, which simulates a metric module obtained from an older code. We compare its keys with the latest state dict of RecMetricModule Differential Revision: D88207525
1 parent a533012 commit c80b502

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

torchrec/metrics/tests/test_metric_module.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import tempfile
1717
import unittest
18+
from collections import OrderedDict
1819
from typing import Any, Callable, Dict, List, Optional
1920
from unittest.mock import MagicMock, patch
2021

@@ -223,6 +224,53 @@ def test_rectask_info(self) -> None:
223224
metric_module_unified_task_info.rec_metrics[0]._tasks,
224225
)
225226

227+
def test_compatibility_with_older_metric_module(self) -> None:
228+
"""
229+
This test checks if latest RecMetricModule can load up
230+
metric module from an older checkpoint
231+
"""
232+
# This simulates what an older checkpoint may have
233+
predefined_state_dict = OrderedDict(
234+
{
235+
"rec_metrics.rec_metrics.0._metrics_computations.0.cross_entropy_sum": torch.tensor(
236+
[0.0], dtype=torch.float64
237+
),
238+
"rec_metrics.rec_metrics.0._metrics_computations.0.weighted_num_samples": torch.tensor(
239+
[0.0], dtype=torch.float64
240+
),
241+
"rec_metrics.rec_metrics.0._metrics_computations.0.pos_labels": torch.tensor(
242+
[0.0], dtype=torch.float64
243+
),
244+
"rec_metrics.rec_metrics.0._metrics_computations.0.neg_labels": torch.tensor(
245+
[0.0], dtype=torch.float64
246+
),
247+
"throughput_metric.total_examples": torch.Tensor(0),
248+
"throughput_metric.warmup_examples": torch.tensor(0),
249+
"throughput_metric.time_lapse_after_warmup": torch.tensor(
250+
0.0, dtype=torch.float64
251+
),
252+
}
253+
)
254+
255+
# This is the latest RecMetricModule
256+
mock_optimizer = MockOptimizer()
257+
config = DefaultMetricsConfig
258+
latest_metric_module = generate_metric_module(
259+
TestMetricModule,
260+
metrics_config=config,
261+
batch_size=128,
262+
world_size=64,
263+
my_rank=0,
264+
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
265+
device=torch.device("cpu"),
266+
)
267+
tc = unittest.TestCase()
268+
tc.assertEqual(
269+
predefined_state_dict.keys(),
270+
latest_metric_module.state_dict().keys(),
271+
"RecMetricModule state_dict keys have changed - ensure backward compatibility with older checkpoints",
272+
)
273+
226274
@staticmethod
227275
def _run_trainer_checkpointing(rank: int, world_size: int, backend: str) -> None:
228276
dist.init_process_group(

0 commit comments

Comments
 (0)