Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e6a2aea
sage/response parameter implementation
julienbj Jun 5, 2025
bb1d111
Add computation of SAGE-values
julienbj Jun 5, 2025
9bb7342
Add computation of SAGE-values
julienbj Jun 5, 2025
d449c2f
Merge branch 'julie/sage' of https://github.com/NorskRegnesentral/sha…
julienbj Jun 5, 2025
5613715
.
julienbj Jun 5, 2025
7341fcb
Name changes, structural fixes and cleanup
julienbj Jun 6, 2025
1fde310
Add documentation
julienbj Jun 10, 2025
5c3b030
Add documentation
julienbj Jun 10, 2025
f54d546
Merge branch 'julie/sage' of https://github.com/NorskRegnesentral/sha…
julienbj Jun 10, 2025
01c71f6
Add tests for sage
julienbj Jun 10, 2025
2ab36ba
.
julienbj Jun 10, 2025
7b2b376
Add plot functionality for sage
julienbj Jun 11, 2025
48ceca9
Add documentation
julienbj Jun 11, 2025
e2a5d2c
update documentation
martinju Jun 12, 2025
fb388aa
update test files
martinju Jun 12, 2025
af68f0f
lints
martinju Jun 12, 2025
28d2675
remove library(xgboost)
martinju Jun 12, 2025
c7ecd7b
Update sage-tests + change of error message
julienbj Jun 12, 2025
11b0e52
Add loss_func + minor plot touch-ups
julienbj Jun 12, 2025
8545d0e
Add documentation
julienbj Jun 12, 2025
881d328
styler/lint
julienbj Jun 12, 2025
31e7b33
Changes to sage value function, and improvements to loss-function det…
julienbj Jun 24, 2025
585a61c
Fixes to plotting of sage values
julienbj Jun 24, 2025
80abc9b
Improvements and additions to SAGE-testing
julienbj Jun 24, 2025
7e1a8b8
lint
julienbj Jun 24, 2025
80688b8
Changes to loss_func compatibility + documentation
julienbj Jun 25, 2025
a2a4f48
Merge remote-tracking branch 'origin/master' into julie/sage
julienbj Jun 25, 2025
f47e82f
lint
julienbj Jun 25, 2025
898689e
updates to tests
julienbj Jun 26, 2025
34632cd
(Hopefully) fix environment issues with loss_func
julienbj Jun 27, 2025
127e56a
Add computation of SAGE values
julienbj Jul 4, 2025
af3280c
Merge branch 'julie/sage' of https://github.com/NorskRegnesentral/sha…
julienbj Jul 4, 2025
dbb2787
Merge branch 'julie/sage' of https://github.com/NorskRegnesentral/sha…
julienbj Jul 4, 2025
a06edc0
.
julienbj Jul 4, 2025
7032a95
.
julienbj Jul 4, 2025
bb4738e
Merge branch 'julie/sage' of https://github.com/NorskRegnesentral/sha…
julienbj Jul 4, 2025
9c30eff
Add plotting for shaprpy
julienbj Jul 4, 2025
a8e85ef
minor cleanup
julienbj Jul 7, 2025
29ae2c1
cleanup
julienbj Aug 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ check_convergence <- function(internal) {
n_sampled_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions
exact <- internal$iter_list[[iter]]$exact

shap_names <- internal$parameters$shap_names
shap_names_with_none <- c("none", shap_names)
shapley_names <- internal$parameters$shapley_names
shapley_names_with_none <- c("none", shapley_names)

dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd
Expand All @@ -28,7 +28,7 @@ check_convergence <- function(internal) {
cli::cli_abort("The column names of the dt_shapley_est and dt_shapley_df are not equal.")
}

max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = shap_names_with_none, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = shapley_names_with_none, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
Expand All @@ -41,8 +41,8 @@ check_convergence <- function(internal) {
} else {
converged_exact <- FALSE
if (!is.null(convergence_tol)) {
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = shapley_names, by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = shapley_names, by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
Expand Down
97 changes: 69 additions & 28 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,47 +127,56 @@ postprocess_vS_list <- function(vS_list, internal) {
compute_shapley <- function(internal, dt_vS) {
  # Converts the contribution-function values in `dt_vS` into Shapley-type values
  # by multiplying with the weight matrix `W` of the most recent iteration.
  #
  # Arguments:
  #   internal  List. This function reads `parameters` (`is_groupwise`, `type`,
  #             `sage`, `shapley_names`, `loss_func`, `horizon`, `output_labels`),
  #             the last element of `iter_list` (`W`, `id_coalition_mapper_dt`),
  #             `objects` (`cols_per_horizon`, `W_list`) and `data$response`.
  #   dt_vS     data.table with an `id_coalition` column; every other column holds
  #             v(S) values (presumably one row per coalition -- confirm with caller).
  #
  # Returns:
  #   data.table with the column "none" followed by one column per feature/group.
  #   For `type == "forecast"` the `output_labels` columns are prepended.
  #   With `sage = TRUE` the result has a single row.
  is_groupwise <- internal$parameters$is_groupwise
  type <- internal$parameters$type
  sage <- internal$parameters$sage
  shapley_names <- internal$parameters$shapley_names

  iter <- length(internal$iter_list)
  W <- internal$iter_list[[iter]]$W

  if (type == "forecast") {
    # If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon
    id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
    horizon <- internal$parameters$horizon
    cols_per_horizon <- internal$objects$cols_per_horizon
    W_list <- internal$objects$W_list

    kshapley_list <- vector("list", horizon)
    for (i in seq_len(horizon)) {
      W0 <- W_list[[i]]

      # Keep only the coalitions mapped to horizon i, ordered as in the per-horizon weight matrix
      dt_vS0 <- merge(dt_vS, id_coalition_mapper_dt[horizon == i], by = "id_coalition", all.y = TRUE)
      data.table::setorder(dt_vS0, horizon_id_coalition)
      these_vS0_cols <- grep(paste0("p_hat", i, "_"), names(dt_vS0))

      kshapley0 <- t(W0 %*% as.matrix(dt_vS0[, these_vS0_cols, with = FALSE]))
      kshapley_list[[i]] <- data.table::as.data.table(kshapley0)

      if (!is_groupwise) {
        names(kshapley_list[[i]]) <- c("none", cols_per_horizon[[i]])
      } else {
        names(kshapley_list[[i]]) <- c("none", shapley_names)
      }
    }

    # fill = TRUE because the horizons may have different sets of columns
    dt_kshapley <- cbind(internal$parameters$output_labels, rbindlist(kshapley_list, fill = TRUE))
  } else if (isTRUE(sage)) {
    # SAGE: reduce each row of v(S) to one scalar, the negative loss between the
    # response and that row's values, and distribute that single value vector.
    response <- internal$data$response
    loss_func <- internal$parameters$loss_func

    # Drop the id column by name (as in the regular branch) rather than by position
    vS_mat <- as.matrix(dt_vS[, -"id_coalition"])
    vS_sage <- -apply(vS_mat, 1, function(pred_row) loss_func(response, pred_row))

    kshapley <- t(W %*% as.matrix(vS_sage))
    dt_kshapley <- data.table::as.data.table(kshapley)
    colnames(dt_kshapley) <- c("none", shapley_names)
  } else {
    kshapley <- t(W %*% as.matrix(dt_vS[, -"id_coalition"]))
    dt_kshapley <- data.table::as.data.table(kshapley)
    colnames(dt_kshapley) <- c("none", shapley_names)
  }

  return(dt_kshapley)
}

#' @keywords internal
Expand All @@ -184,32 +193,48 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100) {
X <- X_list[[i]]
if (is_groupwise) {
n_shapley_values <- internal$parameters$n_shapley_values
shap_names <- internal$parameters$shap_names
shapley_names <- internal$parameters$shapley_names
} else {
n_shapley_values <- length(internal$parameters$horizon_features[[i]])
shap_names <- internal$parameters$horizon_features[[i]]
shapley_names <- internal$parameters$horizon_features[[i]]
}
dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
n_coal_each_size <- choose(n_shapley_values, seq(n_shapley_values - 1))
result[[i]] <-
bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_coal_each_size, n_boot_samps)
bootstrap_shapley_inner(
X,
n_shapley_values,
shapley_names,
internal,
dt_vS_this,
n_coal_each_size,
n_boot_samps
)
}
result <- cbind(internal$parameters$output_labels, rbindlist(result, fill = TRUE))
} else {
X <- internal$iter_list[[iter]]$X
n_shapley_values <- internal$parameters$n_shapley_values
shap_names <- internal$parameters$shap_names
shapley_names <- internal$parameters$shapley_names
n_coal_each_size <- internal$parameters$n_coal_each_size
result <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS, n_coal_each_size, n_boot_samps)
result <- bootstrap_shapley_inner(
X,
n_shapley_values,
shapley_names,
internal,
dt_vS,
n_coal_each_size,
n_boot_samps
)
}
return(result)
}

#' @keywords internal
bootstrap_shapley_inner <- function(X,
n_shapley_values,
shap_names,
shapley_names,
internal,
dt_vS,
n_coal_each_size,
Expand All @@ -222,6 +247,9 @@ bootstrap_shapley_inner <- function(X,
semi_deterministic_sampling <- internal$parameters$extra_computation_args$semi_deterministic_sampling
shapley_reweight <- internal$parameters$extra_computation_args$kernelSHAP_reweighting

sage <- internal$parameters$sage
loss_func <- internal$parameters$loss_func

if (type == "forecast") {
# For forecast we set it to zero as all coalitions except empty and grand can be sampled
max_fixed_coal_size <- 0
Expand All @@ -231,8 +259,11 @@ bootstrap_shapley_inner <- function(X,
}

X_org <- copy(X)

boot_sd_array <- array(NA, dim = c(n_explain, n_shapley_values + 1, n_boot_samps))
if (sage) {
boot_sd_array <- array(NA, dim = c(1, n_shapley_values + 1, n_boot_samps))
} else {
boot_sd_array <- array(NA, dim = c(n_explain, n_shapley_values + 1, n_boot_samps))
}

# Split X_org into the deterministic coalitions and the sampled coalitions
X_keep <- X_org[is.na(sample_freq), .(id_coalition, coalitions, coalition_size, N, shapley_weight)]
Expand Down Expand Up @@ -344,18 +375,28 @@ bootstrap_shapley_inner <- function(X,
normalize_W_weights = TRUE
)

kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_coalition %in% X_boot[
boot_id == i,
id_coalition
], -"id_coalition"]))
if (sage) {
response <- internal$data$response
loss_func <- internal$parameters$loss_func

vS_SAGE <- -apply(t(dt_vS[, -1]), 2, function(pred_col) loss_func(response, pred_col))
ksage_boot <- t(W_boot %*% as.matrix(vS_SAGE[X_boot[boot_id == i, id_coalition]]))

boot_sd_array[, , i] <- copy(kshap_boot)
boot_sd_array[, , i] <- copy(ksage_boot)
} else {
kshapley_boot <- t(W_boot %*% as.matrix(dt_vS[id_coalition %in% X_boot[
boot_id == i,
id_coalition
], -"id_coalition"]))

boot_sd_array[, , i] <- copy(kshapley_boot)
}
}

std_dev_mat <- apply(boot_sd_array, c(1, 2), sd)

dt_kshap_boot_sd <- data.table::as.data.table(std_dev_mat)
colnames(dt_kshap_boot_sd) <- c("none", shap_names)
dt_kshapley_boot_sd <- data.table::as.data.table(std_dev_mat)
colnames(dt_kshapley_boot_sd) <- c("none", shapley_names)

return(dt_kshap_boot_sd)
return(dt_kshapley_boot_sd)
}
55 changes: 49 additions & 6 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
#' If `NULL` (default), the argument is set to `TRUE` if there are more than 5 features/groups, and `FALSE` otherwise.
#' If eventually `TRUE`, the Shapley values are estimated iteratively.
#' This provides sufficiently accurate Shapley value estimates faster.
#' First an initial number of coalitions is sampled, then bootsrapping is used to estimate the variance of the Shapley
#' First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley
#' values.
#' A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small.
#' If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more
Expand Down Expand Up @@ -134,7 +134,7 @@
#' @param asymmetric Logical.
#' Not applicable for (regular) non-causal or asymmetric explanations.
#' If `FALSE` (default), `explain` computes regular symmetric Shapley values,
#' If `TRUE`, then `explain` compute asymmetric Shapley values based on the (partial) causal ordering
#' If `TRUE`, then `explain` computes asymmetric Shapley values based on the (partial) causal ordering
#' given by `causal_ordering`. That is, `explain` only uses the feature combinations/coalitions that
#' respect the causal ordering when computing the asymmetric Shapley values. If `asymmetric` is `TRUE` and
#' `confounding` is `NULL` (default), then `explain` computes asymmetric conditional Shapley values as specified in
Expand Down Expand Up @@ -170,8 +170,24 @@
#' `asymmetric`. The `approach` cannot be `regression_separate` and `regression_surrogate` as the
#' regression-based approaches are not applicable to the causal Shapley value methodology.
#'
#' @param ... Further arguments passed to specific approaches, see below.
#' @param sage Logical.
#' If `FALSE` (default), Shapley value explanations for individual predictions are computed.
#' If `TRUE`, Shapley value explanations of the global model loss (SAGE) are computed.
#' A single set of Shapley values is then computed over the observations provided to `x_explain`.
#' See details for further information.
#'
#' @param response Numerical vector.
#' Not applicable unless the `sage` parameter is set to `TRUE`.
#' `response` is used in computations of the SAGE values.
#'
#' @param loss_func Function.
#' Not applicable unless the `sage` parameter is set to `TRUE`.
#' Should be a function of two parameters, whereof the first will be the true value of the response,
#' and the second will be the model's prediction.
#' If `NULL` (default), the loss-function will be set to logistic loss in case of
#' binary response vectors, and MSE loss otherwise.
#'
#' @param ... Further arguments passed to specific approaches, see below.
#'
#' @inheritDotParams setup_approach.categorical
#' @inheritDotParams setup_approach.copula
Expand All @@ -184,7 +200,7 @@
#' @inheritDotParams setup_approach.timeseries
#' @inheritDotParams setup_approach.vaeac
#'
#' @details The `shapr` package implements kernelSHAP estimation of dependence-aware Shapley values with
#' @details The `shapr` package implements kernelSHAP estimation of dependence-aware Shapley value explanations with
#' eight different Monte Carlo-based approaches for estimating the conditional distributions of the data.
#' These are all introduced in the
#' \href{https://norskregnesentral.github.io/shapr/articles/general_usage.html}{general usage vignette}.
Expand Down Expand Up @@ -222,6 +238,16 @@
#' Heskes et al. (2020)} as a way to explain the total effect of features
#' on the prediction, taking into account their causal relationships, by adapting the sampling procedure in `shapr`.
#'
#' When `sage = TRUE`, Shapley value explanations of the global model loss (SAGE) are computed.
#' A single set of Shapley values is then computed over the observations provided to `x_explain`,
#' and the output under `shapley_values_est` will then contain the SAGE values, while
#' `shapley_values_sd` will contain the standard deviation for the SAGE values.
#' The computation of the SAGE values is based on
#' \href{https://proceedings.neurips.cc/paper/2020/file/c7bf0b7c1a86d5eb3be2c722cf2cf746-Paper.pdf}{
#' Covert et al. (2020)}, sampling from conditional distributions rather than the marginal sampling described
#' by Covert et al.
#' The SHAP values for the individual predictions can be found under `internal$output$shap_values_est` in all cases.
#'
#' The package allows for parallelized computation with progress updates through the tightly connected
#' [future::future] and [progressr::progressr] packages.
#' See the examples below.
Expand All @@ -231,15 +257,21 @@
#' This combined batch computing of the v(S) values, enables fast and accurate estimation of the Shapley values
#' in a memory friendly manner.
#'
#' The package can also be used for computation of SAGE values as described by
#' \href{https://proceedings.neurips.cc/paper/2020/file/c7bf0b7c1a86d5eb3be2c722cf2cf746-Paper.pdf}{
#' Covert et al. (2020)}.
#'
#' @return Object of class `c("shapr", "list")`. Contains the following items:
#' \describe{
#' \item{`shapley_values_est`}{data.table with the estimated Shapley values with explained observation in the rows and
#' features along the columns.
#' The column `none` is the prediction not devoted to any of the features (given by the argument `phi0`)}
#' The column `none` is the prediction not devoted to any of the features (given by the argument `phi0`).
#' If `sage = TRUE` in [explain()], the data.table will contain a single row with the estimated SAGE values.}
#' \item{`shapley_values_sd`}{data.table with the standard deviation of the Shapley values reflecting the uncertainty.
#' Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by
#' definition 0 when all coalitions are used.
#' Only present when `extra_computation_args$compute_sd=TRUE`, which is the default when `iterative = TRUE`}
#' Only present when `extra_computation_args$compute_sd=TRUE`, which is the default when `iterative = TRUE`.
#' If `sage = TRUE` in [explain()], the data.table will contain a single row with the estimated sd for the SAGE values.}
#' \item{`internal`}{List with the different parameters, data, functions and other output used internally.}
#' \item{`pred_explain`}{Numeric vector with the predictions for the explained observations}
#' \item{`MSEv`}{List with the values of the MSEv evaluation criterion for the approach. See the
Expand Down Expand Up @@ -467,6 +499,9 @@ explain <- function(model,
extra_computation_args = list(),
iterative_args = list(),
output_args = list(),
sage = FALSE,
response = NULL,
loss_func = NULL,
...) { # ... is further arguments passed to specific approaches


Expand Down Expand Up @@ -501,6 +536,9 @@ explain <- function(model,
confounding = confounding,
output_args = output_args,
extra_computation_args = extra_computation_args,
sage = sage,
response = response,
loss_func = loss_func,
...
)

Expand Down Expand Up @@ -630,6 +668,11 @@ testing_cleanup <- function(output) {
output$internal$objects$regression.surrogate_model <- NULL
}

# Remove the loss function that is stored in the parameters when SAGE values are computed.
# The list element is `parameters`; the misspelled lookup returned NULL, which makes
# `if ()` fail with "argument is of length zero". isTRUE() also covers outputs where
# no `sage` flag is present at all.
if (isTRUE(output$internal$parameters$sage)) {
  output$internal$parameters$loss_func <- NULL
}

# Delete the saving_path
output$internal$parameters$output_args$saving_path <- NULL
output$saving_path <- NULL
Expand Down
2 changes: 1 addition & 1 deletion R/explain_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ get_data_forecast <- function(y, xreg, train_idx, explain_idx, explain_y_lags, e
xreg = xreg,
group = reg_fcast$group,
horizon_group = reg_fcast$horizon_group,
shap_names = names(data_lag$group),
shapley_names = names(data_lag$group),
n_endo = ncol(data_lag$lagged),
x_train = cbind(
data.table::as.data.table(data_lag$lagged[train_idx, , drop = FALSE]),
Expand Down
21 changes: 21 additions & 0 deletions R/finalize_explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ finalize_explanation <- function(internal) {
type <- internal$parameters$type
dt_vS <- internal$output$dt_vS

sage <- internal$parameters$sage

# Extracting iter (and deleting the last temporary empty list of iter_list)
iter <- length(internal$iter_list) - 1
internal$iter_list[[iter + 1]] <- NULL
Expand Down Expand Up @@ -43,6 +45,24 @@ finalize_explanation <- function(internal) {
# Extract iterative results in a simplified format
iterative_results <- get_iter_results(internal$iter_list)

# Compute SAGE-values for internal
if (sage) {
response <- internal$data$response

dt_shapley_est$explain_id <- NA
dt_shapley_sd$explain_id <- NA

internal$output$sage_values_est <- dt_shapley_est

W <- internal$objects$W
kshap <- t(W %*% as.matrix(dt_vS[, -"id_coalition"]))
dt_kshap <- data.table::as.data.table(kshap)
colnames(dt_kshap) <- c("none", internal$parameters$shapley_names)
internal$output$shap_values_est <- dt_kshap
} else {
internal$output$shap_values_est <- dt_shapley_est
}

output <- list(
shapley_values_est = dt_shapley_est,
shapley_values_sd = dt_shapley_sd,
Expand All @@ -52,6 +72,7 @@ finalize_explanation <- function(internal) {
saving_path = internal$parameters$output_args$saving_path,
internal = internal
)

attr(output, "class") <- c("shapr", "list")

return(output)
Expand Down
Loading
Loading