@@ -215,7 +215,7 @@ def _test_tracin_regression(self, features: int, mode: int) -> None:
215215 for i in range (len (idx )):
216216 self .assertTrue (isSorted (idx [i ]))
217217
218- if mode == "check_autograd_hacks " :
218+ if mode == "sample_wise_trick " :
219219
220220 criterion = nn .MSELoss (reduction = "none" )
221221
@@ -228,39 +228,47 @@ def _test_tracin_regression(self, features: int, mode: int) -> None:
228228 False ,
229229 )
230230
231- # With autograd hacks
231+ # With sample-wise trick
232232 criterion = nn .MSELoss (reduction = "sum" )
233- tracin_hack = self .tracin_constructor (
233+ tracin_sample_wise_trick = self .tracin_constructor (
234234 net , dataset , tmpdir , batch_size , criterion , True
235235 )
236236
237237 train_scores = tracin .influence (train_inputs , train_labels )
238- train_scores_hack = tracin_hack .influence (train_inputs , train_labels )
239- assertTensorAlmostEqual (self , train_scores , train_scores_hack )
238+ train_scores_sample_wise_trick = tracin_sample_wise_trick .influence (
239+ train_inputs , train_labels
240+ )
241+ assertTensorAlmostEqual (
242+ self , train_scores , train_scores_sample_wise_trick
243+ )
240244
241245 test_scores = tracin .influence (test_inputs , test_labels )
242- test_scores_hack = tracin_hack .influence (test_inputs , test_labels )
243- assertTensorAlmostEqual (self , test_scores , test_scores_hack )
246+ test_scores_sample_wise_trick = tracin_sample_wise_trick .influence (
247+ test_inputs , test_labels
248+ )
249+ assertTensorAlmostEqual (
250+ self , test_scores , test_scores_sample_wise_trick
251+ )
244252
245253
246254class _TestTracInRegression1DCheckIdx (_TestTracInRegression ):
247255 def test_tracin_regression_1D_check_idx (self ):
248256 self ._test_tracin_regression (1 , "check_idx" )
249257
250258
251- class _TestTracInRegression1DCheckAutogradHacks (_TestTracInRegression ):
252- def test_tracin_regression_1D_check_autograd_hacks (self ):
253- self ._test_tracin_regression (1 , "check_autograd_hacks " )
259+ class _TestTracInRegression1DCheckSampleWiseTrick (_TestTracInRegression ):
260+ def test_tracin_regression_1D_check_sample_wise_trick (self ):
261+ self ._test_tracin_regression (1 , "sample_wise_trick " )
254262
255263
256264class _TestTracInRegression20DCheckIdx (_TestTracInRegression ):
257265 def test_tracin_regression_20D_check_idx (self ):
258266 self ._test_tracin_regression (20 , "check_idx" )
259267
260268
261- class _TestTracInRegression20DCheckAutogradHacks (_TestTracInRegression ):
262- def test_tracin_regression_20D_check_autograd_hacks (self ):
263- self ._test_tracin_regression (20 , "check_autograd_hacks " )
269+ class _TestTracInRegression20DCheckSampleWiseTrick (_TestTracInRegression ):
270+ def test_tracin_regression_20D_check_sample_wise_trick (self ):
271+ self ._test_tracin_regression (20 , "sample_wise_tricksample_wise_trick " )
264272
265273
266274class _TestTracInXOR :
@@ -434,7 +442,7 @@ def _test_tracin_xor(self, mode) -> None:
434442 influence_labels = dataset .labels [idx [i ][0 :5 ], 0 ]
435443 self .assertTrue (torch .all (testlabels [i , 0 ] == influence_labels ))
436444
437- if mode == "check_autograd_hacks " :
445+ if mode == "sample_wise_trick " :
438446
439447 criterion = nn .MSELoss (reduction = "none" )
440448
@@ -447,9 +455,9 @@ def _test_tracin_xor(self, mode) -> None:
447455 False ,
448456 )
449457
450- # With autograd hacks
458+ # With sample-wise trick
451459 criterion = nn .MSELoss (reduction = "sum" )
452- tracin_hack = self .tracin_constructor (
460+ tracin_sample_wise_trick = self .tracin_constructor (
453461 net ,
454462 dataset ,
455463 tmpdir ,
@@ -459,18 +467,22 @@ def _test_tracin_xor(self, mode) -> None:
459467 )
460468
461469 test_scores = tracin .influence (testset , testlabels )
462- test_scores_hack = tracin_hack .influence (testset , testlabels )
463- assertTensorAlmostEqual (self , test_scores , test_scores_hack )
470+ test_scores_sample_wise_trick = tracin_sample_wise_trick .influence (
471+ testset , testlabels
472+ )
473+ assertTensorAlmostEqual (
474+ self , test_scores , test_scores_sample_wise_trick
475+ )
464476
465477
466478class _TestTracInXORCheckIdx (_TestTracInXOR ):
467479 def test_tracin_xor_check_idx (self ):
468480 self ._test_tracin_xor ("check_idx" )
469481
470482
471- class _TestTracInXORCheckAutogradHacks (_TestTracInXOR ):
472- def test_tracin_xor_check_autograd_hacks (self ):
473- self ._test_tracin_xor ("check_autograd_hacks " )
483+ class _TestTracInXORCheckSampleWiseTrick (_TestTracInXOR ):
484+ def test_tracin_xor_check_sample_wise_trick (self ):
485+ self ._test_tracin_xor ("sample_wise_trick " )
474486
475487
476488class _TestTracInIdentityRegression :
@@ -537,7 +549,7 @@ def _test_tracin_identity_regression(self, mode) -> None:
537549 for i in range (len (idx )):
538550 self .assertEqual (idx [i ][0 ], i )
539551
540- if mode == "check_autograd_hacks " :
552+ if mode == "sample_wise_trick " :
541553
542554 criterion = nn .MSELoss (reduction = "none" )
543555
@@ -550,9 +562,9 @@ def _test_tracin_identity_regression(self, mode) -> None:
550562 False ,
551563 )
552564
553- # With autograd hacks
565+ # With sample-wise trick
554566 criterion = nn .MSELoss (reduction = "sum" )
555- tracin_hack = self .tracin_constructor (
567+ tracin_sample_wise_trick = self .tracin_constructor (
556568 net ,
557569 dataset ,
558570 tmpdir ,
@@ -562,18 +574,22 @@ def _test_tracin_identity_regression(self, mode) -> None:
562574 )
563575
564576 train_scores = tracin .influence (train_inputs , train_labels )
565- train_scores_hack = tracin_hack .influence (train_inputs , train_labels )
566- assertTensorAlmostEqual (self , train_scores , train_scores_hack )
577+ train_scores_tracin_sample_wise_trick = (
578+ tracin_sample_wise_trick .influence (train_inputs , train_labels )
579+ )
580+ assertTensorAlmostEqual (
581+ self , train_scores , train_scores_tracin_sample_wise_trick
582+ )
567583
568584
569585class _TestTracInIdentityRegressionCheckIdx (_TestTracInIdentityRegression ):
570586 def test_tracin_identity_regression_check_idx (self ):
571587 self ._test_tracin_identity_regression ("check_idx" )
572588
573589
574- class _TestTracInIdentityRegressionCheckAutogradHacks (_TestTracInIdentityRegression ):
575- def test_tracin_identity_regression_check_autograd_hacks (self ):
576- self ._test_tracin_identity_regression ("check_autograd_hacks " )
590+ class _TestTracInIdentityRegressionCheckSampleWiseTrick (_TestTracInIdentityRegression ):
591+ def test_tracin_identity_regression_check_sample_wise_trick (self ):
592+ self ._test_tracin_identity_regression ("sample_wise_trick " )
577593
578594
579595class _TestTracInRandomProjectionRegression :
0 commit comments