Describe the issue
When running predictions on multi-block datasets (especially with cross-validation), I noticed significant slowdowns (~10 s) in `predict.mixo_pls`. Profiling showed that the `WeightedPredict` calculation was the bottleneck (as mentioned in #236). I revisited the weighted-vote part of `predict.mixo_pls` and ran additional correctness tests and benchmarks.
Currently, the code uses:
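```r
out$WeightedPredict = array(unlist(lapply(temp.all, function(x){
    # x: samples x responses x blocks (one element of temp.all per component)
    apply(x, c(1,2), function(z){
        # z: the predictions of all blocks for one sample/response cell;
        # sum the block weights for each distinct predicted value
        temp = aggregate(rowMeans(object$weights), list(z), sum)
        ind = which(temp[,2] == max(temp[,2])) # if two max, then NA
        if(length(ind) == 1)
        {
            res = temp[ind, 1]
        } else {
            res = NA
        }
        res
    })})), dim(Y.hat[[1]]), dimnames = list(rownames(newdata[[1]]), colnames(Y), paste0("dim", c(1:min(ncomp[-object$indY])))))
```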
As discussed earlier, this construct:
- repeatedly computes `rowMeans(object$weights)` for every sample/response cell of every component, even though it is constant,
- calls `aggregate()` on floating-point predictions, where exact ties are so rare that almost no aggregation occurs,
- results in very slow predictions on larger datasets (especially regression with CV/permutation).
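To make the semantics of the per-cell vote concrete, here is a standalone copy of its logic applied to toy inputs. The names `weighted_vote()`, `z` and `w` are hypothetical: `z` plays the role of one cell's block predictions and `w` stands in for `rowMeans(object$weights)`. With distinct (continuous) predictions, `aggregate()` returns one row per block, so the vote reduces to the prediction of the block with the largest weight; weights are only pooled when blocks predict exactly the same value.

```r
# Hypothetical standalone copy of the per-cell vote.
# z: one prediction per block; w: one weight per block
# (stand-in for rowMeans(object$weights)).
weighted_vote <- function(z, w) {
  temp <- aggregate(w, list(z), sum)         # sum weights per distinct value
  ind <- which(temp[, 2] == max(temp[, 2]))  # if two max, then NA
  if (length(ind) == 1) temp[ind, 1] else NA
}

w <- c(0.25, 0.5, 0.375)  # dyadic weights, so the sums below are exact

# Continuous predictions: every value is distinct, one group per block.
z <- c(1.53, -0.27, 0.88)
nrow(aggregate(w, list(z), sum))  # 3: nothing is aggregated
weighted_vote(z, w)               # -0.27, i.e. z[which.max(w)]

# Pooling only changes the result when blocks predict exactly the same
# value: blocks 1 and 3 (0.25 + 0.375) outvote block 2 (0.5).
weighted_vote(c(1, 2, 1), w)      # 1

# A tie in the largest total weight gives NA.
weighted_vote(z, c(0.5, 0.5, 0.25))  # NA
```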
Expected behavior
This block can be replaced by:
```r
out$WeightedPredict = array(unlist(lapply(temp.all, function(x){
    # keep the predictions of the block with the largest mean weight
    x[, , which.max(rowMeans(object$weights))]
})), dim(Y.hat[[1]]), dimnames = list(rownames(newdata[[1]]), colnames(Y), paste0("dim", c(1:min(ncomp[-object$indY])))))
out$temp.all.double <- is.double(unlist(temp.all)) # For testing purposes, always TRUE
```
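A minimal equivalence check of the replacement, written in base R on a synthetic array rather than on a fitted mixOmics model. The names `vote_original()`, `vote_fast()`, `vote_fast_na()`, `x` and `w` are hypothetical: `x` stands in for one element of `temp.all` (samples x responses x blocks) and `w` for `rowMeans(object$weights)`. With distinct predictions the two versions agree exactly; they differ only when two blocks tie for the largest mean weight, where the original loop returns `NA` and `which.max()` returns the first tied block.

```r
set.seed(1)
# Hypothetical stand-in for one element of temp.all:
# 5 samples x 3 responses x 4 blocks, for a single component.
x <- array(rnorm(5 * 3 * 4), dim = c(5, 3, 4))
w <- c(0.1, 0.4, 0.3, 0.2)  # hypothetical rowMeans(object$weights)

# Original per-cell weighted vote.
vote_original <- function(x, w) {
  apply(x, c(1, 2), function(z) {
    temp <- aggregate(w, list(z), sum)
    ind <- which(temp[, 2] == max(temp[, 2]))  # if two max, then NA
    if (length(ind) == 1) temp[ind, 1] else NA
  })
}

# Replacement: take the slice of the block with the largest weight.
vote_fast <- function(x, w) x[, , which.max(w)]

identical(vote_original(x, w), vote_fast(x, w))  # TRUE

# Tie in the largest weight: the original returns NA in every cell,
# while which.max() silently picks the first tied block.
w_tie <- c(0.4, 0.4, 0.1, 0.1)
all(is.na(vote_original(x, w_tie)))              # TRUE
identical(vote_fast(x, w_tie), x[, , 1])         # TRUE

# Hypothetical variant that keeps the NA-on-tie behaviour, assuming the
# tied blocks never predict exactly the same value.
vote_fast_na <- function(x, w) {
  best <- which(w == max(w))
  if (length(best) == 1) x[, , best] else array(NA_real_, dim(x)[1:2])
}
all(is.na(vote_fast_na(x, w_tie)))               # TRUE
identical(vote_fast_na(x, w), vote_fast(x, w))   # TRUE
```

Whether such ties can occur in practice depends on how the block weights are computed; if they can, a guard like `vote_fast_na()` preserves the old `NA` result at essentially no extra cost.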
Screenshots
I tested the 4 affected models (Multi-Block PLS, Multi-Block sPLS, Multi-Block PLS-DA and Multi-Block sPLS-DA), using both random matrices generated with `rnorm` and diagonal matrices simulating sparse omics data.
In every model, `temp.all` contains floating-point values, and the resulting `WeightedPredict` (and `WeightedPredict.class` for the DA models) is identical to the output of the original function.
I also benchmarked each model. The gain is largest for the regression models (about 194 times faster for Multi-Block PLS and about 100 times faster for Multi-Block sPLS) and smaller for the DA models (about 2.5 times faster).
| Model | Modified (ms) | Unmodified (ms) |
|---|---|---|
| Multi-Block PLS | 67.11455 | 13010.12190 |
| Multi-Block sPLS | 65.91465 | 6566.60865 |
| Multi-Block PLS-DA | 97.25735 | 247.86015 |
| Multi-Block sPLS-DA | 97.06685 | 248.61410 |
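A rough, dependency-free way to reproduce this kind of gap, using base R `system.time()` on synthetic data instead of fitted models. The problem sizes, `slow()`, `fast()` and `W` (a blocks x components matrix standing in for `object$weights`) are hypothetical, and the timings depend on the machine. The original pattern calls `rowMeans()` and `aggregate()` once per sample/response cell of every component, while the replacement does one `rowMeans()` and one array slice per component, so the gap grows with the number of samples and responses and is paid again in every CV fold.

```r
set.seed(42)
# Hypothetical problem size: samples, responses, blocks, components.
n <- 100; q <- 5; J <- 3; ncomp <- 2
x <- array(rnorm(n * q * J), dim = c(n, q, J))  # one element of temp.all
W <- matrix(runif(J * ncomp), nrow = J)          # stand-in for object$weights

# Original pattern: rowMeans(W) and aggregate() evaluated once per cell.
slow <- function(x, W) apply(x, c(1, 2), function(z) {
  temp <- aggregate(rowMeans(W), list(z), sum)
  ind <- which(temp[, 2] == max(temp[, 2]))
  if (length(ind) == 1) temp[ind, 1] else NA
})

# Replacement pattern: one rowMeans() and one slice.
fast <- function(x, W) x[, , which.max(rowMeans(W))]

t_slow <- system.time(r_slow <- slow(x, W))[["elapsed"]]
t_fast <- system.time(r_fast <- fast(x, W))[["elapsed"]]
cat(sprintf("slow: %.3f s, fast: %.3f s\n", t_slow, t_fast))
stopifnot(identical(r_slow, r_fast))  # same values, different cost
```

For repeated measurements, `microbenchmark::microbenchmark(slow(x, W), fast(x, W), times = 10)` from the attached `microbenchmark` package can replace the single `system.time()` calls.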
The modified function also passed the mixOmics test suite.
Output of sessionInfo() (if it is an issue):
```
# session info
Platform: x86_64-w64-mingw32/x64
Running under: Windows 11 x64 (build 22631)
Matrix products: default
locale:
[1] LC_COLLATE=Chinese (Simplified)_China.utf8
[2] LC_CTYPE=Chinese (Simplified)_China.utf8
[3] LC_MONETARY=Chinese (Simplified)_China.utf8
[4] LC_NUMERIC=C
[5] LC_TIME=Chinese (Simplified)_China.utf8
time zone: Etc/GMT-8
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] microbenchmark_1.4.10 mixOmics_6.32.0 ggplot2_3.5.2
[4] lattice_0.22-6 MASS_7.3-60.2
loaded via a namespace (and not attached):
[1] Matrix_1.7-0 gtable_0.3.6 dplyr_1.1.4
[4] compiler_4.4.1 tidyselect_1.2.1 Rcpp_1.0.12
[7] ellipse_0.5.0 stringr_1.5.1 parallel_4.4.1
[10] gridExtra_2.3 tidyr_1.3.1 rARPACK_0.11-0
[13] scales_1.4.0 BiocParallel_1.38.0 R6_2.6.1
[16] plyr_1.8.9 generics_0.1.3 igraph_2.0.3
[19] ggrepel_0.9.5 tibble_3.3.0 pillar_1.11.0
[22] RColorBrewer_1.1-3 rlang_1.1.4 stringi_1.8.4
[25] cli_3.6.3 withr_3.0.2 magrittr_2.0.3
[28] grid_4.4.1 lifecycle_1.0.4 vctrs_0.6.5
[31] RSpectra_0.16-2 glue_1.7.0 farver_2.1.2
[34] corpcor_1.6.10 codetools_0.2-20 purrr_1.0.2
[37] reshape2_1.4.4 matrixStats_1.3.0 tools_4.4.1
[40] pkgconfig_2.0.3
# package version
[1] "C:/Users/18269/AppData/Local/R/win-library/4.4/mixOmics"
package * version date (UTC) lib source
mixOmics * 6.32.0 2025-09-23 [1] Github (mixOmicsTeam/mixOmics@ad47493)
```