Skip to content

Commit a5473db

Browse files
prajjwal1facebook-github-bot
authored andcommitted
Add Metric compatibility test for RecMetricsModule (#3586)
Summary: In response to https://www.internalfb.com/sevmanager/view/592632, we add a Metric module compatibility test wherein we try to load an older metric module with the latest RecMetricModule. We have a predefined state dict, which simulates a metric module obtained from an older code. We compare its keys with the latest state dict of RecMetricModule Differential Revision: D88207525
1 parent 0a2cebd commit a5473db

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

torchrec/metrics/tests/test_metric_module.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,188 @@ def test_rectask_info(self) -> None:
223223
metric_module_unified_task_info.rec_metrics[0]._tasks,
224224
)
225225

226+
def test_compatibility_with_older_metric_module(self) -> None:
227+
"""
228+
This test checks if latest RecMetricModule can load up
229+
metric module from an older checkpoint
230+
"""
231+
232+
def _create_comprehensive_metrics_config() -> MetricsConfig:
233+
"""
234+
Similar to DefaultMetricsConfig, but with comprehensive metrics and tasks.
235+
"""
236+
from torchrec.metrics.metric_module import REC_METRICS_MAPPING
237+
from torchrec.metrics.metrics_config import SessionMetricDef
238+
239+
metric_arguments = {
240+
RecMetricEnum.MULTICLASS_RECALL: {"number_of_classes": 2},
241+
RecMetricEnum.TOWER_QPS: {"warmup_steps": 100},
242+
}
243+
244+
# Session-level metrics require special task configuration with session_metric_def
245+
session_task_info = RecTaskInfo(
246+
name="SessionTask",
247+
label_name="label",
248+
prediction_name="prediction",
249+
weight_name="weight",
250+
session_metric_def=SessionMetricDef(
251+
session_var_name="session",
252+
top_threshold=10,
253+
run_ranking_of_labels=False,
254+
),
255+
)
256+
257+
# Tensor weighted average metric requires tensor_name
258+
tensor_task_info = RecTaskInfo(
259+
name="TensorTask",
260+
label_name="label",
261+
prediction_name="prediction",
262+
weight_name="weight",
263+
tensor_name="target_tensor",
264+
)
265+
266+
session_metrics = {
267+
RecMetricEnum.RECALL_SESSION_LEVEL,
268+
RecMetricEnum.PRECISION_SESSION_LEVEL,
269+
}
270+
271+
tensor_metrics = {
272+
RecMetricEnum.TENSOR_WEIGHTED_AVG,
273+
}
274+
275+
all_metric_defs: Dict[RecMetricEnum, RecMetricDef] = {}
276+
for metric_enum in REC_METRICS_MAPPING.keys():
277+
if isinstance(metric_enum, RecMetricEnum):
278+
# Session-level metrics require special task configuration
279+
if metric_enum in session_metrics:
280+
all_metric_defs[metric_enum] = RecMetricDef(
281+
rec_tasks=[session_task_info],
282+
window_size=_DEFAULT_WINDOW_SIZE,
283+
arguments=metric_arguments.get(metric_enum),
284+
)
285+
# Tensor metrics require tensor_name
286+
elif metric_enum in tensor_metrics:
287+
all_metric_defs[metric_enum] = RecMetricDef(
288+
rec_tasks=[tensor_task_info],
289+
window_size=_DEFAULT_WINDOW_SIZE,
290+
arguments=metric_arguments.get(metric_enum),
291+
)
292+
else:
293+
arguments = metric_arguments.get(metric_enum)
294+
all_metric_defs[metric_enum] = RecMetricDef(
295+
rec_tasks=[DefaultTaskInfo],
296+
window_size=_DEFAULT_WINDOW_SIZE,
297+
arguments=arguments,
298+
)
299+
comprehensive_config = MetricsConfig(
300+
rec_tasks=[DefaultTaskInfo],
301+
rec_metrics=all_metric_defs,
302+
throughput_metric=ThroughputDef(),
303+
state_metrics=[],
304+
)
305+
return comprehensive_config
306+
307+
ComprehensiveMetricsConfig: MetricsConfig = (
308+
_create_comprehensive_metrics_config()
309+
)
310+
# This simulates what an older checkpoint may have
311+
predefined_state_dict_keys = [
312+
"rec_metrics.rec_metrics.0._metrics_computations.0.cross_entropy_sum",
313+
"rec_metrics.rec_metrics.0._metrics_computations.0.weighted_num_samples",
314+
"rec_metrics.rec_metrics.0._metrics_computations.0.pos_labels",
315+
"rec_metrics.rec_metrics.0._metrics_computations.0.neg_labels",
316+
"rec_metrics.rec_metrics.1._metrics_computations.0.cross_entropy_positive_sum",
317+
"rec_metrics.rec_metrics.1._metrics_computations.0.weighted_num_samples",
318+
"rec_metrics.rec_metrics.1._metrics_computations.0.pos_labels",
319+
"rec_metrics.rec_metrics.1._metrics_computations.0.neg_labels",
320+
"rec_metrics.rec_metrics.2._metrics_computations.0.cross_entropy_sum",
321+
"rec_metrics.rec_metrics.2._metrics_computations.0.weighted_num_samples",
322+
"rec_metrics.rec_metrics.2._metrics_computations.0.pos_labels",
323+
"rec_metrics.rec_metrics.2._metrics_computations.0.neg_labels",
324+
"rec_metrics.rec_metrics.3._metrics_computations.0.cross_entropy_sum",
325+
"rec_metrics.rec_metrics.3._metrics_computations.0.weighted_num_samples",
326+
"rec_metrics.rec_metrics.3._metrics_computations.0.pos_labels",
327+
"rec_metrics.rec_metrics.3._metrics_computations.0.neg_labels",
328+
"rec_metrics.rec_metrics.4._metrics_computations.0.calibration_num",
329+
"rec_metrics.rec_metrics.4._metrics_computations.0.calibration_denom",
330+
"rec_metrics.rec_metrics.5._metrics_computations.0.ctr_num",
331+
"rec_metrics.rec_metrics.5._metrics_computations.0.ctr_denom",
332+
"rec_metrics.rec_metrics.6._metrics_computations.0.calibration_num",
333+
"rec_metrics.rec_metrics.6._metrics_computations.0.calibration_denom",
334+
"rec_metrics.rec_metrics.10._metrics_computations.0.error_sum",
335+
"rec_metrics.rec_metrics.10._metrics_computations.0.weighted_num_samples",
336+
"rec_metrics.rec_metrics.11._metrics_computations.0.error_sum",
337+
"rec_metrics.rec_metrics.11._metrics_computations.0.weighted_num_samples",
338+
"rec_metrics.rec_metrics.12._metrics_computations.0.tp_at_k",
339+
"rec_metrics.rec_metrics.12._metrics_computations.0.total_weights",
340+
"rec_metrics.rec_metrics.13._metrics_computations.0.weighted_sum",
341+
"rec_metrics.rec_metrics.13._metrics_computations.0.weighted_num_samples",
342+
"rec_metrics.rec_metrics.14._metrics_computations.0.num_examples",
343+
"rec_metrics.rec_metrics.14._metrics_computations.0.warmup_examples",
344+
"rec_metrics.rec_metrics.14._metrics_computations.0.time_lapse",
345+
"rec_metrics.rec_metrics.15._metrics_computations.0.num_true_pos",
346+
"rec_metrics.rec_metrics.15._metrics_computations.0.num_false_neg",
347+
"rec_metrics.rec_metrics.16._metrics_computations.0.num_true_pos",
348+
"rec_metrics.rec_metrics.16._metrics_computations.0.num_false_pos",
349+
"rec_metrics.rec_metrics.17._metrics_computations.0.accuracy_sum",
350+
"rec_metrics.rec_metrics.17._metrics_computations.0.weighted_num_samples",
351+
"rec_metrics.rec_metrics.18._metrics_computations.0.sum_ndcg",
352+
"rec_metrics.rec_metrics.18._metrics_computations.0.num_sessions",
353+
"rec_metrics.rec_metrics.19._metrics_computations.0.error_sum",
354+
"rec_metrics.rec_metrics.19._metrics_computations.0.weighted_num_pairs",
355+
"rec_metrics.rec_metrics.21._metrics_computations.0.true_pos_sum",
356+
"rec_metrics.rec_metrics.21._metrics_computations.0.false_pos_sum",
357+
"rec_metrics.rec_metrics.22._metrics_computations.0.true_pos_sum",
358+
"rec_metrics.rec_metrics.22._metrics_computations.0.false_neg_sum",
359+
"rec_metrics.rec_metrics.23._metrics_computations.0.cross_entropy_sum",
360+
"rec_metrics.rec_metrics.23._metrics_computations.0.weighted_num_samples",
361+
"rec_metrics.rec_metrics.23._metrics_computations.0.pos_labels",
362+
"rec_metrics.rec_metrics.23._metrics_computations.0.neg_labels",
363+
"rec_metrics.rec_metrics.23._metrics_computations.0.num_examples",
364+
"rec_metrics.rec_metrics.24._metrics_computations.0.calibration_num",
365+
"rec_metrics.rec_metrics.24._metrics_computations.0.calibration_denom",
366+
"rec_metrics.rec_metrics.24._metrics_computations.0.num_examples",
367+
"rec_metrics.rec_metrics.26._metrics_computations.0.weighted_sum",
368+
"rec_metrics.rec_metrics.26._metrics_computations.0.weighted_num_samples",
369+
"rec_metrics.rec_metrics.27._metrics_computations.0.cross_entropy_sum",
370+
"rec_metrics.rec_metrics.27._metrics_computations.0.weighted_num_samples",
371+
"rec_metrics.rec_metrics.27._metrics_computations.0.pos_labels",
372+
"rec_metrics.rec_metrics.27._metrics_computations.0.neg_labels",
373+
"rec_metrics.rec_metrics.27._metrics_computations.0.weighted_sum_predictions",
374+
"rec_metrics.rec_metrics.28._metrics_computations.0.cross_entropy_sum",
375+
"rec_metrics.rec_metrics.28._metrics_computations.0.weighted_num_samples",
376+
"rec_metrics.rec_metrics.28._metrics_computations.0.pos_labels",
377+
"rec_metrics.rec_metrics.28._metrics_computations.0.neg_labels",
378+
"rec_metrics.rec_metrics.29._metrics_computations.0.true_pos_sum",
379+
"rec_metrics.rec_metrics.29._metrics_computations.0.false_pos_sum",
380+
"rec_metrics.rec_metrics.29._metrics_computations.0.false_neg_sum",
381+
"rec_metrics.rec_metrics.30._metrics_computations.0.error_sum",
382+
"rec_metrics.rec_metrics.30._metrics_computations.0.weighted_num_samples",
383+
"rec_metrics.rec_metrics.30._metrics_computations.0.const_pred_error_sum",
384+
"throughput_metric.total_examples",
385+
"throughput_metric.warmup_examples",
386+
"throughput_metric.time_lapse_after_warmup",
387+
]
388+
389+
# This is the latest RecMetricModule
390+
mock_optimizer = MockOptimizer()
391+
392+
latest_metric_module = generate_metric_module(
393+
TestMetricModule,
394+
metrics_config=ComprehensiveMetricsConfig,
395+
batch_size=128,
396+
world_size=64,
397+
my_rank=0,
398+
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
399+
device=torch.device("cpu"),
400+
)
401+
tc = unittest.TestCase()
402+
tc.assertSetEqual(
403+
set(predefined_state_dict_keys),
404+
set(latest_metric_module.state_dict().keys()),
405+
"RecMetricModule state_dict keys have changed - ensure backward compatibility with older checkpoints",
406+
)
407+
226408
@staticmethod
227409
def _run_trainer_checkpointing(rank: int, world_size: int, backend: str) -> None:
228410
dist.init_process_group(

0 commit comments

Comments
 (0)