|
15 | 15 | import os |
16 | 16 | import tempfile |
17 | 17 | import unittest |
| 18 | +from collections import OrderedDict |
18 | 19 | from typing import Any, Callable, Dict, List, Optional |
19 | 20 | from unittest.mock import MagicMock, patch |
20 | 21 |
|
@@ -223,6 +224,53 @@ def test_rectask_info(self) -> None: |
223 | 224 | metric_module_unified_task_info.rec_metrics[0]._tasks, |
224 | 225 | ) |
225 | 226 |
|
def test_compatibility_with_older_metric_module(self) -> None:
    """
    Check that the latest RecMetricModule still exposes the state_dict
    keys that an older checkpoint would contain.

    Only the key sets are compared: the tensor values, shapes and dtypes
    in the hand-written state_dict below are placeholders and are not
    checked against the module's own buffers.
    """
    # This simulates what an older checkpoint may have
    predefined_state_dict = OrderedDict(
        {
            "rec_metrics.rec_metrics.0._metrics_computations.0.cross_entropy_sum": torch.tensor(
                [0.0], dtype=torch.float64
            ),
            "rec_metrics.rec_metrics.0._metrics_computations.0.weighted_num_samples": torch.tensor(
                [0.0], dtype=torch.float64
            ),
            "rec_metrics.rec_metrics.0._metrics_computations.0.pos_labels": torch.tensor(
                [0.0], dtype=torch.float64
            ),
            "rec_metrics.rec_metrics.0._metrics_computations.0.neg_labels": torch.tensor(
                [0.0], dtype=torch.float64
            ),
            # torch.tensor(0) builds a scalar holding 0, matching the
            # warmup_examples entry below. The legacy torch.Tensor(0)
            # constructor would instead treat 0 as a size and return an
            # empty tensor of shape (0,).
            # NOTE(review): presumably the real buffer is a scalar counter;
            # confirm against the throughput metric implementation.
            "throughput_metric.total_examples": torch.tensor(0),
            "throughput_metric.warmup_examples": torch.tensor(0),
            "throughput_metric.time_lapse_after_warmup": torch.tensor(
                0.0, dtype=torch.float64
            ),
        }
    )

    # This is the latest RecMetricModule
    mock_optimizer = MockOptimizer()
    config = DefaultMetricsConfig
    latest_metric_module = generate_metric_module(
        TestMetricModule,
        metrics_config=config,
        batch_size=128,
        world_size=64,
        my_rank=0,
        state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
        device=torch.device("cpu"),
    )

    # Key views compare like sets, so this fails on any added, removed or
    # renamed entry regardless of ordering.
    tc = unittest.TestCase()
    tc.assertEqual(
        predefined_state_dict.keys(),
        latest_metric_module.state_dict().keys(),
        "RecMetricModule state_dict keys have changed - ensure backward compatibility with older checkpoints",
    )
| 273 | + |
226 | 274 | @staticmethod |
227 | 275 | def _run_trainer_checkpointing(rank: int, world_size: int, backend: str) -> None: |
228 | 276 | dist.init_process_group( |
|
0 commit comments