@@ -223,6 +223,188 @@ def test_rectask_info(self) -> None:
223223 metric_module_unified_task_info .rec_metrics [0 ]._tasks ,
224224 )
225225
226+ def test_compatibility_with_older_metric_module (self ) -> None :
227+ """
228+ This test checks if latest RecMetricModule can load up
229+ metric module from an older checkpoint
230+ """
231+
232+ def _create_comprehensive_metrics_config () -> MetricsConfig :
233+ """
234+ Similar to DefaultMetricsConfig, but with comprehensive metrics and tasks.
235+ """
236+ from torchrec .metrics .metric_module import REC_METRICS_MAPPING
237+ from torchrec .metrics .metrics_config import SessionMetricDef
238+
239+ metric_arguments = {
240+ RecMetricEnum .MULTICLASS_RECALL : {"number_of_classes" : 2 },
241+ RecMetricEnum .TOWER_QPS : {"warmup_steps" : 100 },
242+ }
243+
244+ # Session-level metrics require special task configuration with session_metric_def
245+ session_task_info = RecTaskInfo (
246+ name = "SessionTask" ,
247+ label_name = "label" ,
248+ prediction_name = "prediction" ,
249+ weight_name = "weight" ,
250+ session_metric_def = SessionMetricDef (
251+ session_var_name = "session" ,
252+ top_threshold = 10 ,
253+ run_ranking_of_labels = False ,
254+ ),
255+ )
256+
257+ # Tensor weighted average metric requires tensor_name
258+ tensor_task_info = RecTaskInfo (
259+ name = "TensorTask" ,
260+ label_name = "label" ,
261+ prediction_name = "prediction" ,
262+ weight_name = "weight" ,
263+ tensor_name = "target_tensor" ,
264+ )
265+
266+ session_metrics = {
267+ RecMetricEnum .RECALL_SESSION_LEVEL ,
268+ RecMetricEnum .PRECISION_SESSION_LEVEL ,
269+ }
270+
271+ tensor_metrics = {
272+ RecMetricEnum .TENSOR_WEIGHTED_AVG ,
273+ }
274+
275+ all_metric_defs : Dict [RecMetricEnum , RecMetricDef ] = {}
276+ for metric_enum in REC_METRICS_MAPPING .keys ():
277+ if isinstance (metric_enum , RecMetricEnum ):
278+ # Session-level metrics require special task configuration
279+ if metric_enum in session_metrics :
280+ all_metric_defs [metric_enum ] = RecMetricDef (
281+ rec_tasks = [session_task_info ],
282+ window_size = _DEFAULT_WINDOW_SIZE ,
283+ arguments = metric_arguments .get (metric_enum ),
284+ )
285+ # Tensor metrics require tensor_name
286+ elif metric_enum in tensor_metrics :
287+ all_metric_defs [metric_enum ] = RecMetricDef (
288+ rec_tasks = [tensor_task_info ],
289+ window_size = _DEFAULT_WINDOW_SIZE ,
290+ arguments = metric_arguments .get (metric_enum ),
291+ )
292+ else :
293+ arguments = metric_arguments .get (metric_enum )
294+ all_metric_defs [metric_enum ] = RecMetricDef (
295+ rec_tasks = [DefaultTaskInfo ],
296+ window_size = _DEFAULT_WINDOW_SIZE ,
297+ arguments = arguments ,
298+ )
299+ comprehensive_config = MetricsConfig (
300+ rec_tasks = [DefaultTaskInfo ],
301+ rec_metrics = all_metric_defs ,
302+ throughput_metric = ThroughputDef (),
303+ state_metrics = [],
304+ )
305+ return comprehensive_config
306+
307+ ComprehensiveMetricsConfig : MetricsConfig = (
308+ _create_comprehensive_metrics_config ()
309+ )
310+ # This simulates what an older checkpoint may have
311+ predefined_state_dict_keys = [
312+ "rec_metrics.rec_metrics.0._metrics_computations.0.cross_entropy_sum" ,
313+ "rec_metrics.rec_metrics.0._metrics_computations.0.weighted_num_samples" ,
314+ "rec_metrics.rec_metrics.0._metrics_computations.0.pos_labels" ,
315+ "rec_metrics.rec_metrics.0._metrics_computations.0.neg_labels" ,
316+ "rec_metrics.rec_metrics.1._metrics_computations.0.cross_entropy_positive_sum" ,
317+ "rec_metrics.rec_metrics.1._metrics_computations.0.weighted_num_samples" ,
318+ "rec_metrics.rec_metrics.1._metrics_computations.0.pos_labels" ,
319+ "rec_metrics.rec_metrics.1._metrics_computations.0.neg_labels" ,
320+ "rec_metrics.rec_metrics.2._metrics_computations.0.cross_entropy_sum" ,
321+ "rec_metrics.rec_metrics.2._metrics_computations.0.weighted_num_samples" ,
322+ "rec_metrics.rec_metrics.2._metrics_computations.0.pos_labels" ,
323+ "rec_metrics.rec_metrics.2._metrics_computations.0.neg_labels" ,
324+ "rec_metrics.rec_metrics.3._metrics_computations.0.cross_entropy_sum" ,
325+ "rec_metrics.rec_metrics.3._metrics_computations.0.weighted_num_samples" ,
326+ "rec_metrics.rec_metrics.3._metrics_computations.0.pos_labels" ,
327+ "rec_metrics.rec_metrics.3._metrics_computations.0.neg_labels" ,
328+ "rec_metrics.rec_metrics.4._metrics_computations.0.calibration_num" ,
329+ "rec_metrics.rec_metrics.4._metrics_computations.0.calibration_denom" ,
330+ "rec_metrics.rec_metrics.5._metrics_computations.0.ctr_num" ,
331+ "rec_metrics.rec_metrics.5._metrics_computations.0.ctr_denom" ,
332+ "rec_metrics.rec_metrics.6._metrics_computations.0.calibration_num" ,
333+ "rec_metrics.rec_metrics.6._metrics_computations.0.calibration_denom" ,
334+ "rec_metrics.rec_metrics.10._metrics_computations.0.error_sum" ,
335+ "rec_metrics.rec_metrics.10._metrics_computations.0.weighted_num_samples" ,
336+ "rec_metrics.rec_metrics.11._metrics_computations.0.error_sum" ,
337+ "rec_metrics.rec_metrics.11._metrics_computations.0.weighted_num_samples" ,
338+ "rec_metrics.rec_metrics.12._metrics_computations.0.tp_at_k" ,
339+ "rec_metrics.rec_metrics.12._metrics_computations.0.total_weights" ,
340+ "rec_metrics.rec_metrics.13._metrics_computations.0.weighted_sum" ,
341+ "rec_metrics.rec_metrics.13._metrics_computations.0.weighted_num_samples" ,
342+ "rec_metrics.rec_metrics.14._metrics_computations.0.num_examples" ,
343+ "rec_metrics.rec_metrics.14._metrics_computations.0.warmup_examples" ,
344+ "rec_metrics.rec_metrics.14._metrics_computations.0.time_lapse" ,
345+ "rec_metrics.rec_metrics.15._metrics_computations.0.num_true_pos" ,
346+ "rec_metrics.rec_metrics.15._metrics_computations.0.num_false_neg" ,
347+ "rec_metrics.rec_metrics.16._metrics_computations.0.num_true_pos" ,
348+ "rec_metrics.rec_metrics.16._metrics_computations.0.num_false_pos" ,
349+ "rec_metrics.rec_metrics.17._metrics_computations.0.accuracy_sum" ,
350+ "rec_metrics.rec_metrics.17._metrics_computations.0.weighted_num_samples" ,
351+ "rec_metrics.rec_metrics.18._metrics_computations.0.sum_ndcg" ,
352+ "rec_metrics.rec_metrics.18._metrics_computations.0.num_sessions" ,
353+ "rec_metrics.rec_metrics.19._metrics_computations.0.error_sum" ,
354+ "rec_metrics.rec_metrics.19._metrics_computations.0.weighted_num_pairs" ,
355+ "rec_metrics.rec_metrics.21._metrics_computations.0.true_pos_sum" ,
356+ "rec_metrics.rec_metrics.21._metrics_computations.0.false_pos_sum" ,
357+ "rec_metrics.rec_metrics.22._metrics_computations.0.true_pos_sum" ,
358+ "rec_metrics.rec_metrics.22._metrics_computations.0.false_neg_sum" ,
359+ "rec_metrics.rec_metrics.23._metrics_computations.0.cross_entropy_sum" ,
360+ "rec_metrics.rec_metrics.23._metrics_computations.0.weighted_num_samples" ,
361+ "rec_metrics.rec_metrics.23._metrics_computations.0.pos_labels" ,
362+ "rec_metrics.rec_metrics.23._metrics_computations.0.neg_labels" ,
363+ "rec_metrics.rec_metrics.23._metrics_computations.0.num_examples" ,
364+ "rec_metrics.rec_metrics.24._metrics_computations.0.calibration_num" ,
365+ "rec_metrics.rec_metrics.24._metrics_computations.0.calibration_denom" ,
366+ "rec_metrics.rec_metrics.24._metrics_computations.0.num_examples" ,
367+ "rec_metrics.rec_metrics.26._metrics_computations.0.weighted_sum" ,
368+ "rec_metrics.rec_metrics.26._metrics_computations.0.weighted_num_samples" ,
369+ "rec_metrics.rec_metrics.27._metrics_computations.0.cross_entropy_sum" ,
370+ "rec_metrics.rec_metrics.27._metrics_computations.0.weighted_num_samples" ,
371+ "rec_metrics.rec_metrics.27._metrics_computations.0.pos_labels" ,
372+ "rec_metrics.rec_metrics.27._metrics_computations.0.neg_labels" ,
373+ "rec_metrics.rec_metrics.27._metrics_computations.0.weighted_sum_predictions" ,
374+ "rec_metrics.rec_metrics.28._metrics_computations.0.cross_entropy_sum" ,
375+ "rec_metrics.rec_metrics.28._metrics_computations.0.weighted_num_samples" ,
376+ "rec_metrics.rec_metrics.28._metrics_computations.0.pos_labels" ,
377+ "rec_metrics.rec_metrics.28._metrics_computations.0.neg_labels" ,
378+ "rec_metrics.rec_metrics.29._metrics_computations.0.true_pos_sum" ,
379+ "rec_metrics.rec_metrics.29._metrics_computations.0.false_pos_sum" ,
380+ "rec_metrics.rec_metrics.29._metrics_computations.0.false_neg_sum" ,
381+ "rec_metrics.rec_metrics.30._metrics_computations.0.error_sum" ,
382+ "rec_metrics.rec_metrics.30._metrics_computations.0.weighted_num_samples" ,
383+ "rec_metrics.rec_metrics.30._metrics_computations.0.const_pred_error_sum" ,
384+ "throughput_metric.total_examples" ,
385+ "throughput_metric.warmup_examples" ,
386+ "throughput_metric.time_lapse_after_warmup" ,
387+ ]
388+
389+ # This is the latest RecMetricModule
390+ mock_optimizer = MockOptimizer ()
391+
392+ latest_metric_module = generate_metric_module (
393+ TestMetricModule ,
394+ metrics_config = ComprehensiveMetricsConfig ,
395+ batch_size = 128 ,
396+ world_size = 64 ,
397+ my_rank = 0 ,
398+ state_metrics_mapping = {StateMetricEnum .OPTIMIZERS : mock_optimizer },
399+ device = torch .device ("cpu" ),
400+ )
401+ tc = unittest .TestCase ()
402+ tc .assertSetEqual (
403+ set (predefined_state_dict_keys ),
404+ set (latest_metric_module .state_dict ().keys ()),
405+ "RecMetricModule state_dict keys have changed - ensure backward compatibility with older checkpoints" ,
406+ )
407+
226408 @staticmethod
227409 def _run_trainer_checkpointing (rank : int , world_size : int , backend : str ) -> None :
228410 dist .init_process_group (
0 commit comments