 #!/usr/bin/env python3

 import glob
+import warnings
 from abc import abstractmethod
 from os.path import join
 from typing import Any, Callable, Iterator, List, Optional, Union, Tuple, NamedTuple
@@ -385,7 +386,23 @@ def __init__(
                 be computed for all layers. Otherwise, they will only be computed
                 for the layers specified in `layers`.
                 Default: None
-            loss_fn (Callable, optional): The loss function applied to model.
+            loss_fn (Callable, optional): The loss function applied to model. There
+                are two options for the return type of `loss_fn`. First, `loss_fn`
+                can be a "per-example" loss function - returns a 1D Tensor of
+                losses for each example in a batch. `nn.BCELoss(reduction="none")`
+                would be a "per-example" loss function. Second, `loss_fn` can be
+                a "reduction" loss function that reduces the per-example losses
+                in a batch and returns a single scalar Tensor. For this option,
+                the reduction must be the *sum* or the *mean* of the per-example
+                losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable.
+                Note that for the first option, the `sample_wise_grads_per_batch`
+                argument must be False, and for the second option,
+                `sample_wise_grads_per_batch` must be True. Also note that for
+                the second option, if `loss_fn` has no "reduction" attribute,
+                the implementation assumes that the reduction is the *sum* of the
+                per-example losses. If this is not the case, i.e. the reduction
+                is the *mean*, please set the "reduction" attribute of `loss_fn`
+                to "mean", i.e. `loss_fn.reduction = "mean"`.
                 Default: None
             batch_size (int or None, optional): Batch size of the DataLoader created to
                 iterate through `influence_src_dataset`, if it is a Dataset.
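
For concreteness, here is a minimal usage sketch of the two `loss_fn` / `sample_wise_grads_per_batch` pairings described in the docstring added above. The diff does not show which class this `__init__` belongs to, so the `TracInCP` class, its import path, and the toy model, dataset, and checkpoint below are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from captum.influence import TracInCP  # assumed class/import path

# Toy setup (illustrative only): a tiny binary classifier, a random dataset,
# and a single saved checkpoint.
net = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
train_dataset = TensorDataset(torch.rand(8, 10), torch.rand(8, 1))
torch.save(net.state_dict(), "checkpoint-0.pt")  # hypothetical checkpoint file

# Option 1: a "per-example" loss returning one loss per example;
# `sample_wise_grads_per_batch` must remain False.
tracin_per_example = TracInCP(
    net,
    train_dataset,
    checkpoints=["checkpoint-0.pt"],
    loss_fn=nn.BCELoss(reduction="none"),
    sample_wise_grads_per_batch=False,
)

# Option 2: a "reduction" loss (sum or mean of the per-example losses);
# `sample_wise_grads_per_batch` must then be True.
tracin_reduced = TracInCP(
    net,
    train_dataset,
    checkpoints=["checkpoint-0.pt"],
    loss_fn=nn.BCELoss(reduction="sum"),
    sample_wise_grads_per_batch=True,
)
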
@@ -404,10 +421,16 @@ def __init__(
                 inefficient. We offer an implementation of batch-wise gradient
                 computations w.r.t. to model parameters which is computationally
                 more efficient. This implementation can be enabled by setting the
-                `sample_wise_grad_per_batch` argument to `True`. Note that our
+                `sample_wise_grads_per_batch` argument to `True`, and should be
+                enabled if and only if the `loss_fn` argument is a "reduction" loss
+                function. For example, `nn.BCELoss(reduction="sum")` would be a
+                valid `loss_fn` if this implementation is enabled (see
+                documentation for `loss_fn` for more details). Note that our
                 current implementation enables batch-wise gradient computations
                 only for a limited number of PyTorch nn.Modules: Conv2D and Linear.
-                This list will be expanded in the near future.
+                This list will be expanded in the near future. Therefore, please
+                do not enable this implementation if gradients will be computed
+                for other kinds of layers.
                 Default: False
         """

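
Because the batch-wise gradient implementation currently covers only Conv2D and Linear modules, one way to follow the advice above is to restrict the `layers` argument to supported module types whenever `sample_wise_grads_per_batch` is enabled. A minimal sketch, with the same assumptions as the previous example (the `TracInCP` class and the toy model, module names, dataset, and checkpoint are illustrative, not from the commit):

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from captum.influence import TracInCP  # assumed class/import path

# A model mixing a supported layer type (Linear) with an unsupported one (GRU).
class MixedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(10, 16, batch_first=True)  # batch-wise gradients not implemented
        self.fc = nn.Linear(16, 1)                   # supported

    def forward(self, x):
        out, _ = self.rnn(x)
        return torch.sigmoid(self.fc(out[:, -1]))

net = MixedNet()
train_dataset = TensorDataset(torch.rand(8, 5, 10), torch.rand(8, 1))
torch.save(net.state_dict(), "checkpoint-0.pt")  # hypothetical checkpoint file

# Restricting `layers` to the Linear layer keeps gradient computation within the
# supported module types, so enabling the batch-wise implementation is safe here.
tracin = TracInCP(
    net,
    train_dataset,
    checkpoints=["checkpoint-0.pt"],
    layers=["fc"],
    loss_fn=nn.BCELoss(reduction="sum"),
    sample_wise_grads_per_batch=True,
)
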
@@ -423,14 +446,47 @@ def __init__(

         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

-        if (
-            self.sample_wise_grads_per_batch
-            and isinstance(loss_fn, Module)  # TODO: allow loss_fn to be Callable
-            and hasattr(loss_fn, "reduction")
-        ):
-            self.reduction_type = str(loss_fn.reduction)
+        # If we are able to access the reduction used by `loss_fn`, we check whether
+        # the reduction is compatible with `sample_wise_grads_per_batch`
+        if isinstance(loss_fn, Module) and hasattr(
+            loss_fn, "reduction"
+        ):  # TODO: allow loss_fn to be Callable
+            if self.sample_wise_grads_per_batch:
+                assert loss_fn.reduction in ["sum", "mean"], (
+                    'reduction for `loss_fn` must be "sum" or "mean" when '
+                    "`sample_wise_grads_per_batch` is True"
+                )
+                self.reduction_type = str(loss_fn.reduction)
+            else:
+                assert loss_fn.reduction == "none", (
+                    'reduction for `loss_fn` must be "none" when '
+                    "`sample_wise_grads_per_batch` is False"
+                )
         else:
-            self.reduction_type = "sum"
+            # if we are unable to access the reduction used by `loss_fn`, we warn
+            # the user about the assumptions we are making regarding the reduction
+            # used by `loss_fn`
+            if self.sample_wise_grads_per_batch:
+                warnings.warn(
+                    'Since `loss_fn` has no "reduction" attribute, and '
+                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
+                    'that `loss_fn` is a "reduction" loss function that reduces the '
+                    "per-example losses by taking their *sum*. If `loss_fn` "
+                    "instead reduces the per-example losses by taking their mean, "
+                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
+                    '`loss_fn.reduction = "mean"`. Note that if '
+                    "`sample_wise_grads_per_batch` is True, the implementation "
+                    "assumes the reduction is either a sum or mean reduction."
+                )
+                self.reduction_type = "sum"
+            else:
+                warnings.warn(
+                    'Since `loss_fn` has no "reduction" attribute, and '
+                    "`sample_wise_grads_per_batch` is False, the implementation "
+                    'assumes that `loss_fn` is a "per-example" loss function (see '
+                    "documentation for `loss_fn` for details). Please ensure that "
+                    "this is the case."
+                )

         r"""
         TODO: Either restore model state after done (would have to place functionality
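
Finally, a sketch of the fallback branch added in the last hunk: when `loss_fn` is not an `nn.Module` exposing a `reduction` attribute, `__init__` cannot inspect the reduction, so it warns and, with `sample_wise_grads_per_batch=True`, assumes the per-example losses are reduced by a sum. As in the earlier sketches, the `TracInCP` class and the toy setup are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset

from captum.influence import TracInCP  # assumed class/import path

def summed_bce(outputs, targets):
    # A plain callable: not an nn.Module, so its reduction cannot be inspected.
    return F.binary_cross_entropy(outputs, targets, reduction="sum")

net = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
train_dataset = TensorDataset(torch.rand(8, 10), torch.rand(8, 1))
torch.save(net.state_dict(), "checkpoint-0.pt")  # hypothetical checkpoint file

# This emits the new UserWarning, and the reduction is assumed to be the *sum*
# of the per-example losses (i.e. `reduction_type` is set to "sum").
tracin = TracInCP(
    net,
    train_dataset,
    checkpoints=["checkpoint-0.pt"],
    loss_fn=summed_bce,
    sample_wise_grads_per_batch=True,
)
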