diff --git a/.github/workflows/check-litgen.yml b/.github/workflows/check-litgen.yml new file mode 100644 index 0000000000..531bb393cd --- /dev/null +++ b/.github/workflows/check-litgen.yml @@ -0,0 +1,79 @@ +name: Checks - litgen + +permissions: + contents: read + +on: + pull_request: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref_name }} + cancel-in-progress: true + +jobs: + litgen_check: + runs-on: ubuntu-latest + container: + image: ghcr.io/llnl/sundials_spack_cache:llvm-17.0.4-h4lflucc3v2vage45opbo2didtcuigsn.spack + steps: + - name: Install dependencies with apt + run: | + apt update + apt install -y git python3 python3-pip + + - name: Install black + run: pip install black + + - name: Print black version + run: black --version + + - name: Install cmake-format + run: pip install cmakelang + + - name: Print cmake-format version + run: cmake-format --version + + - name: Install fprettify + run: pip install fprettify + + - name: Print fprettify version + run: fprettify --version + + - name: Print clang-format version + run: clang-format --version + + - name: Install sundials generator+litgen + run: | + pip install black isort pyyaml litgen@git+https://github.com/sundials-codes/litgen.git + + - name: Check out repository code + uses: actions/checkout@v5 + with: + submodules: true + + - name: Add safe directory + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Run generator on code + run: | + cd bindings/sundials4py + python3 sundials4py-generate/generate.py . + + - name: Run formatter on code + run: ./scripts/format.sh bindings/sundials4py + + - name: Run git diff to see if anything changed + run: /usr/bin/git diff --exit-code + + - name: Run git diff if we failed + if: failure() + run: /usr/bin/git diff > litgen.patch + + - name: Archive diff as a patch if we failed + uses: actions/upload-artifact@v5 + if: failure() + with: + name: litgen.patch + path: | + ${{ github.workspace }}/litgen.patch diff --git a/.github/workflows/sundials4py-wheels.yml b/.github/workflows/sundials4py-wheels.yml new file mode 100644 index 0000000000..550298886a --- /dev/null +++ b/.github/workflows/sundials4py-wheels.yml @@ -0,0 +1,120 @@ +name: sundials4py wheels + +permissions: + contents: read + +on: + push: + branches: + - main + - develop + pull_request: + merge_group: + workflow_dispatch: + release: + types: + - published + +jobs: + build_sdist: + name: Build SDist + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + with: + submodules: true + + - name: Build SDist + run: pipx run build --sdist + + - name: Check metadata + run: pipx run twine check dist/* + + - uses: actions/upload-artifact@v5 + with: + name: dist-sdist + path: dist/*.tar.gz + + + build_wheels: + name: Wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, macos-15-intel, windows-latest] + + steps: + - name: Set macOS deployment target + if: contains(matrix.os, 'macos') + run: echo "MACOSX_DEPLOYMENT_TARGET=$(sw_vers -productVersion | cut -d '.' -f 1-2)" >> $GITHUB_ENV + + - uses: actions/checkout@v5 + with: + submodules: true + + - uses: pypa/cibuildwheel@v2.22 + env: + CIBW_TEST_COMMAND: "pytest {project}/bindings/sundials4py/test" + + - name: Upload wheels + uses: actions/upload-artifact@v5 + with: + path: wheelhouse/*.whl + name: dist-${{ matrix.os }} + + build_wheels_extra: + name: Build extra wheels on ${{ matrix.os }} (${{ matrix.precision }}/${{ matrix.index }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + precision: [single, double] + index: [32, 64] + + steps: + - name: Set macOS deployment target + if: contains(matrix.os, 'macos') + run: echo "MACOSX_DEPLOYMENT_TARGET=$(sw_vers -productVersion | cut -d '.' -f 1-2)" >> $GITHUB_ENV + + - uses: actions/checkout@v5 + with: + submodules: true + + - uses: pypa/cibuildwheel@v2.22 + env: + CMAKE_ARGS: "-DSUNDIALS_PRECISION=${{ matrix.precision }} -DSUNDIALS_INDEX_SIZE=${{ matrix.index }}" + CIBW_TEST_COMMAND: "pytest {project}/bindings/sundials4py/test" + + - name: Rename wheels to include config + run: | + for whl in wheelhouse/*.whl; do + mv "$whl" "$(echo $whl | sed 's/\.whl$/-${{ matrix.precision }}-${{ matrix.index }}.whl/')" + done + shell: bash + + - name: Upload wheels + uses: actions/upload-artifact@v5 + with: + path: wheelhouse/*.whl + name: dist-${{ matrix.os }} + + upload: + name: Upload if release + needs: [build_wheels, build_sdist] + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + + steps: + - uses: actions/setup-python@v5 + - uses: actions/download-artifact@v4 + with: + path: dist + pattern: dist-* + merge-multiple: true + + - uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.pypi_password }} diff --git a/.gitignore b/.gitignore index 9d1c6fe6e6..1d8e49e3af 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ /.pydevproject .vscode compile_commands.json +.venv .clangd # custom test environment setup script @@ -69,6 +70,7 @@ compile_commands.json /doc/build/ /doc/*/build/ /doc/shared/__pycache__ +/doc/shared/Python/sundials4py-*-functions.rst # PDFs of user guides and example docs /doc/*/*_guide.pdf @@ -82,3 +84,7 @@ uberenv_libs # tools /tools/suntools/__pycache__ + +# python +.pytest_cache +__pycache__ diff --git a/.gitmodules b/.gitmodules index e915f32aa6..91609780fd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "external/sundials-addon-example"] path = external/sundials-addon-example url = https://github.com/sundials-codes/sundials-addon-example.git +[submodule "bindings/sundials4py/sundials4py-generate"] + path = bindings/sundials4py/sundials4py-generate + url = https://github.com/sundials-codes/sundials4py-generate.git diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 5e5cd86b40..81cd62a113 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,9 +7,9 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-20.04 + os: ubuntu-24.04 tools: - python: "3.9" + python: "3.12" apt_packages: - graphviz diff --git a/CHANGELOG.md b/CHANGELOG.md index e7e6173e9d..c34a8b1fde 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ ### Major Features +SUNDIALS now has official Python interfaces! With this release, we are shipping a **beta version** of +the sundials4py Python module (created with nanobind and litgen). sundials4py provides explicit +interfaces to most features of SUNDIALS. + ### New Features and Enhancements The functions `CVodeGetUserDataB` and `IDAGetUserDataB` were added to CVODES diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fd1d6b26c..ea5ed8443e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -243,6 +243,10 @@ if(SUNDIALS_ENABLE_EXTERNAL_ADDONS) add_subdirectory(external) endif() +if(SUNDIALS_ENABLE_PYTHON) + add_subdirectory(bindings/sundials4py) +endif() + # =============================================================== # Install configuration header files and license file. # =============================================================== diff --git a/bindings/sundials4py/CMakeLists.txt b/bindings/sundials4py/CMakeLists.txt new file mode 100644 index 0000000000..a290646f12 --- /dev/null +++ b/bindings/sundials4py/CMakeLists.txt @@ -0,0 +1,139 @@ +# --------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# --------------------------------------------------------------- + +# Warn if the user invokes CMake directly +if(NOT SKBUILD) + message( + WARNING + "\ + This CMake file is meant to be executed using 'scikit-build-core'. + Running it directly will almost certainly not produce the desired + result. If you are a user trying to install this package, use the + command below, which will install all necessary build dependencies, + compile the package in an isolated environment, and then install it. + ===================================================================== + $ pip install . + ===================================================================== + If you are a software developer, and this is your own package, then + it is usually much more efficient to install the build dependencies + in your environment once and use the following command that avoids + a costly creation of a new virtual environment at every compilation: + ===================================================================== + $ pip install nanobind scikit-build-core[pyproject] + $ pip install --no-build-isolation -ve . + ===================================================================== + You may optionally add -Ceditable.rebuild=true to auto-rebuild when + the package is imported. Otherwise, you need to rerun the above + after editing C++ files.") +endif() + +# nanobind needs the Python Interpreter and Development components +find_package( + Python 3.12 + COMPONENTS Interpreter Development.Module + OPTIONAL_COMPONENTS Development.SABIModule + REQUIRED) + +# Determine location of nanobind cmake config file +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE nanobind_ROOT) + +# nanobind must already be installed as a Python module (e.g., with pip) +find_package(nanobind CONFIG REQUIRED) + +# Add the source files for the bindings +set(sundials_SOURCES + arkode/arkode_arkstep.cpp + arkode/arkode_erkstep.cpp + arkode/arkode_forcingstep.cpp + arkode/arkode_lsrkstep.cpp + arkode/arkode_mristep.cpp + arkode/arkode_splittingstep.cpp + arkode/arkode_sprkstep.cpp + arkode/arkode.cpp + cvodes/cvodes.cpp + idas/idas.cpp + kinsol/kinsol.cpp + nvector/nvector_serial.cpp + nvector/nvector_manyvector.cpp + sundials4py.cpp + sunadaptcontroller/sunadaptcontroller_imexgus.cpp + sunadaptcontroller/sunadaptcontroller_mrihtol.cpp + sunadaptcontroller/sunadaptcontroller_soderlind.cpp + sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed.cpp + sundials/sundials_adaptcontroller.cpp + sundials/sundials_adjointcheckpointscheme.cpp + sundials/sundials_adjointstepper.cpp + sundials/sundials_context.cpp + sundials/sundials_core.cpp + sundials/sundials_domeigestimator.cpp + sundials/sundials_linearsolver.cpp + sundials/sundials_logger.cpp + sundials/sundials_matrix.cpp + sundials/sundials_memory.cpp + sundials/sundials_nonlinearsolver.cpp + sundials/sundials_nvector.cpp + sundials/sundials_profiler.cpp + sundials/sundials_stepper.cpp + sundomeigest/sundomeigest_power.cpp + sunlinsol/sunlinsol_band.cpp + sunlinsol/sunlinsol_dense.cpp + sunlinsol/sunlinsol_pcg.cpp + sunlinsol/sunlinsol_spbcgs.cpp + sunlinsol/sunlinsol_spfgmr.cpp + sunlinsol/sunlinsol_spgmr.cpp + sunlinsol/sunlinsol_sptfqmr.cpp + sunmatrix/sunmatrix_band.cpp + sunmatrix/sunmatrix_dense.cpp + sunmatrix/sunmatrix_sparse.cpp + sunmemory/sunmemory_system.cpp + sunnonlinsol/sunnonlinsol_fixedpoint.cpp + sunnonlinsol/sunnonlinsol_newton.cpp) + +# Create the Python sundials library +nanobind_add_module(sundials4py STABLE_ABI NB_STATIC ${sundials_SOURCES}) + +# Include private header locations +target_include_directories( + sundials4py + PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include ${SUNDIALS_SOURCE_DIR}/src + ${SUNDIALS_SOURCE_DIR}/src/arkode ${SUNDIALS_SOURCE_DIR}/src/cvodes + ${SUNDIALS_SOURCE_DIR}/src/idas ${SUNDIALS_SOURCE_DIR}/src/kinsol) + +# Link against sundials libraries +target_link_libraries( + sundials4py + PRIVATE sundials_arkode + sundials_cvodes + sundials_idas + sundials_kinsol + sundials_nvecserial + sundials_sunlinsolspgmr + sundials_sunlinsoldense + sundials_sunlinsolband + sundials_sunlinsolspbcgs + sundials_sunlinsolspfgmr + sundials_sunlinsolsptfqmr + sundials_sunlinsolpcg + sundials_sunmatrixband + sundials_sunmatrixdense + sundials_sunmatrixsparse + sundials_sundomeigestpower + sundials_core) + +# Install directive for scikit-build-core +install(TARGETS sundials4py LIBRARY DESTINATION .) diff --git a/bindings/sundials4py/arkode/arkode.cpp b/bindings/sundials4py/arkode/arkode.cpp new file mode 100644 index 0000000000..4db08507a8 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode.cpp @@ -0,0 +1,244 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "arkode/arkode_impl.h" +#include "arkode_usersupplied.hpp" +#include "sundials_adjointcheckpointscheme_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +// Forward declarations of functions defined in other translation units +void bind_arkode_erkstep(nb::module_& m); +void bind_arkode_arkstep(nb::module_& m); +void bind_arkode_sprkstep(nb::module_& m); +void bind_arkode_lsrkstep(nb::module_& m); +void bind_arkode_mristep(nb::module_& m); +void bind_arkode_forcingstep(nb::module_& m); +void bind_arkode_splittingstep(nb::module_& m); + +// ARKODE callback binding macros +#define BIND_ARKODE_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER, ...) \ + m.def( \ + #NAME, \ + [](void* ark_mem, std::function> fn) \ + { \ + auto fn_table = get_arkode_fn_table(ark_mem); \ + fn_table->MEMBER = nb::cast(fn); \ + if (fn) { return NAME(ark_mem, &WRAPPER); } \ + else { return NAME(ark_mem, nullptr); } \ + }, \ + __VA_ARGS__) + +#define BIND_ARKODE_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \ + MEMBER2, WRAPPER2, ...) \ + m.def( \ + #NAME, \ + [](void* ark_mem, std::function> fn1, \ + std::function> fn2) \ + { \ + auto fn_table = get_arkode_fn_table(ark_mem); \ + fn_table->MEMBER1 = nb::cast(fn1); \ + fn_table->MEMBER2 = nb::cast(fn2); \ + if (fn1) { return NAME(ark_mem, &WRAPPER1, &WRAPPER2); } \ + else { return NAME(ark_mem, nullptr, &WRAPPER2); } \ + }, \ + __VA_ARGS__) + +void bind_arkode(nb::module_& m) +{ +#include "arkode_generated.hpp" + + ///////////////////////////////////////////////////////////////////////////// + // Interface view classes for ARKODE level objects + ///////////////////////////////////////////////////////////////////////////// + + nb::class_(m, "ARKodeView") + .def("get", nb::overload_cast<>(&ARKodeView::get, nb::const_), + nb::rv_policy::reference); + + ///////////////////////////////////////////////////////////////////////////// + // ARKODE user-supplied function setters + ///////////////////////////////////////////////////////////////////////////// + + m.def("ARKodeRootInit", + [](void* ark_mem, int nrtfn, + std::function> fn) + { + auto fn_table = get_arkode_fn_table(ark_mem); + fn_table->rootfn = nb::cast(fn); + return ARKodeRootInit(ark_mem, nrtfn, &arkode_rootfn_wrapper); + }); + + BIND_ARKODE_CALLBACK(ARKodeWFtolerances, ARKEwtFn, ewtn, arkode_ewtfn_wrapper, + nb::arg("arkode_mem"), nb::arg("efun").none()); + + BIND_ARKODE_CALLBACK(ARKodeResFtolerance, ARKRwtFn, rwtn, arkode_rwtfn_wrapper, + nb::arg("arkode_mem"), nb::arg("efun").none()); + + m.def( + "ARKodeResize", + [](void* ark_mem, N_Vector y_new, sunrealtype h_scale, sunrealtype t0, + std::function> fn) + { + auto fn_table = get_arkode_fn_table(ark_mem); + fn_table->vecresizefn = nb::cast(fn); + return ARKodeResize(ark_mem, y_new, h_scale, t0, + arkode_vecresizefn_wrapper, ark_mem); + }, + nb::arg("arkode_mem"), nb::arg("y_new"), nb::arg("h_scale"), nb::arg("t0"), + nb::arg("resize_fn").none()); + + BIND_ARKODE_CALLBACK2(ARKodeSetRelaxFn, ARKRelaxFn, relaxfn, + arkode_relaxfn_wrapper, ARKRelaxJacFn, relaxjacfn, + arkode_relaxjacfn_wrapper, nb::arg("arkode_mem"), + nb::arg("rfn").none(), nb::arg("rjacfn").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetPostprocessStepFn, ARKPostProcessFn, + postprocessstepfn, arkode_postprocessstepfn_wrapper, + nb::arg("arkode_mem"), nb::arg("postprocessstep").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetPostprocessStageFn, ARKPostProcessFn, + postprocessstagefn, arkode_postprocessstagefn_wrapper, + nb::arg("arkode_mem"), nb::arg("postprocessstage").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetStagePredictFn, ARKStagePredictFn, + stagepredictfn, arkode_stagepredictfn_wrapper, + nb::arg("arkode_mem"), nb::arg("stagepredict").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetNlsRhsFn, ARKRhsFn, nlsfi, + arkode_nlsrhsfn_wrapper, nb::arg("arkode_mem"), + nb::arg("nls_fi").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetJacFn, ARKLsJacFn, lsjacfn, + arkode_lsjacfn_wrapper, nb::arg("arkode_mem"), + nb::arg("jac").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetMassFn, ARKLsMassFn, lsmassfn, + arkode_lsmassfn_wrapper, nb::arg("arkode_mem"), + nb::arg("mass").none()); + + BIND_ARKODE_CALLBACK2(ARKodeSetPreconditioner, ARKLsPrecSetupFn, + lsprecsetupfn, arkode_lsprecsetupfn_wrapper, + ARKLsPrecSolveFn, lsprecsolvefn, + arkode_lsprecsolvefn_wrapper, nb::arg("arkode_mem"), + nb::arg("psetup").none(), nb::arg("psolve").none()); + + BIND_ARKODE_CALLBACK2(ARKodeSetMassPreconditioner, ARKLsMassPrecSetupFn, + lsmassprecsetupfn, arkode_lsmassprecsetupfn_wrapper, + ARKLsMassPrecSolveFn, lsmassprecsolvefn, + arkode_lsmassprecsolvefn_wrapper, nb::arg("arkode_mem"), + nb::arg("psetup").none(), nb::arg("psolve").none()); + + BIND_ARKODE_CALLBACK2(ARKodeSetJacTimes, ARKLsJacTimesSetupFn, + lsjactimessetupfn, arkode_lsjactimessetupfn_wrapper, + ARKLsJacTimesVecFn, lsjactimesvecfn, + arkode_lsjactimesvecfn_wrapper, nb::arg("arkode_mem"), + nb::arg("jtsetup").none(), nb::arg("jtimes").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetJacTimesRhsFn, ARKRhsFn, lsjacrhsfn, + arkode_lsjacrhsfn_wrapper, nb::arg("arkode_mem"), + nb::arg("jtimesRhsFn").none()); + + m.def( + "ARKodeSetMassTimes", + [](void* ark_mem, + std::function> msetup, + std::function> mtimes) + { + auto fn_table = get_arkode_fn_table(ark_mem); + fn_table->lsmasstimessetupfn = nb::cast(msetup); + fn_table->lsmasstimesvecfn = nb::cast(mtimes); + return ARKodeSetMassTimes(ark_mem, &arkode_lsmasstimessetupfn_wrapper, + &arkode_lsmasstimesvecfn_wrapper, nullptr); + }, + nb::arg("arkode_mem"), nb::arg("msetup").none(), nb::arg("mtimes").none()); + + BIND_ARKODE_CALLBACK(ARKodeSetLinSysFn, ARKLsLinSysFn, lslinsysfn, + arkode_lslinsysfn_wrapper, nb::arg("arkode_mem"), + nb::arg("linsys").none()); + + // ARKodeSetMassTimes doesn't fit the BIND_ARKODE_CALLBACK macro pattern(s) + // due to the 4th argument for user data, so we just write it out explicitly. + m.def( + "ARKodeSetMassTimes", + [](void* ark_mem, + std::function> msetup, + std::function> mtimes) + { + auto fn_table = get_arkode_fn_table(ark_mem); + fn_table->lsmasstimessetupfn = nb::cast(msetup); + fn_table->lsmasstimesvecfn = nb::cast(mtimes); + return ARKodeSetMassTimes(ark_mem, &arkode_lsmasstimessetupfn_wrapper, + &arkode_lsmasstimesvecfn_wrapper, nullptr); + }, + nb::arg("ark_mem"), nb::arg("msetup").none(), nb::arg("mtimes").none()); + + ///////////////////////////////////////////////////////////////////////////// + // Additional functions that litgen cannot generate + ///////////////////////////////////////////////////////////////////////////// + + m.def( + "ARKodeSetOptions", + [](void* ark_mem, const std::string& arkid, const std::string& file_name, + int argc, const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return ARKodeSetOptions(ark_mem, arkid.empty() ? nullptr : arkid.c_str(), + file_name.empty() ? nullptr : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("ark_mem"), nb::arg("arkid"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + // This function has optional arguments which litgen cannot deal with because they are followed by non-optional arguments. + m.def("ARKodeSetMassLinearSolver", ARKodeSetMassLinearSolver, + nb::arg("arkode_mem"), nb::arg("LS"), nb::arg("M").none(), + nb::arg("time_dep")); + + bind_arkode_arkstep(m); + bind_arkode_erkstep(m); + bind_arkode_sprkstep(m); + bind_arkode_lsrkstep(m); + bind_arkode_mristep(m); + bind_arkode_forcingstep(m); + bind_arkode_splittingstep(m); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_arkstep.cpp b/bindings/sundials4py/arkode/arkode_arkstep.cpp new file mode 100644 index 0000000000..c478d118bc --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_arkstep.cpp @@ -0,0 +1,112 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +#include "arkode_impl.h" +#include "arkode_usersupplied.hpp" +#include "sundials_adjointstepper_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_arkstep(nb::module_& m) +{ +#include "arkode_arkstep_generated.hpp" + + m.def( + "ARKStepCreate", + [](std::function> fe, + std::function> fi, sunrealtype t0, + N_Vector y0, SUNContext sunctx) + { + auto fe_wrapper = fe ? arkstep_fe_wrapper : nullptr; + auto fi_wrapper = fi ? arkstep_fi_wrapper : nullptr; + + void* ark_mem = ARKStepCreate(fe_wrapper, fi_wrapper, t0, y0, sunctx); + if (ark_mem == nullptr) + { + throw sundials4py::error_returned("Failed to create ARKODE memory"); + } + + // Create the user-supplied function table to store the Python user functions + auto fn_table = arkode_user_supplied_fn_table_alloc(); + + // Smuggle the user-supplied function table into callback wrappers through the user_data pointer + static_cast(ark_mem)->python = fn_table; + int ark_status = ARKodeSetUserData(ark_mem, ark_mem); + if (ark_status != ARK_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in ARKODE memory"); + } + + // Finally, set the RHS functions + if (fe) { fn_table->arkstep_fe = nb::cast(fe); } + if (fi) { fn_table->arkstep_fi = nb::cast(fi); } + + return std::make_shared(ark_mem); + }, // .none() must be added to functions that accept nullptr as a valid argument + nb::arg("fe").none(), nb::arg("fi").none(), nb::arg("t0"), nb::arg("y0"), + nb::arg("sunctx"), nb::keep_alive<0, 5>()); + + m.def( + "ARKStepCreateAdjointStepper", + [](void* arkode_mem, std::function> adj_fe, + std::function> adj_fi, sunrealtype tf, + N_Vector sf, SUNContext sunctx) -> std::tuple + { + auto fe_wrapper = adj_fe ? arkstep_adjfe_wrapper : nullptr; + auto fi_wrapper = adj_fi ? arkstep_adjfi_wrapper : nullptr; + + SUNAdjointStepper adj_stepper = nullptr; + int ark_status = ARKStepCreateAdjointStepper(arkode_mem, fe_wrapper, + fi_wrapper, tf, sf, sunctx, + &adj_stepper); + if (ark_status != ARK_SUCCESS) + { + throw sundials4py::error_returned( + "Failed to create adjoint stepper in py-sundials memory"); + } + + // Finally, set the RHS functions + void* user_data = nullptr; + ark_status = ARKodeGetUserData(arkode_mem, &user_data); + if (ark_status != ARK_SUCCESS) + { + throw sundials4py::error_returned("Failed to extract ARKODE user data"); + } + + auto fn_table = get_arkode_fn_table(arkode_mem); + + if (adj_fe) { fn_table->arkstep_adjfe = nb::cast(adj_fe); } + if (adj_fi) { fn_table->arkstep_adjfi = nb::cast(adj_fi); } + + return std::make_tuple(ark_status, adj_stepper); + }, + nb::arg("arkode_mem"), nb::arg("adj_fe").none(), nb::arg("adj_fi").none(), + nb::arg("tf"), nb::arg("sf"), nb::arg("sunctx")); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_arkstep_generated.hpp b/bindings/sundials4py/arkode/arkode_arkstep_generated.hpp new file mode 100644 index 0000000000..cd2a1a4775 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_arkstep_generated.hpp @@ -0,0 +1,85 @@ +// #ifndef _ARKSTEP_H +// +// #ifdef __cplusplus +// #endif +// + +m.def("ARKStepSetExplicit", ARKStepSetExplicit, nb::arg("arkode_mem")); + +m.def("ARKStepSetImplicit", ARKStepSetImplicit, nb::arg("arkode_mem")); + +m.def("ARKStepSetImEx", ARKStepSetImEx, nb::arg("arkode_mem")); + +m.def("ARKStepSetTables", ARKStepSetTables, nb::arg("arkode_mem"), nb::arg("q"), + nb::arg("p"), nb::arg("Bi"), nb::arg("Be")); + +m.def("ARKStepSetTableNum", ARKStepSetTableNum, nb::arg("arkode_mem"), + nb::arg("itable"), nb::arg("etable")); + +m.def("ARKStepSetTableName", ARKStepSetTableName, nb::arg("arkode_mem"), + nb::arg("itable"), nb::arg("etable")); + +m.def( + "ARKStepGetCurrentButcherTables", + [](void* arkode_mem) -> std::tuple + { + auto ARKStepGetCurrentButcherTables_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + ARKodeButcherTable Bi_adapt_modifiable; + ARKodeButcherTable Be_adapt_modifiable; + + int r = ARKStepGetCurrentButcherTables(arkode_mem, &Bi_adapt_modifiable, + &Be_adapt_modifiable); + return std::make_tuple(r, Bi_adapt_modifiable, Be_adapt_modifiable); + }; + + return ARKStepGetCurrentButcherTables_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem"), + " Optional output functions\n\n nb::rv_policy::reference", + nb::rv_policy::reference); + +m.def( + "ARKStepGetTimestepperStats", + [](void* arkode_mem) -> std::tuple + { + auto ARKStepGetTimestepperStats_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) + -> std::tuple + { + long expsteps_adapt_modifiable; + long accsteps_adapt_modifiable; + long step_attempts_adapt_modifiable; + long nfe_evals_adapt_modifiable; + long nfi_evals_adapt_modifiable; + long nlinsetups_adapt_modifiable; + long netfails_adapt_modifiable; + + int r = ARKStepGetTimestepperStats(arkode_mem, &expsteps_adapt_modifiable, + &accsteps_adapt_modifiable, + &step_attempts_adapt_modifiable, + &nfe_evals_adapt_modifiable, + &nfi_evals_adapt_modifiable, + &nlinsetups_adapt_modifiable, + &netfails_adapt_modifiable); + return std::make_tuple(r, expsteps_adapt_modifiable, + accsteps_adapt_modifiable, + step_attempts_adapt_modifiable, + nfe_evals_adapt_modifiable, + nfi_evals_adapt_modifiable, + nlinsetups_adapt_modifiable, + netfails_adapt_modifiable); + }; + + return ARKStepGetTimestepperStats_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_erkstep.cpp b/bindings/sundials4py/arkode/arkode_erkstep.cpp new file mode 100644 index 0000000000..fe8b4f3475 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_erkstep.cpp @@ -0,0 +1,98 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +#include "arkode_impl.h" +#include "arkode_usersupplied.hpp" +#include "sundials_adjointstepper_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_erkstep(nb::module_& m) +{ +#include "arkode_erkstep_generated.hpp" + + m.def( + "ERKStepCreate", + [](std::function> rhs, sunrealtype t0, + N_Vector y0, SUNContext sunctx) + { + if (!rhs) { throw sundials4py::illegal_value("rhs was null"); } + + void* ark_mem = ERKStepCreate(erkstep_f_wrapper, t0, y0, sunctx); + if (ark_mem == nullptr) + { + throw sundials4py::error_returned("Failed to create ARKODE memory"); + } + + // Create the user-supplied function table to store the Python user functions + auto fn_table = arkode_user_supplied_fn_table_alloc(); + + // Smuggle the user-supplied function table into callback wrappers through the user_data pointer + static_cast(ark_mem)->python = fn_table; + int ark_status = ARKodeSetUserData(ark_mem, ark_mem); + if (ark_status != ARK_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in ARKODE memory"); + } + + // Finally, set the RHS function + fn_table->erkstep_f = nb::cast(rhs); + + return std::make_shared(ark_mem); + }, + nb::arg("rhs"), nb::arg("t0"), nb::arg("y0"), nb::arg("sunctx"), + nb::keep_alive<0, 4>()); + + m.def( + "ERKStepCreateAdjointStepper", + [](void* arkode_mem, + std::function> adj_f, sunrealtype tf, + N_Vector sf, SUNContext sunctx) -> std::tuple + { + if (!adj_f) { throw sundials4py::illegal_value("adj_f was null"); } + + SUNAdjointStepper adj_stepper = nullptr; + int ark_status = ERKStepCreateAdjointStepper(arkode_mem, + erkstep_adjf_wrapper, tf, sf, + sunctx, &adj_stepper); + if (ark_status != ARK_SUCCESS) + { + throw sundials4py::error_returned("Failed to create adjoint stepper"); + } + + auto fn_table = get_arkode_fn_table(arkode_mem); + + fn_table->erkstep_adjf = nb::cast(adj_f); + + return std::make_tuple(ark_status, adj_stepper); + }, + nb::arg("arkode_mem"), nb::arg("adj_f").none(), nb::arg("tf"), + nb::arg("sf"), nb::arg("sunctx")); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_erkstep_generated.hpp b/bindings/sundials4py/arkode/arkode_erkstep_generated.hpp new file mode 100644 index 0000000000..019ac44004 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_erkstep_generated.hpp @@ -0,0 +1,68 @@ +// #ifndef _ERKSTEP_H +// +// #ifdef __cplusplus +// #endif +// + +m.def("ERKStepSetTable", ERKStepSetTable, nb::arg("arkode_mem"), nb::arg("B")); + +m.def("ERKStepSetTableNum", ERKStepSetTableNum, nb::arg("arkode_mem"), + nb::arg("etable")); + +m.def("ERKStepSetTableName", ERKStepSetTableName, nb::arg("arkode_mem"), + nb::arg("etable")); + +m.def( + "ERKStepGetCurrentButcherTable", + [](void* arkode_mem) -> std::tuple + { + auto ERKStepGetCurrentButcherTable_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + ARKodeButcherTable B_adapt_modifiable; + + int r = ERKStepGetCurrentButcherTable(arkode_mem, &B_adapt_modifiable); + return std::make_tuple(r, B_adapt_modifiable); + }; + + return ERKStepGetCurrentButcherTable_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem"), + " Optional output functions\n\n nb::rv_policy::reference", + nb::rv_policy::reference); + +m.def( + "ERKStepGetTimestepperStats", + [](void* arkode_mem) -> std::tuple + { + auto ERKStepGetTimestepperStats_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long expsteps_adapt_modifiable; + long accsteps_adapt_modifiable; + long step_attempts_adapt_modifiable; + long nfevals_adapt_modifiable; + long netfails_adapt_modifiable; + + int r = ERKStepGetTimestepperStats(arkode_mem, &expsteps_adapt_modifiable, + &accsteps_adapt_modifiable, + &step_attempts_adapt_modifiable, + &nfevals_adapt_modifiable, + &netfails_adapt_modifiable); + return std::make_tuple(r, expsteps_adapt_modifiable, + accsteps_adapt_modifiable, + step_attempts_adapt_modifiable, + nfevals_adapt_modifiable, netfails_adapt_modifiable); + }; + + return ERKStepGetTimestepperStats_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem"), "Grouped optional output functions"); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_forcingstep.cpp b/bindings/sundials4py/arkode/arkode_forcingstep.cpp new file mode 100644 index 0000000000..71a40e5bbc --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_forcingstep.cpp @@ -0,0 +1,48 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include +#include + +#include "arkode_mristep_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_forcingstep(nb::module_& m) +{ +#include "arkode_forcingstep_generated.hpp" + + m.def( + "ForcingStepCreate", + [](SUNStepper stepper1, SUNStepper stepper2, sunrealtype t0, N_Vector y0, + SUNContext sunctx) + { + return std::make_shared( + ForcingStepCreate(stepper1, stepper2, t0, y0, sunctx)); + }, + nb::arg("stepper1"), nb::arg("stepper2"), nb::arg("t0"), nb::arg("y0"), + nb::arg("sunctx"), nb::keep_alive<0, 5>()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_forcingstep_generated.hpp b/bindings/sundials4py/arkode/arkode_forcingstep_generated.hpp new file mode 100644 index 0000000000..ca5d33a7c3 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_forcingstep_generated.hpp @@ -0,0 +1,33 @@ +// #ifndef ARKODE_FORCINGINGSTEP_H_ +// +// #ifdef __cplusplus +// #endif +// + +m.def("ForcingStepReInit", ForcingStepReInit, nb::arg("arkode_mem"), + nb::arg("stepper1"), nb::arg("stepper2"), nb::arg("t0"), nb::arg("y0")); + +m.def( + "ForcingStepGetNumEvolves", + [](void* arkode_mem, int partition) -> std::tuple + { + auto ForcingStepGetNumEvolves_adapt_modifiable_immutable_to_return = + [](void* arkode_mem, int partition) -> std::tuple + { + long evolves_adapt_modifiable; + + int r = ForcingStepGetNumEvolves(arkode_mem, partition, + &evolves_adapt_modifiable); + return std::make_tuple(r, evolves_adapt_modifiable); + }; + + return ForcingStepGetNumEvolves_adapt_modifiable_immutable_to_return(arkode_mem, + partition); + }, + nb::arg("arkode_mem"), nb::arg("partition")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_generated.hpp b/bindings/sundials4py/arkode/arkode_generated.hpp new file mode 100644 index 0000000000..47b9a33b90 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_generated.hpp @@ -0,0 +1,1922 @@ +// #ifndef _ARKODE_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("ARK_NORMAL") = 1; +m.attr("ARK_ONE_STEP") = 2; +m.attr("ARK_ADAPT_CUSTOM") = -1; +m.attr("ARK_ADAPT_PID") = 0; +m.attr("ARK_ADAPT_PI") = 1; +m.attr("ARK_ADAPT_I") = 2; +m.attr("ARK_ADAPT_EXP_GUS") = 3; +m.attr("ARK_ADAPT_IMP_GUS") = 4; +m.attr("ARK_ADAPT_IMEX_GUS") = 5; +m.attr("ARK_FULLRHS_START") = 0; +m.attr("ARK_FULLRHS_END") = 1; +m.attr("ARK_FULLRHS_OTHER") = 2; +m.attr("ARK_INTERP_MAX_DEGREE") = 5; +m.attr("ARK_INTERP_NONE") = -1; +m.attr("ARK_INTERP_HERMITE") = 0; +m.attr("ARK_INTERP_LAGRANGE") = 1; +m.attr("ARK_SUCCESS") = 0; +m.attr("ARK_TSTOP_RETURN") = 1; +m.attr("ARK_ROOT_RETURN") = 2; +m.attr("ARK_WARNING") = 99; +m.attr("ARK_TOO_MUCH_WORK") = -1; +m.attr("ARK_TOO_MUCH_ACC") = -2; +m.attr("ARK_ERR_FAILURE") = -3; +m.attr("ARK_CONV_FAILURE") = -4; +m.attr("ARK_LINIT_FAIL") = -5; +m.attr("ARK_LSETUP_FAIL") = -6; +m.attr("ARK_LSOLVE_FAIL") = -7; +m.attr("ARK_RHSFUNC_FAIL") = -8; +m.attr("ARK_FIRST_RHSFUNC_ERR") = -9; +m.attr("ARK_REPTD_RHSFUNC_ERR") = -10; +m.attr("ARK_UNREC_RHSFUNC_ERR") = -11; +m.attr("ARK_RTFUNC_FAIL") = -12; +m.attr("ARK_LFREE_FAIL") = -13; +m.attr("ARK_MASSINIT_FAIL") = -14; +m.attr("ARK_MASSSETUP_FAIL") = -15; +m.attr("ARK_MASSSOLVE_FAIL") = -16; +m.attr("ARK_MASSFREE_FAIL") = -17; +m.attr("ARK_MASSMULT_FAIL") = -18; +m.attr("ARK_CONSTR_FAIL") = -19; +m.attr("ARK_MEM_FAIL") = -20; +m.attr("ARK_MEM_NULL") = -21; +m.attr("ARK_ILL_INPUT") = -22; +m.attr("ARK_NO_MALLOC") = -23; +m.attr("ARK_BAD_K") = -24; +m.attr("ARK_BAD_T") = -25; +m.attr("ARK_BAD_DKY") = -26; +m.attr("ARK_TOO_CLOSE") = -27; +m.attr("ARK_VECTOROP_ERR") = -28; +m.attr("ARK_NLS_INIT_FAIL") = -29; +m.attr("ARK_NLS_SETUP_FAIL") = -30; +m.attr("ARK_NLS_SETUP_RECVR") = -31; +m.attr("ARK_NLS_OP_ERR") = -32; +m.attr("ARK_INNERSTEP_ATTACH_ERR") = -33; +m.attr("ARK_INNERSTEP_FAIL") = -34; +m.attr("ARK_OUTERTOINNER_FAIL") = -35; +m.attr("ARK_INNERTOOUTER_FAIL") = -36; +m.attr("ARK_POSTPROCESS_FAIL") = -37; +m.attr("ARK_POSTPROCESS_STEP_FAIL") = -37; +m.attr("ARK_POSTPROCESS_STAGE_FAIL") = -38; +m.attr("ARK_USER_PREDICT_FAIL") = -39; +m.attr("ARK_INTERP_FAIL") = -40; +m.attr("ARK_INVALID_TABLE") = -41; +m.attr("ARK_CONTEXT_ERR") = -42; +m.attr("ARK_RELAX_FAIL") = -43; +m.attr("ARK_RELAX_MEM_NULL") = -44; +m.attr("ARK_RELAX_FUNC_FAIL") = -45; +m.attr("ARK_RELAX_JAC_FAIL") = -46; +m.attr("ARK_CONTROLLER_ERR") = -47; +m.attr("ARK_STEPPER_UNSUPPORTED") = -48; +m.attr("ARK_DOMEIG_FAIL") = -49; +m.attr("ARK_MAX_STAGE_LIMIT_FAIL") = -50; +m.attr("ARK_SUNSTEPPER_ERR") = -51; +m.attr("ARK_STEP_DIRECTION_ERR") = -52; +m.attr("ARK_ADJ_CHECKPOINT_FAIL") = -53; +m.attr("ARK_ADJ_RECOMPUTE_FAIL") = -54; +m.attr("ARK_SUNADJSTEPPER_ERR") = -55; +m.attr("ARK_DEE_FAIL") = -56; +m.attr("ARK_UNRECOGNIZED_ERROR") = -99; + +auto pyEnumARKRelaxSolver = + nb::enum_(m, "ARKRelaxSolver", nb::is_arithmetic(), + " --------------------------\n * Relaxation Solver " + "Options\n * --------------------------") + .value("ARK_RELAX_BRENT", ARK_RELAX_BRENT, "") + .value("ARK_RELAX_NEWTON", ARK_RELAX_NEWTON, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyEnumARKAccumError = + nb::enum_(m, "ARKAccumError", nb::is_arithmetic(), "") + .value("ARK_ACCUMERROR_NONE", ARK_ACCUMERROR_NONE, "") + .value("ARK_ACCUMERROR_MAX", ARK_ACCUMERROR_MAX, "") + .value("ARK_ACCUMERROR_SUM", ARK_ACCUMERROR_SUM, "") + .value("ARK_ACCUMERROR_AVG", ARK_ACCUMERROR_AVG, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def("ARKodeReset", ARKodeReset, nb::arg("arkode_mem"), nb::arg("tR"), + nb::arg("yR")); + +m.def( + "ARKodeCreateMRIStepInnerStepper", + [](void* arkode_mem) + -> std::tuple>> + { + auto ARKodeCreateMRIStepInnerStepper_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + MRIStepInnerStepper stepper_adapt_modifiable; + + int r = ARKodeCreateMRIStepInnerStepper(arkode_mem, + &stepper_adapt_modifiable); + return std::make_tuple(r, stepper_adapt_modifiable); + }; + auto ARKodeCreateMRIStepInnerStepper_adapt_return_type_to_shared_ptr = + [&ARKodeCreateMRIStepInnerStepper_adapt_modifiable_immutable_to_return]( + void* arkode_mem) + -> std::tuple>> + { + auto lambda_result = + ARKodeCreateMRIStepInnerStepper_adapt_modifiable_immutable_to_return( + arkode_mem); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + MRIStepInnerStepperDeleter>( + std::get<1>(lambda_result))); + }; + + return ARKodeCreateMRIStepInnerStepper_adapt_return_type_to_shared_ptr( + arkode_mem); + }, + nb::arg("arkode_mem"), "Utility to wrap ARKODE as an MRIStepInnerStepper", + nb::rv_policy::reference); + +m.def("ARKodeSStolerances", ARKodeSStolerances, nb::arg("arkode_mem"), + nb::arg("reltol"), nb::arg("abstol")); + +m.def("ARKodeSVtolerances", ARKodeSVtolerances, nb::arg("arkode_mem"), + nb::arg("reltol"), nb::arg("abstol")); + +m.def("ARKodeResStolerance", ARKodeResStolerance, nb::arg("arkode_mem"), + nb::arg("rabstol")); + +m.def("ARKodeResVtolerance", ARKodeResVtolerance, nb::arg("arkode_mem"), + nb::arg("rabstol")); + +m.def( + "ARKodeSetRootDirection", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeSetRootDirection_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + int rootdir_adapt_modifiable; + + int r = ARKodeSetRootDirection(arkode_mem, &rootdir_adapt_modifiable); + return std::make_tuple(r, rootdir_adapt_modifiable); + }; + + return ARKodeSetRootDirection_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("ARKodeSetNoInactiveRootWarn", ARKodeSetNoInactiveRootWarn, + nb::arg("arkode_mem")); + +m.def("ARKodeSetDefaults", ARKodeSetDefaults, nb::arg("arkode_mem")); + +m.def("ARKodeSetOrder", ARKodeSetOrder, nb::arg("arkode_mem"), nb::arg("maxord")); + +m.def("ARKodeSetInterpolantType", ARKodeSetInterpolantType, + nb::arg("arkode_mem"), nb::arg("itype")); + +m.def("ARKodeSetInterpolantDegree", ARKodeSetInterpolantDegree, + nb::arg("arkode_mem"), nb::arg("degree")); + +m.def("ARKodeSetMaxNumSteps", ARKodeSetMaxNumSteps, nb::arg("arkode_mem"), + nb::arg("mxsteps")); + +m.def("ARKodeSetInterpolateStopTime", ARKodeSetInterpolateStopTime, + nb::arg("arkode_mem"), nb::arg("interp")); + +m.def("ARKodeSetStopTime", ARKodeSetStopTime, nb::arg("arkode_mem"), + nb::arg("tstop")); + +m.def("ARKodeClearStopTime", ARKodeClearStopTime, nb::arg("arkode_mem")); + +m.def("ARKodeSetFixedStep", ARKodeSetFixedStep, nb::arg("arkode_mem"), + nb::arg("hfixed")); + +m.def("ARKodeSetStepDirection", ARKodeSetStepDirection, nb::arg("arkode_mem"), + nb::arg("stepdir")); + +m.def("ARKodeSetNonlinearSolver", ARKodeSetNonlinearSolver, + nb::arg("arkode_mem"), nb::arg("NLS")); + +m.def("ARKodeSetLinear", ARKodeSetLinear, nb::arg("arkode_mem"), + nb::arg("timedepend")); + +m.def("ARKodeSetNonlinear", ARKodeSetNonlinear, nb::arg("arkode_mem")); + +m.def("ARKodeSetAutonomous", ARKodeSetAutonomous, nb::arg("arkode_mem"), + nb::arg("autonomous")); + +m.def("ARKodeSetDeduceImplicitRhs", ARKodeSetDeduceImplicitRhs, + nb::arg("arkode_mem"), nb::arg("deduce")); + +m.def("ARKodeSetNonlinCRDown", ARKodeSetNonlinCRDown, nb::arg("arkode_mem"), + nb::arg("crdown")); + +m.def("ARKodeSetNonlinRDiv", ARKodeSetNonlinRDiv, nb::arg("arkode_mem"), + nb::arg("rdiv")); + +m.def("ARKodeSetDeltaGammaMax", ARKodeSetDeltaGammaMax, nb::arg("arkode_mem"), + nb::arg("dgmax")); + +m.def("ARKodeSetLSetupFrequency", ARKodeSetLSetupFrequency, + nb::arg("arkode_mem"), nb::arg("msbp")); + +m.def("ARKodeSetPredictorMethod", ARKodeSetPredictorMethod, + nb::arg("arkode_mem"), nb::arg("method")); + +m.def("ARKodeSetMaxNonlinIters", ARKodeSetMaxNonlinIters, nb::arg("arkode_mem"), + nb::arg("maxcor")); + +m.def("ARKodeSetMaxConvFails", ARKodeSetMaxConvFails, nb::arg("arkode_mem"), + nb::arg("maxncf")); + +m.def("ARKodeSetNonlinConvCoef", ARKodeSetNonlinConvCoef, nb::arg("arkode_mem"), + nb::arg("nlscoef")); + +m.def("ARKodeSetAdaptController", ARKodeSetAdaptController, + nb::arg("arkode_mem"), nb::arg("C")); + +m.def("ARKodeSetAdaptControllerByName", ARKodeSetAdaptControllerByName, + nb::arg("arkode_mem"), nb::arg("cname")); + +m.def("ARKodeSetAdaptivityAdjustment", ARKodeSetAdaptivityAdjustment, + nb::arg("arkode_mem"), nb::arg("adjust")); + +m.def("ARKodeSetCFLFraction", ARKodeSetCFLFraction, nb::arg("arkode_mem"), + nb::arg("cfl_frac")); + +m.def("ARKodeSetErrorBias", ARKodeSetErrorBias, nb::arg("arkode_mem"), + nb::arg("bias")); + +m.def("ARKodeSetSafetyFactor", ARKodeSetSafetyFactor, nb::arg("arkode_mem"), + nb::arg("safety")); + +m.def("ARKodeSetMaxGrowth", ARKodeSetMaxGrowth, nb::arg("arkode_mem"), + nb::arg("mx_growth")); + +m.def("ARKodeSetMinReduction", ARKodeSetMinReduction, nb::arg("arkode_mem"), + nb::arg("eta_min")); + +m.def("ARKodeSetFixedStepBounds", ARKodeSetFixedStepBounds, + nb::arg("arkode_mem"), nb::arg("lb"), nb::arg("ub")); + +m.def("ARKodeSetMaxFirstGrowth", ARKodeSetMaxFirstGrowth, nb::arg("arkode_mem"), + nb::arg("etamx1")); + +m.def("ARKodeSetMaxEFailGrowth", ARKodeSetMaxEFailGrowth, nb::arg("arkode_mem"), + nb::arg("etamxf")); + +m.def("ARKodeSetSmallNumEFails", ARKodeSetSmallNumEFails, nb::arg("arkode_mem"), + nb::arg("small_nef")); + +m.def("ARKodeSetMaxCFailGrowth", ARKodeSetMaxCFailGrowth, nb::arg("arkode_mem"), + nb::arg("etacf")); + +m.def("ARKodeSetMaxErrTestFails", ARKodeSetMaxErrTestFails, + nb::arg("arkode_mem"), nb::arg("maxnef")); + +m.def("ARKodeSetConstraints", ARKodeSetConstraints, nb::arg("arkode_mem"), + nb::arg("constraints")); + +m.def("ARKodeSetMaxHnilWarns", ARKodeSetMaxHnilWarns, nb::arg("arkode_mem"), + nb::arg("mxhnil")); + +m.def("ARKodeSetInitStep", ARKodeSetInitStep, nb::arg("arkode_mem"), + nb::arg("hin")); + +m.def("ARKodeSetMinStep", ARKodeSetMinStep, nb::arg("arkode_mem"), + nb::arg("hmin")); + +m.def("ARKodeSetMaxStep", ARKodeSetMaxStep, nb::arg("arkode_mem"), + nb::arg("hmax")); + +m.def("ARKodeSetMaxNumConstrFails", ARKodeSetMaxNumConstrFails, + nb::arg("arkode_mem"), nb::arg("maxfails")); + +m.def("ARKodeSetAdjointCheckpointScheme", ARKodeSetAdjointCheckpointScheme, + nb::arg("arkode_mem"), nb::arg("checkpoint_scheme")); + +m.def("ARKodeSetAdjointCheckpointIndex", ARKodeSetAdjointCheckpointIndex, + nb::arg("arkode_mem"), nb::arg("step_index")); + +m.def("ARKodeSetUseCompensatedSums", ARKodeSetUseCompensatedSums, + nb::arg("arkode_mem"), nb::arg("onoff")); + +m.def("ARKodeSetAccumulatedErrorType", ARKodeSetAccumulatedErrorType, + nb::arg("arkode_mem"), nb::arg("accum_type")); + +m.def("ARKodeResetAccumulatedError", ARKodeResetAccumulatedError, + nb::arg("arkode_mem")); + +m.def( + "ARKodeEvolve", + [](void* arkode_mem, sunrealtype tout, N_Vector yout, + int itask) -> std::tuple + { + auto ARKodeEvolve_adapt_modifiable_immutable_to_return = + [](void* arkode_mem, sunrealtype tout, N_Vector yout, + int itask) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = ARKodeEvolve(arkode_mem, tout, yout, &tret_adapt_modifiable, itask); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return ARKodeEvolve_adapt_modifiable_immutable_to_return(arkode_mem, tout, + yout, itask); + }, + nb::arg("arkode_mem"), nb::arg("tout"), nb::arg("yout"), nb::arg("itask"), + "Integrate the ODE over an interval in t"); + +m.def("ARKodeGetDky", ARKodeGetDky, nb::arg("arkode_mem"), nb::arg("t"), + nb::arg("k"), nb::arg("dky"), + "Computes the kth derivative of the y function at time t"); + +m.def("ARKodeComputeState", ARKodeComputeState, nb::arg("arkode_mem"), + nb::arg("zcor"), nb::arg("z"), + "Utility function to update/compute y based on zcor"); + +m.def( + "ARKodeGetNumRhsEvals", + [](void* arkode_mem, int partition_index) -> std::tuple + { + auto ARKodeGetNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem, int partition_index) -> std::tuple + { + long num_rhs_evals_adapt_modifiable; + + int r = ARKodeGetNumRhsEvals(arkode_mem, partition_index, + &num_rhs_evals_adapt_modifiable); + return std::make_tuple(r, num_rhs_evals_adapt_modifiable); + }; + + return ARKodeGetNumRhsEvals_adapt_modifiable_immutable_to_return(arkode_mem, + partition_index); + }, + nb::arg("arkode_mem"), nb::arg("partition_index")); + +m.def( + "ARKodeGetNumStepAttempts", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumStepAttempts_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long step_attempts_adapt_modifiable; + + int r = ARKodeGetNumStepAttempts(arkode_mem, + &step_attempts_adapt_modifiable); + return std::make_tuple(r, step_attempts_adapt_modifiable); + }; + + return ARKodeGetNumStepAttempts_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumSteps", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumSteps_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nsteps_adapt_modifiable; + + int r = ARKodeGetNumSteps(arkode_mem, &nsteps_adapt_modifiable); + return std::make_tuple(r, nsteps_adapt_modifiable); + }; + + return ARKodeGetNumSteps_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetLastStep", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetLastStep_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype hlast_adapt_modifiable; + + int r = ARKodeGetLastStep(arkode_mem, &hlast_adapt_modifiable); + return std::make_tuple(r, hlast_adapt_modifiable); + }; + + return ARKodeGetLastStep_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetCurrentStep", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetCurrentStep_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype hcur_adapt_modifiable; + + int r = ARKodeGetCurrentStep(arkode_mem, &hcur_adapt_modifiable); + return std::make_tuple(r, hcur_adapt_modifiable); + }; + + return ARKodeGetCurrentStep_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetStepDirection", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetStepDirection_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype stepdir_adapt_modifiable; + + int r = ARKodeGetStepDirection(arkode_mem, &stepdir_adapt_modifiable); + return std::make_tuple(r, stepdir_adapt_modifiable); + }; + + return ARKodeGetStepDirection_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("ARKodeGetErrWeights", ARKodeGetErrWeights, nb::arg("arkode_mem"), + nb::arg("eweight")); + +m.def( + "ARKodeGetNumGEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumGEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long ngevals_adapt_modifiable; + + int r = ARKodeGetNumGEvals(arkode_mem, &ngevals_adapt_modifiable); + return std::make_tuple(r, ngevals_adapt_modifiable); + }; + + return ARKodeGetNumGEvals_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetRootInfo", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetRootInfo_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + int rootsfound_adapt_modifiable; + + int r = ARKodeGetRootInfo(arkode_mem, &rootsfound_adapt_modifiable); + return std::make_tuple(r, rootsfound_adapt_modifiable); + }; + + return ARKodeGetRootInfo_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("ARKodePrintAllStats", ARKodePrintAllStats, nb::arg("arkode_mem"), + nb::arg("outfile"), nb::arg("fmt")); + +m.def("ARKodeGetReturnFlagName", ARKodeGetReturnFlagName, nb::arg("flag")); + +m.def("ARKodeWriteParameters", ARKodeWriteParameters, nb::arg("arkode_mem"), + nb::arg("fp")); + +m.def( + "ARKodeGetNumExpSteps", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumExpSteps_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long expsteps_adapt_modifiable; + + int r = ARKodeGetNumExpSteps(arkode_mem, &expsteps_adapt_modifiable); + return std::make_tuple(r, expsteps_adapt_modifiable); + }; + + return ARKodeGetNumExpSteps_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumAccSteps", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumAccSteps_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long accsteps_adapt_modifiable; + + int r = ARKodeGetNumAccSteps(arkode_mem, &accsteps_adapt_modifiable); + return std::make_tuple(r, accsteps_adapt_modifiable); + }; + + return ARKodeGetNumAccSteps_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumErrTestFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long netfails_adapt_modifiable; + + int r = ARKodeGetNumErrTestFails(arkode_mem, &netfails_adapt_modifiable); + return std::make_tuple(r, netfails_adapt_modifiable); + }; + + return ARKodeGetNumErrTestFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("ARKodeGetEstLocalErrors", ARKodeGetEstLocalErrors, nb::arg("arkode_mem"), + nb::arg("ele")); + +m.def( + "ARKodeGetActualInitStep", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetActualInitStep_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype hinused_adapt_modifiable; + + int r = ARKodeGetActualInitStep(arkode_mem, &hinused_adapt_modifiable); + return std::make_tuple(r, hinused_adapt_modifiable); + }; + + return ARKodeGetActualInitStep_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetTolScaleFactor", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetTolScaleFactor_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype tolsfac_adapt_modifiable; + + int r = ARKodeGetTolScaleFactor(arkode_mem, &tolsfac_adapt_modifiable); + return std::make_tuple(r, tolsfac_adapt_modifiable); + }; + + return ARKodeGetTolScaleFactor_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumConstrFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumConstrFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nconstrfails_adapt_modifiable; + + int r = ARKodeGetNumConstrFails(arkode_mem, &nconstrfails_adapt_modifiable); + return std::make_tuple(r, nconstrfails_adapt_modifiable); + }; + + return ARKodeGetNumConstrFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetStepStats", + [](void* arkode_mem) + -> std::tuple + { + auto ARKodeGetStepStats_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) + -> std::tuple + { + long nsteps_adapt_modifiable; + sunrealtype hinused_adapt_modifiable; + sunrealtype hlast_adapt_modifiable; + sunrealtype hcur_adapt_modifiable; + sunrealtype tcur_adapt_modifiable; + + int r = ARKodeGetStepStats(arkode_mem, &nsteps_adapt_modifiable, + &hinused_adapt_modifiable, + &hlast_adapt_modifiable, + &hcur_adapt_modifiable, &tcur_adapt_modifiable); + return std::make_tuple(r, nsteps_adapt_modifiable, + hinused_adapt_modifiable, hlast_adapt_modifiable, + hcur_adapt_modifiable, tcur_adapt_modifiable); + }; + + return ARKodeGetStepStats_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetAccumulatedError", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetAccumulatedError_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype accum_error_adapt_modifiable; + + int r = ARKodeGetAccumulatedError(arkode_mem, + &accum_error_adapt_modifiable); + return std::make_tuple(r, accum_error_adapt_modifiable); + }; + + return ARKodeGetAccumulatedError_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumLinSolvSetups", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumLinSolvSetups_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nlinsetups_adapt_modifiable; + + int r = ARKodeGetNumLinSolvSetups(arkode_mem, &nlinsetups_adapt_modifiable); + return std::make_tuple(r, nlinsetups_adapt_modifiable); + }; + + return ARKodeGetNumLinSolvSetups_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetCurrentTime", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetCurrentTime_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype tcur_adapt_modifiable; + + int r = ARKodeGetCurrentTime(arkode_mem, &tcur_adapt_modifiable); + return std::make_tuple(r, tcur_adapt_modifiable); + }; + + return ARKodeGetCurrentTime_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetCurrentState", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetCurrentState_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + N_Vector state_adapt_modifiable; + + int r = ARKodeGetCurrentState(arkode_mem, &state_adapt_modifiable); + return std::make_tuple(r, state_adapt_modifiable); + }; + + return ARKodeGetCurrentState_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "ARKodeGetCurrentGamma", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetCurrentGamma_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype gamma_adapt_modifiable; + + int r = ARKodeGetCurrentGamma(arkode_mem, &gamma_adapt_modifiable); + return std::make_tuple(r, gamma_adapt_modifiable); + }; + + return ARKodeGetCurrentGamma_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumNonlinSolvIters", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nniters_adapt_modifiable; + + int r = ARKodeGetNumNonlinSolvIters(arkode_mem, &nniters_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable); + }; + + return ARKodeGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumNonlinSolvConvFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nnfails_adapt_modifiable; + + int r = ARKodeGetNumNonlinSolvConvFails(arkode_mem, + &nnfails_adapt_modifiable); + return std::make_tuple(r, nnfails_adapt_modifiable); + }; + + return ARKodeGetNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNonlinSolvStats", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNonlinSolvStats_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nniters_adapt_modifiable; + long nnfails_adapt_modifiable; + + int r = ARKodeGetNonlinSolvStats(arkode_mem, &nniters_adapt_modifiable, + &nnfails_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable, + nnfails_adapt_modifiable); + }; + + return ARKodeGetNonlinSolvStats_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumStepSolveFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumStepSolveFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nncfails_adapt_modifiable; + + int r = ARKodeGetNumStepSolveFails(arkode_mem, &nncfails_adapt_modifiable); + return std::make_tuple(r, nncfails_adapt_modifiable); + }; + + return ARKodeGetNumStepSolveFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetJac", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetJac_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + SUNMatrix J_adapt_modifiable; + + int r = ARKodeGetJac(arkode_mem, &J_adapt_modifiable); + return std::make_tuple(r, J_adapt_modifiable); + }; + + return ARKodeGetJac_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "ARKodeGetJacTime", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetJacTime_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + sunrealtype t_J_adapt_modifiable; + + int r = ARKodeGetJacTime(arkode_mem, &t_J_adapt_modifiable); + return std::make_tuple(r, t_J_adapt_modifiable); + }; + + return ARKodeGetJacTime_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetJacNumSteps", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetJacNumSteps_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nst_J_adapt_modifiable; + + int r = ARKodeGetJacNumSteps(arkode_mem, &nst_J_adapt_modifiable); + return std::make_tuple(r, nst_J_adapt_modifiable); + }; + + return ARKodeGetJacNumSteps_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumJacEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumJacEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long njevals_adapt_modifiable; + + int r = ARKodeGetNumJacEvals(arkode_mem, &njevals_adapt_modifiable); + return std::make_tuple(r, njevals_adapt_modifiable); + }; + + return ARKodeGetNumJacEvals_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumPrecEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumPrecEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long npevals_adapt_modifiable; + + int r = ARKodeGetNumPrecEvals(arkode_mem, &npevals_adapt_modifiable); + return std::make_tuple(r, npevals_adapt_modifiable); + }; + + return ARKodeGetNumPrecEvals_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumPrecSolves", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumPrecSolves_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long npsolves_adapt_modifiable; + + int r = ARKodeGetNumPrecSolves(arkode_mem, &npsolves_adapt_modifiable); + return std::make_tuple(r, npsolves_adapt_modifiable); + }; + + return ARKodeGetNumPrecSolves_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumLinIters", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumLinIters_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nliters_adapt_modifiable; + + int r = ARKodeGetNumLinIters(arkode_mem, &nliters_adapt_modifiable); + return std::make_tuple(r, nliters_adapt_modifiable); + }; + + return ARKodeGetNumLinIters_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumLinConvFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumLinConvFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nlcfails_adapt_modifiable; + + int r = ARKodeGetNumLinConvFails(arkode_mem, &nlcfails_adapt_modifiable); + return std::make_tuple(r, nlcfails_adapt_modifiable); + }; + + return ARKodeGetNumLinConvFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumJTSetupEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumJTSetupEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long njtsetups_adapt_modifiable; + + int r = ARKodeGetNumJTSetupEvals(arkode_mem, &njtsetups_adapt_modifiable); + return std::make_tuple(r, njtsetups_adapt_modifiable); + }; + + return ARKodeGetNumJTSetupEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumJtimesEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumJtimesEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long njvevals_adapt_modifiable; + + int r = ARKodeGetNumJtimesEvals(arkode_mem, &njvevals_adapt_modifiable); + return std::make_tuple(r, njvevals_adapt_modifiable); + }; + + return ARKodeGetNumJtimesEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumLinRhsEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumLinRhsEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nfevalsLS_adapt_modifiable; + + int r = ARKodeGetNumLinRhsEvals(arkode_mem, &nfevalsLS_adapt_modifiable); + return std::make_tuple(r, nfevalsLS_adapt_modifiable); + }; + + return ARKodeGetNumLinRhsEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetLastLinFlag", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetLastLinFlag_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long flag_adapt_modifiable; + + int r = ARKodeGetLastLinFlag(arkode_mem, &flag_adapt_modifiable); + return std::make_tuple(r, flag_adapt_modifiable); + }; + + return ARKodeGetLastLinFlag_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("ARKodeGetLinReturnFlagName", ARKodeGetLinReturnFlagName, nb::arg("flag")); + +m.def( + "ARKodeGetCurrentMassMatrix", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetCurrentMassMatrix_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + SUNMatrix M_adapt_modifiable; + + int r = ARKodeGetCurrentMassMatrix(arkode_mem, &M_adapt_modifiable); + return std::make_tuple(r, M_adapt_modifiable); + }; + + return ARKodeGetCurrentMassMatrix_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem"), " Optional output functions (non-identity mass matrices)\n\n nb::rv_policy::reference", + nb::rv_policy::reference); + +m.def("ARKodeGetResWeights", ARKodeGetResWeights, nb::arg("arkode_mem"), + nb::arg("rweight")); + +m.def( + "ARKodeGetNumMassSetups", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassSetups_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmsetups_adapt_modifiable; + + int r = ARKodeGetNumMassSetups(arkode_mem, &nmsetups_adapt_modifiable); + return std::make_tuple(r, nmsetups_adapt_modifiable); + }; + + return ARKodeGetNumMassSetups_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassMultSetups", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassMultSetups_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmvsetups_adapt_modifiable; + + int r = ARKodeGetNumMassMultSetups(arkode_mem, &nmvsetups_adapt_modifiable); + return std::make_tuple(r, nmvsetups_adapt_modifiable); + }; + + return ARKodeGetNumMassMultSetups_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassMult", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassMult_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmvevals_adapt_modifiable; + + int r = ARKodeGetNumMassMult(arkode_mem, &nmvevals_adapt_modifiable); + return std::make_tuple(r, nmvevals_adapt_modifiable); + }; + + return ARKodeGetNumMassMult_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassSolves", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassSolves_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmsolves_adapt_modifiable; + + int r = ARKodeGetNumMassSolves(arkode_mem, &nmsolves_adapt_modifiable); + return std::make_tuple(r, nmsolves_adapt_modifiable); + }; + + return ARKodeGetNumMassSolves_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassPrecEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassPrecEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmpevals_adapt_modifiable; + + int r = ARKodeGetNumMassPrecEvals(arkode_mem, &nmpevals_adapt_modifiable); + return std::make_tuple(r, nmpevals_adapt_modifiable); + }; + + return ARKodeGetNumMassPrecEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassPrecSolves", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassPrecSolves_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmpsolves_adapt_modifiable; + + int r = ARKodeGetNumMassPrecSolves(arkode_mem, &nmpsolves_adapt_modifiable); + return std::make_tuple(r, nmpsolves_adapt_modifiable); + }; + + return ARKodeGetNumMassPrecSolves_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassIters", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassIters_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmiters_adapt_modifiable; + + int r = ARKodeGetNumMassIters(arkode_mem, &nmiters_adapt_modifiable); + return std::make_tuple(r, nmiters_adapt_modifiable); + }; + + return ARKodeGetNumMassIters_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMassConvFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMassConvFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmcfails_adapt_modifiable; + + int r = ARKodeGetNumMassConvFails(arkode_mem, &nmcfails_adapt_modifiable); + return std::make_tuple(r, nmcfails_adapt_modifiable); + }; + + return ARKodeGetNumMassConvFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumMTSetups", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumMTSetups_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nmtsetups_adapt_modifiable; + + int r = ARKodeGetNumMTSetups(arkode_mem, &nmtsetups_adapt_modifiable); + return std::make_tuple(r, nmtsetups_adapt_modifiable); + }; + + return ARKodeGetNumMTSetups_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetLastMassFlag", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetLastMassFlag_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long flag_adapt_modifiable; + + int r = ARKodeGetLastMassFlag(arkode_mem, &flag_adapt_modifiable); + return std::make_tuple(r, flag_adapt_modifiable); + }; + + return ARKodeGetLastMassFlag_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("ARKodePrintMem", ARKodePrintMem, nb::arg("arkode_mem"), nb::arg("outfile"), + "Output the ARKODE memory structure (useful when debugging)"); + +m.def("ARKodeSetRelaxEtaFail", ARKodeSetRelaxEtaFail, nb::arg("arkode_mem"), + nb::arg("eta_rf")); + +m.def("ARKodeSetRelaxLowerBound", ARKodeSetRelaxLowerBound, + nb::arg("arkode_mem"), nb::arg("lower")); + +m.def("ARKodeSetRelaxMaxFails", ARKodeSetRelaxMaxFails, nb::arg("arkode_mem"), + nb::arg("max_fails")); + +m.def("ARKodeSetRelaxMaxIters", ARKodeSetRelaxMaxIters, nb::arg("arkode_mem"), + nb::arg("max_iters")); + +m.def("ARKodeSetRelaxSolver", ARKodeSetRelaxSolver, nb::arg("arkode_mem"), + nb::arg("solver")); + +m.def("ARKodeSetRelaxResTol", ARKodeSetRelaxResTol, nb::arg("arkode_mem"), + nb::arg("res_tol")); + +m.def("ARKodeSetRelaxTol", ARKodeSetRelaxTol, nb::arg("arkode_mem"), + nb::arg("rel_tol"), nb::arg("abs_tol")); + +m.def("ARKodeSetRelaxUpperBound", ARKodeSetRelaxUpperBound, + nb::arg("arkode_mem"), nb::arg("upper")); + +m.def( + "ARKodeGetNumRelaxFnEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumRelaxFnEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long r_evals_adapt_modifiable; + + int r = ARKodeGetNumRelaxFnEvals(arkode_mem, &r_evals_adapt_modifiable); + return std::make_tuple(r, r_evals_adapt_modifiable); + }; + + return ARKodeGetNumRelaxFnEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumRelaxJacEvals", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumRelaxJacEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long J_evals_adapt_modifiable; + + int r = ARKodeGetNumRelaxJacEvals(arkode_mem, &J_evals_adapt_modifiable); + return std::make_tuple(r, J_evals_adapt_modifiable); + }; + + return ARKodeGetNumRelaxJacEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumRelaxFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumRelaxFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long relax_fails_adapt_modifiable; + + int r = ARKodeGetNumRelaxFails(arkode_mem, &relax_fails_adapt_modifiable); + return std::make_tuple(r, relax_fails_adapt_modifiable); + }; + + return ARKodeGetNumRelaxFails_adapt_modifiable_immutable_to_return(arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumRelaxBoundFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumRelaxBoundFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long fails_adapt_modifiable; + + int r = ARKodeGetNumRelaxBoundFails(arkode_mem, &fails_adapt_modifiable); + return std::make_tuple(r, fails_adapt_modifiable); + }; + + return ARKodeGetNumRelaxBoundFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumRelaxSolveFails", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumRelaxSolveFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long fails_adapt_modifiable; + + int r = ARKodeGetNumRelaxSolveFails(arkode_mem, &fails_adapt_modifiable); + return std::make_tuple(r, fails_adapt_modifiable); + }; + + return ARKodeGetNumRelaxSolveFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeGetNumRelaxSolveIters", + [](void* arkode_mem) -> std::tuple + { + auto ARKodeGetNumRelaxSolveIters_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long iters_adapt_modifiable; + + int r = ARKodeGetNumRelaxSolveIters(arkode_mem, &iters_adapt_modifiable); + return std::make_tuple(r, iters_adapt_modifiable); + }; + + return ARKodeGetNumRelaxSolveIters_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "ARKodeCreateSUNStepper", + [](void* arkode_mem) + -> std::tuple>> + { + auto ARKodeCreateSUNStepper_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + SUNStepper stepper_adapt_modifiable; + + int r = ARKodeCreateSUNStepper(arkode_mem, &stepper_adapt_modifiable); + return std::make_tuple(r, stepper_adapt_modifiable); + }; + auto ARKodeCreateSUNStepper_adapt_return_type_to_shared_ptr = + [&ARKodeCreateSUNStepper_adapt_modifiable_immutable_to_return]( + void* arkode_mem) + -> std::tuple>> + { + auto lambda_result = + ARKodeCreateSUNStepper_adapt_modifiable_immutable_to_return(arkode_mem); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNStepperDeleter>( + std::get<1>(lambda_result))); + }; + + return ARKodeCreateSUNStepper_adapt_return_type_to_shared_ptr(arkode_mem); + }, + nb::arg("arkode_mem"), "SUNStepper functions", nb::rv_policy::reference); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _ARKLS_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("ARKLS_SUCCESS") = 0; +m.attr("ARKLS_MEM_NULL") = -1; +m.attr("ARKLS_LMEM_NULL") = -2; +m.attr("ARKLS_ILL_INPUT") = -3; +m.attr("ARKLS_MEM_FAIL") = -4; +m.attr("ARKLS_PMEM_NULL") = -5; +m.attr("ARKLS_MASSMEM_NULL") = -6; +m.attr("ARKLS_JACFUNC_UNRECVR") = -7; +m.attr("ARKLS_JACFUNC_RECVR") = -8; +m.attr("ARKLS_MASSFUNC_UNRECVR") = -9; +m.attr("ARKLS_MASSFUNC_RECVR") = -10; +m.attr("ARKLS_SUNMAT_FAIL") = -11; +m.attr("ARKLS_SUNLS_FAIL") = -12; + +m.def( + "ARKodeSetLinearSolver", + [](void* arkode_mem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + auto ARKodeSetLinearSolver_adapt_optional_arg_with_default_null = + [](void* arkode_mem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = ARKodeSetLinearSolver(arkode_mem, LS, + A_adapt_default_null); + return lambda_result; + }; + + return ARKodeSetLinearSolver_adapt_optional_arg_with_default_null(arkode_mem, + LS, A); + }, + nb::arg("arkode_mem"), nb::arg("LS"), nb::arg("A").none() = nb::none()); + +m.def("ARKodeSetJacEvalFrequency", ARKodeSetJacEvalFrequency, + nb::arg("arkode_mem"), nb::arg("msbj")); + +m.def("ARKodeSetLinearSolutionScaling", ARKodeSetLinearSolutionScaling, + nb::arg("arkode_mem"), nb::arg("onoff")); + +m.def("ARKodeSetEpsLin", ARKodeSetEpsLin, nb::arg("arkode_mem"), + nb::arg("eplifac")); + +m.def("ARKodeSetMassEpsLin", ARKodeSetMassEpsLin, nb::arg("arkode_mem"), + nb::arg("eplifac")); + +m.def("ARKodeSetLSNormFactor", ARKodeSetLSNormFactor, nb::arg("arkode_mem"), + nb::arg("nrmfac")); + +m.def("ARKodeSetMassLSNormFactor", ARKodeSetMassLSNormFactor, + nb::arg("arkode_mem"), nb::arg("nrmfac")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _ARKODE_BUTCHER_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClassARKodeButcherTableMem = + nb::class_(m, + "ARKodeButcherTableMem", "---------------------------------------------------------------\n Types : struct ARKodeButcherTableMem, ARKodeButcherTable\n ---------------------------------------------------------------") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "ARKodeButcherTable_Create", + [](int s, int q, int p, sundials4py::Array1d c_1d, sundials4py::Array1d A_1d, + sundials4py::Array1d b_1d, sundials4py::Array1d d_1d) + -> std::shared_ptr> + { + auto ARKodeButcherTable_Create_adapt_arr_ptr_to_std_vector = + [](int s, int q, int p, sundials4py::Array1d c_1d, + sundials4py::Array1d A_1d, sundials4py::Array1d b_1d, + sundials4py::Array1d d_1d) -> ARKodeButcherTable + { + sunrealtype* c_1d_ptr = reinterpret_cast(c_1d.data()); + sunrealtype* A_1d_ptr = reinterpret_cast(A_1d.data()); + sunrealtype* b_1d_ptr = reinterpret_cast(b_1d.data()); + sunrealtype* d_1d_ptr = reinterpret_cast(d_1d.data()); + + auto lambda_result = ARKodeButcherTable_Create(s, q, p, c_1d_ptr, A_1d_ptr, + b_1d_ptr, d_1d_ptr); + return lambda_result; + }; + auto ARKodeButcherTable_Create_adapt_return_type_to_shared_ptr = + [&ARKodeButcherTable_Create_adapt_arr_ptr_to_std_vector](int s, int q, + int p, + sundials4py::Array1d c_1d, + sundials4py::Array1d A_1d, + sundials4py::Array1d b_1d, + sundials4py::Array1d d_1d) + -> std::shared_ptr> + { + auto lambda_result = + ARKodeButcherTable_Create_adapt_arr_ptr_to_std_vector(s, q, p, c_1d, + A_1d, b_1d, d_1d); + + return our_make_shared, + ARKodeButcherTableDeleter>(lambda_result); + }; + + return ARKodeButcherTable_Create_adapt_return_type_to_shared_ptr(s, q, p, + c_1d, A_1d, + b_1d, d_1d); + }, + nb::arg("s"), nb::arg("q"), nb::arg("p"), nb::arg("c_1d"), nb::arg("A_1d"), + nb::arg("b_1d"), nb::arg("d_1d")); + +m.def( + "ARKodeButcherTable_Copy", + [](ARKodeButcherTable B) + -> std::shared_ptr> + { + auto ARKodeButcherTable_Copy_adapt_return_type_to_shared_ptr = + [](ARKodeButcherTable B) + -> std::shared_ptr> + { + auto lambda_result = ARKodeButcherTable_Copy(B); + + return our_make_shared, + ARKodeButcherTableDeleter>(lambda_result); + }; + + return ARKodeButcherTable_Copy_adapt_return_type_to_shared_ptr(B); + }, + nb::arg("B")); + +m.def("ARKodeButcherTable_Write", ARKodeButcherTable_Write, nb::arg("B"), + nb::arg("outfile")); + +m.def("ARKodeButcherTable_IsStifflyAccurate", + ARKodeButcherTable_IsStifflyAccurate, nb::arg("B")); + +m.def( + "ARKodeButcherTable_CheckOrder", + [](ARKodeButcherTable B, FILE* outfile) -> std::tuple + { + auto ARKodeButcherTable_CheckOrder_adapt_modifiable_immutable_to_return = + [](ARKodeButcherTable B, FILE* outfile) -> std::tuple + { + int q_adapt_modifiable; + int p_adapt_modifiable; + + int r = ARKodeButcherTable_CheckOrder(B, &q_adapt_modifiable, + &p_adapt_modifiable, outfile); + return std::make_tuple(r, q_adapt_modifiable, p_adapt_modifiable); + }; + + return ARKodeButcherTable_CheckOrder_adapt_modifiable_immutable_to_return(B, + outfile); + }, + nb::arg("B"), nb::arg("outfile")); + +m.def( + "ARKodeButcherTable_CheckARKOrder", + [](ARKodeButcherTable B1, ARKodeButcherTable B2, + FILE* outfile) -> std::tuple + { + auto ARKodeButcherTable_CheckARKOrder_adapt_modifiable_immutable_to_return = + [](ARKodeButcherTable B1, ARKodeButcherTable B2, + FILE* outfile) -> std::tuple + { + int q_adapt_modifiable; + int p_adapt_modifiable; + + int r = ARKodeButcherTable_CheckARKOrder(B1, B2, &q_adapt_modifiable, + &p_adapt_modifiable, outfile); + return std::make_tuple(r, q_adapt_modifiable, p_adapt_modifiable); + }; + + return ARKodeButcherTable_CheckARKOrder_adapt_modifiable_immutable_to_return(B1, + B2, + outfile); + }, + nb::arg("B1"), nb::arg("B2"), nb::arg("outfile")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _ARKODE_ERK_TABLES_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumARKODE_ERKTableID = + nb::enum_(m, "ARKODE_ERKTableID", nb::is_arithmetic(), "") + .value("ARKODE_ERK_NONE", ARKODE_ERK_NONE, "ensure enum is signed int") + .value("ARKODE_HEUN_EULER_2_1_2", ARKODE_HEUN_EULER_2_1_2, "") + .value("ARKODE_MIN_ERK_NUM", ARKODE_MIN_ERK_NUM, "") + .value("ARKODE_BOGACKI_SHAMPINE_4_2_3", ARKODE_BOGACKI_SHAMPINE_4_2_3, "") + .value("ARKODE_ARK324L2SA_ERK_4_2_3", ARKODE_ARK324L2SA_ERK_4_2_3, "") + .value("ARKODE_ZONNEVELD_5_3_4", ARKODE_ZONNEVELD_5_3_4, "") + .value("ARKODE_ARK436L2SA_ERK_6_3_4", ARKODE_ARK436L2SA_ERK_6_3_4, "") + .value("ARKODE_SAYFY_ABURUB_6_3_4", ARKODE_SAYFY_ABURUB_6_3_4, "") + .value("ARKODE_CASH_KARP_6_4_5", ARKODE_CASH_KARP_6_4_5, "") + .value("ARKODE_FEHLBERG_6_4_5", ARKODE_FEHLBERG_6_4_5, "") + .value("ARKODE_DORMAND_PRINCE_7_4_5", ARKODE_DORMAND_PRINCE_7_4_5, "") + .value("ARKODE_ARK548L2SA_ERK_8_4_5", ARKODE_ARK548L2SA_ERK_8_4_5, "") + .value("ARKODE_VERNER_8_5_6", ARKODE_VERNER_8_5_6, "") + .value("ARKODE_FEHLBERG_13_7_8", ARKODE_FEHLBERG_13_7_8, "") + .value("ARKODE_KNOTH_WOLKE_3_3", ARKODE_KNOTH_WOLKE_3_3, "") + .value("ARKODE_ARK437L2SA_ERK_7_3_4", ARKODE_ARK437L2SA_ERK_7_3_4, "") + .value("ARKODE_ARK548L2SAb_ERK_8_4_5", ARKODE_ARK548L2SAb_ERK_8_4_5, "") + .value("ARKODE_ARK2_ERK_3_1_2", ARKODE_ARK2_ERK_3_1_2, "") + .value("ARKODE_SOFRONIOU_SPALETTA_5_3_4", ARKODE_SOFRONIOU_SPALETTA_5_3_4, "") + .value("ARKODE_SHU_OSHER_3_2_3", ARKODE_SHU_OSHER_3_2_3, "") + .value("ARKODE_VERNER_9_5_6", ARKODE_VERNER_9_5_6, "") + .value("ARKODE_VERNER_10_6_7", ARKODE_VERNER_10_6_7, "") + .value("ARKODE_VERNER_13_7_8", ARKODE_VERNER_13_7_8, "") + .value("ARKODE_VERNER_16_8_9", ARKODE_VERNER_16_8_9, "") + .value("ARKODE_FORWARD_EULER_1_1", ARKODE_FORWARD_EULER_1_1, "") + .value("ARKODE_RALSTON_EULER_2_1_2", ARKODE_RALSTON_EULER_2_1_2, "") + .value("ARKODE_EXPLICIT_MIDPOINT_EULER_2_1_2", + ARKODE_EXPLICIT_MIDPOINT_EULER_2_1_2, "") + .value("ARKODE_RALSTON_3_1_2", ARKODE_RALSTON_3_1_2, "") + .value("ARKODE_TSITOURAS_7_4_5", ARKODE_TSITOURAS_7_4_5, "") + .value("ARKODE_MAX_ERK_NUM", ARKODE_MAX_ERK_NUM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def( + "ARKodeButcherTable_LoadERK", + [](ARKODE_ERKTableID emethod) + -> std::shared_ptr> + { + auto ARKodeButcherTable_LoadERK_adapt_return_type_to_shared_ptr = + [](ARKODE_ERKTableID emethod) + -> std::shared_ptr> + { + auto lambda_result = ARKodeButcherTable_LoadERK(emethod); + + return our_make_shared, + ARKodeButcherTableDeleter>(lambda_result); + }; + + return ARKodeButcherTable_LoadERK_adapt_return_type_to_shared_ptr(emethod); + }, + nb::arg("emethod"), "Accessor routine to load built-in ERK table"); + +m.def( + "ARKodeButcherTable_LoadERKByName", + [](const char* emethod) + -> std::shared_ptr> + { + auto ARKodeButcherTable_LoadERKByName_adapt_return_type_to_shared_ptr = + [](const char* emethod) + -> std::shared_ptr> + { + auto lambda_result = ARKodeButcherTable_LoadERKByName(emethod); + + return our_make_shared, + ARKodeButcherTableDeleter>(lambda_result); + }; + + return ARKodeButcherTable_LoadERKByName_adapt_return_type_to_shared_ptr( + emethod); + }, + nb::arg("emethod")); + +m.def("ARKodeButcherTable_ERKIDToName", ARKodeButcherTable_ERKIDToName, + nb::arg("emethod")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _ARKODE_DIRK_TABLES_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumARKODE_DIRKTableID = + nb::enum_(m, "ARKODE_DIRKTableID", nb::is_arithmetic(), "") + .value("ARKODE_DIRK_NONE", ARKODE_DIRK_NONE, "ensure enum is signed int") + .value("ARKODE_SDIRK_2_1_2", ARKODE_SDIRK_2_1_2, "") + .value("ARKODE_MIN_DIRK_NUM", ARKODE_MIN_DIRK_NUM, "") + .value("ARKODE_BILLINGTON_3_3_2", ARKODE_BILLINGTON_3_3_2, "") + .value("ARKODE_TRBDF2_3_3_2", ARKODE_TRBDF2_3_3_2, "") + .value("ARKODE_KVAERNO_4_2_3", ARKODE_KVAERNO_4_2_3, "") + .value("ARKODE_ARK324L2SA_DIRK_4_2_3", ARKODE_ARK324L2SA_DIRK_4_2_3, "") + .value("ARKODE_CASH_5_2_4", ARKODE_CASH_5_2_4, "") + .value("ARKODE_CASH_5_3_4", ARKODE_CASH_5_3_4, "") + .value("ARKODE_SDIRK_5_3_4", ARKODE_SDIRK_5_3_4, "") + .value("ARKODE_KVAERNO_5_3_4", ARKODE_KVAERNO_5_3_4, "") + .value("ARKODE_ARK436L2SA_DIRK_6_3_4", ARKODE_ARK436L2SA_DIRK_6_3_4, "") + .value("ARKODE_KVAERNO_7_4_5", ARKODE_KVAERNO_7_4_5, "") + .value("ARKODE_ARK548L2SA_DIRK_8_4_5", ARKODE_ARK548L2SA_DIRK_8_4_5, "") + .value("ARKODE_ARK437L2SA_DIRK_7_3_4", ARKODE_ARK437L2SA_DIRK_7_3_4, "") + .value("ARKODE_ARK548L2SAb_DIRK_8_4_5", ARKODE_ARK548L2SAb_DIRK_8_4_5, "") + .value("ARKODE_ESDIRK324L2SA_4_2_3", ARKODE_ESDIRK324L2SA_4_2_3, "") + .value("ARKODE_ESDIRK325L2SA_5_2_3", ARKODE_ESDIRK325L2SA_5_2_3, "") + .value("ARKODE_ESDIRK32I5L2SA_5_2_3", ARKODE_ESDIRK32I5L2SA_5_2_3, "") + .value("ARKODE_ESDIRK436L2SA_6_3_4", ARKODE_ESDIRK436L2SA_6_3_4, "") + .value("ARKODE_ESDIRK43I6L2SA_6_3_4", ARKODE_ESDIRK43I6L2SA_6_3_4, "") + .value("ARKODE_QESDIRK436L2SA_6_3_4", ARKODE_QESDIRK436L2SA_6_3_4, "") + .value("ARKODE_ESDIRK437L2SA_7_3_4", ARKODE_ESDIRK437L2SA_7_3_4, "") + .value("ARKODE_ESDIRK547L2SA_7_4_5", ARKODE_ESDIRK547L2SA_7_4_5, "") + .value("ARKODE_ESDIRK547L2SA2_7_4_5", ARKODE_ESDIRK547L2SA2_7_4_5, "") + .value("ARKODE_ARK2_DIRK_3_1_2", ARKODE_ARK2_DIRK_3_1_2, "") + .value("ARKODE_BACKWARD_EULER_1_1", ARKODE_BACKWARD_EULER_1_1, "") + .value("ARKODE_IMPLICIT_MIDPOINT_1_2", ARKODE_IMPLICIT_MIDPOINT_1_2, "") + .value("ARKODE_IMPLICIT_TRAPEZOIDAL_2_2", ARKODE_IMPLICIT_TRAPEZOIDAL_2_2, "") + .value("ARKODE_MAX_DIRK_NUM", ARKODE_MAX_DIRK_NUM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def( + "ARKodeButcherTable_LoadDIRK", + [](ARKODE_DIRKTableID imethod) + -> std::shared_ptr> + { + auto ARKodeButcherTable_LoadDIRK_adapt_return_type_to_shared_ptr = + [](ARKODE_DIRKTableID imethod) + -> std::shared_ptr> + { + auto lambda_result = ARKodeButcherTable_LoadDIRK(imethod); + + return our_make_shared, + ARKodeButcherTableDeleter>(lambda_result); + }; + + return ARKodeButcherTable_LoadDIRK_adapt_return_type_to_shared_ptr(imethod); + }, + nb::arg("imethod"), "Accessor routine to load built-in DIRK table"); + +m.def( + "ARKodeButcherTable_LoadDIRKByName", + [](const char* imethod) + -> std::shared_ptr> + { + auto ARKodeButcherTable_LoadDIRKByName_adapt_return_type_to_shared_ptr = + [](const char* imethod) + -> std::shared_ptr> + { + auto lambda_result = ARKodeButcherTable_LoadDIRKByName(imethod); + + return our_make_shared, + ARKodeButcherTableDeleter>(lambda_result); + }; + + return ARKodeButcherTable_LoadDIRKByName_adapt_return_type_to_shared_ptr( + imethod); + }, + nb::arg("imethod"), "Accessor routine to load built-in DIRK table"); + +m.def("ARKodeButcherTable_DIRKIDToName", ARKodeButcherTable_DIRKIDToName, + nb::arg("imethod")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _ARKODE_SPRKTABLE_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumARKODE_SPRKMethodID = + nb::enum_(m, "ARKODE_SPRKMethodID", nb::is_arithmetic(), "") + .value("ARKODE_SPRK_NONE", ARKODE_SPRK_NONE, "ensure enum is signed int") + .value("ARKODE_SPRK_EULER_1_1", ARKODE_SPRK_EULER_1_1, "") + .value("ARKODE_MIN_SPRK_NUM", ARKODE_MIN_SPRK_NUM, "") + .value("ARKODE_SPRK_LEAPFROG_2_2", ARKODE_SPRK_LEAPFROG_2_2, "") + .value("ARKODE_SPRK_PSEUDO_LEAPFROG_2_2", ARKODE_SPRK_PSEUDO_LEAPFROG_2_2, "") + .value("ARKODE_SPRK_RUTH_3_3", ARKODE_SPRK_RUTH_3_3, "") + .value("ARKODE_SPRK_MCLACHLAN_2_2", ARKODE_SPRK_MCLACHLAN_2_2, "") + .value("ARKODE_SPRK_MCLACHLAN_3_3", ARKODE_SPRK_MCLACHLAN_3_3, "") + .value("ARKODE_SPRK_CANDY_ROZMUS_4_4", ARKODE_SPRK_CANDY_ROZMUS_4_4, "") + .value("ARKODE_SPRK_MCLACHLAN_4_4", ARKODE_SPRK_MCLACHLAN_4_4, "") + .value("ARKODE_SPRK_MCLACHLAN_5_6", ARKODE_SPRK_MCLACHLAN_5_6, "") + .value("ARKODE_SPRK_YOSHIDA_6_8", ARKODE_SPRK_YOSHIDA_6_8, "") + .value("ARKODE_SPRK_SUZUKI_UMENO_8_16", ARKODE_SPRK_SUZUKI_UMENO_8_16, "") + .value("ARKODE_SPRK_SOFRONIOU_10_36", ARKODE_SPRK_SOFRONIOU_10_36, "") + .value("ARKODE_MAX_SPRK_NUM", ARKODE_MAX_SPRK_NUM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClassARKodeSPRKTableMem = + nb::class_(m, "ARKodeSPRKTableMem", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "ARKodeSPRKTable_Create", + [](int s, int q, sundials4py::Array1d a_1d, sundials4py::Array1d ahat_1d) + -> std::shared_ptr> + { + auto ARKodeSPRKTable_Create_adapt_arr_ptr_to_std_vector = + [](int s, int q, sundials4py::Array1d a_1d, + sundials4py::Array1d ahat_1d) -> ARKodeSPRKTable + { + sunrealtype* a_1d_ptr = reinterpret_cast(a_1d.data()); + sunrealtype* ahat_1d_ptr = reinterpret_cast(ahat_1d.data()); + + auto lambda_result = ARKodeSPRKTable_Create(s, q, a_1d_ptr, ahat_1d_ptr); + return lambda_result; + }; + auto ARKodeSPRKTable_Create_adapt_return_type_to_shared_ptr = + [&ARKodeSPRKTable_Create_adapt_arr_ptr_to_std_vector](int s, int q, + sundials4py::Array1d a_1d, + sundials4py::Array1d ahat_1d) + -> std::shared_ptr> + { + auto lambda_result = + ARKodeSPRKTable_Create_adapt_arr_ptr_to_std_vector(s, q, a_1d, ahat_1d); + + return our_make_shared, + ARKodeSPRKTableDeleter>(lambda_result); + }; + + return ARKodeSPRKTable_Create_adapt_return_type_to_shared_ptr(s, q, a_1d, + ahat_1d); + }, + nb::arg("s"), nb::arg("q"), nb::arg("a_1d"), nb::arg("ahat_1d")); + +m.def( + "ARKodeSPRKTable_Load", + [](ARKODE_SPRKMethodID id) + -> std::shared_ptr> + { + auto ARKodeSPRKTable_Load_adapt_return_type_to_shared_ptr = + [](ARKODE_SPRKMethodID id) + -> std::shared_ptr> + { + auto lambda_result = ARKodeSPRKTable_Load(id); + + return our_make_shared, + ARKodeSPRKTableDeleter>(lambda_result); + }; + + return ARKodeSPRKTable_Load_adapt_return_type_to_shared_ptr(id); + }, + nb::arg("id")); + +m.def( + "ARKodeSPRKTable_LoadByName", + [](const char* method) -> std::shared_ptr> + { + auto ARKodeSPRKTable_LoadByName_adapt_return_type_to_shared_ptr = + [](const char* method) + -> std::shared_ptr> + { + auto lambda_result = ARKodeSPRKTable_LoadByName(method); + + return our_make_shared, + ARKodeSPRKTableDeleter>(lambda_result); + }; + + return ARKodeSPRKTable_LoadByName_adapt_return_type_to_shared_ptr(method); + }, + nb::arg("method")); + +m.def( + "ARKodeSPRKTable_Copy", + [](ARKodeSPRKTable that_sprk_storage) + -> std::shared_ptr> + { + auto ARKodeSPRKTable_Copy_adapt_return_type_to_shared_ptr = + [](ARKodeSPRKTable that_sprk_storage) + -> std::shared_ptr> + { + auto lambda_result = ARKodeSPRKTable_Copy(that_sprk_storage); + + return our_make_shared, + ARKodeSPRKTableDeleter>(lambda_result); + }; + + return ARKodeSPRKTable_Copy_adapt_return_type_to_shared_ptr(that_sprk_storage); + }, + nb::arg("that_sprk_storage")); + +m.def("ARKodeSPRKTable_Write", ARKodeSPRKTable_Write, nb::arg("sprk_table"), + nb::arg("outfile")); + +m.def( + "ARKodeSPRKTable_ToButcher", + [](ARKodeSPRKTable sprk_storage) + -> std::tuple>, + std::shared_ptr>> + { + auto ARKodeSPRKTable_ToButcher_adapt_modifiable_immutable_to_return = + [](ARKodeSPRKTable sprk_storage) + -> std::tuple + { + ARKodeButcherTable a_ptr_adapt_modifiable; + ARKodeButcherTable b_ptr_adapt_modifiable; + + int r = ARKodeSPRKTable_ToButcher(sprk_storage, &a_ptr_adapt_modifiable, + &b_ptr_adapt_modifiable); + return std::make_tuple(r, a_ptr_adapt_modifiable, b_ptr_adapt_modifiable); + }; + auto ARKodeSPRKTable_ToButcher_adapt_return_type_to_shared_ptr = + [&ARKodeSPRKTable_ToButcher_adapt_modifiable_immutable_to_return]( + ARKodeSPRKTable sprk_storage) + -> std::tuple>, + std::shared_ptr>> + { + auto lambda_result = + ARKodeSPRKTable_ToButcher_adapt_modifiable_immutable_to_return( + sprk_storage); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + ARKodeButcherTableDeleter>( + std::get<1>(lambda_result)), + our_make_shared, + ARKodeButcherTableDeleter>( + std::get<2>(lambda_result))); + }; + + return ARKodeSPRKTable_ToButcher_adapt_return_type_to_shared_ptr(sprk_storage); + }, + nb::arg("sprk_storage"), nb::rv_policy::reference); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_lsrkstep.cpp b/bindings/sundials4py/arkode/arkode_lsrkstep.cpp new file mode 100644 index 0000000000..7f8989214c --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_lsrkstep.cpp @@ -0,0 +1,109 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +#include "arkode_impl.h" +#include "arkode_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_lsrkstep(nb::module_& m) +{ +#include "arkode_lsrkstep_generated.hpp" + + m.def( + "LSRKStepCreateSTS", + [](std::function> rhs, sunrealtype t0, + N_Vector y0, SUNContext sunctx) + { + if (!rhs) { throw sundials4py::illegal_value("rhs was null"); } + + void* ark_mem = LSRKStepCreateSTS(lsrkstep_f_wrapper, t0, y0, sunctx); + if (ark_mem == nullptr) + { + throw sundials4py::error_returned("Failed to create LSRKStep memory"); + } + + auto fn_table = arkode_user_supplied_fn_table_alloc(); + + static_cast(ark_mem)->python = fn_table; + + int ark_status = ARKodeSetUserData(ark_mem, ark_mem); + if (ark_status != ARK_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in ARKODE memory"); + } + + fn_table->lsrkstep_f = nb::cast(rhs); + + return std::make_shared(ark_mem); + }, + nb::arg("rhs"), nb::arg("t0"), nb::arg("y0"), nb::arg("sunctx"), + nb::keep_alive<0, 4>()); + + m.def( + "LSRKStepCreateSSP", + [](std::function> rhs, sunrealtype t0, + N_Vector y0, SUNContext sunctx) + { + if (!rhs) { throw sundials4py::illegal_value("rhs was null"); } + + void* ark_mem = LSRKStepCreateSSP(lsrkstep_f_wrapper, t0, y0, sunctx); + if (ark_mem == nullptr) + { + throw sundials4py::error_returned("Failed to create LSRKStep memory"); + } + + auto fn_table = arkode_user_supplied_fn_table_alloc(); + + static_cast(ark_mem)->python = fn_table; + + int ark_status = ARKodeSetUserData(ark_mem, ark_mem); + if (ark_status != ARK_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in ARKODE memory"); + } + + fn_table->lsrkstep_f = nb::cast(rhs); + + return std::make_shared(ark_mem); + }, + nb::arg("rhs"), nb::arg("t0"), nb::arg("y0"), nb::arg("sunctx"), + nb::keep_alive<0, 4>()); + + m.def("LSRKStepSetDomEigFn", + [](void* ark_mem, std::function> fn) + { + auto fn_table = get_arkode_fn_table(ark_mem); + fn_table->lsrkstep_domeig = nb::cast(fn); + return LSRKStepSetDomEigFn(ark_mem, &lsrkstep_domeig_wrapper); + }); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_lsrkstep_generated.hpp b/bindings/sundials4py/arkode/arkode_lsrkstep_generated.hpp new file mode 100644 index 0000000000..252f93ff92 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_lsrkstep_generated.hpp @@ -0,0 +1,135 @@ +// #ifndef _LSRKSTEP_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumARKODE_LSRKMethodType = + nb::enum_(m, "ARKODE_LSRKMethodType", + nb::is_arithmetic(), "") + .value("ARKODE_LSRK_RKC_2", ARKODE_LSRK_RKC_2, "") + .value("ARKODE_LSRK_RKL_2", ARKODE_LSRK_RKL_2, "") + .value("ARKODE_LSRK_SSP_S_2", ARKODE_LSRK_SSP_S_2, "") + .value("ARKODE_LSRK_SSP_S_3", ARKODE_LSRK_SSP_S_3, "") + .value("ARKODE_LSRK_SSP_10_4", ARKODE_LSRK_SSP_10_4, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def("LSRKStepSetSTSMethod", LSRKStepSetSTSMethod, nb::arg("arkode_mem"), + nb::arg("method")); + +m.def("LSRKStepSetSSPMethod", LSRKStepSetSSPMethod, nb::arg("arkode_mem"), + nb::arg("method")); + +m.def("LSRKStepSetSTSMethodByName", LSRKStepSetSTSMethodByName, + nb::arg("arkode_mem"), nb::arg("emethod")); + +m.def("LSRKStepSetSSPMethodByName", LSRKStepSetSSPMethodByName, + nb::arg("arkode_mem"), nb::arg("emethod")); + +m.def("LSRKStepSetDomEigEstimator", LSRKStepSetDomEigEstimator, + nb::arg("arkode_mem"), nb::arg("DEE")); + +m.def("LSRKStepSetDomEigFrequency", LSRKStepSetDomEigFrequency, + nb::arg("arkode_mem"), nb::arg("nsteps")); + +m.def("LSRKStepSetMaxNumStages", LSRKStepSetMaxNumStages, nb::arg("arkode_mem"), + nb::arg("stage_max_limit")); + +m.def("LSRKStepSetDomEigSafetyFactor", LSRKStepSetDomEigSafetyFactor, + nb::arg("arkode_mem"), nb::arg("dom_eig_safety")); + +m.def("LSRKStepSetNumDomEigEstInitPreprocessIters", + LSRKStepSetNumDomEigEstInitPreprocessIters, nb::arg("arkode_mem"), + nb::arg("num_iters")); + +m.def("LSRKStepSetNumDomEigEstPreprocessIters", + LSRKStepSetNumDomEigEstPreprocessIters, nb::arg("arkode_mem"), + nb::arg("num_iters")); + +m.def("LSRKStepSetNumSSPStages", LSRKStepSetNumSSPStages, nb::arg("arkode_mem"), + nb::arg("num_of_stages")); + +m.def( + "LSRKStepGetNumDomEigUpdates", + [](void* arkode_mem) -> std::tuple + { + auto LSRKStepGetNumDomEigUpdates_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long dom_eig_num_evals_adapt_modifiable; + + int r = LSRKStepGetNumDomEigUpdates(arkode_mem, + &dom_eig_num_evals_adapt_modifiable); + return std::make_tuple(r, dom_eig_num_evals_adapt_modifiable); + }; + + return LSRKStepGetNumDomEigUpdates_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "LSRKStepGetMaxNumStages", + [](void* arkode_mem) -> std::tuple + { + auto LSRKStepGetMaxNumStages_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + int stage_max_adapt_modifiable; + + int r = LSRKStepGetMaxNumStages(arkode_mem, &stage_max_adapt_modifiable); + return std::make_tuple(r, stage_max_adapt_modifiable); + }; + + return LSRKStepGetMaxNumStages_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "LSRKStepGetNumDomEigEstRhsEvals", + [](void* arkode_mem) -> std::tuple + { + auto LSRKStepGetNumDomEigEstRhsEvals_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long nfeDQ_adapt_modifiable; + + int r = LSRKStepGetNumDomEigEstRhsEvals(arkode_mem, + &nfeDQ_adapt_modifiable); + return std::make_tuple(r, nfeDQ_adapt_modifiable); + }; + + return LSRKStepGetNumDomEigEstRhsEvals_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "LSRKStepGetNumDomEigEstIters", + [](void* arkode_mem) -> std::tuple + { + auto LSRKStepGetNumDomEigEstIters_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long num_iters_adapt_modifiable; + + int r = LSRKStepGetNumDomEigEstIters(arkode_mem, + &num_iters_adapt_modifiable); + return std::make_tuple(r, num_iters_adapt_modifiable); + }; + + return LSRKStepGetNumDomEigEstIters_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_mristep.cpp b/bindings/sundials4py/arkode/arkode_mristep.cpp new file mode 100644 index 0000000000..079619c986 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_mristep.cpp @@ -0,0 +1,146 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "arkode/arkode.h" +#include "sundials4py.hpp" + +#include +#include +#include + +#include "arkode_mristep_impl.h" +#include "arkode_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_mristep(nb::module_& m) +{ +#include "arkode_mristep_generated.hpp" + + // _MRIStepInnerStepper is a opaque/private class forward declared in a public header but + // defined in a source file elsewhere. As such, we need to declare it here since its + // not picked up in any header files by the generator. + nb::class_<_MRIStepInnerStepper>(m, "_MRIStepInnerStepper"); + + m.def("MRIStepInnerStepper_Create", + [](SUNContext sunctx) + { + MRIStepInnerStepper stepper = nullptr; + + int status = MRIStepInnerStepper_Create(sunctx, &stepper); + auto fn_table = mristepinnerstepper_user_supplied_fn_table_alloc(); + stepper->python = static_cast(fn_table); + + return std::make_tuple(status, + our_make_shared< + std::remove_pointer_t, + MRIStepInnerStepperDeleter>(stepper)); + }); + + m.def("MRIStepInnerStepper_CreateFromSUNStepper", + [](SUNStepper stepper) + { + MRIStepInnerStepper inner_stepper = nullptr; + + int status = MRIStepInnerStepper_CreateFromSUNStepper(stepper, + &inner_stepper); + + return std::make_tuple(status, + our_make_shared< + std::remove_pointer_t, + MRIStepInnerStepperDeleter>(inner_stepper)); + }); + + // m.def("MRIStepInnerStepper_GetForcingData", + // [](MRIStepInnerStepper stepper) -> std::tuple, int> { + + // sunrealtype tshift = 0.0; + // sunrealtype tscale = 0.0; + // N_Vector* forcing_1d = nullptr; + // int nforcing = 0; + + // int status = MRIStepInnerStepper_GetForcingData(stepper, &tshift, &tscale, &forcing_1d, &nforcing); + + // std::vector forcing(nforcing); + // // TODO(CJB): for some reason this causes a segfault unless you clone + // // for (int i = 0; i < nforcing; i++) { + // // // forcing[i] = N_VClone(forcing_1d[i]); + // // forcing[i] = forcing_1d[i]; + // // } + + // return std::make_tuple(status, tshift, tscale, forcing, nforcing); + // }); + + m.def("ARKodeCreateMRIStepInnerStepper", + [](void* inner_arkode_mem) + { + MRIStepInnerStepper stepper = nullptr; + + int status = ARKodeCreateMRIStepInnerStepper(inner_arkode_mem, + &stepper); + + return std::make_tuple(status, + our_make_shared< + std::remove_pointer_t, + MRIStepInnerStepperDeleter>(stepper)); + }); + + m.def( + "MRIStepCreate", + [](std::function> fse, + std::function> fsi, sunrealtype t0, + N_Vector y0, MRIStepInnerStepper stepper, SUNContext sunctx) + { + auto fse_wrapper = fse ? mristep_fse_wrapper : nullptr; + auto fsi_wrapper = fsi ? mristep_fsi_wrapper : nullptr; + + void* ark_mem = MRIStepCreate(fse_wrapper, fsi_wrapper, t0, y0, stepper, + sunctx); + if (ark_mem == nullptr) + { + throw sundials4py::error_returned("Failed to create ARKODE memory"); + } + + // Create the user-supplied function table to store the Python user functions + auto fn_table = arkode_user_supplied_fn_table_alloc(); + + // Smuggle the user-supplied function table into callback wrappers through the user_data pointer + static_cast(ark_mem)->python = fn_table; + int ark_status = ARKodeSetUserData(ark_mem, ark_mem); + if (ark_status != ARK_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in ARKODE memory"); + } + + // Finally, set the RHS function + fn_table->mristep_fse = nb::cast(fse); + fn_table->mristep_fsi = nb::cast(fsi); + + return std::make_shared(ark_mem); + }, + nb::arg("fse").none(), nb::arg("fsi").none(), nb::arg("t0"), nb::arg("y0"), + nb::arg("inner_stepper"), nb::arg("sunctx"), nb::keep_alive<0, 6>()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_mristep_generated.hpp b/bindings/sundials4py/arkode/arkode_mristep_generated.hpp new file mode 100644 index 0000000000..83f8d214bb --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_mristep_generated.hpp @@ -0,0 +1,252 @@ +// #ifndef _MRISTEP_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumMRISTEP_METHOD_TYPE = + nb::enum_(m, "MRISTEP_METHOD_TYPE", nb::is_arithmetic(), + "MRIStep method types") + .value("MRISTEP_EXPLICIT", MRISTEP_EXPLICIT, "") + .value("MRISTEP_IMPLICIT", MRISTEP_IMPLICIT, "") + .value("MRISTEP_IMEX", MRISTEP_IMEX, "") + .value("MRISTEP_MERK", MRISTEP_MERK, "") + .value("MRISTEP_SR", MRISTEP_SR, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyEnumARKODE_MRITableID = + nb::enum_(m, "ARKODE_MRITableID", nb::is_arithmetic(), + "MRI coupling table IDs") + .value("ARKODE_MRI_NONE", ARKODE_MRI_NONE, "ensure enum is signed int") + .value("ARKODE_MIS_KW3", ARKODE_MIS_KW3, "") + .value("ARKODE_MIN_MRI_NUM", ARKODE_MIN_MRI_NUM, "") + .value("ARKODE_MRI_GARK_ERK33a", ARKODE_MRI_GARK_ERK33a, "") + .value("ARKODE_MRI_GARK_ERK45a", ARKODE_MRI_GARK_ERK45a, "") + .value("ARKODE_MRI_GARK_IRK21a", ARKODE_MRI_GARK_IRK21a, "") + .value("ARKODE_MRI_GARK_ESDIRK34a", ARKODE_MRI_GARK_ESDIRK34a, "") + .value("ARKODE_MRI_GARK_ESDIRK46a", ARKODE_MRI_GARK_ESDIRK46a, "") + .value("ARKODE_IMEX_MRI_GARK3a", ARKODE_IMEX_MRI_GARK3a, "") + .value("ARKODE_IMEX_MRI_GARK3b", ARKODE_IMEX_MRI_GARK3b, "") + .value("ARKODE_IMEX_MRI_GARK4", ARKODE_IMEX_MRI_GARK4, "") + .value("ARKODE_MRI_GARK_FORWARD_EULER", ARKODE_MRI_GARK_FORWARD_EULER, "") + .value("ARKODE_MRI_GARK_RALSTON2", ARKODE_MRI_GARK_RALSTON2, "") + .value("ARKODE_MRI_GARK_ERK22a", ARKODE_MRI_GARK_ERK22a, "") + .value("ARKODE_MRI_GARK_ERK22b", ARKODE_MRI_GARK_ERK22b, "") + .value("ARKODE_MRI_GARK_RALSTON3", ARKODE_MRI_GARK_RALSTON3, "") + .value("ARKODE_MRI_GARK_BACKWARD_EULER", ARKODE_MRI_GARK_BACKWARD_EULER, "") + .value("ARKODE_MRI_GARK_IMPLICIT_MIDPOINT", + ARKODE_MRI_GARK_IMPLICIT_MIDPOINT, "") + .value("ARKODE_IMEX_MRI_GARK_EULER", ARKODE_IMEX_MRI_GARK_EULER, "") + .value("ARKODE_IMEX_MRI_GARK_TRAPEZOIDAL", ARKODE_IMEX_MRI_GARK_TRAPEZOIDAL, + "") + .value("ARKODE_IMEX_MRI_GARK_MIDPOINT", ARKODE_IMEX_MRI_GARK_MIDPOINT, "") + .value("ARKODE_MERK21", ARKODE_MERK21, "") + .value("ARKODE_MERK32", ARKODE_MERK32, "") + .value("ARKODE_MERK43", ARKODE_MERK43, "") + .value("ARKODE_MERK54", ARKODE_MERK54, "") + .value("ARKODE_IMEX_MRI_SR21", ARKODE_IMEX_MRI_SR21, "") + .value("ARKODE_IMEX_MRI_SR32", ARKODE_IMEX_MRI_SR32, "") + .value("ARKODE_IMEX_MRI_SR43", ARKODE_IMEX_MRI_SR43, "") + .value("ARKODE_MAX_MRI_NUM", ARKODE_MAX_MRI_NUM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClassMRIStepCouplingMem = + nb::class_(m, + "MRIStepCouplingMem", "---------------------------------------------------------------\n MRI coupling data structure and associated utility routines\n ---------------------------------------------------------------") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "MRIStepCoupling_LoadTable", + [](ARKODE_MRITableID method) + -> std::shared_ptr> + { + auto MRIStepCoupling_LoadTable_adapt_return_type_to_shared_ptr = + [](ARKODE_MRITableID method) + -> std::shared_ptr> + { + auto lambda_result = MRIStepCoupling_LoadTable(method); + + return our_make_shared, + MRIStepCouplingDeleter>(lambda_result); + }; + + return MRIStepCoupling_LoadTable_adapt_return_type_to_shared_ptr(method); + }, + nb::arg("method"), "Accessor routine to load built-in MRI table"); + +m.def( + "MRIStepCoupling_LoadTableByName", + [](const char* method) -> std::shared_ptr> + { + auto MRIStepCoupling_LoadTableByName_adapt_return_type_to_shared_ptr = + [](const char* method) + -> std::shared_ptr> + { + auto lambda_result = MRIStepCoupling_LoadTableByName(method); + + return our_make_shared, + MRIStepCouplingDeleter>(lambda_result); + }; + + return MRIStepCoupling_LoadTableByName_adapt_return_type_to_shared_ptr(method); + }, + nb::arg("method"), "Accessor routine to load built-in MRI table from string"); + +m.def( + "MRIStepCoupling_Create", + [](int nmat, int stages, int q, int p, sundials4py::Array1d W_1d, + sundials4py::Array1d G_1d, sundials4py::Array1d c_1d) + -> std::shared_ptr> + { + auto MRIStepCoupling_Create_adapt_arr_ptr_to_std_vector = + [](int nmat, int stages, int q, int p, sundials4py::Array1d W_1d, + sundials4py::Array1d G_1d, sundials4py::Array1d c_1d) -> MRIStepCoupling + { + sunrealtype* W_1d_ptr = reinterpret_cast(W_1d.data()); + sunrealtype* G_1d_ptr = reinterpret_cast(G_1d.data()); + sunrealtype* c_1d_ptr = reinterpret_cast(c_1d.data()); + + auto lambda_result = MRIStepCoupling_Create(nmat, stages, q, p, W_1d_ptr, + G_1d_ptr, c_1d_ptr); + return lambda_result; + }; + auto MRIStepCoupling_Create_adapt_return_type_to_shared_ptr = + [&MRIStepCoupling_Create_adapt_arr_ptr_to_std_vector](int nmat, int stages, + int q, int p, + sundials4py::Array1d W_1d, + sundials4py::Array1d G_1d, + sundials4py::Array1d c_1d) + -> std::shared_ptr> + { + auto lambda_result = + MRIStepCoupling_Create_adapt_arr_ptr_to_std_vector(nmat, stages, q, p, + W_1d, G_1d, c_1d); + + return our_make_shared, + MRIStepCouplingDeleter>(lambda_result); + }; + + return MRIStepCoupling_Create_adapt_return_type_to_shared_ptr(nmat, stages, + q, p, W_1d, + G_1d, c_1d); + }, + nb::arg("nmat"), nb::arg("stages"), nb::arg("q"), nb::arg("p"), + nb::arg("W_1d"), nb::arg("G_1d"), nb::arg("c_1d")); + +m.def( + "MRIStepCoupling_MIStoMRI", + [](ARKodeButcherTable B, int q, + int p) -> std::shared_ptr> + { + auto MRIStepCoupling_MIStoMRI_adapt_return_type_to_shared_ptr = + [](ARKodeButcherTable B, int q, + int p) -> std::shared_ptr> + { + auto lambda_result = MRIStepCoupling_MIStoMRI(B, q, p); + + return our_make_shared, + MRIStepCouplingDeleter>(lambda_result); + }; + + return MRIStepCoupling_MIStoMRI_adapt_return_type_to_shared_ptr(B, q, p); + }, + nb::arg("B"), nb::arg("q"), nb::arg("p")); + +m.def( + "MRIStepCoupling_Copy", + [](MRIStepCoupling MRIC) -> std::shared_ptr> + { + auto MRIStepCoupling_Copy_adapt_return_type_to_shared_ptr = + [](MRIStepCoupling MRIC) + -> std::shared_ptr> + { + auto lambda_result = MRIStepCoupling_Copy(MRIC); + + return our_make_shared, + MRIStepCouplingDeleter>(lambda_result); + }; + + return MRIStepCoupling_Copy_adapt_return_type_to_shared_ptr(MRIC); + }, + nb::arg("MRIC")); + +m.def("MRIStepCoupling_Write", MRIStepCoupling_Write, nb::arg("MRIC"), + nb::arg("outfile")); + +m.def("MRIStepSetCoupling", MRIStepSetCoupling, nb::arg("arkode_mem"), + nb::arg("MRIC")); + +m.def( + "MRIStepGetCurrentCoupling", + [](void* arkode_mem) -> std::tuple + { + auto MRIStepGetCurrentCoupling_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + MRIStepCoupling MRIC_adapt_modifiable; + + int r = MRIStepGetCurrentCoupling(arkode_mem, &MRIC_adapt_modifiable); + return std::make_tuple(r, MRIC_adapt_modifiable); + }; + + return MRIStepGetCurrentCoupling_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem"), + " Optional output functions\n\n nb::rv_policy::reference", + nb::rv_policy::reference); + +m.def( + "MRIStepGetLastInnerStepFlag", + [](void* arkode_mem) -> std::tuple + { + auto MRIStepGetLastInnerStepFlag_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + int flag_adapt_modifiable; + + int r = MRIStepGetLastInnerStepFlag(arkode_mem, &flag_adapt_modifiable); + return std::make_tuple(r, flag_adapt_modifiable); + }; + + return MRIStepGetLastInnerStepFlag_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def( + "MRIStepGetNumInnerStepperFails", + [](void* arkode_mem) -> std::tuple + { + auto MRIStepGetNumInnerStepperFails_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + long inner_fails_adapt_modifiable; + + int r = MRIStepGetNumInnerStepperFails(arkode_mem, + &inner_fails_adapt_modifiable); + return std::make_tuple(r, inner_fails_adapt_modifiable); + }; + + return MRIStepGetNumInnerStepperFails_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem")); + +m.def("MRIStepInnerStepper_AddForcing", MRIStepInnerStepper_AddForcing, + nb::arg("stepper"), nb::arg("t"), nb::arg("f")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_splittingstep.cpp b/bindings/sundials4py/arkode/arkode_splittingstep.cpp new file mode 100644 index 0000000000..e1d242553b --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_splittingstep.cpp @@ -0,0 +1,55 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +#include "arkode_mristep_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_splittingstep(nb::module_& m) +{ +#include "arkode_splittingstep_generated.hpp" + + m.def( + "SplittingStepCreate", + [](std::vector steppers, int partitions, sunrealtype t0, + N_Vector y0, SUNContext sunctx) + { + return std::make_shared( + SplittingStepCreate(steppers.data(), partitions, t0, y0, sunctx)); + }, + nb::arg("steppers"), nb::arg("partitions"), nb::arg("t0"), nb::arg("y0"), + nb::arg("sunctx"), nb::keep_alive<0, 5>()); + + m.def("SplittingStepReInit", + [](void* arkode_mem, std::vector steppers, int partitions, + sunrealtype t0, N_Vector y0) -> int + { + return SplittingStepReInit(arkode_mem, steppers.data(), partitions, + t0, y0); + }); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_splittingstep_generated.hpp b/bindings/sundials4py/arkode/arkode_splittingstep_generated.hpp new file mode 100644 index 0000000000..8483a4667c --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_splittingstep_generated.hpp @@ -0,0 +1,320 @@ +// #ifndef ARKODE_SPLITTINGSTEP_H_ +// +// #ifdef __cplusplus +// #endif +// + +auto pyClassSplittingStepCoefficientsMem = + nb::class_(m, + "SplittingStepCoefficientsMem", "---------------------------------------------------------------\n Types : struct SplittingStepCoefficientsMem, SplittingStepCoefficients\n ---------------------------------------------------------------") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyEnumARKODE_SplittingCoefficientsID = + nb::enum_(m, "ARKODE_SplittingCoefficientsID", + nb::is_arithmetic(), " Splitting names use the convention\n * ARKODE_SPLITTING____") + .value("ARKODE_SPLITTING_NONE", ARKODE_SPLITTING_NONE, + "ensure enum is signed int") + .value("ARKODE_SPLITTING_LIE_TROTTER_1_1_2", + ARKODE_SPLITTING_LIE_TROTTER_1_1_2, "") + .value("ARKODE_MIN_SPLITTING_NUM", ARKODE_MIN_SPLITTING_NUM, "") + .value("ARKODE_SPLITTING_STRANG_2_2_2", ARKODE_SPLITTING_STRANG_2_2_2, "") + .value("ARKODE_SPLITTING_BEST_2_2_2", ARKODE_SPLITTING_BEST_2_2_2, "") + .value("ARKODE_SPLITTING_SUZUKI_3_3_2", ARKODE_SPLITTING_SUZUKI_3_3_2, "") + .value("ARKODE_SPLITTING_RUTH_3_3_2", ARKODE_SPLITTING_RUTH_3_3_2, "") + .value("ARKODE_SPLITTING_YOSHIDA_4_4_2", ARKODE_SPLITTING_YOSHIDA_4_4_2, "") + .value("ARKODE_SPLITTING_YOSHIDA_8_6_2", ARKODE_SPLITTING_YOSHIDA_8_6_2, "") + .value("ARKODE_MAX_SPLITTING_NUM", ARKODE_MAX_SPLITTING_NUM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def( + "SplittingStepCoefficients_Create", + [](int sequential_methods, int stages, int partitions, int order, + sundials4py::Array1d alpha_1d, sundials4py::Array1d beta_1d) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_Create_adapt_arr_ptr_to_std_vector = + [](int sequential_methods, int stages, int partitions, int order, + sundials4py::Array1d alpha_1d, + sundials4py::Array1d beta_1d) -> SplittingStepCoefficients + { + sunrealtype* alpha_1d_ptr = reinterpret_cast(alpha_1d.data()); + sunrealtype* beta_1d_ptr = reinterpret_cast(beta_1d.data()); + + auto lambda_result = + SplittingStepCoefficients_Create(sequential_methods, stages, partitions, + order, alpha_1d_ptr, beta_1d_ptr); + return lambda_result; + }; + auto SplittingStepCoefficients_Create_adapt_return_type_to_shared_ptr = + [&SplittingStepCoefficients_Create_adapt_arr_ptr_to_std_vector](int sequential_methods, + int stages, + int partitions, + int order, + sundials4py::Array1d + alpha_1d, + sundials4py::Array1d + beta_1d) + -> std::shared_ptr> + { + auto lambda_result = + SplittingStepCoefficients_Create_adapt_arr_ptr_to_std_vector(sequential_methods, + stages, + partitions, + order, + alpha_1d, + beta_1d); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_Create_adapt_return_type_to_shared_ptr(sequential_methods, + stages, + partitions, + order, + alpha_1d, + beta_1d); + }, + nb::arg("sequential_methods"), nb::arg("stages"), nb::arg("partitions"), + nb::arg("order"), nb::arg("alpha_1d"), nb::arg("beta_1d")); + +m.def( + "SplittingStepCoefficients_Copy", + [](SplittingStepCoefficients coefficients) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_Copy_adapt_return_type_to_shared_ptr = + [](SplittingStepCoefficients coefficients) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_Copy(coefficients); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_Copy_adapt_return_type_to_shared_ptr( + coefficients); + }, + nb::arg("coefficients")); + +m.def("SplittingStepCoefficients_Write", SplittingStepCoefficients_Write, + nb::arg("coefficients"), nb::arg("outfile")); + +m.def( + "SplittingStepCoefficients_LoadCoefficients", + [](ARKODE_SplittingCoefficientsID id) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_LoadCoefficients_adapt_return_type_to_shared_ptr = + [](ARKODE_SplittingCoefficientsID id) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_LoadCoefficients(id); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_LoadCoefficients_adapt_return_type_to_shared_ptr( + id); + }, + nb::arg("id")); + +m.def( + "SplittingStepCoefficients_LoadCoefficientsByName", + [](const char* name) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_LoadCoefficientsByName_adapt_return_type_to_shared_ptr = + [](const char* name) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_LoadCoefficientsByName(name); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_LoadCoefficientsByName_adapt_return_type_to_shared_ptr( + name); + }, + nb::arg("name")); + +m.def("SplittingStepCoefficients_IDToName", SplittingStepCoefficients_IDToName, + nb::arg("id")); + +m.def( + "SplittingStepCoefficients_LieTrotter", + [](int partitions) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_LieTrotter_adapt_return_type_to_shared_ptr = + [](int partitions) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_LieTrotter(partitions); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_LieTrotter_adapt_return_type_to_shared_ptr( + partitions); + }, + nb::arg("partitions")); + +m.def( + "SplittingStepCoefficients_Strang", + [](int partitions) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_Strang_adapt_return_type_to_shared_ptr = + [](int partitions) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_Strang(partitions); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_Strang_adapt_return_type_to_shared_ptr( + partitions); + }, + nb::arg("partitions")); + +m.def( + "SplittingStepCoefficients_Parallel", + [](int partitions) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_Parallel_adapt_return_type_to_shared_ptr = + [](int partitions) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_Parallel(partitions); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_Parallel_adapt_return_type_to_shared_ptr( + partitions); + }, + nb::arg("partitions")); + +m.def( + "SplittingStepCoefficients_SymmetricParallel", + [](int partitions) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_SymmetricParallel_adapt_return_type_to_shared_ptr = + [](int partitions) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_SymmetricParallel(partitions); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_SymmetricParallel_adapt_return_type_to_shared_ptr( + partitions); + }, + nb::arg("partitions")); + +m.def( + "SplittingStepCoefficients_ThirdOrderSuzuki", + [](int partitions) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_ThirdOrderSuzuki_adapt_return_type_to_shared_ptr = + [](int partitions) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_ThirdOrderSuzuki(partitions); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_ThirdOrderSuzuki_adapt_return_type_to_shared_ptr( + partitions); + }, + nb::arg("partitions")); + +m.def( + "SplittingStepCoefficients_TripleJump", + [](int partitions, int order) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_TripleJump_adapt_return_type_to_shared_ptr = + [](int partitions, int order) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_TripleJump(partitions, + order); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_TripleJump_adapt_return_type_to_shared_ptr(partitions, + order); + }, + nb::arg("partitions"), nb::arg("order")); + +m.def( + "SplittingStepCoefficients_SuzukiFractal", + [](int partitions, int order) + -> std::shared_ptr> + { + auto SplittingStepCoefficients_SuzukiFractal_adapt_return_type_to_shared_ptr = + [](int partitions, int order) + -> std::shared_ptr> + { + auto lambda_result = SplittingStepCoefficients_SuzukiFractal(partitions, + order); + + return our_make_shared, + SplittingStepCoefficientsDeleter>(lambda_result); + }; + + return SplittingStepCoefficients_SuzukiFractal_adapt_return_type_to_shared_ptr(partitions, + order); + }, + nb::arg("partitions"), nb::arg("order")); + +m.def("SplittingStepSetCoefficients", SplittingStepSetCoefficients, + nb::arg("arkode_mem"), nb::arg("coefficients")); + +m.def( + "SplittingStepGetNumEvolves", + [](void* arkode_mem, int partition) -> std::tuple + { + auto SplittingStepGetNumEvolves_adapt_modifiable_immutable_to_return = + [](void* arkode_mem, int partition) -> std::tuple + { + long evolves_adapt_modifiable; + + int r = SplittingStepGetNumEvolves(arkode_mem, partition, + &evolves_adapt_modifiable); + return std::make_tuple(r, evolves_adapt_modifiable); + }; + + return SplittingStepGetNumEvolves_adapt_modifiable_immutable_to_return(arkode_mem, + partition); + }, + nb::arg("arkode_mem"), nb::arg("partition")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_sprkstep.cpp b/bindings/sundials4py/arkode/arkode_sprkstep.cpp new file mode 100644 index 0000000000..c2988c8cb9 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_sprkstep.cpp @@ -0,0 +1,73 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +#include "arkode_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_arkode_sprkstep(nb::module_& m) +{ +#include "arkode_sprkstep_generated.hpp" + + m.def( + "SPRKStepCreate", + [](std::function> f1, + std::function> f2, sunrealtype t0, + N_Vector y0, SUNContext sunctx) + { + if (!f1 && !f2) + { + throw sundials4py::illegal_value("f1 and f2 cannot be null"); + } + + void* ark_mem = SPRKStepCreate(sprkstep_f1_wrapper, sprkstep_f2_wrapper, + t0, y0, sunctx); + if (ark_mem == nullptr) + { + throw sundials4py::error_returned("Failed to create SPRKStep memory"); + } + + auto fn_table = arkode_user_supplied_fn_table_alloc(); + fn_table->sprkstep_f1 = nb::cast(f1); + fn_table->sprkstep_f2 = nb::cast(f2); + + static_cast(ark_mem)->python = fn_table; + + int ark_status = ARKodeSetUserData(ark_mem, ark_mem); + if (ark_status != ARK_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in SPRKStep memory"); + } + + return std::make_shared(ark_mem); + }, + nb::arg("f1"), nb::arg("f2"), nb::arg("t0"), nb::arg("y0"), + nb::arg("sunctx"), nb::keep_alive<0, 5>()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/arkode/arkode_sprkstep_generated.hpp b/bindings/sundials4py/arkode/arkode_sprkstep_generated.hpp new file mode 100644 index 0000000000..793d18738a --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_sprkstep_generated.hpp @@ -0,0 +1,38 @@ +// #ifndef _ARKODE_SPRKSTEP_H +// +// #ifdef __cplusplus +// #endif +// + +m.def("SPRKStepSetMethod", SPRKStepSetMethod, nb::arg("arkode_mem"), + nb::arg("sprk_storage")); + +m.def("SPRKStepSetMethodName", SPRKStepSetMethodName, nb::arg("arkode_mem"), + nb::arg("method")); + +m.def( + "SPRKStepGetCurrentMethod", + [](void* arkode_mem) -> std::tuple + { + auto SPRKStepGetCurrentMethod_adapt_modifiable_immutable_to_return = + [](void* arkode_mem) -> std::tuple + { + ARKodeSPRKTable sprk_storage_adapt_modifiable; + + int r = SPRKStepGetCurrentMethod(arkode_mem, + &sprk_storage_adapt_modifiable); + return std::make_tuple(r, sprk_storage_adapt_modifiable); + }; + + return SPRKStepGetCurrentMethod_adapt_modifiable_immutable_to_return( + arkode_mem); + }, + nb::arg("arkode_mem"), + " Optional output functions\n\n nb::rv_policy::reference", + nb::rv_policy::reference); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/arkode/arkode_usersupplied.hpp b/bindings/sundials4py/arkode/arkode_usersupplied.hpp new file mode 100644 index 0000000000..010337c979 --- /dev/null +++ b/bindings/sundials4py/arkode/arkode_usersupplied.hpp @@ -0,0 +1,609 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_ARKODE_USERSUPPLIED_HPP +#define _SUNDIALS4PY_ARKODE_USERSUPPLIED_HPP + +#include +#include +#include +#include +#include +#include + +#include + +#include "arkode_mristep_impl.h" + +#include "sundials/sundials_nvector.h" +#include "sundials4py_helpers.hpp" + +/////////////////////////////////////////////////////////////////////////////// +// ARKODE user-supplied function table +// Every integrator-level user-supplied function must be in this table. +// The user-supplied function table is passed to ARKODE as user_data. +/////////////////////////////////////////////////////////////////////////////// + +struct arkode_user_supplied_fn_table +{ + // common user-supplied function pointers + nb::object rootfn; + nb::object ewtn; + nb::object rwtn; + nb::object vecresizefn; + nb::object postprocessstepfn; + nb::object postprocessstagefn; + nb::object stagepredictfn; + nb::object relaxfn; + nb::object relaxjacfn; + nb::object nlsfi; + + // arkode_ls user-supplied function pointers + nb::object lsjacfn; + nb::object lsmassfn; + nb::object lsprecsetupfn; + nb::object lsprecsolvefn; + nb::object lsjactimessetupfn; + nb::object lsjactimesvecfn; + nb::object lslinsysfn; + nb::object lsmasstimessetupfn; + nb::object lsmasstimesvecfn; + nb::object lsmassprecsetupfn; + nb::object lsmassprecsolvefn; + nb::object lsjacrhsfn; + + // erkstep-specific user-supplied function pointers + nb::object erkstep_f; + nb::object erkstep_adjf; + + // arkstep-specific user-supplied function pointers + nb::object arkstep_fe; + nb::object arkstep_fi; + nb::object arkstep_adjfe; + nb::object arkstep_adjfi; + + // sprkstep-specific user-supplied function pointers + nb::object sprkstep_f1; + nb::object sprkstep_f2; + + // lsrkstep-specific user-supplied function pointers + nb::object lsrkstep_f; + nb::object lsrkstep_domeig; + + // mristep-specific user-supplied function pointers + nb::object mristep_fse; + nb::object mristep_fsi; + nb::object mristep_preinnerfn; + nb::object mristep_postinnerfn; +}; + +struct mristepinnerstepper_user_supplied_fn_table +{ + nb::object mristepinner_evolvefn; + nb::object mristepinner_fullrhsfn; + nb::object mristepinner_resetfn; + nb::object mristepinner_getaccumulatederrorfn; + nb::object mristepinner_resetaccumulatederrorfn; + nb::object mristepinner_setrtolfn; +}; + +/////////////////////////////////////////////////////////////////////////////// +// ARKODE user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +inline arkode_user_supplied_fn_table* arkode_user_supplied_fn_table_alloc() +{ + // We must use malloc since ARKodeFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(arkode_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(arkode_user_supplied_fn_table)); + + return fn_table; +} + +inline arkode_user_supplied_fn_table* get_arkode_fn_table(void* ark_mem) +{ + auto mem = static_cast(ark_mem); + auto fn_table = static_cast(mem->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python function table from ARKODE memory"); + } + return fn_table; +} + +inline mristepinnerstepper_user_supplied_fn_table* mristepinnerstepper_user_supplied_fn_table_alloc() +{ + // We must use malloc since ARKodeFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(mristepinnerstepper_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(mristepinnerstepper_user_supplied_fn_table)); + + return fn_table; +} + +using ARKRootStdFn = int(sunrealtype t, N_Vector y, sundials4py::Array1d gout, + void* user_data); + +inline int arkode_rootfn_wrapper(sunrealtype t, N_Vector y, + sunrealtype* gout_1d, void* user_data) +{ + auto fn_table = get_arkode_fn_table(user_data); + auto fn = nb::cast>(fn_table->rootfn); + auto nrtfn = static_cast(user_data)->root_mem->nrtfn; + + sundials4py::Array1d gout(gout_1d, {static_cast(nrtfn)}, + nb::find(gout_1d)); + + int status = fn(t, y, gout, nullptr); + + return status; +} + +template +inline int arkode_ewtfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::ewtn, std::forward(args)...); +} + +template +inline int arkode_rwtfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::rwtn, std::forward(args)...); +} + +template +inline int arkode_vecresizefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::vecresizefn, + std::forward(args)...); +} + +template +inline int arkode_postprocessstepfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::postprocessstepfn, + std::forward(args)...); +} + +template +inline int arkode_postprocessstagefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::postprocessstagefn, + std::forward(args)...); +} + +template +inline int arkode_stagepredictfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::stagepredictfn, + std::forward(args)...); +} + +template +inline int arkode_nlsrhsfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::nlsfi, std::forward(args)...); +} + +using ARKRelaxStdFn = std::tuple(N_Vector y, void* user_data); + +inline int arkode_relaxfn_wrapper(N_Vector y, sunrealtype* r, void* user_data) +{ + auto fn_table = static_cast(user_data); + auto fn = nb::cast>(fn_table->relaxfn); + + auto result = fn(y, nullptr); + + *r = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int arkode_relaxjacfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::relaxjacfn, std::forward(args)...); +} + +template +inline int arkode_lsjacfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 4>(&arkode_user_supplied_fn_table::lsjacfn, std::forward(args)...); +} + +template +inline int arkode_lsmassfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 4>(&arkode_user_supplied_fn_table::lsmassfn, std::forward(args)...); +} + +using ARKLsPrecSetupStdFn = std::tuple( + sunrealtype t, N_Vector y, N_Vector fy, sunbooleantype jok, sunrealtype gamma, + void* user_data); + +inline int arkode_lsprecsetupfn_wrapper(sunrealtype t, N_Vector y, N_Vector fy, + sunbooleantype jok, + sunbooleantype* jcurPtr, + sunrealtype gamma, void* user_data) +{ + auto fn_table = static_cast(user_data); + auto fn = nb::cast>(fn_table->lsprecsetupfn); + + auto result = fn(t, y, fy, jok, gamma, nullptr); + + *jcurPtr = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int arkode_lsprecsolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::lsprecsolvefn, + std::forward(args)...); +} + +template +inline int arkode_lsjactimessetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::lsjactimessetupfn, + std::forward(args)...); +} + +template +inline int arkode_lsjactimesvecfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 2>(&arkode_user_supplied_fn_table::lsjactimesvecfn, + std::forward(args)...); +} + +using ARKLsLinSysStdFn = std::tuple( + sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, SUNMatrix M, + sunbooleantype jok, sunrealtype gamma, void* user_data, N_Vector tmp1, + N_Vector tmp2, N_Vector tmp3); + +inline int arkode_lslinsysfn_wrapper(sunrealtype t, N_Vector y, N_Vector fy, + SUNMatrix A, SUNMatrix M, + sunbooleantype jok, sunbooleantype* jcurPtr, + sunrealtype gamma, void* user_data, + N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) +{ + auto fn_table = static_cast(user_data); + auto fn = nb::cast>(fn_table->lslinsysfn); + + auto result = fn(t, y, fy, A, M, jok, gamma, nullptr, tmp1, tmp2, tmp3); + + *jcurPtr = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int arkode_lsmasstimessetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::lsmasstimessetupfn, + std::forward(args)...); +} + +template +inline int arkode_lsmasstimesvecfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::lsmasstimesvecfn, + std::forward(args)...); +} + +template +inline int arkode_lsmassprecsetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::lsmassprecsetupfn, + std::forward(args)...); +} + +template +inline int arkode_lsmassprecsolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::lsmassprecsolvefn, + std::forward(args)...); +} + +template +inline int arkode_lsjacrhsfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::lsjacrhsfn, std::forward(args)...); +} + +/////////////////////////////////////////////////////////////////////////////// +// ERKStep user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +template +inline int erkstep_f_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::erkstep_f, std::forward(args)...); +} + +template +inline int erkstep_adjf_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::erkstep_adjf, std::forward(args)...); +} + +/////////////////////////////////////////////////////////////////////////////// +// ARKStep user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +inline int arkstep_fe_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, + void* user_data) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::arkstep_fe, t, y, ydot, user_data); +} + +inline int arkstep_fi_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, + void* user_data) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::arkstep_fi, t, y, ydot, user_data); +} + +inline int arkstep_adjfe_wrapper(sunrealtype t, N_Vector y, N_Vector sens, + N_Vector sens_dot, void* user_data) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::arkstep_adjfe, t, y, sens, + sens_dot, user_data); +} + +inline int arkstep_adjfi_wrapper(sunrealtype t, N_Vector y, N_Vector sens, + N_Vector sens_dot, void* user_data) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::arkstep_adjfi, t, y, sens, + sens_dot, user_data); +} + +/////////////////////////////////////////////////////////////////////////////// +// SPRKStep user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +template +inline int sprkstep_f1_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::sprkstep_f1, std::forward(args)...); +} + +template +inline int sprkstep_f2_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::sprkstep_f2, std::forward(args)...); +} + +/////////////////////////////////////////////////////////////////////////////// +// LSRKStep user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +template +inline int lsrkstep_f_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::lsrkstep_f, std::forward(args)...); +} + +using ARKDomEigStdFn = std::tuple( + sunrealtype t, N_Vector y, N_Vector fn, void* user_data, N_Vector temp1, + N_Vector temp2, N_Vector temp3); + +inline int lsrkstep_domeig_wrapper(sunrealtype t, N_Vector y, N_Vector fn, + sunrealtype* lambdaR, sunrealtype* lambdaI, + void* user_data, N_Vector temp1, + N_Vector temp2, N_Vector temp3) +{ + auto fn_table = get_arkode_fn_table(user_data); + auto callback = + nb::cast>(fn_table->lsrkstep_domeig); + + auto result = callback(t, y, fn, nullptr, temp1, temp2, temp3); + + *lambdaR = std::get<1>(result); + *lambdaI = std::get<2>(result); + + return std::get<0>(result); +} + +/////////////////////////////////////////////////////////////////////////////// +// MRIStep user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +inline int mristep_fse_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, + void* user_data) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::mristep_fse, t, y, ydot, user_data); +} + +inline int mristep_fsi_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, + void* user_data) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, ARKodeMem, + 1>(&arkode_user_supplied_fn_table::mristep_fsi, t, y, ydot, user_data); +} + +using MRIStepPreInnerStdFn = int(sunrealtype t, std::vector f, + int nvecs, void* user_data); + +inline int mristep_preinnerfn_wrapper(sunrealtype t, N_Vector* f_1d, int nvecs, + void* user_data) +{ + auto fn_table = static_cast(user_data); + auto fn = + nb::cast>(fn_table->mristep_preinnerfn); + + std::vector f(f_1d, f_1d + nvecs); + + return fn(t, f, nvecs, nullptr); +} + +template +inline int mristep_postinnerfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::mristep_postinnerfn, + std::forward(args)...); +} + +inline int mristepinner_evolvefn_wrapper(MRIStepInnerStepper stepper, + sunrealtype t0, sunrealtype tout, + N_Vector y) +{ + void* user_data = nullptr; + MRIStepInnerStepper_GetContent(stepper, &user_data); + + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, mristepinnerstepper_user_supplied_fn_table, + MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_evolvefn, + stepper, t0, tout, y); +} + +inline int mristepinner_fullrhsfn_wrapper(MRIStepInnerStepper stepper, + sunrealtype t, N_Vector y, N_Vector f, + int mode) +{ + void* user_data = nullptr; + MRIStepInnerStepper_GetContent(stepper, &user_data); + + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, mristepinnerstepper_user_supplied_fn_table, + MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_fullrhsfn, + stepper, t, y, f, mode); +} + +inline int mristepinner_resetfn_wrapper(MRIStepInnerStepper stepper, + sunrealtype tR, N_Vector yR) +{ + void* user_data = nullptr; + MRIStepInnerStepper_GetContent(stepper, &user_data); + + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, mristepinnerstepper_user_supplied_fn_table, + MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_resetfn, + stepper, tR, yR); +} + +using MRIStepInnerGetAccumulatedErrorStdFn = + std::tuple(MRIStepInnerStepper stepper); + +inline int mristepinner_getaccumulatederrorfn_wrapper(MRIStepInnerStepper stepper, + sunrealtype* accum_error) +{ + void* user_data = nullptr; + MRIStepInnerStepper_GetContent(stepper, &user_data); + + auto fn_table = + static_cast(user_data); + auto fn = nb::cast>( + fn_table->mristepinner_getaccumulatederrorfn); + + auto result = fn(stepper); + + *accum_error = std::get<1>(result); + + return std::get<0>(result); +} + +inline int mristepinner_resetaccumulatederrorfn_wrapper(MRIStepInnerStepper stepper) +{ + void* user_data = nullptr; + MRIStepInnerStepper_GetContent(stepper, &user_data); + + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, + mristepinnerstepper_user_supplied_fn_table, + MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_resetaccumulatederrorfn, + stepper); +} + +inline int mristepinner_setrtolfn_wrapper(MRIStepInnerStepper stepper, + sunrealtype rtol) +{ + void* user_data = nullptr; + MRIStepInnerStepper_GetContent(stepper, &user_data); + + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, mristepinnerstepper_user_supplied_fn_table, + MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_setrtolfn, + stepper, rtol); +} + +#endif diff --git a/bindings/sundials4py/arkode/generate.yaml b/bindings/sundials4py/arkode/generate.yaml new file mode 100644 index 0000000000..6622cb2061 --- /dev/null +++ b/bindings/sundials4py/arkode/generate.yaml @@ -0,0 +1,149 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# ARKODE module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + fn_exclude_by_name__regex: + - "Free" # Free and destroy functions should not need to be called as all objects on the Python side are RAII + - "Destroy" + - "Space" # Space functions are deprecated, so don't expose them in Python + # Due to the need to convert between sys.argv and C argv, we need to do custom wrappers of these + - "SetOptions" + macro_define_include_by_name__regex: + - "^SUN_" + - "^ARK_" + - "^ARKLS_" + - "^ARKODE_" + fn_params_optional_with_default_null: + "SetLinearSolver": + - "A" + arkode: + path: arkode/arkode_generated.hpp + headers: + - ../../include/arkode/arkode.h + - ../../include/arkode/arkode_ls.h + - ../../include/arkode/arkode_butcher.h + - ../../include/arkode/arkode_butcher_erk.h + - ../../include/arkode/arkode_butcher_dirk.h + - ../../include/arkode/arkode_sprk.h + sundials_pointer_types: + - "ARKodeButcherTable" + - "ARKodeSPRKTable" + fn_exclude_by_name__regex: + # we dont interface the Alloc functions as they are not needed from Python + - "^ARKodeButcherTable_Alloc$" + - "^ARKodeSPRKTable_Alloc$" + # we use user_data for sneaking in python contexts, so we don't interface these + - "^ARKodeGetUserData$" + - "^ARKodeSetUserData$" + # generator cannot handle setting of function pointers, so we do something custom + - "^ARKodeSet.*Fn$" + - "^ARKodeSetWFtolerances$" + - "^ARKodeSet.*Preconditioner$" + - "^ARKodeSet.*Times$" + - "^ARKodeWFtolerances$" + - "^ARKodeResFtolerance$" + - "^ARKodeRootInit$" + - "^ARKodeResize$" + # generator cannot handle functions with optional (i.e. NULLable) parameters that is not + # followed by only optional parameters, so we have to do something custom + - "^ARKodeSetMassLinearSolver$" + # TODO(CJB): interface this (in the future?) + # generator cannot yet handle mixing pointer outputs and ** in the same function + - "^ARKodeGetNonlinearSystemData$" + arkstep: + path: arkode/arkode_arkstep_generated.hpp + headers: + - ../../include/arkode/arkode_arkstep.h + fn_exclude_by_name__regex: + # we dont bind the ARKStepCreate function as we need to do something custom + - "^ARKStepCreate$" + - "^ARKStepCreateAdjointStepper$" + - "^ARKStepReInit$" + # reinit also requires custom handling due to function pointers + - "^ARKStepReInit$" + erkstep: + path: arkode/arkode_erkstep_generated.hpp + headers: + - ../../include/arkode/arkode_erkstep.h + fn_exclude_by_name__regex: + # we dont bind the ERKStepCreate function as we need to do something custom + - "^ERKStepCreate$" + - "^ERKStepCreateAdjointStepper$" + # reinit also requires custom handling due to function pointers + - "^ERKStepReInit$" + sprkstep: + path: arkode/arkode_sprkstep_generated.hpp + headers: + - ../../include/arkode/arkode_sprkstep.h + fn_exclude_by_name__regex: + - "^SPRKStepCreate$" + # reinit also requires custom handling due to function pointers + - "^SPRKStepReInit$" + lsrkstep: + path: arkode/arkode_lsrkstep_generated.hpp + headers: + - ../../include/arkode/arkode_lsrkstep.h + fn_exclude_by_name__regex: + # these functions take function pointers, so we have to do something custom + - "^LSRKStepCreate.*$" + - "^LSRKStepReInit.*$" + - "^LSRKStepSetDomEigFn$" + mristep: + path: arkode/arkode_mristep_generated.hpp + headers: + - ../../include/arkode/arkode_mristep.h + sundials_pointer_types: + - "MRIStepCoupling" + - "MRIStepInnerStepper" + fn_exclude_by_name__regex: + # we don't allow alloc to be used from Python + - "^MRIStepCoupling_Alloc$" + # we do custom handling of Create routines + - "^MRIStepCreate$" + - "^MRIStepInnerStepper_Create$" + - "^MRIStepInnerStepper_CreateFromSUNStepper$" + # reinit also requires custom handling due to function pointers + - "^MRIStepReInit$" + # We steal the MRIStepInnerStepper content for the callback table, so don't interface the set/get + - "^MRIStepInnerStepper_GetContent$" + - "^MRIStepInnerStepper_SetContent$" + # we have to do custom things to handle setting function pointers + - "^MRIStepInnerStepper_Set.*Fn$" + - "^MRIStepSetPreInnerFn$" + - "^MRIStepSetPostInnerFn$" + # This function takes a pointer to a pointer which the generator cannot yet handle, so we do custom handling + - "^MRIStepInnerStepper_GetForcingData$" + forcingstep: + path: arkode/arkode_forcingstep_generated.hpp + headers: + - ../../include/arkode/arkode_forcingstep.h + fn_exclude_by_name__regex: + - "^ForcingStepCreate$" + splittingstep: + path: arkode/arkode_splittingstep_generated.hpp + headers: + - ../../include/arkode/arkode_splittingstep.h + sundials_pointer_types: + - "SplittingStepCoefficients" + fn_exclude_by_name__regex: + # we don't allow alloc to be used from Python + - "^SplittingStepCoefficients_Alloc$" + # we do custom handling of the stepper create/reinit + - "^SplittingStepCreate$" + - "^SplittingStepReInit$" diff --git a/bindings/sundials4py/cvodes/cvodes.cpp b/bindings/sundials4py/cvodes/cvodes.cpp new file mode 100644 index 0000000000..0eef362a4c --- /dev/null +++ b/bindings/sundials4py/cvodes/cvodes.cpp @@ -0,0 +1,345 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +#include +#include +#include +#include + +#include "cvodes/cvodes_impl.h" +#include "cvodes_usersupplied.hpp" + +#include "sundials_adjointcheckpointscheme_impl.h" + +namespace sundials4py { + +using namespace sundials::experimental; + +#define BIND_CVODE_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER, ...) \ + m.def( \ + #NAME, \ + [](void* cv_mem, std::function> fn) \ + { \ + auto fn_table = get_cvode_fn_table(cv_mem); \ + fn_table->MEMBER = nb::cast(fn); \ + if (fn) { return NAME(cv_mem, &WRAPPER); } \ + else { return NAME(cv_mem, nullptr); } \ + }, \ + __VA_ARGS__) + +#define BIND_CVODE_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \ + MEMBER2, WRAPPER2, ...) \ + m.def( \ + #NAME, \ + [](void* cv_mem, std::function> fn1, \ + std::function> fn2) \ + { \ + auto fn_table = get_cvode_fn_table(cv_mem); \ + fn_table->MEMBER1 = nb::cast(fn1); \ + fn_table->MEMBER2 = nb::cast(fn2); \ + if (fn1) { return NAME(cv_mem, WRAPPER1, WRAPPER2); } \ + else { return NAME(cv_mem, nullptr, WRAPPER2); } \ + }, \ + __VA_ARGS__) + +#define BIND_CVODEB_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER, ...) \ + m.def( \ + #NAME, \ + [](void* cv_mem, int which, std::function> fn) \ + { \ + void* user_data = nullptr; \ + auto fn_table = get_cvodea_fn_table(cv_mem, which); \ + fn_table->MEMBER = nb::cast(fn); \ + if (fn) { return NAME(cv_mem, which, &WRAPPER); } \ + else { return NAME(cv_mem, which, nullptr); } \ + }, \ + __VA_ARGS__) + +#define BIND_CVODEB_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \ + MEMBER2, WRAPPER2, ...) \ + m.def( \ + #NAME, \ + [](void* cv_mem, int which, \ + std::function> fn1, \ + std::function> fn2) \ + { \ + void* user_data = nullptr; \ + auto fn_table = get_cvodea_fn_table(cv_mem, which); \ + fn_table->MEMBER1 = nb::cast(fn1); \ + fn_table->MEMBER2 = nb::cast(fn2); \ + if (fn1) { return NAME(cv_mem, which, WRAPPER1, WRAPPER2); } \ + else { return NAME(cv_mem, which, nullptr, WRAPPER2); } \ + }, \ + __VA_ARGS__) + +void bind_cvodes(nb::module_& m) +{ +#include "cvodes_generated.hpp" + + nb::class_(m, "CVodeView") + .def("get", nb::overload_cast<>(&CVodeView::get, nb::const_), + nb::rv_policy::reference); + + m.def( + "CVodeSetOptions", + [](void* cv_mem, const std::string& cvid, const std::string& file_name, + int argc, const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return CVodeSetOptions(cv_mem, cvid.empty() ? nullptr : cvid.c_str(), + file_name.empty() ? nullptr : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("cv_mem"), nb::arg("cvid"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + m.def( + "CVodeCreate", + [](int lmm, SUNContext sunctx) + { return std::make_shared(CVodeCreate(lmm, sunctx)); }, + nb::arg("lmm"), nb::arg("sunctx"), nb::keep_alive<0, 2>()); + + m.def("CVodeInit", + [](void* cv_mem, std::function> rhs, + sunrealtype t0, N_Vector y0) + { + int cv_status = CVodeInit(cv_mem, cvode_f_wrapper, t0, y0); + + // Create the user-supplied function table to store the Python user functions + auto fn_table = cvode_user_supplied_fn_table_alloc(); + + static_cast(cv_mem)->python = fn_table; + + // Smuggle the user-supplied function table into callback wrappers through the user_data pointer + cv_status = CVodeSetUserData(cv_mem, cv_mem); + if (cv_status != CV_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in CVODE memory"); + } + + // Finally, set the RHS function + fn_table->f = nb::cast(rhs); + + return cv_status; + }); + + m.def("CVodeRootInit", + [](void* cv_mem, int nrtfn, + std::function> fn) + { + auto fn_table = get_cvode_fn_table(cv_mem); + fn_table->rootfn = nb::cast(fn); + return CVodeRootInit(cv_mem, nrtfn, &cvode_rootfn_wrapper); + }); + + m.def("CVodeQuadInit", + [](void* cv_mem, std::function> fQ, + N_Vector yQ0) + { + auto fn_table = get_cvode_fn_table(cv_mem); + fn_table->fQ = nb::cast(fQ); + return CVodeQuadInit(cv_mem, &cvode_fQ_wrapper, yQ0); + }); + + BIND_CVODE_CALLBACK(CVodeWFtolerances, CVEwtFn, ewtn, cvode_ewtfn_wrapper, + nb::arg("cvode_mem"), nb::arg("efun").none()); + + BIND_CVODE_CALLBACK(CVodeSetNlsRhsFn, CVRhsFn, fNLS, cvode_nlsrhsfn_wrapper, + nb::arg("cvode_mem"), nb::arg("f").none()); + + BIND_CVODE_CALLBACK(CVodeSetJacFn, CVLsJacFn, lsjacfn, cvode_lsjacfn_wrapper, + nb::arg("cvode_mem"), nb::arg("jac").none()); + + BIND_CVODE_CALLBACK2(CVodeSetPreconditioner, CVLsPrecSetupFn, lsprecsetupfn, + cvode_lsprecsetupfn_wrapper, CVLsPrecSolveFn, + lsprecsolvefn, cvode_lsprecsolvefn_wrapper, + nb::arg("cvode_mem"), nb::arg("pset").none(), + nb::arg("psolve").none()); + + BIND_CVODE_CALLBACK2(CVodeSetJacTimes, CVLsJacTimesSetupFn, lsjactimessetupfn, + cvode_lsjactimessetupfn_wrapper, CVLsJacTimesVecFn, + lsjactimesvecfn, cvode_lsjactimesvecfn_wrapper, + nb::arg("cvode_mem"), nb::arg("jtsetup").none(), + nb::arg("jtimes").none()); + + BIND_CVODE_CALLBACK(CVodeSetLinSysFn, CVLsLinSysFn, lslinsysfn, + cvode_lslinsysfn_wrapper, nb::arg("cvode_mem"), + nb::arg("linsys").none()); + + BIND_CVODE_CALLBACK(CVodeSetJacTimesRhsFn, CVLsJacTimesVecFn, lsjacrhsfn, + cvode_lsjacrhsfn_wrapper, nb::arg("cvode_mem"), + nb::arg("jtimesRhsFn").none()); + + BIND_CVODE_CALLBACK(CVodeSetProjFn, CVProjFn, projfn, cvode_projfn_wrapper, + nb::arg("cvode_mem"), nb::arg("pfun").none()); + + m.def("CVodeQuadSensInit", + [](void* cv_mem, std::function fQS, + std::vector yQS0) + { + auto fn_table = get_cvode_fn_table(cv_mem); + fn_table->fQS = nb::cast(fQS); + return CVodeQuadSensInit(cv_mem, cvode_fQS_wrapper, yQS0.data()); + }); + + m.def("CVodeSensInit", + [](void* cv_mem, int Ns, int ism, std::function fS, + std::vector yS0) + { + auto fn_table = get_cvode_fn_table(cv_mem); + fn_table->fS = nb::cast(fS); + return CVodeSensInit(cv_mem, Ns, ism, cvode_fS_wrapper, yS0.data()); + }); + + m.def("CVodeSensInit1", + [](void* cv_mem, int Ns, int ism, + std::function> fS1, + std::vector yS0) + { + auto fn_table = get_cvode_fn_table(cv_mem); + fn_table->fS1 = nb::cast(fS1); + return CVodeSensInit1(cv_mem, Ns, ism, cvode_fS1_wrapper, yS0.data()); + }); + + /// + // CVODES Adjoint Bindings + /// + + m.def("CVodeInitB", + [](void* cv_mem, int which, + std::function> fB, sunrealtype tB0, + N_Vector yB0) + { + int cv_status = CVodeInitB(cv_mem, which, cvode_fB_wrapper, tB0, yB0); + + auto fn_table = cvodea_user_supplied_fn_table_alloc(); + auto cvb_mem = + static_cast(CVodeGetAdjCVodeBmem(cv_mem, which)); + cvb_mem->python = fn_table; + + cv_status = CVodeSetUserDataB(cv_mem, which, cvb_mem); + if (cv_status != CV_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in CVODE memory"); + } + + fn_table->fB = nb::cast(fB); + return cv_status; + }); + + m.def("CVodeQuadInitB", + [](void* cv_mem, int which, + std::function> fQB, N_Vector yQBO) + { + auto fn_table = get_cvodea_fn_table(cv_mem, which); + fn_table->fQB = nb::cast(fQB); + return CVodeQuadInitB(cv_mem, which, cvode_fQB_wrapper, yQBO); + }); + + BIND_CVODEB_CALLBACK(CVodeSetJacFnB, CVLsJacFnB, lsjacfnB, + cvode_lsjacfnB_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("jacB").none()); + + BIND_CVODEB_CALLBACK2(CVodeSetPreconditionerB, CVLsPrecSetupFnB, lsprecsetupfnB, + cvode_lsprecsetupfnB_wrapper, CVLsPrecSolveFnB, + lsprecsolvefnB, cvode_lsprecsolvefnB_wrapper, + nb::arg("cv_mem"), nb::arg("which"), + nb::arg("psetB").none(), nb::arg("psolveB").none()); + + BIND_CVODEB_CALLBACK2(CVodeSetJacTimesB, CVLsJacTimesSetupFnB, + lsjactimessetupfnB, cvode_lsjactimessetupfnB_wrapper, + CVLsJacTimesVecFnB, lsjactimesvecfnB, + cvode_lsjactimesvecfnB_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("jsetupB").none(), + nb::arg("jtimesB").none()); + + BIND_CVODEB_CALLBACK(CVodeSetLinSysFnB, CVLsLinSysFnB, lslinsysfnB, + cvode_lslinsysfnB_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("linsysB").none()); + + m.def("CVodeInitBS", + [](void* cv_mem, int which, std::function fBS, + sunrealtype tB0, N_Vector yB0) + { + int cv_status = CVodeInitBS(cv_mem, which, cvode_fBS_wrapper, tB0, yB0); + + auto fn_table = cvodea_user_supplied_fn_table_alloc(); + auto cvb_mem = + static_cast(CVodeGetAdjCVodeBmem(cv_mem, which)); + cvb_mem->python = fn_table; + + cv_status = CVodeSetUserDataB(cv_mem, which, cvb_mem); + if (cv_status != CV_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in CVODE memory"); + } + + fn_table->fBS = nb::cast(fBS); + return cv_status; + }); + + m.def("CVodeQuadInitBS", + [](void* cv_mem, int which, std::function fQBS, + N_Vector yQBO) + { + auto fn_table = get_cvodea_fn_table(cv_mem, which); + fn_table->fQBS = nb::cast(fQBS); + return CVodeQuadInitBS(cv_mem, which, cvode_fQBS_wrapper, yQBO); + }); + + BIND_CVODEB_CALLBACK(CVodeSetJacFnBS, CVLsJacStdFnBS, lsjacfnBS, + cvode_lsjacfnBS_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("jacBS").none()); + + BIND_CVODEB_CALLBACK2(CVodeSetPreconditionerBS, CVLsPrecSetupStdFnBS, + lsprecsetupfnBS, cvode_lsprecsetupfnBS_wrapper, + CVLsPrecSolveStdFnBS, lsprecsolvefnBS, + cvode_lsprecsolvefnBS_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("psetBS").none(), + nb::arg("psolveBS").none()); + + BIND_CVODEB_CALLBACK2(CVodeSetJacTimesBS, CVLsJacTimesSetupStdFnBS, + lsjactimessetupfnBS, cvode_lsjactimessetupfnBS_wrapper, + CVLsJacTimesVecStdFnBS, lsjactimesvecfnBS, + cvode_lsjactimesvecfnBS_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("jsetupBS").none(), + nb::arg("jtimesBS").none()); + + BIND_CVODEB_CALLBACK(CVodeSetLinSysFnBS, CVLsLinSysStdFnBS, lslinsysfnBS, + cvode_lslinsysfnBS_wrapper, nb::arg("cv_mem"), + nb::arg("which"), nb::arg("linsysBS").none()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/cvodes/cvodes_generated.hpp b/bindings/sundials4py/cvodes/cvodes_generated.hpp new file mode 100644 index 0000000000..70a5acad18 --- /dev/null +++ b/bindings/sundials4py/cvodes/cvodes_generated.hpp @@ -0,0 +1,2035 @@ +// #ifndef _CVODES_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("CV_ADAMS") = 1; +m.attr("CV_BDF") = 2; +m.attr("CV_NORMAL") = 1; +m.attr("CV_ONE_STEP") = 2; +m.attr("CV_SIMULTANEOUS") = 1; +m.attr("CV_STAGGERED") = 2; +m.attr("CV_STAGGERED1") = 3; +m.attr("CV_CENTERED") = 1; +m.attr("CV_FORWARD") = 2; +m.attr("CV_HERMITE") = 1; +m.attr("CV_POLYNOMIAL") = 2; +m.attr("CV_SUCCESS") = 0; +m.attr("CV_TSTOP_RETURN") = 1; +m.attr("CV_ROOT_RETURN") = 2; +m.attr("CV_WARNING") = 99; +m.attr("CV_TOO_MUCH_WORK") = -1; +m.attr("CV_TOO_MUCH_ACC") = -2; +m.attr("CV_ERR_FAILURE") = -3; +m.attr("CV_CONV_FAILURE") = -4; +m.attr("CV_LINIT_FAIL") = -5; +m.attr("CV_LSETUP_FAIL") = -6; +m.attr("CV_LSOLVE_FAIL") = -7; +m.attr("CV_RHSFUNC_FAIL") = -8; +m.attr("CV_FIRST_RHSFUNC_ERR") = -9; +m.attr("CV_REPTD_RHSFUNC_ERR") = -10; +m.attr("CV_UNREC_RHSFUNC_ERR") = -11; +m.attr("CV_RTFUNC_FAIL") = -12; +m.attr("CV_NLS_INIT_FAIL") = -13; +m.attr("CV_NLS_SETUP_FAIL") = -14; +m.attr("CV_CONSTR_FAIL") = -15; +m.attr("CV_NLS_FAIL") = -16; +m.attr("CV_MEM_FAIL") = -20; +m.attr("CV_MEM_NULL") = -21; +m.attr("CV_ILL_INPUT") = -22; +m.attr("CV_NO_MALLOC") = -23; +m.attr("CV_BAD_K") = -24; +m.attr("CV_BAD_T") = -25; +m.attr("CV_BAD_DKY") = -26; +m.attr("CV_TOO_CLOSE") = -27; +m.attr("CV_VECTOROP_ERR") = -28; +m.attr("CV_NO_QUAD") = -30; +m.attr("CV_QRHSFUNC_FAIL") = -31; +m.attr("CV_FIRST_QRHSFUNC_ERR") = -32; +m.attr("CV_REPTD_QRHSFUNC_ERR") = -33; +m.attr("CV_UNREC_QRHSFUNC_ERR") = -34; +m.attr("CV_NO_SENS") = -40; +m.attr("CV_SRHSFUNC_FAIL") = -41; +m.attr("CV_FIRST_SRHSFUNC_ERR") = -42; +m.attr("CV_REPTD_SRHSFUNC_ERR") = -43; +m.attr("CV_UNREC_SRHSFUNC_ERR") = -44; +m.attr("CV_BAD_IS") = -45; +m.attr("CV_NO_QUADSENS") = -50; +m.attr("CV_QSRHSFUNC_FAIL") = -51; +m.attr("CV_FIRST_QSRHSFUNC_ERR") = -52; +m.attr("CV_REPTD_QSRHSFUNC_ERR") = -53; +m.attr("CV_UNREC_QSRHSFUNC_ERR") = -54; +m.attr("CV_CONTEXT_ERR") = -55; +m.attr("CV_PROJ_MEM_NULL") = -56; +m.attr("CV_PROJFUNC_FAIL") = -57; +m.attr("CV_REPTD_PROJFUNC_ERR") = -58; +m.attr("CV_BAD_TINTERP") = -59; +m.attr("CV_UNRECOGNIZED_ERR") = -99; +m.attr("CV_NO_ADJ") = -101; +m.attr("CV_NO_FWD") = -102; +m.attr("CV_NO_BCK") = -103; +m.attr("CV_BAD_TB0") = -104; +m.attr("CV_REIFWD_FAIL") = -105; +m.attr("CV_FWD_FAIL") = -106; +m.attr("CV_GETY_BADT") = -107; + +m.def("CVodeReInit", CVodeReInit, nb::arg("cvode_mem"), nb::arg("t0"), + nb::arg("y0")); + +m.def( + "CVodeResizeHistory", + [](void* cvode_mem, sundials4py::Array1d t_hist_1d, + std::vector y_hist_1d, std::vector f_hist_1d, + int num_y_hist, int num_f_hist) -> int + { + auto CVodeResizeHistory_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sundials4py::Array1d t_hist_1d, + std::vector y_hist_1d, std::vector f_hist_1d, + int num_y_hist, int num_f_hist) -> int + { + sunrealtype* t_hist_1d_ptr = + reinterpret_cast(t_hist_1d.data()); + N_Vector* y_hist_1d_ptr = reinterpret_cast( + y_hist_1d.empty() ? nullptr : y_hist_1d.data()); + N_Vector* f_hist_1d_ptr = reinterpret_cast( + f_hist_1d.empty() ? nullptr : f_hist_1d.data()); + + auto lambda_result = CVodeResizeHistory(cvode_mem, t_hist_1d_ptr, + y_hist_1d_ptr, f_hist_1d_ptr, + num_y_hist, num_f_hist); + return lambda_result; + }; + + return CVodeResizeHistory_adapt_arr_ptr_to_std_vector(cvode_mem, t_hist_1d, + y_hist_1d, f_hist_1d, + num_y_hist, num_f_hist); + }, + nb::arg("cvode_mem"), nb::arg("t_hist_1d"), nb::arg("y_hist_1d"), + nb::arg("f_hist_1d"), nb::arg("num_y_hist"), nb::arg("num_f_hist")); + +m.def("CVodeSStolerances", CVodeSStolerances, nb::arg("cvode_mem"), + nb::arg("reltol"), nb::arg("abstol")); + +m.def("CVodeSVtolerances", CVodeSVtolerances, nb::arg("cvode_mem"), + nb::arg("reltol"), nb::arg("abstol")); + +m.def("CVodeSetConstraints", CVodeSetConstraints, nb::arg("cvode_mem"), + nb::arg("constraints")); + +m.def("CVodeSetDeltaGammaMaxLSetup", CVodeSetDeltaGammaMaxLSetup, + nb::arg("cvode_mem"), nb::arg("dgmax_lsetup")); + +m.def("CVodeSetInitStep", CVodeSetInitStep, nb::arg("cvode_mem"), nb::arg("hin")); + +m.def("CVodeSetLSetupFrequency", CVodeSetLSetupFrequency, nb::arg("cvode_mem"), + nb::arg("msbp")); + +m.def("CVodeSetMaxConvFails", CVodeSetMaxConvFails, nb::arg("cvode_mem"), + nb::arg("maxncf")); + +m.def("CVodeSetMaxErrTestFails", CVodeSetMaxErrTestFails, nb::arg("cvode_mem"), + nb::arg("maxnef")); + +m.def("CVodeSetMaxHnilWarns", CVodeSetMaxHnilWarns, nb::arg("cvode_mem"), + nb::arg("mxhnil")); + +m.def("CVodeSetMaxNonlinIters", CVodeSetMaxNonlinIters, nb::arg("cvode_mem"), + nb::arg("maxcor")); + +m.def("CVodeSetMaxNumSteps", CVodeSetMaxNumSteps, nb::arg("cvode_mem"), + nb::arg("mxsteps")); + +m.def("CVodeSetMaxOrd", CVodeSetMaxOrd, nb::arg("cvode_mem"), nb::arg("maxord")); + +m.def("CVodeSetMaxStep", CVodeSetMaxStep, nb::arg("cvode_mem"), nb::arg("hmax")); + +m.def("CVodeSetMinStep", CVodeSetMinStep, nb::arg("cvode_mem"), nb::arg("hmin")); + +m.def("CVodeSetMonitorFrequency", CVodeSetMonitorFrequency, + nb::arg("cvode_mem"), nb::arg("nst")); + +m.def("CVodeSetNonlinConvCoef", CVodeSetNonlinConvCoef, nb::arg("cvode_mem"), + nb::arg("nlscoef")); + +m.def("CVodeSetNonlinearSolver", CVodeSetNonlinearSolver, nb::arg("cvode_mem"), + nb::arg("NLS")); + +m.def("CVodeSetStabLimDet", CVodeSetStabLimDet, nb::arg("cvode_mem"), + nb::arg("stldet")); + +m.def("CVodeSetStopTime", CVodeSetStopTime, nb::arg("cvode_mem"), + nb::arg("tstop")); + +m.def("CVodeSetInterpolateStopTime", CVodeSetInterpolateStopTime, + nb::arg("cvode_mem"), nb::arg("interp")); + +m.def("CVodeClearStopTime", CVodeClearStopTime, nb::arg("cvode_mem")); + +m.def("CVodeSetEtaFixedStepBounds", CVodeSetEtaFixedStepBounds, + nb::arg("cvode_mem"), nb::arg("eta_min_fx"), nb::arg("eta_max_fx")); + +m.def("CVodeSetEtaMaxFirstStep", CVodeSetEtaMaxFirstStep, nb::arg("cvode_mem"), + nb::arg("eta_max_fs")); + +m.def("CVodeSetEtaMaxEarlyStep", CVodeSetEtaMaxEarlyStep, nb::arg("cvode_mem"), + nb::arg("eta_max_es")); + +m.def("CVodeSetNumStepsEtaMaxEarlyStep", CVodeSetNumStepsEtaMaxEarlyStep, + nb::arg("cvode_mem"), nb::arg("small_nst")); + +m.def("CVodeSetEtaMax", CVodeSetEtaMax, nb::arg("cvode_mem"), + nb::arg("eta_max_gs")); + +m.def("CVodeSetEtaMin", CVodeSetEtaMin, nb::arg("cvode_mem"), nb::arg("eta_min")); + +m.def("CVodeSetEtaMinErrFail", CVodeSetEtaMinErrFail, nb::arg("cvode_mem"), + nb::arg("eta_min_ef")); + +m.def("CVodeSetEtaMaxErrFail", CVodeSetEtaMaxErrFail, nb::arg("cvode_mem"), + nb::arg("eta_max_ef")); + +m.def("CVodeSetNumFailsEtaMaxErrFail", CVodeSetNumFailsEtaMaxErrFail, + nb::arg("cvode_mem"), nb::arg("small_nef")); + +m.def("CVodeSetEtaConvFail", CVodeSetEtaConvFail, nb::arg("cvode_mem"), + nb::arg("eta_cf")); + +m.def( + "CVodeSetRootDirection", + [](void* cvode_mem) -> std::tuple + { + auto CVodeSetRootDirection_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + int rootdir_adapt_modifiable; + + int r = CVodeSetRootDirection(cvode_mem, &rootdir_adapt_modifiable); + return std::make_tuple(r, rootdir_adapt_modifiable); + }; + + return CVodeSetRootDirection_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def("CVodeSetNoInactiveRootWarn", CVodeSetNoInactiveRootWarn, + nb::arg("cvode_mem")); + +m.def( + "CVode", + [](void* cvode_mem, sunrealtype tout, N_Vector yout, + int itask) -> std::tuple + { + auto CVode_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, sunrealtype tout, N_Vector yout, + int itask) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = CVode(cvode_mem, tout, yout, &tret_adapt_modifiable, itask); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return CVode_adapt_modifiable_immutable_to_return(cvode_mem, tout, yout, + itask); + }, + nb::arg("cvode_mem"), nb::arg("tout"), nb::arg("yout"), nb::arg("itask"), + "Solver function"); + +m.def("CVodeComputeState", CVodeComputeState, nb::arg("cvode_mem"), + nb::arg("ycor"), nb::arg("y")); + +m.def( + "CVodeComputeStateSens", + [](void* cvode_mem, std::vector yScor_1d, + std::vector yS_1d) -> int + { + auto CVodeComputeStateSens_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, std::vector yScor_1d, + std::vector yS_1d) -> int + { + N_Vector* yScor_1d_ptr = reinterpret_cast( + yScor_1d.empty() ? nullptr : yScor_1d.data()); + N_Vector* yS_1d_ptr = + reinterpret_cast(yS_1d.empty() ? nullptr : yS_1d.data()); + + auto lambda_result = CVodeComputeStateSens(cvode_mem, yScor_1d_ptr, + yS_1d_ptr); + return lambda_result; + }; + + return CVodeComputeStateSens_adapt_arr_ptr_to_std_vector(cvode_mem, + yScor_1d, yS_1d); + }, + nb::arg("cvode_mem"), nb::arg("yScor_1d"), nb::arg("yS_1d")); + +m.def("CVodeComputeStateSens1", CVodeComputeStateSens1, nb::arg("cvode_mem"), + nb::arg("idx"), nb::arg("yScor1"), nb::arg("yS1")); + +m.def("CVodeGetDky", CVodeGetDky, nb::arg("cvode_mem"), nb::arg("t"), + nb::arg("k"), nb::arg("dky"), "Dense output function"); + +m.def( + "CVodeGetNumSteps", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumSteps_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nsteps_adapt_modifiable; + + int r = CVodeGetNumSteps(cvode_mem, &nsteps_adapt_modifiable); + return std::make_tuple(r, nsteps_adapt_modifiable); + }; + + return CVodeGetNumSteps_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumRhsEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfevals_adapt_modifiable; + + int r = CVodeGetNumRhsEvals(cvode_mem, &nfevals_adapt_modifiable); + return std::make_tuple(r, nfevals_adapt_modifiable); + }; + + return CVodeGetNumRhsEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumLinSolvSetups", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumLinSolvSetups_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nlinsetups_adapt_modifiable; + + int r = CVodeGetNumLinSolvSetups(cvode_mem, &nlinsetups_adapt_modifiable); + return std::make_tuple(r, nlinsetups_adapt_modifiable); + }; + + return CVodeGetNumLinSolvSetups_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumErrTestFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long netfails_adapt_modifiable; + + int r = CVodeGetNumErrTestFails(cvode_mem, &netfails_adapt_modifiable); + return std::make_tuple(r, netfails_adapt_modifiable); + }; + + return CVodeGetNumErrTestFails_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetLastOrder", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetLastOrder_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + int qlast_adapt_modifiable; + + int r = CVodeGetLastOrder(cvode_mem, &qlast_adapt_modifiable); + return std::make_tuple(r, qlast_adapt_modifiable); + }; + + return CVodeGetLastOrder_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetCurrentOrder", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetCurrentOrder_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + int qcur_adapt_modifiable; + + int r = CVodeGetCurrentOrder(cvode_mem, &qcur_adapt_modifiable); + return std::make_tuple(r, qcur_adapt_modifiable); + }; + + return CVodeGetCurrentOrder_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetCurrentGamma", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetCurrentGamma_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype gamma_adapt_modifiable; + + int r = CVodeGetCurrentGamma(cvode_mem, &gamma_adapt_modifiable); + return std::make_tuple(r, gamma_adapt_modifiable); + }; + + return CVodeGetCurrentGamma_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumStabLimOrderReds", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumStabLimOrderReds_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nslred_adapt_modifiable; + + int r = CVodeGetNumStabLimOrderReds(cvode_mem, &nslred_adapt_modifiable); + return std::make_tuple(r, nslred_adapt_modifiable); + }; + + return CVodeGetNumStabLimOrderReds_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetActualInitStep", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetActualInitStep_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype hinused_adapt_modifiable; + + int r = CVodeGetActualInitStep(cvode_mem, &hinused_adapt_modifiable); + return std::make_tuple(r, hinused_adapt_modifiable); + }; + + return CVodeGetActualInitStep_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetLastStep", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetLastStep_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype hlast_adapt_modifiable; + + int r = CVodeGetLastStep(cvode_mem, &hlast_adapt_modifiable); + return std::make_tuple(r, hlast_adapt_modifiable); + }; + + return CVodeGetLastStep_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetCurrentStep", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetCurrentStep_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype hcur_adapt_modifiable; + + int r = CVodeGetCurrentStep(cvode_mem, &hcur_adapt_modifiable); + return std::make_tuple(r, hcur_adapt_modifiable); + }; + + return CVodeGetCurrentStep_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetCurrentState", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetCurrentState_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + N_Vector y_adapt_modifiable; + + int r = CVodeGetCurrentState(cvode_mem, &y_adapt_modifiable); + return std::make_tuple(r, y_adapt_modifiable); + }; + + return CVodeGetCurrentState_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "CVodeGetCurrentSensSolveIndex", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetCurrentSensSolveIndex_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + int index_adapt_modifiable; + + int r = CVodeGetCurrentSensSolveIndex(cvode_mem, &index_adapt_modifiable); + return std::make_tuple(r, index_adapt_modifiable); + }; + + return CVodeGetCurrentSensSolveIndex_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetCurrentTime", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetCurrentTime_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype tcur_adapt_modifiable; + + int r = CVodeGetCurrentTime(cvode_mem, &tcur_adapt_modifiable); + return std::make_tuple(r, tcur_adapt_modifiable); + }; + + return CVodeGetCurrentTime_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetTolScaleFactor", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetTolScaleFactor_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype tolsfac_adapt_modifiable; + + int r = CVodeGetTolScaleFactor(cvode_mem, &tolsfac_adapt_modifiable); + return std::make_tuple(r, tolsfac_adapt_modifiable); + }; + + return CVodeGetTolScaleFactor_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def("CVodeGetErrWeights", CVodeGetErrWeights, nb::arg("cvode_mem"), + nb::arg("eweight")); + +m.def("CVodeGetEstLocalErrors", CVodeGetEstLocalErrors, nb::arg("cvode_mem"), + nb::arg("ele")); + +m.def( + "CVodeGetNumGEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumGEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long ngevals_adapt_modifiable; + + int r = CVodeGetNumGEvals(cvode_mem, &ngevals_adapt_modifiable); + return std::make_tuple(r, ngevals_adapt_modifiable); + }; + + return CVodeGetNumGEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetRootInfo", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetRootInfo_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + int rootsfound_adapt_modifiable; + + int r = CVodeGetRootInfo(cvode_mem, &rootsfound_adapt_modifiable); + return std::make_tuple(r, rootsfound_adapt_modifiable); + }; + + return CVodeGetRootInfo_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetIntegratorStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetIntegratorStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) + -> std::tuple + { + long nsteps_adapt_modifiable; + long nfevals_adapt_modifiable; + long nlinsetups_adapt_modifiable; + long netfails_adapt_modifiable; + int qlast_adapt_modifiable; + int qcur_adapt_modifiable; + sunrealtype hinused_adapt_modifiable; + sunrealtype hlast_adapt_modifiable; + sunrealtype hcur_adapt_modifiable; + sunrealtype tcur_adapt_modifiable; + + int r = + CVodeGetIntegratorStats(cvode_mem, &nsteps_adapt_modifiable, + &nfevals_adapt_modifiable, + &nlinsetups_adapt_modifiable, + &netfails_adapt_modifiable, + &qlast_adapt_modifiable, &qcur_adapt_modifiable, + &hinused_adapt_modifiable, + &hlast_adapt_modifiable, &hcur_adapt_modifiable, + &tcur_adapt_modifiable); + return std::make_tuple(r, nsteps_adapt_modifiable, nfevals_adapt_modifiable, + nlinsetups_adapt_modifiable, + netfails_adapt_modifiable, qlast_adapt_modifiable, + qcur_adapt_modifiable, hinused_adapt_modifiable, + hlast_adapt_modifiable, hcur_adapt_modifiable, + tcur_adapt_modifiable); + }; + + return CVodeGetIntegratorStats_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumNonlinSolvIters", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nniters_adapt_modifiable; + + int r = CVodeGetNumNonlinSolvIters(cvode_mem, &nniters_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable); + }; + + return CVodeGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumNonlinSolvConvFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nnfails_adapt_modifiable; + + int r = CVodeGetNumNonlinSolvConvFails(cvode_mem, + &nnfails_adapt_modifiable); + return std::make_tuple(r, nnfails_adapt_modifiable); + }; + + return CVodeGetNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNonlinSolvStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNonlinSolvStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nniters_adapt_modifiable; + long nnfails_adapt_modifiable; + + int r = CVodeGetNonlinSolvStats(cvode_mem, &nniters_adapt_modifiable, + &nnfails_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable, + nnfails_adapt_modifiable); + }; + + return CVodeGetNonlinSolvStats_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumStepSolveFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumStepSolveFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nncfails_adapt_modifiable; + + int r = CVodeGetNumStepSolveFails(cvode_mem, &nncfails_adapt_modifiable); + return std::make_tuple(r, nncfails_adapt_modifiable); + }; + + return CVodeGetNumStepSolveFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def("CVodePrintAllStats", CVodePrintAllStats, nb::arg("cvode_mem"), + nb::arg("outfile"), nb::arg("fmt")); + +m.def("CVodeGetReturnFlagName", CVodeGetReturnFlagName, nb::arg("flag")); + +m.def("CVodeQuadReInit", CVodeQuadReInit, nb::arg("cvode_mem"), nb::arg("yQ0")); + +m.def("CVodeQuadSStolerances", CVodeQuadSStolerances, nb::arg("cvode_mem"), + nb::arg("reltolQ"), nb::arg("abstolQ")); + +m.def("CVodeQuadSVtolerances", CVodeQuadSVtolerances, nb::arg("cvode_mem"), + nb::arg("reltolQ"), nb::arg("abstolQ")); + +m.def("CVodeSetQuadErrCon", CVodeSetQuadErrCon, nb::arg("cvode_mem"), + nb::arg("errconQ"), "Optional input specification functions"); + +m.def( + "CVodeGetQuad", + [](void* cvode_mem, N_Vector yQout) -> std::tuple + { + auto CVodeGetQuad_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, N_Vector yQout) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = CVodeGetQuad(cvode_mem, &tret_adapt_modifiable, yQout); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return CVodeGetQuad_adapt_modifiable_immutable_to_return(cvode_mem, yQout); + }, + nb::arg("cvode_mem"), nb::arg("yQout")); + +m.def("CVodeGetQuadDky", CVodeGetQuadDky, nb::arg("cvode_mem"), nb::arg("t"), + nb::arg("k"), nb::arg("dky")); + +m.def( + "CVodeGetQuadNumRhsEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetQuadNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfQevals_adapt_modifiable; + + int r = CVodeGetQuadNumRhsEvals(cvode_mem, &nfQevals_adapt_modifiable); + return std::make_tuple(r, nfQevals_adapt_modifiable); + }; + + return CVodeGetQuadNumRhsEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetQuadNumErrTestFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetQuadNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nQetfails_adapt_modifiable; + + int r = CVodeGetQuadNumErrTestFails(cvode_mem, &nQetfails_adapt_modifiable); + return std::make_tuple(r, nQetfails_adapt_modifiable); + }; + + return CVodeGetQuadNumErrTestFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def("CVodeGetQuadErrWeights", CVodeGetQuadErrWeights, nb::arg("cvode_mem"), + nb::arg("eQweight")); + +m.def( + "CVodeGetQuadStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetQuadStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfQevals_adapt_modifiable; + long nQetfails_adapt_modifiable; + + int r = CVodeGetQuadStats(cvode_mem, &nfQevals_adapt_modifiable, + &nQetfails_adapt_modifiable); + return std::make_tuple(r, nfQevals_adapt_modifiable, + nQetfails_adapt_modifiable); + }; + + return CVodeGetQuadStats_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeSensReInit", + [](void* cvode_mem, int ism, std::vector yS0_1d) -> int + { + auto CVodeSensReInit_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, int ism, std::vector yS0_1d) -> int + { + N_Vector* yS0_1d_ptr = + reinterpret_cast(yS0_1d.empty() ? nullptr : yS0_1d.data()); + + auto lambda_result = CVodeSensReInit(cvode_mem, ism, yS0_1d_ptr); + return lambda_result; + }; + + return CVodeSensReInit_adapt_arr_ptr_to_std_vector(cvode_mem, ism, yS0_1d); + }, + nb::arg("cvode_mem"), nb::arg("ism"), nb::arg("yS0_1d")); + +m.def( + "CVodeSensSStolerances", + [](void* cvode_mem, sunrealtype reltolS, sundials4py::Array1d abstolS_1d) -> int + { + auto CVodeSensSStolerances_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype reltolS, + sundials4py::Array1d abstolS_1d) -> int + { + sunrealtype* abstolS_1d_ptr = + reinterpret_cast(abstolS_1d.data()); + + auto lambda_result = CVodeSensSStolerances(cvode_mem, reltolS, + abstolS_1d_ptr); + return lambda_result; + }; + + return CVodeSensSStolerances_adapt_arr_ptr_to_std_vector(cvode_mem, reltolS, + abstolS_1d); + }, + nb::arg("cvode_mem"), nb::arg("reltolS"), nb::arg("abstolS_1d")); + +m.def( + "CVodeSensSVtolerances", + [](void* cvode_mem, sunrealtype reltolS, std::vector abstolS_1d) -> int + { + auto CVodeSensSVtolerances_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype reltolS, + std::vector abstolS_1d) -> int + { + N_Vector* abstolS_1d_ptr = reinterpret_cast( + abstolS_1d.empty() ? nullptr : abstolS_1d.data()); + + auto lambda_result = CVodeSensSVtolerances(cvode_mem, reltolS, + abstolS_1d_ptr); + return lambda_result; + }; + + return CVodeSensSVtolerances_adapt_arr_ptr_to_std_vector(cvode_mem, reltolS, + abstolS_1d); + }, + nb::arg("cvode_mem"), nb::arg("reltolS"), nb::arg("abstolS_1d")); + +m.def("CVodeSensEEtolerances", CVodeSensEEtolerances, nb::arg("cvode_mem")); + +m.def("CVodeSetSensDQMethod", CVodeSetSensDQMethod, nb::arg("cvode_mem"), + nb::arg("DQtype"), nb::arg("DQrhomax")); + +m.def("CVodeSetSensErrCon", CVodeSetSensErrCon, nb::arg("cvode_mem"), + nb::arg("errconS")); + +m.def("CVodeSetSensMaxNonlinIters", CVodeSetSensMaxNonlinIters, + nb::arg("cvode_mem"), nb::arg("maxcorS")); + +m.def( + "CVodeSetSensParams", + [](void* cvode_mem, sundials4py::Array1d p_1d, sundials4py::Array1d pbar_1d, + std::vector plist_1d) -> int + { + auto CVodeSetSensParams_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sundials4py::Array1d p_1d, + sundials4py::Array1d pbar_1d, std::vector plist_1d) -> int + { + sunrealtype* p_1d_ptr = reinterpret_cast(p_1d.data()); + sunrealtype* pbar_1d_ptr = reinterpret_cast(pbar_1d.data()); + int* plist_1d_ptr = + reinterpret_cast(plist_1d.empty() ? nullptr : plist_1d.data()); + + auto lambda_result = CVodeSetSensParams(cvode_mem, p_1d_ptr, pbar_1d_ptr, + plist_1d_ptr); + return lambda_result; + }; + + return CVodeSetSensParams_adapt_arr_ptr_to_std_vector(cvode_mem, p_1d, + pbar_1d, plist_1d); + }, + nb::arg("cvode_mem"), nb::arg("p_1d"), nb::arg("pbar_1d"), nb::arg("plist_1d")); + +m.def("CVodeSetNonlinearSolverSensSim", CVodeSetNonlinearSolverSensSim, + nb::arg("cvode_mem"), nb::arg("NLS")); + +m.def("CVodeSetNonlinearSolverSensStg", CVodeSetNonlinearSolverSensStg, + nb::arg("cvode_mem"), nb::arg("NLS")); + +m.def("CVodeSetNonlinearSolverSensStg1", CVodeSetNonlinearSolverSensStg1, + nb::arg("cvode_mem"), nb::arg("NLS")); + +m.def("CVodeSensToggleOff", CVodeSensToggleOff, nb::arg("cvode_mem"), + "Enable/disable sensitivities"); + +m.def( + "CVodeGetSens", + [](void* cvode_mem, + std::vector ySout_1d) -> std::tuple + { + auto CVodeGetSens_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype* tret, std::vector ySout_1d) -> int + { + N_Vector* ySout_1d_ptr = reinterpret_cast( + ySout_1d.empty() ? nullptr : ySout_1d.data()); + + auto lambda_result = CVodeGetSens(cvode_mem, tret, ySout_1d_ptr); + return lambda_result; + }; + auto CVodeGetSens_adapt_modifiable_immutable_to_return = + [&CVodeGetSens_adapt_arr_ptr_to_std_vector](void* cvode_mem, + std::vector ySout_1d) + -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = CVodeGetSens_adapt_arr_ptr_to_std_vector(cvode_mem, + &tret_adapt_modifiable, + ySout_1d); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return CVodeGetSens_adapt_modifiable_immutable_to_return(cvode_mem, ySout_1d); + }, + nb::arg("cvode_mem"), nb::arg("ySout_1d")); + +m.def( + "CVodeGetSens1", + [](void* cvode_mem, int is, N_Vector ySout) -> std::tuple + { + auto CVodeGetSens1_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int is, N_Vector ySout) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = CVodeGetSens1(cvode_mem, &tret_adapt_modifiable, is, ySout); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return CVodeGetSens1_adapt_modifiable_immutable_to_return(cvode_mem, is, + ySout); + }, + nb::arg("cvode_mem"), nb::arg("is_"), nb::arg("ySout")); + +m.def( + "CVodeGetSensDky", + [](void* cvode_mem, sunrealtype t, int k, std::vector dkyA_1d) -> int + { + auto CVodeGetSensDky_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype t, int k, + std::vector dkyA_1d) -> int + { + N_Vector* dkyA_1d_ptr = + reinterpret_cast(dkyA_1d.empty() ? nullptr : dkyA_1d.data()); + + auto lambda_result = CVodeGetSensDky(cvode_mem, t, k, dkyA_1d_ptr); + return lambda_result; + }; + + return CVodeGetSensDky_adapt_arr_ptr_to_std_vector(cvode_mem, t, k, dkyA_1d); + }, + nb::arg("cvode_mem"), nb::arg("t"), nb::arg("k"), nb::arg("dkyA_1d")); + +m.def("CVodeGetSensDky1", CVodeGetSensDky1, nb::arg("cvode_mem"), nb::arg("t"), + nb::arg("k"), nb::arg("is_"), nb::arg("dky")); + +m.def( + "CVodeGetSensNumRhsEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfSevals_adapt_modifiable; + + int r = CVodeGetSensNumRhsEvals(cvode_mem, &nfSevals_adapt_modifiable); + return std::make_tuple(r, nfSevals_adapt_modifiable); + }; + + return CVodeGetSensNumRhsEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumRhsEvalsSens", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumRhsEvalsSens_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfevalsS_adapt_modifiable; + + int r = CVodeGetNumRhsEvalsSens(cvode_mem, &nfevalsS_adapt_modifiable); + return std::make_tuple(r, nfevalsS_adapt_modifiable); + }; + + return CVodeGetNumRhsEvalsSens_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetSensNumErrTestFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSetfails_adapt_modifiable; + + int r = CVodeGetSensNumErrTestFails(cvode_mem, &nSetfails_adapt_modifiable); + return std::make_tuple(r, nSetfails_adapt_modifiable); + }; + + return CVodeGetSensNumErrTestFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetSensNumLinSolvSetups", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensNumLinSolvSetups_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nlinsetupsS_adapt_modifiable; + + int r = CVodeGetSensNumLinSolvSetups(cvode_mem, + &nlinsetupsS_adapt_modifiable); + return std::make_tuple(r, nlinsetupsS_adapt_modifiable); + }; + + return CVodeGetSensNumLinSolvSetups_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetSensErrWeights", + [](void* cvode_mem, std::vector eSweight_1d) -> int + { + auto CVodeGetSensErrWeights_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, std::vector eSweight_1d) -> int + { + N_Vector* eSweight_1d_ptr = reinterpret_cast( + eSweight_1d.empty() ? nullptr : eSweight_1d.data()); + + auto lambda_result = CVodeGetSensErrWeights(cvode_mem, eSweight_1d_ptr); + return lambda_result; + }; + + return CVodeGetSensErrWeights_adapt_arr_ptr_to_std_vector(cvode_mem, + eSweight_1d); + }, + nb::arg("cvode_mem"), nb::arg("eSweight_1d")); + +m.def( + "CVodeGetSensStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfSevals_adapt_modifiable; + long nfevalsS_adapt_modifiable; + long nSetfails_adapt_modifiable; + long nlinsetupsS_adapt_modifiable; + + int r = CVodeGetSensStats(cvode_mem, &nfSevals_adapt_modifiable, + &nfevalsS_adapt_modifiable, + &nSetfails_adapt_modifiable, + &nlinsetupsS_adapt_modifiable); + return std::make_tuple(r, nfSevals_adapt_modifiable, + nfevalsS_adapt_modifiable, + nSetfails_adapt_modifiable, + nlinsetupsS_adapt_modifiable); + }; + + return CVodeGetSensStats_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetSensNumNonlinSolvIters", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSniters_adapt_modifiable; + + int r = CVodeGetSensNumNonlinSolvIters(cvode_mem, + &nSniters_adapt_modifiable); + return std::make_tuple(r, nSniters_adapt_modifiable); + }; + + return CVodeGetSensNumNonlinSolvIters_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetSensNumNonlinSolvConvFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSnfails_adapt_modifiable; + + int r = CVodeGetSensNumNonlinSolvConvFails(cvode_mem, + &nSnfails_adapt_modifiable); + return std::make_tuple(r, nSnfails_adapt_modifiable); + }; + + return CVodeGetSensNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetSensNonlinSolvStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetSensNonlinSolvStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSniters_adapt_modifiable; + long nSnfails_adapt_modifiable; + + int r = CVodeGetSensNonlinSolvStats(cvode_mem, &nSniters_adapt_modifiable, + &nSnfails_adapt_modifiable); + return std::make_tuple(r, nSniters_adapt_modifiable, + nSnfails_adapt_modifiable); + }; + + return CVodeGetSensNonlinSolvStats_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumStepSensSolveFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumStepSensSolveFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSncfails_adapt_modifiable; + + int r = CVodeGetNumStepSensSolveFails(cvode_mem, + &nSncfails_adapt_modifiable); + return std::make_tuple(r, nSncfails_adapt_modifiable); + }; + + return CVodeGetNumStepSensSolveFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetStgrSensNumNonlinSolvIters", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetStgrSensNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSTGR1niters_adapt_modifiable; + + int r = CVodeGetStgrSensNumNonlinSolvIters(cvode_mem, + &nSTGR1niters_adapt_modifiable); + return std::make_tuple(r, nSTGR1niters_adapt_modifiable); + }; + + return CVodeGetStgrSensNumNonlinSolvIters_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetStgrSensNumNonlinSolvConvFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetStgrSensNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSTGR1nfails_adapt_modifiable; + + int r = + CVodeGetStgrSensNumNonlinSolvConvFails(cvode_mem, + &nSTGR1nfails_adapt_modifiable); + return std::make_tuple(r, nSTGR1nfails_adapt_modifiable); + }; + + return CVodeGetStgrSensNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetStgrSensNonlinSolvStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetStgrSensNonlinSolvStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSTGR1niters_adapt_modifiable; + long nSTGR1nfails_adapt_modifiable; + + int r = CVodeGetStgrSensNonlinSolvStats(cvode_mem, + &nSTGR1niters_adapt_modifiable, + &nSTGR1nfails_adapt_modifiable); + return std::make_tuple(r, nSTGR1niters_adapt_modifiable, + nSTGR1nfails_adapt_modifiable); + }; + + return CVodeGetStgrSensNonlinSolvStats_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumStepStgrSensSolveFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumStepStgrSensSolveFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nSTGR1ncfails_adapt_modifiable; + + int r = CVodeGetNumStepStgrSensSolveFails(cvode_mem, + &nSTGR1ncfails_adapt_modifiable); + return std::make_tuple(r, nSTGR1ncfails_adapt_modifiable); + }; + + return CVodeGetNumStepStgrSensSolveFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeQuadSensReInit", + [](void* cvode_mem, std::vector yQS0_1d) -> int + { + auto CVodeQuadSensReInit_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, std::vector yQS0_1d) -> int + { + N_Vector* yQS0_1d_ptr = + reinterpret_cast(yQS0_1d.empty() ? nullptr : yQS0_1d.data()); + + auto lambda_result = CVodeQuadSensReInit(cvode_mem, yQS0_1d_ptr); + return lambda_result; + }; + + return CVodeQuadSensReInit_adapt_arr_ptr_to_std_vector(cvode_mem, yQS0_1d); + }, + nb::arg("cvode_mem"), nb::arg("yQS0_1d")); + +m.def( + "CVodeQuadSensSStolerances", + [](void* cvode_mem, sunrealtype reltolQS, sundials4py::Array1d abstolQS_1d) -> int + { + auto CVodeQuadSensSStolerances_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype reltolQS, + sundials4py::Array1d abstolQS_1d) -> int + { + sunrealtype* abstolQS_1d_ptr = + reinterpret_cast(abstolQS_1d.data()); + + auto lambda_result = CVodeQuadSensSStolerances(cvode_mem, reltolQS, + abstolQS_1d_ptr); + return lambda_result; + }; + + return CVodeQuadSensSStolerances_adapt_arr_ptr_to_std_vector(cvode_mem, + reltolQS, + abstolQS_1d); + }, + nb::arg("cvode_mem"), nb::arg("reltolQS"), nb::arg("abstolQS_1d")); + +m.def( + "CVodeQuadSensSVtolerances", + [](void* cvode_mem, sunrealtype reltolQS, std::vector abstolQS_1d) -> int + { + auto CVodeQuadSensSVtolerances_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype reltolQS, + std::vector abstolQS_1d) -> int + { + N_Vector* abstolQS_1d_ptr = reinterpret_cast( + abstolQS_1d.empty() ? nullptr : abstolQS_1d.data()); + + auto lambda_result = CVodeQuadSensSVtolerances(cvode_mem, reltolQS, + abstolQS_1d_ptr); + return lambda_result; + }; + + return CVodeQuadSensSVtolerances_adapt_arr_ptr_to_std_vector(cvode_mem, + reltolQS, + abstolQS_1d); + }, + nb::arg("cvode_mem"), nb::arg("reltolQS"), nb::arg("abstolQS_1d")); + +m.def("CVodeQuadSensEEtolerances", CVodeQuadSensEEtolerances, + nb::arg("cvode_mem")); + +m.def("CVodeSetQuadSensErrCon", CVodeSetQuadSensErrCon, nb::arg("cvode_mem"), + nb::arg("errconQS"), "Optional input specification functions"); + +m.def( + "CVodeGetQuadSens", + [](void* cvode_mem, + std::vector yQSout_1d) -> std::tuple + { + auto CVodeGetQuadSens_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype* tret, std::vector yQSout_1d) -> int + { + N_Vector* yQSout_1d_ptr = reinterpret_cast( + yQSout_1d.empty() ? nullptr : yQSout_1d.data()); + + auto lambda_result = CVodeGetQuadSens(cvode_mem, tret, yQSout_1d_ptr); + return lambda_result; + }; + auto CVodeGetQuadSens_adapt_modifiable_immutable_to_return = + [&CVodeGetQuadSens_adapt_arr_ptr_to_std_vector](void* cvode_mem, + std::vector yQSout_1d) + -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = CVodeGetQuadSens_adapt_arr_ptr_to_std_vector(cvode_mem, + &tret_adapt_modifiable, + yQSout_1d); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return CVodeGetQuadSens_adapt_modifiable_immutable_to_return(cvode_mem, + yQSout_1d); + }, + nb::arg("cvode_mem"), nb::arg("yQSout_1d")); + +m.def( + "CVodeGetQuadSens1", + [](void* cvode_mem, int is, N_Vector yQSout) -> std::tuple + { + auto CVodeGetQuadSens1_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int is, N_Vector yQSout) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = CVodeGetQuadSens1(cvode_mem, &tret_adapt_modifiable, is, yQSout); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return CVodeGetQuadSens1_adapt_modifiable_immutable_to_return(cvode_mem, is, + yQSout); + }, + nb::arg("cvode_mem"), nb::arg("is_"), nb::arg("yQSout")); + +m.def( + "CVodeGetQuadSensDky", + [](void* cvode_mem, sunrealtype t, int k, + std::vector dkyQS_all_1d) -> int + { + auto CVodeGetQuadSensDky_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, sunrealtype t, int k, + std::vector dkyQS_all_1d) -> int + { + N_Vector* dkyQS_all_1d_ptr = reinterpret_cast( + dkyQS_all_1d.empty() ? nullptr : dkyQS_all_1d.data()); + + auto lambda_result = CVodeGetQuadSensDky(cvode_mem, t, k, dkyQS_all_1d_ptr); + return lambda_result; + }; + + return CVodeGetQuadSensDky_adapt_arr_ptr_to_std_vector(cvode_mem, t, k, + dkyQS_all_1d); + }, + nb::arg("cvode_mem"), nb::arg("t"), nb::arg("k"), nb::arg("dkyQS_all_1d")); + +m.def("CVodeGetQuadSensDky1", CVodeGetQuadSensDky1, nb::arg("cvode_mem"), + nb::arg("t"), nb::arg("k"), nb::arg("is_"), nb::arg("dkyQS")); + +m.def( + "CVodeGetQuadSensNumRhsEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetQuadSensNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfQSevals_adapt_modifiable; + + int r = CVodeGetQuadSensNumRhsEvals(cvode_mem, &nfQSevals_adapt_modifiable); + return std::make_tuple(r, nfQSevals_adapt_modifiable); + }; + + return CVodeGetQuadSensNumRhsEvals_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetQuadSensNumErrTestFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetQuadSensNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nQSetfails_adapt_modifiable; + + int r = CVodeGetQuadSensNumErrTestFails(cvode_mem, + &nQSetfails_adapt_modifiable); + return std::make_tuple(r, nQSetfails_adapt_modifiable); + }; + + return CVodeGetQuadSensNumErrTestFails_adapt_modifiable_immutable_to_return( + cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetQuadSensErrWeights", + [](void* cvode_mem, std::vector eQSweight_1d) -> int + { + auto CVodeGetQuadSensErrWeights_adapt_arr_ptr_to_std_vector = + [](void* cvode_mem, std::vector eQSweight_1d) -> int + { + N_Vector* eQSweight_1d_ptr = reinterpret_cast( + eQSweight_1d.empty() ? nullptr : eQSweight_1d.data()); + + auto lambda_result = CVodeGetQuadSensErrWeights(cvode_mem, + eQSweight_1d_ptr); + return lambda_result; + }; + + return CVodeGetQuadSensErrWeights_adapt_arr_ptr_to_std_vector(cvode_mem, + eQSweight_1d); + }, + nb::arg("cvode_mem"), nb::arg("eQSweight_1d")); + +m.def( + "CVodeGetQuadSensStats", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetQuadSensStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfQSevals_adapt_modifiable; + long nQSetfails_adapt_modifiable; + + int r = CVodeGetQuadSensStats(cvode_mem, &nfQSevals_adapt_modifiable, + &nQSetfails_adapt_modifiable); + return std::make_tuple(r, nfQSevals_adapt_modifiable, + nQSetfails_adapt_modifiable); + }; + + return CVodeGetQuadSensStats_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def("CVodeAdjInit", CVodeAdjInit, nb::arg("cvode_mem"), nb::arg("steps"), + nb::arg("interp")); + +m.def("CVodeAdjReInit", CVodeAdjReInit, nb::arg("cvode_mem")); + +m.def( + "CVodeCreateB", + [](void* cvode_mem, int lmmB) -> std::tuple + { + auto CVodeCreateB_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int lmmB) -> std::tuple + { + int which_adapt_modifiable; + + int r = CVodeCreateB(cvode_mem, lmmB, &which_adapt_modifiable); + return std::make_tuple(r, which_adapt_modifiable); + }; + + return CVodeCreateB_adapt_modifiable_immutable_to_return(cvode_mem, lmmB); + }, + nb::arg("cvode_mem"), nb::arg("lmmB")); + +m.def("CVodeReInitB", CVodeReInitB, nb::arg("cvode_mem"), nb::arg("which"), + nb::arg("tB0"), nb::arg("yB0")); + +m.def("CVodeSStolerancesB", CVodeSStolerancesB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("reltolB"), nb::arg("abstolB")); + +m.def("CVodeSVtolerancesB", CVodeSVtolerancesB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("reltolB"), nb::arg("abstolB")); + +m.def("CVodeQuadReInitB", CVodeQuadReInitB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("yQB0")); + +m.def("CVodeQuadSStolerancesB", CVodeQuadSStolerancesB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("reltolQB"), nb::arg("abstolQB")); + +m.def("CVodeQuadSVtolerancesB", CVodeQuadSVtolerancesB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("reltolQB"), nb::arg("abstolQB")); + +m.def( + "CVodeF", + [](void* cvode_mem, sunrealtype tout, N_Vector yout, + int itask) -> std::tuple + { + auto CVodeF_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, sunrealtype tout, N_Vector yout, + int itask) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + int ncheckPtr_adapt_modifiable; + + int r = CVodeF(cvode_mem, tout, yout, &tret_adapt_modifiable, itask, + &ncheckPtr_adapt_modifiable); + return std::make_tuple(r, tret_adapt_modifiable, + ncheckPtr_adapt_modifiable); + }; + + return CVodeF_adapt_modifiable_immutable_to_return(cvode_mem, tout, yout, + itask); + }, + nb::arg("cvode_mem"), nb::arg("tout"), nb::arg("yout"), nb::arg("itask")); + +m.def("CVodeB", CVodeB, nb::arg("cvode_mem"), nb::arg("tBout"), + nb::arg("itaskB")); + +m.def("CVodeSetAdjNoSensi", CVodeSetAdjNoSensi, nb::arg("cvode_mem")); + +m.def("CVodeSetMaxOrdB", CVodeSetMaxOrdB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("maxordB")); + +m.def("CVodeSetMaxNumStepsB", CVodeSetMaxNumStepsB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("mxstepsB")); + +m.def("CVodeSetStabLimDetB", CVodeSetStabLimDetB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("stldetB")); + +m.def("CVodeSetInitStepB", CVodeSetInitStepB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("hinB")); + +m.def("CVodeSetMinStepB", CVodeSetMinStepB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("hminB")); + +m.def("CVodeSetMaxStepB", CVodeSetMaxStepB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("hmaxB")); + +m.def("CVodeSetConstraintsB", CVodeSetConstraintsB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("constraintsB")); + +m.def("CVodeSetQuadErrConB", CVodeSetQuadErrConB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("errconQB")); + +m.def("CVodeSetNonlinearSolverB", CVodeSetNonlinearSolverB, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("NLS")); + +m.def( + "CVodeGetB", + [](void* cvode_mem, int which, N_Vector yB) -> std::tuple + { + auto CVodeGetB_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int which, N_Vector yB) -> std::tuple + { + sunrealtype tBret_adapt_modifiable; + + int r = CVodeGetB(cvode_mem, which, &tBret_adapt_modifiable, yB); + return std::make_tuple(r, tBret_adapt_modifiable); + }; + + return CVodeGetB_adapt_modifiable_immutable_to_return(cvode_mem, which, yB); + }, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("yB")); + +m.def( + "CVodeGetQuadB", + [](void* cvode_mem, int which, N_Vector qB) -> std::tuple + { + auto CVodeGetQuadB_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int which, N_Vector qB) -> std::tuple + { + sunrealtype tBret_adapt_modifiable; + + int r = CVodeGetQuadB(cvode_mem, which, &tBret_adapt_modifiable, qB); + return std::make_tuple(r, tBret_adapt_modifiable); + }; + + return CVodeGetQuadB_adapt_modifiable_immutable_to_return(cvode_mem, which, + qB); + }, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("qB")); + +m.def("CVodeGetAdjCVodeBmem", CVodeGetAdjCVodeBmem, nb::arg("cvode_mem"), + nb::arg("which")); + +m.def("CVodeGetAdjY", CVodeGetAdjY, nb::arg("cvode_mem"), nb::arg("t"), + nb::arg("y")); + +m.def("CVodeGetAdjCheckPointsInfo", CVodeGetAdjCheckPointsInfo, + nb::arg("cvode_mem"), nb::arg("ckpnt")); + +m.def( + "CVodeGetAdjDataPointHermite", + [](void* cvode_mem, int which, N_Vector y, + N_Vector yd) -> std::tuple + { + auto CVodeGetAdjDataPointHermite_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int which, N_Vector y, + N_Vector yd) -> std::tuple + { + sunrealtype t_adapt_modifiable; + + int r = CVodeGetAdjDataPointHermite(cvode_mem, which, &t_adapt_modifiable, + y, yd); + return std::make_tuple(r, t_adapt_modifiable); + }; + + return CVodeGetAdjDataPointHermite_adapt_modifiable_immutable_to_return(cvode_mem, + which, + y, + yd); + }, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("y"), nb::arg("yd")); + +m.def( + "CVodeGetAdjDataPointPolynomial", + [](void* cvode_mem, int which, N_Vector y) -> std::tuple + { + auto CVodeGetAdjDataPointPolynomial_adapt_modifiable_immutable_to_return = + [](void* cvode_mem, int which, + N_Vector y) -> std::tuple + { + sunrealtype t_adapt_modifiable; + int order_adapt_modifiable; + + int r = CVodeGetAdjDataPointPolynomial(cvode_mem, which, + &t_adapt_modifiable, + &order_adapt_modifiable, y); + return std::make_tuple(r, t_adapt_modifiable, order_adapt_modifiable); + }; + + return CVodeGetAdjDataPointPolynomial_adapt_modifiable_immutable_to_return(cvode_mem, + which, + y); + }, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("y")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _CVSLS_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("CVLS_SUCCESS") = 0; +m.attr("CVLS_MEM_NULL") = -1; +m.attr("CVLS_LMEM_NULL") = -2; +m.attr("CVLS_ILL_INPUT") = -3; +m.attr("CVLS_MEM_FAIL") = -4; +m.attr("CVLS_PMEM_NULL") = -5; +m.attr("CVLS_JACFUNC_UNRECVR") = -6; +m.attr("CVLS_JACFUNC_RECVR") = -7; +m.attr("CVLS_SUNMAT_FAIL") = -8; +m.attr("CVLS_SUNLS_FAIL") = -9; +m.attr("CVLS_NO_ADJ") = -101; +m.attr("CVLS_LMEMB_NULL") = -102; + +m.def( + "CVodeSetLinearSolver", + [](void* cvode_mem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + auto CVodeSetLinearSolver_adapt_optional_arg_with_default_null = + [](void* cvode_mem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = CVodeSetLinearSolver(cvode_mem, LS, + A_adapt_default_null); + return lambda_result; + }; + + return CVodeSetLinearSolver_adapt_optional_arg_with_default_null(cvode_mem, + LS, A); + }, + nb::arg("cvode_mem"), nb::arg("LS"), nb::arg("A").none() = nb::none()); + +m.def("CVodeSetJacEvalFrequency", CVodeSetJacEvalFrequency, + nb::arg("cvode_mem"), nb::arg("msbj")); + +m.def("CVodeSetLinearSolutionScaling", CVodeSetLinearSolutionScaling, + nb::arg("cvode_mem"), nb::arg("onoff")); + +m.def("CVodeSetDeltaGammaMaxBadJac", CVodeSetDeltaGammaMaxBadJac, + nb::arg("cvode_mem"), nb::arg("dgmax_jbad")); + +m.def("CVodeSetEpsLin", CVodeSetEpsLin, nb::arg("cvode_mem"), nb::arg("eplifac")); + +m.def("CVodeSetLSNormFactor", CVodeSetLSNormFactor, nb::arg("arkode_mem"), + nb::arg("nrmfac")); + +m.def( + "CVodeGetJac", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetJac_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + SUNMatrix J_adapt_modifiable; + + int r = CVodeGetJac(cvode_mem, &J_adapt_modifiable); + return std::make_tuple(r, J_adapt_modifiable); + }; + + return CVodeGetJac_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "CVodeGetJacTime", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetJacTime_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + sunrealtype t_J_adapt_modifiable; + + int r = CVodeGetJacTime(cvode_mem, &t_J_adapt_modifiable); + return std::make_tuple(r, t_J_adapt_modifiable); + }; + + return CVodeGetJacTime_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetJacNumSteps", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetJacNumSteps_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nst_J_adapt_modifiable; + + int r = CVodeGetJacNumSteps(cvode_mem, &nst_J_adapt_modifiable); + return std::make_tuple(r, nst_J_adapt_modifiable); + }; + + return CVodeGetJacNumSteps_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumJacEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumJacEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long njevals_adapt_modifiable; + + int r = CVodeGetNumJacEvals(cvode_mem, &njevals_adapt_modifiable); + return std::make_tuple(r, njevals_adapt_modifiable); + }; + + return CVodeGetNumJacEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumPrecEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumPrecEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long npevals_adapt_modifiable; + + int r = CVodeGetNumPrecEvals(cvode_mem, &npevals_adapt_modifiable); + return std::make_tuple(r, npevals_adapt_modifiable); + }; + + return CVodeGetNumPrecEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumPrecSolves", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumPrecSolves_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long npsolves_adapt_modifiable; + + int r = CVodeGetNumPrecSolves(cvode_mem, &npsolves_adapt_modifiable); + return std::make_tuple(r, npsolves_adapt_modifiable); + }; + + return CVodeGetNumPrecSolves_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumLinIters", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumLinIters_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nliters_adapt_modifiable; + + int r = CVodeGetNumLinIters(cvode_mem, &nliters_adapt_modifiable); + return std::make_tuple(r, nliters_adapt_modifiable); + }; + + return CVodeGetNumLinIters_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumLinConvFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumLinConvFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nlcfails_adapt_modifiable; + + int r = CVodeGetNumLinConvFails(cvode_mem, &nlcfails_adapt_modifiable); + return std::make_tuple(r, nlcfails_adapt_modifiable); + }; + + return CVodeGetNumLinConvFails_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumJTSetupEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumJTSetupEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long njtsetups_adapt_modifiable; + + int r = CVodeGetNumJTSetupEvals(cvode_mem, &njtsetups_adapt_modifiable); + return std::make_tuple(r, njtsetups_adapt_modifiable); + }; + + return CVodeGetNumJTSetupEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumJtimesEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumJtimesEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long njvevals_adapt_modifiable; + + int r = CVodeGetNumJtimesEvals(cvode_mem, &njvevals_adapt_modifiable); + return std::make_tuple(r, njvevals_adapt_modifiable); + }; + + return CVodeGetNumJtimesEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumLinRhsEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumLinRhsEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nfevalsLS_adapt_modifiable; + + int r = CVodeGetNumLinRhsEvals(cvode_mem, &nfevalsLS_adapt_modifiable); + return std::make_tuple(r, nfevalsLS_adapt_modifiable); + }; + + return CVodeGetNumLinRhsEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetLinSolveStats", + [](void* cvode_mem) + -> std::tuple + { + auto CVodeGetLinSolveStats_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) + -> std::tuple + { + long njevals_adapt_modifiable; + long nfevalsLS_adapt_modifiable; + long nliters_adapt_modifiable; + long nlcfails_adapt_modifiable; + long npevals_adapt_modifiable; + long npsolves_adapt_modifiable; + long njtsetups_adapt_modifiable; + long njtimes_adapt_modifiable; + + int r = CVodeGetLinSolveStats(cvode_mem, &njevals_adapt_modifiable, + &nfevalsLS_adapt_modifiable, + &nliters_adapt_modifiable, + &nlcfails_adapt_modifiable, + &npevals_adapt_modifiable, + &npsolves_adapt_modifiable, + &njtsetups_adapt_modifiable, + &njtimes_adapt_modifiable); + return std::make_tuple(r, njevals_adapt_modifiable, + nfevalsLS_adapt_modifiable, + nliters_adapt_modifiable, nlcfails_adapt_modifiable, + npevals_adapt_modifiable, npsolves_adapt_modifiable, + njtsetups_adapt_modifiable, + njtimes_adapt_modifiable); + }; + + return CVodeGetLinSolveStats_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetLastLinFlag", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetLastLinFlag_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long flag_adapt_modifiable; + + int r = CVodeGetLastLinFlag(cvode_mem, &flag_adapt_modifiable); + return std::make_tuple(r, flag_adapt_modifiable); + }; + + return CVodeGetLastLinFlag_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def("CVodeGetLinReturnFlagName", CVodeGetLinReturnFlagName, nb::arg("flag")); + +m.def( + "CVodeSetLinearSolverB", + [](void* cvode_mem, int which, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + auto CVodeSetLinearSolverB_adapt_optional_arg_with_default_null = + [](void* cvode_mem, int which, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = CVodeSetLinearSolverB(cvode_mem, which, LS, + A_adapt_default_null); + return lambda_result; + }; + + return CVodeSetLinearSolverB_adapt_optional_arg_with_default_null(cvode_mem, + which, LS, + A); + }, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("LS"), + nb::arg("A").none() = nb::none()); + +m.def("CVodeSetEpsLinB", CVodeSetEpsLinB, nb::arg("cvode_mem"), + nb::arg("which"), nb::arg("eplifacB")); + +m.def("CVodeSetLSNormFactorB", CVodeSetLSNormFactorB, nb::arg("arkode_mem"), + nb::arg("which"), nb::arg("nrmfacB")); + +m.def("CVodeSetLinearSolutionScalingB", CVodeSetLinearSolutionScalingB, + nb::arg("cvode_mem"), nb::arg("which"), nb::arg("onoffB")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _CVPROJ_H +// +// #ifdef __cplusplus +// #endif +// + +m.def("CVodeSetProjErrEst", CVodeSetProjErrEst, nb::arg("cvode_mem"), + nb::arg("onoff")); + +m.def("CVodeSetProjFrequency", CVodeSetProjFrequency, nb::arg("cvode_mem"), + nb::arg("proj_freq")); + +m.def("CVodeSetMaxNumProjFails", CVodeSetMaxNumProjFails, nb::arg("cvode_mem"), + nb::arg("max_fails")); + +m.def("CVodeSetEpsProj", CVodeSetEpsProj, nb::arg("cvode_mem"), nb::arg("eps")); + +m.def("CVodeSetProjFailEta", CVodeSetProjFailEta, nb::arg("cvode_mem"), + nb::arg("eta")); + +m.def( + "CVodeGetNumProjEvals", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumProjEvals_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nproj_adapt_modifiable; + + int r = CVodeGetNumProjEvals(cvode_mem, &nproj_adapt_modifiable); + return std::make_tuple(r, nproj_adapt_modifiable); + }; + + return CVodeGetNumProjEvals_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); + +m.def( + "CVodeGetNumProjFails", + [](void* cvode_mem) -> std::tuple + { + auto CVodeGetNumProjFails_adapt_modifiable_immutable_to_return = + [](void* cvode_mem) -> std::tuple + { + long nprf_adapt_modifiable; + + int r = CVodeGetNumProjFails(cvode_mem, &nprf_adapt_modifiable); + return std::make_tuple(r, nprf_adapt_modifiable); + }; + + return CVodeGetNumProjFails_adapt_modifiable_immutable_to_return(cvode_mem); + }, + nb::arg("cvode_mem")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/cvodes/cvodes_usersupplied.hpp b/bindings/sundials4py/cvodes/cvodes_usersupplied.hpp new file mode 100644 index 0000000000..4b116d9c8b --- /dev/null +++ b/bindings/sundials4py/cvodes/cvodes_usersupplied.hpp @@ -0,0 +1,575 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_CVODE_USERSUPPLIED_HPP +#define _SUNDIALS4PY_CVODE_USERSUPPLIED_HPP + +#include +#include + +#include + +#include "sundials/sundials_types.h" +#include "sundials4py_helpers.hpp" + +/////////////////////////////////////////////////////////////////////////////// +// CVODE user-supplied function table +// Every integrator-level user-supplied function must be in this table. +// The user-supplied function table is passed to CVODE as user_data. +/////////////////////////////////////////////////////////////////////////////// + +struct cvode_user_supplied_fn_table +{ + // user-supplied function pointers + nb::object f, rootfn, ewtn, rwtn, fNLS, projfn; + + // cvode_ls user-supplied function pointers + nb::object lsjacfn, lsprecsetupfn, lsprecsolvefn, lsjactimessetupfn, + lsjactimesvecfn, lslinsysfn, lsjacrhsfn; + + // cvode quadrature user-supplied function pointers + nanobind::object fQ, fQS; + + // cvode FSA user-supplied function pointers + nanobind::object fS, fS1; +}; + +// Helper to extract CVodeMem and function table +inline cvode_user_supplied_fn_table* get_cvode_fn_table(void* cv_mem) +{ + auto mem = static_cast(cv_mem); + auto fn_table = static_cast(mem->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python function table from CVODE memory"); + } + return fn_table; +} + +/////////////////////////////////////////////////////////////////////////////// +// CVODE user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +inline cvode_user_supplied_fn_table* cvode_user_supplied_fn_table_alloc() +{ + // We must use malloc since CVODEFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(cvode_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(cvode_user_supplied_fn_table)); + + return fn_table; +} + +template +inline int cvode_f_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 1>(&cvode_user_supplied_fn_table::f, std::forward(args)...); +} + +using CVRootStdFn = int(sunrealtype t, N_Vector y, sundials4py::Array1d gout, + void* user_data); + +inline int cvode_rootfn_wrapper(sunrealtype t, N_Vector y, sunrealtype* gout_1d, + void* user_data) +{ + auto cv_mem = static_cast(user_data); + auto fn_table = get_cvode_fn_table(user_data); + auto fn = nb::cast>(fn_table->rootfn); + + sundials4py::Array1d gout(gout_1d, + {static_cast(cv_mem->cv_nrtfn)}, + nb::find(gout_1d)); + + return fn(t, y, gout, nullptr); +} + +template +inline int cvode_ewtfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 1>(&cvode_user_supplied_fn_table::ewtn, std::forward(args)...); +} + +template +inline int cvode_nlsrhsfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 1>(&cvode_user_supplied_fn_table::fNLS, std::forward(args)...); +} + +template +inline int cvode_lsjacfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 4>(&cvode_user_supplied_fn_table::lsjacfn, std::forward(args)...); +} + +using CVLsPrecSetupStdFn = std::tuple( + sunrealtype t, N_Vector y, N_Vector fy, sunbooleantype jok, sunrealtype gamma, + void* user_data); + +inline int cvode_lsprecsetupfn_wrapper(sunrealtype t, N_Vector y, N_Vector fy, + sunbooleantype jok, + sunbooleantype* jcurPtr, + sunrealtype gamma, void* user_data) +{ + auto fn_table = get_cvode_fn_table(user_data); + auto fn = nb::cast>(fn_table->lsprecsetupfn); + + auto result = fn(t, y, fy, jok, gamma, nullptr); + + *jcurPtr = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int cvode_lsprecsolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, + CVodeMem, 1>(&cvode_user_supplied_fn_table::lsprecsolvefn, + std::forward(args)...); +} + +template +inline int cvode_lsjactimessetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, + CVodeMem, 1>(&cvode_user_supplied_fn_table::lsjactimessetupfn, + std::forward(args)...); +} + +template +inline int cvode_lsjactimesvecfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, + CVodeMem, 2>(&cvode_user_supplied_fn_table::lsjactimesvecfn, + std::forward(args)...); +} + +using CVLsLinSysStdFn = std::tuple( + sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix M, sunbooleantype jok, + sunrealtype gamma, void* user_data, N_Vector tmp1, N_Vector tmp2, + N_Vector tmp3); + +inline int cvode_lslinsysfn_wrapper(sunrealtype t, N_Vector y, N_Vector fy, + SUNMatrix M, sunbooleantype jok, + sunbooleantype* jcur, sunrealtype gamma, + void* user_data, N_Vector tmp1, + N_Vector tmp2, N_Vector tmp3) +{ + auto fn_table = get_cvode_fn_table(user_data); + auto fn = nb::cast>(fn_table->lslinsysfn); + + auto result = fn(t, y, fy, M, jok, gamma, nullptr, tmp1, tmp2, tmp3); + + *jcur = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int cvode_lsjacrhsfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 1>(&cvode_user_supplied_fn_table::lsjacrhsfn, std::forward(args)...); +} + +template +inline int cvode_projfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 1>(&cvode_user_supplied_fn_table::projfn, std::forward(args)...); +} + +template +inline int cvode_fQ_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 1>(&cvode_user_supplied_fn_table::fQ, std::forward(args)...); +} + +using CVQuadSensRhsStdFn = int(int Ns, sunrealtype t, N_Vector y, + std::vector yS_1d, N_Vector yQdot, + std::vector yQSdot_1d, void* user_data, + N_Vector tmp, N_Vector tmpQ); + +inline int cvode_fQS_wrapper(int Ns, sunrealtype t, N_Vector y, N_Vector* yS_1d, + N_Vector yQdot, N_Vector* yQSdot_1d, + void* user_data, N_Vector tmp, N_Vector tmpQ) +{ + auto fn_table = get_cvode_fn_table(user_data); + auto fn = nb::cast>(fn_table->fQS); + + std::vector yS(yS_1d, yS_1d + Ns); + std::vector yQSdot(yQSdot_1d, yQSdot_1d + Ns); + + return fn(Ns, t, y, yS, yQdot, yQSdot, nullptr, tmp, tmpQ); +} + +using CVSensRhsStdFn = int(int Ns, sunrealtype t, N_Vector y, N_Vector ydot, + std::vector yS, std::vector ySdot, + void* user_data, N_Vector tmp1, N_Vector tmp2); + +inline int cvode_fS_wrapper(int Ns, sunrealtype t, N_Vector y, N_Vector ydot, + N_Vector* yS, N_Vector* ySdot, void* user_data, + N_Vector tmp1, N_Vector tmp2) +{ + auto fn_table = get_cvode_fn_table(user_data); + auto fn = nb::cast>(fn_table->fS); + + std::vector yS_1d(yS, yS + Ns); + std::vector ySdot_1d(ySdot, ySdot + Ns); + + return fn(Ns, t, y, ydot, yS_1d, ySdot_1d, nullptr, tmp1, tmp2); +} + +template +inline int cvode_fS1_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvode_user_supplied_fn_table, CVodeMem, + 3>(&cvode_user_supplied_fn_table::fS1, std::forward(args)...); +} + +/////////////////////////////////////////////////////////////////////////////// +// CVODE Adjoint user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +struct cvodea_user_supplied_fn_table +{ + // cvode adjoint user-supplied function pointers + nb::object fB, fBS, fQB, fQBS; + + // cvode_ls adjoint user-supplied function pointers + nb::object lsjacfnB, lsjacfnBS, lsprecsetupfnB, lsprecsetupfnBS, + lsprecsolvefnB, lsprecsolvefnBS, lsjactimessetupfnB, lsjactimessetupfnBS, + lsjactimesvecfnB, lsjactimesvecfnBS, lslinsysfnB, lslinsysfnBS; +}; + +inline cvodea_user_supplied_fn_table* cvodea_user_supplied_fn_table_alloc() +{ + // We must use malloc since CVODEFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(cvodea_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(cvodea_user_supplied_fn_table)); + + return fn_table; +} + +inline cvodea_user_supplied_fn_table* get_cvodea_fn_table(void* cv_mem) +{ + auto fn_table = static_cast( + static_cast(cv_mem)->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python adjoint function table from CVODE memory"); + } + return fn_table; +} + +inline cvodea_user_supplied_fn_table* get_cvodea_fn_table(void* cv_mem, int which) +{ + auto cvb_mem = static_cast(CVodeGetAdjCVodeBmem(cv_mem, which)); + auto fn_table = static_cast(cvb_mem->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python adjoint function table from CVODE memory"); + } + return fn_table; +} + +template +inline int cvode_fB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvodea_user_supplied_fn_table, CVodeMem, + 1>(&cvodea_user_supplied_fn_table::fB, std::forward(args)...); +} + +template +inline int cvode_fQB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvodea_user_supplied_fn_table, CVodeMem, + 1>(&cvodea_user_supplied_fn_table::fQB, std::forward(args)...); +} + +template +inline int cvode_lsjacfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvodea_user_supplied_fn_table, CVodeMem, + 4>(&cvodea_user_supplied_fn_table::lsjacfnB, std::forward(args)...); +} + +using CVLsPrecSetupStdFnB = std::tuple( + sunrealtype t, N_Vector y, N_Vector yB, N_Vector fyB, sunbooleantype jokB, + sunrealtype gammaB, void* user_dataB); + +inline int cvode_lsprecsetupfnB_wrapper(sunrealtype t, N_Vector y, N_Vector yB, + N_Vector fyB, sunbooleantype jokB, + sunbooleantype* jcurPtrB, + sunrealtype gammaB, void* user_dataB) +{ + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->lsprecsetupfnB); + + auto result = fn(t, y, yB, fyB, jokB, gammaB, nullptr); + + *jcurPtrB = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int cvode_lsprecsolvefnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvodea_user_supplied_fn_table, + CVodeMem, 1>(&cvodea_user_supplied_fn_table::lsprecsolvefnB, + std::forward(args)...); +} + +template +inline int cvode_lsjactimessetupfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvodea_user_supplied_fn_table, + CVodeMem, 1>(&cvodea_user_supplied_fn_table::lsjactimessetupfnB, + std::forward(args)...); +} + +template +inline int cvode_lsjactimesvecfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, cvodea_user_supplied_fn_table, + CVodeMem, 2>(&cvodea_user_supplied_fn_table::lsjactimesvecfnB, + std::forward(args)...); +} + +using CVLsLinSysStdFnB = std::tuple( + sunrealtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix AB, + sunbooleantype jokB, sunrealtype gammaB, void* user_dataB, N_Vector tmp1B, + N_Vector tmp2B, N_Vector tmp3B); + +inline int cvode_lslinsysfnB_wrapper(sunrealtype t, N_Vector y, N_Vector yB, + N_Vector fyB, SUNMatrix AB, + sunbooleantype jokB, sunbooleantype* jcurB, + sunrealtype gammaB, void* user_dataB, + N_Vector tmp1B, N_Vector tmp2B, + N_Vector tmp3B) +{ + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->lslinsysfnB); + + auto result = fn(t, y, yB, fyB, AB, jokB, gammaB, nullptr, tmp1B, tmp2B, tmp3B); + + *jcurB = std::get<1>(result); + + return std::get<0>(result); +} + +using CVRhsStdFnBS = int(sunrealtype t, N_Vector y, std::vector yS_1d, + N_Vector yB, N_Vector yBdot, void* user_dataB); + +inline int cvode_fBS_wrapper(sunrealtype t, N_Vector y, N_Vector* yS_1d, + N_Vector yB, N_Vector yBdot, void* user_dataB) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->fBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + return fn(t, y, yS, yB, yBdot, nullptr); +} + +using CVQuadRhsStdFnBS = int(sunrealtype t, N_Vector y, + std::vector yS_1d, N_Vector yB, + N_Vector qBdot, void* user_dataB); + +inline int cvode_fQBS_wrapper(sunrealtype t, N_Vector y, N_Vector* yS_1d, + N_Vector yB, N_Vector qBdot, void* user_dataB) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->fQBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + return fn(t, y, yS, yB, qBdot, nullptr); +} + +using CVLsJacStdFnBS = int(sunrealtype t, N_Vector y, + std::vector yS_1d, N_Vector yB, + N_Vector fyB, SUNMatrix JB, void* user_dataB, + N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B); + +inline int cvode_lsjacfnBS_wrapper(sunrealtype t, N_Vector y, N_Vector* yS_1d, + N_Vector yB, N_Vector fyB, SUNMatrix JB, + void* user_dataB, N_Vector tmp1B, + N_Vector tmp2B, N_Vector tmp3B) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->lsjacfnBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + return fn(t, y, yS, yB, fyB, JB, nullptr, tmp1B, tmp2B, tmp3B); +} + +using CVLsPrecSetupStdFnBS = std::tuple( + sunrealtype t, N_Vector y, std::vector yS_1d, N_Vector yB, + N_Vector fyB, sunbooleantype jokB, sunrealtype gammaB, void* user_dataB); + +inline int cvode_lsprecsetupfnBS_wrapper(sunrealtype t, N_Vector y, + N_Vector* yS_1d, N_Vector yB, + N_Vector fyB, sunbooleantype jokB, + sunbooleantype* jcurPtrB, + sunrealtype gammaB, void* user_dataB) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = + nb::cast>(fn_table->lsprecsetupfnBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + auto result = fn(t, y, yS, yB, fyB, jokB, gammaB, nullptr); + + *jcurPtrB = std::get<1>(result); + + return std::get<0>(result); +} + +using CVLsPrecSolveStdFnBS = int(sunrealtype t, N_Vector y, + std::vector yS_1d, N_Vector yB, + N_Vector fyB, N_Vector rB, N_Vector zB, + sunrealtype gammaB, sunrealtype deltaB, + int lrB, void* user_dataB); + +inline int cvode_lsprecsolvefnBS_wrapper(sunrealtype t, N_Vector y, + N_Vector* yS_1d, N_Vector yB, + N_Vector fyB, N_Vector rB, N_Vector zB, + sunrealtype gammaB, sunrealtype deltaB, + int lrB, void* user_dataB) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = + nb::cast>(fn_table->lsprecsolvefnBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + return fn(t, y, yS, yB, fyB, rB, zB, gammaB, deltaB, lrB, nullptr); +} + +using CVLsJacTimesSetupStdFnBS = int(sunrealtype t, N_Vector y, + std::vector yS_1d, N_Vector yB, + N_Vector fyB, void* user_dataB); + +inline int cvode_lsjactimessetupfnBS_wrapper(sunrealtype t, N_Vector y, + N_Vector* yS_1d, N_Vector yB, + N_Vector fyB, void* user_dataB) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>( + fn_table->lsjactimessetupfnBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + return fn(t, y, yS, yB, fyB, nullptr); +} + +using CVLsJacTimesVecStdFnBS = int(N_Vector vB, N_Vector JvB, sunrealtype t, + N_Vector y, std::vector yS_1d, + N_Vector yB, N_Vector fyB, void* user_dataB, + N_Vector tmpB); + +inline int cvode_lsjactimesvecfnBS_wrapper(N_Vector vB, N_Vector JvB, + sunrealtype t, N_Vector y, + N_Vector* yS_1d, N_Vector yB, + N_Vector fyB, void* user_dataB, + N_Vector tmpB) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = + nb::cast>(fn_table->lsjactimesvecfnBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + return fn(vB, JvB, t, y, yS, yB, fyB, nullptr, tmpB); +} + +using CVLsLinSysStdFnBS = std::tuple( + sunrealtype t, N_Vector y, std::vector yS_1d, N_Vector yB, + N_Vector fyB, SUNMatrix AB, sunbooleantype jokB, sunrealtype gammaB, + void* user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B); + +inline int cvode_lslinsysfnBS_wrapper(sunrealtype t, N_Vector y, N_Vector* yS_1d, + N_Vector yB, N_Vector fyB, SUNMatrix AB, + sunbooleantype jokB, + sunbooleantype* jcurB, sunrealtype gammaB, + void* user_dataB, N_Vector tmp1B, + N_Vector tmp2B, N_Vector tmp3B) +{ + auto cv_mem = static_cast(user_dataB); + auto fn_table = get_cvodea_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->lslinsysfnBS); + auto Ns = cv_mem->cv_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + + auto result = fn(t, y, yS, yB, fyB, AB, jokB, gammaB, nullptr, tmp1B, tmp2B, + tmp3B); + + *jcurB = std::get<1>(result); + + return std::get<0>(result); +} + +#endif diff --git a/bindings/sundials4py/cvodes/generate.yaml b/bindings/sundials4py/cvodes/generate.yaml new file mode 100644 index 0000000000..32ed131194 --- /dev/null +++ b/bindings/sundials4py/cvodes/generate.yaml @@ -0,0 +1,72 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# CVODES module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + fn_exclude_by_name__regex: + - "Free" # Free and destroy functions should not need to be called as all objects on the Python side are RAII + - "Destroy" + - "Space" # Space functions are deprecated, so dont expose them in Python + # Due to the need to convert between sys.argv and C argv, we need to do custom wrappers of these + - "SetOptions" + macro_define_include_by_name__regex: + - "^SUN_" + - "^CV_" + - "^CVLS_" + cvodes: + path: cvodes/cvodes_generated.hpp + headers: + - ../../include/cvodes/cvodes.h + - ../../include/cvodes/cvodes_ls.h + - ../../include/cvodes/cvodes_proj.h + # this option describes the functions which have optional pointer arguments, + # i.e., one where you could provide NULL + fn_params_optional_with_default_null: + "SetLinearSolver": + - "A" + fn_exclude_by_name__regex: + # We do custom handling of Create so we can wrap the void* in a CVodeView + - "^CVodeCreate$" + # we use user_data for sneaking in python contexts, so we don't interface these + - "^CVodeGetUserData$" + - "^CVodeSetUserData$" + - "^CVodeGetUserDataB$" + - "^CVodeSetUserDataB$" + # this function should be deprecated, so we don't interface it + - "^CVodeSetMonitorFn$" + # generator cannot handle setting of function pointers, so we do something custom + - "CVodeInit.*" + - "^CVodeSensInit$" + - "^CVodeSensInit1$" + - "CVodeQuadInit.*" + - "CVodeQuadSensInit.*" + - "CVodeSet.*Fn" + - "CVodeSet.*Preconditioner" + - "CVodeSetJacTimes.*" + - "^CVodeRootInit$" + - "^CVodeWFtolerances$" + # ** parameters are not yet supported by litgen + # - "^CVodeComputeStateSens$" + # - "^CVodeComputeStateSens1$" + # this function returns just a pointer address that is not useful from Python, so we don't interface it + - "^CVodeGetAdjCurrentCheckPoint" + # TODO(CJB): interface these (in the future?) + # generator cannot yet handle mixing pointer outputs and ** in the same function + - "^CVodeGetCurrentStateSens$" + - "^CVodeGetNonlinearSystemData$" + - "^CVodeGetNonlinearSystemDataSens$" \ No newline at end of file diff --git a/bindings/sundials4py/examples/arkode/ark_brusselator.py b/bindings/sundials4py/examples/arkode/ark_brusselator.py new file mode 100644 index 0000000000..8179ad661a --- /dev/null +++ b/bindings/sundials4py/examples/arkode/ark_brusselator.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This is a direct port of the C example, +# examples/arkode/C_serial/ark_brusselator.c, +# specifically with the parameters from Test 2. +# +# The following test simulates a brusselator problem from chemical +# kinetics. This is an ODE system with 3 components, Y = [u,v,w], +# satisfying the equations, +# du/dt = a - (w+1)*u + v*u^2 +# dv/dt = w*u - v*u^2 +# dw/dt = (b-w)/ep - w*u +# for t in the interval [0.0, 10.0], with initial conditions +# Y0 = [u0,v0,w0]. +# +# u0=1.2, v0=3.1, w0=3, a=1, b=3.5, ep=5.0e-6 +# Here, w experiences a fast initial transient, jumping 0.5 +# within a few steps. All values proceed smoothly until +# around t=6.5, when both u and v undergo a sharp transition, +# with u increasing from around 0.5 to 5 and v decreasing +# from around 6 to 1 in less than 0.5 time units. After this +# transition, both u and v continue to evolve somewhat +# rapidly for another 1.4 time units, and finish off smoothly. +# +# This program solves the problem with the DIRK method, using a +# Newton iteration with the SUNDIALS dense linear solver, and a +# user-supplied Jacobian routine. +# +# 100 outputs are printed at equal intervals, and run statistics +# are printed at the end. +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from sundials4py.arkode import * + + +# Brusselator ODE problem class +class BrusselatorODE: + def __init__(self, u0, v0, w0, a, b, ep): + self.u0 = u0 + self.v0 = v0 + self.w0 = w0 + self.a = a + self.b = b + self.ep = ep + + def set_init_cond(self, yvec): + y = N_VGetArrayPointer(yvec) + y[0] = self.u0 + y[1] = self.v0 + y[2] = self.w0 + return 0 + + def f(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + a, b, ep = self.a, self.b, self.ep + u, v, w = y[0], y[1], y[2] + ydot[0] = a - (w + 1.0) * u + v * u * u + ydot[1] = w * u - v * u * u + ydot[2] = (b - w) / ep - w * u + return 0 + + def jac(self, t, yvec, fyvec, J, tmp1, tmp2, tmp3): + y = N_VGetArrayPointer(yvec) + a, b, ep = self.a, self.b, self.ep + u, v, w = y[0], y[1], y[2] + Jdata = SUNDenseMatrix_Data(J) + Jdata[0, 0] = -(w + 1.0) + 2.0 * u * v + Jdata[0, 1] = u * u + Jdata[0, 2] = -u + Jdata[1, 0] = w - 2.0 * u * v + Jdata[1, 1] = -u * u + Jdata[1, 2] = u + Jdata[2, 0] = -w + Jdata[2, 1] = 0.0 + Jdata[2, 2] = -1.0 / ep - u + return 0 + + +def main(): + # Problem parameters for Test 2 + u0 = 1.2 + v0 = 3.1 + w0 = 3.0 + a = 1.0 + b = 3.5 + ep = 5.0e-6 + T0 = 0.0 + Tf = 10.0 + dTout = 1.0 + NEQ = 3 + Nt = int(np.ceil(Tf / dTout)) + reltol = 1.0e-6 + abstol = 1.0e-10 + + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + y = N_VNew_Serial(NEQ, sunctx) + + # Create ODE problem instance and set initial conditions + ode_problem = BrusselatorODE(u0, v0, w0, a, b, ep) + ode_problem.set_init_cond(y) + + ark = ARKStepCreate( + None, # f_E (explicit) + lambda t, yvec, ydotvec, _: ode_problem.f(t, yvec, ydotvec), # f_I (implicit) + T0, + y, + sunctx, + ) + + status = ARKodeSStolerances(ark.get(), reltol, abstol) + assert status == ARK_SUCCESS + + status = ARKodeSetInterpolantType(ark.get(), ARK_INTERP_LAGRANGE) + assert status == ARK_SUCCESS + + status = ARKodeSetDeduceImplicitRhs(ark.get(), 1) + assert status == ARK_SUCCESS + + # Dense matrix and linear solver + A = SUNDenseMatrix(NEQ, NEQ, sunctx) + LS = SUNLinSol_Dense(y, A, sunctx) + + status = ARKodeSetLinearSolver(ark.get(), LS, A) + assert status == ARK_SUCCESS + + status = ARKodeSetJacFn( + ark.get(), + lambda t, yvec, fyvec, J, tmp1, tmp2, tmp3, _: ode_problem.jac( + t, yvec, fyvec, J, tmp1, tmp2, tmp3 + ), + ) + assert status == ARK_SUCCESS + + # Signal that this problem does not explicitly depend on time + status = ARKodeSetAutonomous(ark.get(), 1) + assert status == ARK_SUCCESS + + # Initial problem output + yarr = N_VGetArrayPointer(y) + print("\nBrusselator ODE test problem:") + print(f" initial conditions: u0 = {u0}, v0 = {v0}, w0 = {w0}") + print(f" problem parameters: a = {a}, b = {b}, ep = {ep}") + print(f" reltol = {reltol}, abstol = {abstol}\n") + print(" t u v w") + print(" -------------------------------------------") + print(f" {T0:10.6f} {yarr[0]:10.6f} {yarr[1]:10.6f} {yarr[2]:10.6f}") + + # Main time-stepping loop: calls ARKodeEvolve to perform the integration, + # then prints results. Stops when the final time has been reached. The + # solution is written out to a file. + with open("solution.txt", "w") as UFID: + UFID.write("# t u v w\n") + UFID.write(f" {T0:.16e} {yarr[0]:.16e} {yarr[1]:.16e} {yarr[2]:.16e}\n") + tout = T0 + dTout + for iout in range(Nt): + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + yarr = N_VGetArrayPointer(y) + print(f" {tret:10.6f} {yarr[0]:10.6f} {yarr[1]:10.6f} {yarr[2]:10.6f}") + UFID.write(f" {tret:.16e} {yarr[0]:.16e} {yarr[1]:.16e} {yarr[2]:.16e}\n") + if status == ARK_SUCCESS: + tout += dTout + tout = min(tout, Tf) + else: + print("Solver failure, stopping integration") + break + print(" -------------------------------------------") + + # Print statistics (check status code) + status, nst = ARKodeGetNumSteps(ark.get()) + assert status == ARK_SUCCESS + status, nst_a = ARKodeGetNumStepAttempts(ark.get()) + assert status == ARK_SUCCESS + status, nfe = ARKodeGetNumRhsEvals(ark.get(), 0) + assert status == ARK_SUCCESS + status, nfi = ARKodeGetNumRhsEvals(ark.get(), 1) + assert status == ARK_SUCCESS + status, nsetups = ARKodeGetNumLinSolvSetups(ark.get()) + assert status == ARK_SUCCESS + status, netf = ARKodeGetNumErrTestFails(ark.get()) + assert status == ARK_SUCCESS + status, ncfn = ARKodeGetNumStepSolveFails(ark.get()) + assert status == ARK_SUCCESS + status, nni = ARKodeGetNumNonlinSolvIters(ark.get()) + assert status == ARK_SUCCESS + status, nnf = ARKodeGetNumNonlinSolvConvFails(ark.get()) + assert status == ARK_SUCCESS + status, nje = ARKodeGetNumJacEvals(ark.get()) + assert status == ARK_SUCCESS + status, nfeLS = ARKodeGetNumLinRhsEvals(ark.get()) + assert status == ARK_SUCCESS + + print("\nFinal Solver Statistics:") + print(f" Internal solver steps = {nst} (attempted = {nst_a})") + print(f" Total RHS evals: Fe = {nfe}, Fi = {nfi}") + print(f" Total linear solver setups = {nsetups}") + print(f" Total RHS evals for setting up the linear system = {nfeLS}") + print(f" Total number of Jacobian evaluations = {nje}") + print(f" Total number of Newton iterations = {nni}") + print(f" Total number of nonlinear solver convergence failures = {nnf}") + print(f" Total number of error test failures = {netf}") + print(f" Total number of failed steps from solver failure = {ncfn}") + + +def test_ark_brusselator(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/examples/arkode/ark_heat1D.py b/bindings/sundials4py/examples/arkode/ark_heat1D.py new file mode 100644 index 0000000000..8826612b20 --- /dev/null +++ b/bindings/sundials4py/examples/arkode/ark_heat1D.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This example is a copy of examples/arkode/C_serial/ark_heat1D.c, +# but ported to Python to use sundials4py. +# +# The following test simulates a simple 1D heat equation, +# u_t = k*u_xx + f +# for t in [0, 10], x in [0, 1], with initial conditions +# u(0,x) = 0 +# Dirichlet boundary conditions, i.e. +# u_t(t,0) = u_t(t,1) = 0, +# and a point-source heating term, +# f = 1 for x=0.5. +# +# The spatial derivatives are computed using second-order +# centered differences, with the data distributed over N points +# on a uniform spatial grid. +# +# This program solves the problem with either an ERK or DIRK +# method. For the DIRK method, we use a Newton iteration with +# the SUNLinSol_PCG linear solver, and a user-supplied Jacobian-vector +# product routine. +# +# 100 outputs are printed at equal intervals, and run statistics +# are printed at the end. +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from sundials4py.arkode import * + + +class Heat1DProblem: + def __init__(self, N, k): + self.N = N + self.k = k + self.dx = 1.0 / (N - 1) + + def set_init_cond(self, yvec): + y = N_VGetArrayPointer(yvec) + y[:] = 0.0 + return 0 + + def f(self, t, yvec, ydotvec): + N, k, dx = self.N, self.k, self.dx + Y = N_VGetArrayPointer(yvec) + Ydot = N_VGetArrayPointer(ydotvec) + Ydot[:] = 0.0 + c1 = k / dx / dx + c2 = -2.0 * k / dx / dx + # Vectorized Laplacian + Ydot[1:-1] = c1 * Y[:-2] + c2 * Y[1:-1] + c1 * Y[2:] + # Dirichlet BCs + Ydot[0] = 0.0 + Ydot[-1] = 0.0 + # Point source + isource = N // 2 + Ydot[isource] += 0.01 / dx + return 0 + + def jtv(self, vvec, Jvvec, t, yvec, fyvec, tmpvec): + N, k, dx = self.N, self.k, self.dx + V = N_VGetArrayPointer(vvec) + JV = N_VGetArrayPointer(Jvvec) + JV[:] = 0.0 + c1 = k / dx / dx + c2 = -2.0 * k / dx / dx + # Vectorized tridiagonal product + JV[1:-1] = c1 * V[:-2] + c2 * V[1:-1] + c1 * V[2:] + JV[0] = 0.0 + JV[-1] = 0.0 + return 0 + + +def main(): + # Problem parameters + N = 201 + k = 0.5 + T0 = 0.0 + Tf = 1.0 + Nt = 10 + reltol = 1e-6 + abstol = 1e-10 + + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + y = N_VNew_Serial(N, sunctx) + + problem = Heat1DProblem(N, k) + problem.set_init_cond(y) + + # Call ARKStepCreate to initialize the ARK timestepper module and + # specify the right-hand side function in y'=f(t,y), the initial time + # T0, and the initial dependent variable vector y. Note: since this + # problem is fully implicit, we set f_E to NULL and f_I to f. */ + ark = ARKStepCreate( + None, # f_E (explicit) + lambda t, yvec, ydotvec, _: problem.f(t, yvec, ydotvec), # f_I (implicit) + T0, + y, + sunctx, + ) + + # Set routines + status = ARKodeSStolerances(ark.get(), reltol, abstol) + assert status == ARK_SUCCESS + + status = ARKodeSetMaxNumSteps(ark.get(), 10000) + assert status == ARK_SUCCESS + + status = ARKodeSetPredictorMethod(ark.get(), 1) + assert status == ARK_SUCCESS + + # PCG linear solver with no preconditioning, with up to N iterations + LS = SUNLinSol_PCG(y, 0, N, sunctx) + status = ARKodeSetLinearSolver(ark.get(), LS, None) + assert status == ARK_SUCCESS + + status = ARKodeSetJacTimes( + ark.get(), None, lambda v, Jv, t, y, fy, tmp, _: problem.jtv(v, Jv, t, y, fy, tmp) + ) + assert status == ARK_SUCCESS + + status = ARKodeSetLinear(ark.get(), 0) + assert status == ARK_SUCCESS + + # Output mesh + with open("heat_mesh.txt", "w") as FID: + for i in range(N): + FID.write(f" {problem.dx * i:.16e}\n") + + yarr = N_VGetArrayPointer(y) + with open("heat1D.txt", "w") as UFID: + # Output initial condition + UFID.write(" ".join(f"{val:.16e}" for val in yarr) + "\n") + + # Main time-stepping loop + t = T0 + dTout = (Tf - T0) / Nt + tout = T0 + dTout + print(" t ||u||_rms") + print(" -------------------------") + print(f" {t:10.6f} {np.sqrt(np.dot(yarr, yarr) / N):10.6f}") + for iout in range(Nt): + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + yarr = N_VGetArrayPointer(y) + print(f" {tret:10.6f} {np.sqrt(np.dot(yarr, yarr) / N):10.6f}") + UFID.write(" ".join(f"{val:.16e}" for val in yarr) + "\n") + if status == ARK_SUCCESS: + tout += dTout + tout = min(tout, Tf) + else: + print("Solver failure, stopping integration") + break + print(" -------------------------") + + # Print statistics + status, nst = ARKodeGetNumSteps(ark.get()) + assert status == ARK_SUCCESS + status, nst_a = ARKodeGetNumStepAttempts(ark.get()) + assert status == ARK_SUCCESS + status, nfe = ARKodeGetNumRhsEvals(ark.get(), 0) + assert status == ARK_SUCCESS + status, nfi = ARKodeGetNumRhsEvals(ark.get(), 1) + assert status == ARK_SUCCESS + status, nsetups = ARKodeGetNumLinSolvSetups(ark.get()) + assert status == ARK_SUCCESS + status, nli = ARKodeGetNumLinIters(ark.get()) + assert status == ARK_SUCCESS + status, nJv = ARKodeGetNumJtimesEvals(ark.get()) + assert status == ARK_SUCCESS + status, nlcf = ARKodeGetNumLinConvFails(ark.get()) + assert status == ARK_SUCCESS + status, nni = ARKodeGetNumNonlinSolvIters(ark.get()) + assert status == ARK_SUCCESS + status, ncfn = ARKodeGetNumNonlinSolvConvFails(ark.get()) + assert status == ARK_SUCCESS + status, netf = ARKodeGetNumErrTestFails(ark.get()) + assert status == ARK_SUCCESS + + print("\nFinal Solver Statistics:") + print(f" Internal solver steps = {nst} (attempted = {nst_a})") + print(f" Total RHS evals: Fe = {nfe}, Fi = {nfi}") + print(f" Total linear solver setups = {nsetups}") + print(f" Total linear iterations = {nli}") + print(f" Total number of Jacobian-vector products = {nJv}") + print(f" Total number of linear solver convergence failures = {nlcf}") + print(f" Total number of Newton iterations = {nni}") + print(f" Total number of nonlinear solver convergence failures = {ncfn}") + print(f" Total number of error test failures = {netf}") + + +def test_ark_heat1D(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/examples/arkode/ark_lotka_volterra_ASA.py b/bindings/sundials4py/examples/arkode/ark_lotka_volterra_ASA.py new file mode 100644 index 0000000000..8e020acbe7 --- /dev/null +++ b/bindings/sundials4py/examples/arkode/ark_lotka_volterra_ASA.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# Python port of the SUNDIALS ARKODE Lotka-Volterra adjoint +# sensitivity example (ark_lotka_volterra_ASA.c) +# ----------------------------------------------------------------- +import numpy as np +import sundials4py.core as sun +import sundials4py.arkode as ark + + +class LotkaVolterraODE: + def __init__(self, p): + self.p = np.array(p, dtype=sun.sunrealtype) + self.NEQ = 2 + self.NP = 4 + + def set_init_cond(self, yvec): + y = sun.N_VGetArrayPointer(yvec) + y[0] = 1.0 + y[1] = 1.0 + return 0 + + def f(self, t, yvec, ydotvec): + p = self.p + y = sun.N_VGetArrayPointer(yvec) + ydot = sun.N_VGetArrayPointer(ydotvec) + ydot[0] = p[0] * y[0] - p[1] * y[0] * y[1] + ydot[1] = -p[2] * y[1] + p[3] * y[0] * y[1] + return 0 + + def vjp(self, vvec, Jvvec, t, yvec): + p = self.p + v = sun.N_VGetArrayPointer(vvec) + Jv = sun.N_VGetArrayPointer(Jvvec) + y = sun.N_VGetArrayPointer(yvec) + Jv[0] = (p[0] - p[1] * y[1]) * v[0] + p[3] * y[1] * v[1] + Jv[1] = -p[1] * y[0] * v[0] + (-p[2] + p[3] * y[0]) * v[1] + return 0 + + def parameter_vjp(self, vvec, Jvvec, t, yvec): + v = sun.N_VGetArrayPointer(vvec) + Jv = sun.N_VGetArrayPointer(Jvvec) + y = sun.N_VGetArrayPointer(yvec) + Jv[0] = y[0] * v[0] + Jv[1] = -y[0] * y[1] * v[0] + Jv[2] = -y[1] * v[1] + Jv[3] = y[0] * y[1] * v[1] + return 0 + + def dgdu(self, yvec): + y = sun.N_VGetArrayPointer(yvec) + return np.array([-1.0 + y[0], -1.0 + y[1]], dtype=sun.sunrealtype) + + def dgdp(self, yvec): + return np.zeros(self.NP, dtype=sun.sunrealtype) + + def adj_rhs(self, t, y, sens, sens_dot): + l = sun.N_VGetSubvector_ManyVector(sens, 0) + ldot = sun.N_VGetSubvector_ManyVector(sens_dot, 0) + nu = sun.N_VGetSubvector_ManyVector(sens_dot, 1) + self.vjp(l, ldot, t, y) + self.parameter_vjp(l, nu, t, y) + return 0 + + def quad_rhs(self, t, yvec, muvec, qBdotvec): + self.parameter_vjp(muvec, qBdotvec, t, yvec) + return 0 + + +def main(): + # Program args + tf = 10.0 + dt = 1e-3 + order = 4 + check_freq = 2 + keep_checks = True + + # Problem parameters + p = [1.5, 1.0, 3.0, 1.0] + t0 = 0.0 + reltol = 1e-10 + abstol = 1e-14 + ode = LotkaVolterraODE(p) + NEQ = ode.NEQ + NP = ode.NP + + # + # Create the initial conditions vector + # + status, sunctx = sun.SUNContext_Create(sun.SUN_COMM_NULL) + assert status == sun.SUN_SUCCESS + + y = sun.N_VNew_Serial(NEQ, sunctx) + ode.set_init_cond(y) + + # + # Create the ARKODE stepper that will be used for the forward evolution. + # + arkode = ark.ARKStepCreate(lambda t, y, ydot, _: ode.f(t, y, ydot), None, t0, y, sunctx) + status = ark.ARKodeSetOrder(arkode.get(), 4) + assert status == ark.ARK_SUCCESS + status = ark.ARKodeSStolerances(arkode.get(), reltol, abstol) + assert status == ark.ARK_SUCCESS + status = ark.ARKodeSetFixedStep(arkode.get(), dt) + assert status == ark.ARK_SUCCESS + + # Due to roundoff in the `t` accumulation within the integrator, + # the integrator may actually use nsteps + 1 time steps to reach tf + status = ark.ARKodeSetMaxNumSteps(arkode.get(), int((tf - t0) / dt) + 1) + assert status == ark.ARK_SUCCESS + + # # Enable checkpointing during the forward run + nsteps = int(np.ceil((tf - t0) / dt)) + ncheck = nsteps * order + mem_helper = sun.SUNMemoryHelper_Sys(sunctx) + status, checkpoint_scheme = sun.SUNAdjointCheckpointScheme_Create_Fixed( + sun.SUNDATAIOMODE_INMEM, mem_helper, check_freq, ncheck, keep_checks, sunctx + ) + assert status == sun.SUN_SUCCESS + status = ark.ARKodeSetAdjointCheckpointScheme(arkode.get(), checkpoint_scheme) + assert status == ark.ARK_SUCCESS + + # + # Compute the forward solution + # + + print("Initial condition:") + yarr = sun.N_VGetArrayPointer(y) + print(yarr) + + tret = t0 + status, tret = ark.ARKodeEvolve(arkode.get(), tf, y, ark.ARK_NORMAL) + assert status == ark.ARK_SUCCESS + + print("Forward Solution:") + print(sun.N_VGetArrayPointer(y)) + + print("ARKODE Stats for Forward Solution:") + status, file_ptr = sun.SUNFileOpen("stdout", "w+") + assert status == ark.ARK_SUCCESS + ark.ARKodePrintAllStats(arkode.get(), file_ptr, sun.SUN_OUTPUTFORMAT_TABLE) + + # + # Create the adjoint stepper + # + + # Adjoint terminal condition + uB = sun.N_VNew_Serial(NEQ, sunctx) + arr_uB = ode.dgdu(y) + uB_arr = sun.N_VGetArrayPointer(uB) + uB_arr[:] = arr_uB + qB = sun.N_VNew_Serial(NP, sunctx) + qB_arr = sun.N_VGetArrayPointer(qB) + qB_arr[:] = ode.dgdp(y) + + # Combine adjoint vectors into a ManyVector + sens = [uB, qB] + sf = sun.N_VNew_ManyVector(2, sens, sunctx) + print("Adjoint terminal condition:") + print(sun.N_VGetArrayPointer(uB)) + print(sun.N_VGetArrayPointer(qB)) + + # Create ARKStep adjoint stepper + status, adj_stepper = ark.ARKStepCreateAdjointStepper( + arkode.get(), + lambda t, yv, lv, ldotv, _: ode.adj_rhs(t, yv, lv, ldotv), + None, + tf, + sf, + sunctx, + ) + + # + # Now compute the adjoint solution + # + + status, tret = sun.SUNAdjointStepper_Evolve(adj_stepper, t0, sf) + assert status == ark.ARK_SUCCESS + + print("Adjoint Solution:") + print(sun.N_VGetArrayPointer(uB)) + print(sun.N_VGetArrayPointer(qB)) + + # print("\nARKStep Adjoint Stats:") + # ARKStepAdjointStepperPrintAllStats(adj_stepper.get(), None, 0) + + +def test_ark_lotka_volterra_ASA(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/examples/cvodes/cvs_brusselator.py b/bindings/sundials4py/examples/cvodes/cvs_brusselator.py new file mode 100644 index 0000000000..d642ac8d45 --- /dev/null +++ b/bindings/sundials4py/examples/cvodes/cvs_brusselator.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This is a direct port of ark_brusselator.py to use CVODE. +# The following test simulates a brusselator problem from chemical +# kinetics. This is an ODE system with 3 components, Y = [u,v,w], +# satisfying the equations, +# du/dt = a - (w+1)*u + v*u^2 +# dv/dt = w*u - v*u^2 +# dw/dt = (b-w)/ep - w*u +# for t in the interval [0.0, 10.0], with initial conditions +# Y0 = [u0,v0,w0]. +# +# u0=1.2, v0=3.1, w0=3, a=1, b=3.5, ep=5.0e-6 +# Here, w experiences a fast initial transient, jumping 0.5 +# within a few steps. All values proceed smoothly until +# around t=6.5, when both u and v undergo a sharp transition, +# with u increasing from around 0.5 to 5 and v decreasing +# from around 6 to 1 in less than 0.5 time units. After this +# transition, both u and v continue to evolve somewhat +# rapidly for another 1.4 time units, and finish off smoothly. +# +# This program solves the problem with the DIRK method, using a +# Newton iteration with the SUNDIALS dense linear solver, and a +# user-supplied Jacobian routine. +# +# 100 outputs are printed at equal intervals, and run statistics +# are printed at the end. +# ----------------------------------------------------------------- + +import sys +import numpy as np +from sundials4py.core import * +from sundials4py.cvodes import * + + +class BrusselatorODE: + def __init__(self, u0, v0, w0, a, b, ep): + self.u0 = u0 + self.v0 = v0 + self.w0 = w0 + self.a = a + self.b = b + self.ep = ep + + def set_init_cond(self, yvec): + y = N_VGetArrayPointer(yvec) + y[0] = self.u0 + y[1] = self.v0 + y[2] = self.w0 + return 0 + + def f(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + a, b, ep = self.a, self.b, self.ep + u, v, w = y[0], y[1], y[2] + ydot[0] = a - (w + 1.0) * u + v * u * u + ydot[1] = w * u - v * u * u + ydot[2] = (b - w) / ep - w * u + return 0 + + def jac(self, t, yvec, fyvec, J, tmp1, tmp2, tmp3): + y = N_VGetArrayPointer(yvec) + a, b, ep = self.a, self.b, self.ep + u, v, w = y[0], y[1], y[2] + Jdata = SUNDenseMatrix_Data(J) + Jdata[0, 0] = -(w + 1.0) + 2.0 * u * v + Jdata[0, 1] = u * u + Jdata[0, 2] = -u + Jdata[1, 0] = w - 2.0 * u * v + Jdata[1, 1] = -u * u + Jdata[1, 2] = u + Jdata[2, 0] = -w + Jdata[2, 1] = 0.0 + Jdata[2, 2] = -1.0 / ep - u + return 0 + + +def main(): + # Problem parameters for Test 2 + u0 = 1.2 + v0 = 3.1 + w0 = 3.0 + a = 1.0 + b = 3.5 + ep = 5.0e-6 + T0 = 0.0 + Tf = 10.0 + dTout = 1.0 + NEQ = 3 + Nt = int(np.ceil(Tf / dTout)) + reltol = 1.0e-6 + abstol = 1.0e-10 + + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + y = N_VNew_Serial(NEQ, sunctx) + + ode_problem = BrusselatorODE(u0, v0, w0, a, b, ep) + ode_problem.set_init_cond(y) + + cvode = CVodeCreate(CV_BDF, sunctx) + status = CVodeInit(cvode.get(), lambda t, y, ydot, _: ode_problem.f(t, y, ydot), T0, y) + assert status == CV_SUCCESS + status = CVodeSStolerances(cvode.get(), reltol, abstol) + assert status == CV_SUCCESS + status = CVodeSetMaxNumSteps(cvode.get(), 10000) + assert status == CV_SUCCESS + + # Dense matrix and linear solver + A = SUNDenseMatrix(NEQ, NEQ, sunctx) + LS = SUNLinSol_Dense(y, A, sunctx) + status = CVodeSetLinearSolver(cvode.get(), LS, A) + assert status == CV_SUCCESS + status = CVodeSetJacFn( + cvode.get(), + lambda t, yvec, fyvec, J, tmp1, tmp2, tmp3, _: ode_problem.jac( + t, yvec, fyvec, J, tmp1, tmp2, tmp3 + ), + ) + assert status == CV_SUCCESS + + # Parse any command line arguments + status = CVodeSetOptions(cvode.get(), "", "", len(sys.argv), sys.argv) + assert status == CV_SUCCESS + + # Initial problem output + yarr = N_VGetArrayPointer(y) + print("\nBrusselator ODE test problem (CVODE):") + print(f" initial conditions: u0 = {u0}, v0 = {v0}, w0 = {w0}") + print(f" problem parameters: a = {a}, b = {b}, ep = {ep}") + print(f" reltol = {reltol}, abstol = {abstol}\n") + print(" t u v w") + print(" -------------------------------------------") + print(f" {T0:10.6f} {yarr[0]:10.6f} {yarr[1]:10.6f} {yarr[2]:10.6f}") + + # Main time-stepping loop + with open("cv_solution.txt", "w") as UFID: + UFID.write("# t u v w\n") + UFID.write(f" {T0:.16e} {yarr[0]:.16e} {yarr[1]:.16e} {yarr[2]:.16e}\n") + tout = T0 + dTout + for iout in range(Nt): + status, tret = CVode(cvode.get(), tout, y, CV_NORMAL) + yarr = N_VGetArrayPointer(y) + print(f" {tret:10.6f} {yarr[0]:10.6f} {yarr[1]:10.6f} {yarr[2]:10.6f}") + UFID.write(f" {tret:.16e} {yarr[0]:.16e} {yarr[1]:.16e} {yarr[2]:.16e}\n") + if status == CV_SUCCESS: + tout += dTout + tout = min(tout, Tf) + else: + print("Solver failure, stopping integration") + break + print(" -------------------------------------------") + + # Print statistics + status, nst = CVodeGetNumSteps(cvode.get()) + assert status == CV_SUCCESS + status, nfe = CVodeGetNumRhsEvals(cvode.get()) + assert status == CV_SUCCESS + status, nsetups = CVodeGetNumLinSolvSetups(cvode.get()) + assert status == CV_SUCCESS + status, nni = CVodeGetNumNonlinSolvIters(cvode.get()) + assert status == CV_SUCCESS + status, ncfn = CVodeGetNumNonlinSolvConvFails(cvode.get()) + assert status == CV_SUCCESS + status, nje = CVodeGetNumJacEvals(cvode.get()) + assert status == CV_SUCCESS + status, nfeLS = CVodeGetNumLinRhsEvals(cvode.get()) + assert status == CV_SUCCESS + + print("\nFinal Solver Statistics:") + print(f" Internal solver steps = {nst}") + print(f" Total RHS evals = {nfe}") + print(f" Total linear solver setups = {nsetups}") + print(f" Total number of Jacobian evaluations = {nje}") + print(f" Total number of Newton iterations = {nni}") + print(f" Total number of nonlinear solver convergence failures = {ncfn}") + print(f" Total RHS evals for setting up the linear system = {nfeLS}") + + +def test_cvs_brusselator(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/examples/cvodes/cvs_lotkavolterra_ASA.py b/bindings/sundials4py/examples/cvodes/cvs_lotkavolterra_ASA.py new file mode 100644 index 0000000000..9d68055dc8 --- /dev/null +++ b/bindings/sundials4py/examples/cvodes/cvs_lotkavolterra_ASA.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from sundials4py.cvodes import * + + +class LotkaVolterraODE: + def __init__(self, p): + self.p = np.array(p, dtype=sunrealtype) + self.NEQ = 2 + self.NP = 4 + + def set_init_cond(self, yvec): + # Set initial condition u0 = [1.0, 1.0] + y = N_VGetArrayPointer(yvec) + y[0] = 1.0 + y[1] = 1.0 + return 0 + + def f(self, t, yvec, ydotvec): + # Lotka-Volterra ODE right-hand side + p = self.p + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + ydot[0] = p[0] * y[0] - p[1] * y[0] * y[1] + ydot[1] = -p[2] * y[1] + p[3] * y[0] * y[1] + return 0 + + def vjp(self, vvec, Jvvec, t, yvec): + # Jacobian-vector product v^T (df/du) + p = self.p + v = N_VGetArrayPointer(vvec) + Jv = N_VGetArrayPointer(Jvvec) + y = N_VGetArrayPointer(yvec) + Jv[0] = (p[0] - p[1] * y[1]) * v[0] + p[3] * y[1] * v[1] + Jv[1] = -p[1] * y[0] * v[0] + (-p[2] + p[3] * y[0]) * v[1] + return 0 + + def parameter_vjp(self, vvec, Jvvec, t, yvec): + # Parameter Jacobian-vector product v^T (df/dp) + v = N_VGetArrayPointer(vvec) + Jv = N_VGetArrayPointer(Jvvec) + y = N_VGetArrayPointer(yvec) + # Derivatives w.r.t. each parameter + Jv[0] = y[0] * v[0] + Jv[1] = -y[0] * y[1] * v[0] + Jv[2] = -y[1] * v[1] + Jv[3] = y[0] * y[1] * v[1] + return 0 + + def dgdu(self, yvec): + # Gradient of the cost function w.r.t. u + y = N_VGetArrayPointer(yvec) + # g(u) = 0.5 * ||1 - u||^2, so grad = u - 1 + return np.array([-1.0 + y[0], -1.0 + y[1]], dtype=sunrealtype) + + def adjoint_rhs(self, t, yvec, lvec, ldotvec): + # Adjoint ODE right-hand side: -mu^T (df/du) + self.vjp(lvec, ldotvec, t, yvec) + ldot = N_VGetArrayPointer(ldotvec) + ldot *= -1.0 + return 0 + + def quad_rhs(self, t, yvec, muvec, qBdotvec): + # Quadrature right-hand side: mu^T (df/dp) + self.parameter_vjp(muvec, qBdotvec, t, yvec) + return 0 + + +def main(): + # Problem parameters + p = [1.5, 1.0, 3.0, 1.0] + T0 = 0.0 + Tf = 10.0 + reltol = 1e-10 + abstol = 1e-14 + steps = 5 + ode = LotkaVolterraODE(p) + NEQ = ode.NEQ + NP = ode.NP + + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + y = N_VNew_Serial(NEQ, sunctx) + + # Set initial condition + ode.set_init_cond(y) + + # Create CVODE solver and set up problem + cvode = CVodeCreate(CV_BDF, sunctx) + + # Initialize CVODE with ODE RHS + status = CVodeInit(cvode.get(), lambda t, yv, ydv, _: ode.f(t, yv, ydv), T0, y) + assert status == CV_SUCCESS + + # Set tolerances + status = CVodeSStolerances(cvode.get(), reltol, abstol) + assert status == CV_SUCCESS + + # Set max steps + status = CVodeSetMaxNumSteps(cvode.get(), 100000) + assert status == CV_SUCCESS + + # Set linear solver + ls = SUNLinSol_SPGMR(y, 0, 3, sunctx) + status = CVodeSetLinearSolver(cvode.get(), ls, None) + assert status == CV_SUCCESS + + # Enable adjoint sensitivity analysis + status = CVodeAdjInit(cvode.get(), steps, 1) # CV_HERMITE = 1 + assert status == CV_SUCCESS + + # Output problem setup + print("\nLotka-Volterra ODE test problem (CVODE, ASA):") + print(f" initial conditions: y0 = [1.0, 1.0]") + print(f" parameters: p = {p}") + print(f" reltol = {reltol}, abstol = {abstol}\n") + print(" t x y") + print(" ---------------------------------") + yarr = N_VGetArrayPointer(y) + print(f" {T0:10.6f} {yarr[0]:10.6f} {yarr[1]:10.6f}") + + # Forward integration + tout = Tf + t = 0.0 + status, tret, ncheck = CVodeF(cvode.get(), tout, y, CV_NORMAL) + yarr = N_VGetArrayPointer(y) + print(f" {tout:10.6f} {yarr[0]:10.6f} {yarr[1]:10.6f}") + print(" ---------------------------------") + + # Adjoint terminal condition + uB = N_VNew_Serial(NEQ, sunctx) + arr_uB = N_VGetArrayPointer(uB) + arr_uB[:] = ode.dgdu(y) + qB = N_VNew_Serial(NP, sunctx) + N_VConst(0.0, qB) + print("Adjoint terminal condition:") + print(arr_uB) + print(N_VGetArrayPointer(qB)) + + # Create the CVODES object for the backward problem + status, which = CVodeCreateB(cvode.get(), CV_BDF) + assert status == CV_SUCCESS + + # Initialize the CVODES solver for the backward problem + status = CVodeInitB( + cvode.get(), which, lambda t, yv, lv, ldotv, _: ode.adjoint_rhs(t, yv, lv, ldotv), Tf, uB + ) + assert status == CV_SUCCESS + + # Set the tolerances for the backward problem + status = CVodeSStolerancesB(cvode.get(), which, reltol, abstol) + assert status == CV_SUCCESS + + # Create the linear solver for the backward problem + lsb = SUNLinSol_SPGMR(uB, 0, 3, sunctx) + status = CVodeSetLinearSolverB(cvode.get(), which, lsb, None) + assert status == CV_SUCCESS + + # Call CVodeQuadInitB to allocate internal memory and initialize backward + # quadrature integration. This gives the sensitivities w.r.t. the parameters. + status = CVodeQuadInitB( + cvode.get(), which, lambda t, yv, mu, qBdot, _: ode.quad_rhs(t, yv, mu, qBdot), qB + ) + assert status == CV_SUCCESS + + # Call CVodeSetQuadErrCon to specify whether or not the quadrature variables + # are to be used in the step size control mechanism within CVODES. Call + # CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration + # tolerances for the quadrature variables. + status = CVodeSetQuadErrConB(cvode.get(), which, 1) + assert status == CV_SUCCESS + + # Call CVodeQuadSStolerancesB to specify the scalar relative and absolute tolerances + # for the backward problem. + status = CVodeQuadSStolerancesB(cvode.get(), which, reltol, abstol) + assert status == CV_SUCCESS + + # Integrate the adjoint ODE + status = CVodeB(cvode.get(), T0, CV_NORMAL) + assert status >= 0 + t = 0.0 + + # Get the final adjoint solution + status, t = CVodeGetB(cvode.get(), which, uB) + assert status == CV_SUCCESS + + # Call CVodeGetQuadB to get the quadrature solution vector after a + # successful return from CVodeB. + status, t = CVodeGetQuadB(cvode.get(), which, qB) + assert status == CV_SUCCESS + + # dg/dp = -qB + arr_qB = N_VGetArrayPointer(qB) + arr_qB *= -1.0 + print(f"Adjoint Solution at t = {t}:") + print(N_VGetArrayPointer(uB)) + print(arr_qB) + + +def test_cvs_lotkavolterra_ASA(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/examples/idas/idasSlCrank_dns.py b/bindings/sundials4py/examples/idas/idasSlCrank_dns.py new file mode 100644 index 0000000000..12a230cf63 --- /dev/null +++ b/bindings/sundials4py/examples/idas/idasSlCrank_dns.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Python port of the SUNDIALS slider-crank DAE example in IDAS +# (idasSlCrank_dns.c). +# +# Simulation of a slider-crank mechanism modelled with 3 generalized +# coordinates: crank angle, connecting bar angle, and slider location. +# The mechanism moves under the action of a constant horizontal force +# applied to the connecting rod and a spring-damper connecting the crank +# and connecting rod. +# +# The equations of motion are formulated as a system of stabilized +# index-2 DAEs (Gear-Gupta-Leimkuhler formulation). +# IDAS also computes the average kinetic energy as the quadrature: +# G = int_t0^tend g(t,y,p) dt, +# where +# g(t,y,p) = 0.5*J1*v1^2 + 0.5*J2*v3^2 + 0.5*m2*v2^2 +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from sundials4py.idas import * + + +class SliderCrankDAE: + def __init__(self, a=0.5, J1=1.0, m2=1.0, m1=1.0, J2=2.0, l0=1.0, F=1.0, k=1.0, c=1.0): + self.a = a + self.J1 = J1 + self.m2 = m2 + self.m1 = m1 + self.J2 = J2 + self.l0 = l0 + self.F = F + self.k = k + self.c = c + + def set_initial_conditions(self, yyvec, ypvec): + pi = 4.0 * np.arctan(1.0) + a = self.a + J1 = self.J1 + m2 = self.m2 + J2 = self.J2 + q = pi / 2.0 + p = np.arcsin(-a) + x = np.cos(p) + yy = N_VGetArrayPointer(yyvec) + yp = N_VGetArrayPointer(ypvec) + yy[:] = 0.0 + yp[:] = 0.0 + yy[0] = q + yy[1] = x + yy[2] = p + Q = self.force(yy) + yp[3] = Q[0] / J1 + yp[4] = Q[1] / m2 + yp[5] = Q[2] / J2 + return 0 + + def force(self, yy): + a, k, c, l0, F = self.a, self.k, self.c, self.l0, self.F + q, x, p = yy[0], yy[1], yy[2] + qd, xd, pd = yy[3], yy[4], yy[5] + s1, c1 = np.sin(q), np.cos(q) + s2, c2 = np.sin(p), np.cos(p) + s21 = s2 * c1 - c2 * s1 + c21 = c2 * c1 + s2 * s1 + l2 = x * x - x * (c2 + a * c1) + (1.0 + a * a) / 4.0 + a * c21 / 2.0 + ell = np.sqrt(l2) + ld = ( + 2.0 * x * xd + - xd * (c2 + a * c1) + + x * (s2 * pd + a * s1 * qd) + - a * s21 * (pd - qd) / 2.0 + ) / (2.0 * ell) + f = k * (ell - l0) + c * ld + fl = f / ell + Q = np.zeros(3) + Q[0] = -fl * a * (s21 / 2.0 + x * s1) / 2.0 + Q[1] = fl * (c2 / 2.0 - x + a * c1 / 2.0) + F + Q[2] = -fl * (x * s2 - a * s21 / 2.0) / 2.0 - F * s2 + return Q + + def residual(self, t, yyvec, ypvec, rvec): + a, J1, m2, J2 = self.a, self.J1, self.m2, self.J2 + yy = N_VGetArrayPointer(yyvec) + yp = N_VGetArrayPointer(ypvec) + rr = N_VGetArrayPointer(rvec) + q, x, p = yy[0], yy[1], yy[2] + qd, xd, pd = yy[3], yy[4], yy[5] + lam1, lam2 = yy[6], yy[7] + mu1, mu2 = yy[8], yy[9] + s1, c1 = np.sin(q), np.cos(q) + s2, c2 = np.sin(p), np.cos(p) + Q = self.force(yy) + rr[0] = yp[0] - qd + a * s1 * mu1 - a * c1 * mu2 + rr[1] = yp[1] - xd + mu1 + rr[2] = yp[2] - pd + s2 * mu1 - c2 * mu2 + rr[3] = J1 * yp[3] - Q[0] + a * s1 * lam1 - a * c1 * lam2 + rr[4] = m2 * yp[4] - Q[1] + lam1 + rr[5] = J2 * yp[5] - Q[2] + s2 * lam1 - c2 * lam2 + rr[6] = x - c2 - a * c1 + rr[7] = -s2 - a * s1 + rr[8] = a * s1 * qd + xd + s2 * pd + rr[9] = -a * c1 * qd - c2 * pd + return 0 + + def rhsQ(self, t, yyvec, ypvec, qdotvec): + J1, m2, J2 = self.J1, self.m2, self.J2 + yy = N_VGetArrayPointer(yyvec) + qdot = N_VGetArrayPointer(qdotvec) + v1, v2, v3 = yy[3], yy[4], yy[5] + qdot[0] = 0.5 * (J1 * v1 * v1 + m2 * v2 * v2 + J2 * v3 * v3) + return 0 + + +def main(): + # Problem parameters + RTOLF = 1e-6 + ATOLF = 1e-7 + RTOLQ = 1e-6 + ATOLQ = 1e-8 + TBEGIN = 0.0 + TEND = 10.0 + NEQ = 10 + NOUT = 25 + + # Create model + dae = SliderCrankDAE() + + # SUNDIALS context + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + + # Create N_Vectors + id = N_VNew_Serial(NEQ, sunctx) + yy = N_VClone(id) + yp = N_VClone(id) + q = N_VNew_Serial(1, sunctx) + + # Consistent IC + id_arr = N_VGetArrayPointer(id) + id_arr[:6] = 1.0 + id_arr[6:] = 0.0 + dae.set_initial_conditions(yy, yp) + + # IDAS initialization + ida = IDACreate(sunctx) + status = IDAInit( + ida.get(), lambda t, yv, ypv, rv, _: dae.residual(t, yv, ypv, rv), TBEGIN, yy, yp + ) + assert status == IDA_SUCCESS + status = IDASStolerances(ida.get(), RTOLF, ATOLF) + assert status == IDA_SUCCESS + status = IDASetId(ida.get(), id) + assert status == IDA_SUCCESS + status = IDASetSuppressAlg(ida.get(), True) + assert status == IDA_SUCCESS + status = IDASetMaxNumSteps(ida.get(), 20000) + assert status == IDA_SUCCESS + + # Create dense SUNMatrix to use with dense linear solver + A = SUNDenseMatrix(NEQ, NEQ, sunctx) + + # Create dense linear solver + LS = SUNLinSol_Dense(yy, A, sunctx) + + # Attach the matrix and linear solver + status = IDASetLinearSolver(ida.get(), LS, A) + assert status == IDA_SUCCESS + + # Setup quadrature + N_VConst(0.0, q) + status = IDAQuadInit(ida.get(), lambda t, yv, ypv, qv, _: dae.rhsQ(t, yv, ypv, qv), q) + assert status == IDA_SUCCESS + status = IDASetQuadErrCon(ida.get(), 1) + assert status == IDA_SUCCESS + status = IDAQuadSStolerances(ida.get(), RTOLQ, ATOLQ) + assert status == IDA_SUCCESS + + # Output header + print("\nidasSlCrank_dns: Slider-Crank DAE serial example problem for IDAS") + print("Linear solver: DENSE, Jacobian is computed by IDAS.") + print(f"Tolerance parameters: rtol = {RTOLF} atol = {ATOLF}") + print("---------------------------------------------------------------------") + print(" t y1 y2 y3 | nst k h") + print("---------------------------------------------------------------------") + + # Time stepping loop (C example style) + yarr = N_VGetArrayPointer(yy) + tout = TEND / NOUT + tret = 0.0 + while True: + status, nst = IDAGetNumSteps(ida.get()) + status, kused = IDAGetLastOrder(ida.get()) + status, hused = IDAGetLastStep(ida.get()) + print( + f"{tret:5.2f} {yarr[0]:12.4e} {yarr[1]:12.4e} {yarr[2]:12.4e} | {nst:3d} {kused:1d} {hused:12.4e}" + ) + + status, tret = IDASolve(ida.get(), tout, yy, yp, IDA_NORMAL) + assert status >= 0 + + tout += TEND / NOUT + if tret > TEND: + status, nst = IDAGetNumSteps(ida.get()) + status, kused = IDAGetLastOrder(ida.get()) + status, hused = IDAGetLastStep(ida.get()) + print( + f"{tret:5.2f} {yarr[0]:12.4e} {yarr[1]:12.4e} {yarr[2]:12.4e} | {nst:3d} {kused:1d} {hused:12.4e}" + ) + break + + # Final statistics + status, nst = IDAGetNumSteps(ida.get()) + status, nre = IDAGetNumResEvals(ida.get()) + status, nje = IDAGetNumJacEvals(ida.get()) + status, nni = IDAGetNumNonlinSolvIters(ida.get()) + status, netf = IDAGetNumErrTestFails(ida.get()) + status, nnf = IDAGetNumNonlinSolvConvFails(ida.get()) + status, ncfn = IDAGetNumStepSolveFails(ida.get()) + status, nreLS = IDAGetNumLinResEvals(ida.get()) + + print("\nFinal Run Statistics: \n") + print(f"Number of steps = {nst}") + print(f"Number of residual evaluations = {nre + nreLS}") + print(f"Number of Jacobian evaluations = {nje}") + print(f"Number of nonlinear iterations = {nni}") + print(f"Number of error test failures = {netf}") + print(f"Number of nonlinear conv. failures = {nnf}") + print(f"Number of step solver failures = {ncfn}") + + status, tret = IDAGetQuad(ida.get(), q) + print("--------------------------------------------") + print(f" G = {N_VGetArrayPointer(q)[0]:24.16f}") + print("--------------------------------------------\n") + + +def test_idaSlCrank_dns(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/examples/kinsol/kinLaplace_bnd.py b/bindings/sundials4py/examples/kinsol/kinLaplace_bnd.py new file mode 100644 index 0000000000..26f692059d --- /dev/null +++ b/bindings/sundials4py/examples/kinsol/kinLaplace_bnd.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This is a direct port of the C example, +# examples/kinsol/serial/kinLaplace_bnd.c +# +# This example solves a 2D elliptic PDE +# +# d^2 u / dx^2 + d^2 u / dy^2 = u^3 - u - 2.0 +# +# subject to homogeneous Dirichlet boundary conditions. +# The PDE is discretized on a uniform NX+2 by NY+2 grid with +# central differencing, and with boundary values eliminated, +# leaving a system of size NEQ = NX*NY. +# The nonlinear system is solved by KINSOL using the SUNBAND linear +# solver. +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from sundials4py.kinsol import * + +# Problem Constants +NX = 31 +NY = 31 +NEQ = NX * NY +SKIP = 3 +FTOL = 1e-12 +ZERO = 0.0 +ONE = 1.0 +TWO = 2.0 + + +# Helper to get index in y for (i, j) +def idx(i, j): + return (j - 1) + (i - 1) * NY + + +class Laplace2D: + def __init__(self, NX, NY): + self.NX = NX + self.NY = NY + self.NEQ = NX * NY + self.dx = ONE / (NX + 1) + self.dy = ONE / (NY + 1) + self.hdc = ONE / (self.dx * self.dx) + self.vdc = ONE / (self.dy * self.dy) + + def set_init_cond(self, yvec): + y = N_VGetArrayPointer(yvec) + # Initial guess: zero everywhere + for i in range(1, self.NX + 1): + for j in range(1, self.NY + 1): + y[idx(i, j)] = ZERO + return 0 + + def func(self, uvec, fvec): + u = N_VGetArrayPointer(uvec) + f = N_VGetArrayPointer(fvec) + # Reshape to 2D for vectorized operations + u2d = np.zeros((self.NX + 2, self.NY + 2), dtype=sunrealtype) + # Fill interior points + for i in range(1, self.NX + 1): + for j in range(1, self.NY + 1): + u2d[i, j] = u[idx(i, j)] + # Vectorized finite difference + uij = u2d[1:-1, 1:-1] + udn = u2d[1:-1, 0:-2] + uup = u2d[1:-1, 2:] + ult = u2d[0:-2, 1:-1] + urt = u2d[2:, 1:-1] + hdiff = self.hdc * (ult - TWO * uij + urt) + vdiff = self.vdc * (uup - TWO * uij + udn) + f2d = hdiff + vdiff + uij - uij * uij * uij + 2.0 + # Write back to 1D f + for i in range(1, self.NX + 1): + for j in range(1, self.NY + 1): + f[idx(i, j)] = f2d[i - 1, j - 1] + return 0 + + def print_output(self, yvec): + y = N_VGetArrayPointer(yvec) + print(" ", end="") + for i in range(1, self.NX + 1, SKIP): + x = i * self.dx + print(f"{x:<8.5f} ", end="") + print("\n") + for j in range(1, self.NY + 1, SKIP): + yval = j * self.dy + print(f"{yval:<8.5f} ", end="") + for i in range(1, self.NX + 1, SKIP): + print(f"{y[idx(i, j)]:<8.5f} ", end="") + print() + + +def main(): + print("\n2D elliptic PDE on unit square") + print(" d^2 u / dx^2 + d^2 u / dy^2 = u^3 - u + 2.0") + print(" + homogeneous Dirichlet boundary conditions\n") + print("Solution method: Modified Newton with band linear solver") + print(f"Problem size: {NX} x {NY} = {NEQ}\n") + + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + y = N_VNew_Serial(NEQ, sunctx) + scale = N_VNew_Serial(NEQ, sunctx) + + # Create problem instance and set initial guess + problem = Laplace2D(NX, NY) + problem.set_init_cond(y) + + # Initialize and allocate memory for KINSOL + kin = KINCreate(sunctx) + status = KINInit(kin.get(), lambda u, f, _: problem.func(u, f), y) + assert status == KIN_SUCCESS + + # Set function norm tolerance + status = KINSetFuncNormTol(kin.get(), FTOL) + assert status == KIN_SUCCESS + + # Create band matrix and linear solver + J = SUNBandMatrix(NEQ, NX, NX, sunctx) + LS = SUNLinSol_Band(y, J, sunctx) + status = KINSetLinearSolver(kin.get(), LS, J) + assert status == KIN_SUCCESS + + # Set Modified Newton parameters + status = KINSetMaxSetupCalls(kin.get(), 100) + assert status == KIN_SUCCESS + status = KINSetMaxSubSetupCalls(kin.get(), 1) + assert status == KIN_SUCCESS + + # No scaling used + N_VConst(ONE, scale) + + # Call KINSol to solve problem + status = KINSol(kin.get(), y, KIN_LINESEARCH, scale, scale) + assert status == KIN_SUCCESS + + # Get scaled norm of the system function + status, fnorm = KINGetFuncNorm(kin.get()) + assert status == KIN_SUCCESS + print(f"\nComputed solution (||F|| = {fnorm}):\n") + problem.print_output(y) + + # Print final statistics (faithful to C PrintFinalStats) + status, nni = KINGetNumNonlinSolvIters(kin.get()) + assert status == KIN_SUCCESS + + status, nfe = KINGetNumFuncEvals(kin.get()) + assert status == KIN_SUCCESS + + status, nbcfails = KINGetNumBetaCondFails(kin.get()) + assert status == KIN_SUCCESS + + status, nbacktr = KINGetNumBacktrackOps(kin.get()) + assert status == KIN_SUCCESS + + status, nje = KINGetNumJacEvals(kin.get()) + assert status == KIN_SUCCESS + + status, nfeD = KINGetNumLinFuncEvals(kin.get()) + assert status == KIN_SUCCESS + + print("\nFinal Statistics.. \n") + print(f"nni = {nni:6d} nfe = {nfe:6d}") + print(f"nbcfails = {nbcfails:6d} nbacktr = {nbacktr:6d}") + print(f"nje = {nje:6d} nfeB = {nfeD:6d}") + + +def test_kinLaplace_bnd(): + main() + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/idas/generate.yaml b/bindings/sundials4py/idas/generate.yaml new file mode 100644 index 0000000000..39cb8e79d9 --- /dev/null +++ b/bindings/sundials4py/idas/generate.yaml @@ -0,0 +1,68 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# IDAS module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + fn_exclude_by_name__regex: + - "Free" # Free and destroy functions should not need to be called as all objects on the Python side are RAII objects + - "Destroy" + - "Space" # Space functions are deprecated, so dont expose them in Python + # Due to the need to convert between sys.argv and C argv, we need to do custom wrappers of these + - "SetOptions" + macro_define_include_by_name__regex: + - "^SUN_" + - "^IDA_" + - "^IDALS_" + idas: + path: idas/idas_generated.hpp + headers: + - ../../include/idas/idas.h + - ../../include/idas/idas_ls.h + # this option describes the functions which have optional pointer arguments, + # i.e., one where you could provide NULL + fn_params_optional_with_default_null: + "SetLinearSolver": + - "A" + fn_exclude_by_name__regex: + # We do custom handling of Create so we can wrap the void* in a CVodeView + - "^IDACreate$" + # we use user_data for sneaking in python contexts, users can instead capture their states in a class + - "^IDAGetUserData$" + - "^IDASetUserData$" + - "^IDAGetUserDataB$" + - "^IDASetUserDataB$" + # this function should be deprecated, so we don't interface it + - "^IDASetMonitorFn$" + # generator cannot handle setting of function pointers, so we do something custom + - "IDAInit.*" + - "^IDASensInit$" + - "^IDASensInit1$" + - "IDAQuadInit.*" + - "IDAQuadSensInit.*" + - "IDASet.*Fn" + - "IDASet.*Preconditioner" + - "IDASetJacTimes.*" + - "^IDARootInit$" + - "^IDAWFtolerances$" + # TODO(CJB): interface these (in the future?) + # generator cannot yet handle mixing pointer outputs and ** in the same function + - "^IDAGetAdjCurrentCheckPoint$" + - "^IDAGetCurrentYSens$" + - "^IDAGetCurrentYpSens$" + - "^IDAGetNonlinearSystemData$" + - "^IDAGetNonlinearSystemDataSens$" \ No newline at end of file diff --git a/bindings/sundials4py/idas/idas.cpp b/bindings/sundials4py/idas/idas.cpp new file mode 100644 index 0000000000..b17560bb1a --- /dev/null +++ b/bindings/sundials4py/idas/idas.cpp @@ -0,0 +1,319 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +#include +#include +#include + +#include "idas/idas_impl.h" + +#include "idas_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +#define BIND_IDA_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER, ...) \ + m.def( \ + #NAME, \ + [](void* ida_mem, std::function> fn) \ + { \ + auto fn_table = get_idas_fn_table(ida_mem); \ + fn_table->MEMBER = nb::cast(fn); \ + if (fn) { return NAME(ida_mem, &WRAPPER); } \ + else { return NAME(ida_mem, nullptr); } \ + }, \ + __VA_ARGS__) + +#define BIND_IDA_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \ + MEMBER2, WRAPPER2, ...) \ + m.def( \ + #NAME, \ + [](void* ida_mem, std::function> fn1, \ + std::function> fn2) \ + { \ + auto fn_table = get_idas_fn_table(ida_mem); \ + fn_table->MEMBER1 = nb::cast(fn1); \ + fn_table->MEMBER2 = nb::cast(fn2); \ + if (fn1) { return NAME(ida_mem, WRAPPER1, WRAPPER2); } \ + else { return NAME(ida_mem, nullptr, WRAPPER2); } \ + }, \ + __VA_ARGS__) + +#define BIND_IDAB_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER, ...) \ + m.def( \ + #NAME, \ + [](void* ida_mem, int which, std::function> fn) \ + { \ + void* user_data = nullptr; \ + auto fn_table = get_idasa_fn_table(ida_mem, which); \ + fn_table->MEMBER = nb::cast(fn); \ + if (fn) { return NAME(ida_mem, which, &WRAPPER); } \ + else { return NAME(ida_mem, which, nullptr); } \ + }, \ + __VA_ARGS__) + +#define BIND_IDAB_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \ + MEMBER2, WRAPPER2, ...) \ + m.def( \ + #NAME, \ + [](void* ida_mem, int which, \ + std::function> fn1, \ + std::function> fn2) \ + { \ + void* user_data = nullptr; \ + auto fn_table = get_idasa_fn_table(ida_mem, which); \ + fn_table->MEMBER1 = nb::cast(fn1); \ + fn_table->MEMBER2 = nb::cast(fn2); \ + if (fn1) { return NAME(ida_mem, which, WRAPPER1, WRAPPER2); } \ + else { return NAME(ida_mem, which, nullptr, WRAPPER2); } \ + }, \ + __VA_ARGS__) + +namespace sundials4py { + +void bind_idas(nb::module_& m) +{ +#include "idas_generated.hpp" + + nb::class_(m, "IDAView") + .def("get", nb::overload_cast<>(&IDAView::get, nb::const_), + nb::rv_policy::reference); + + m.def( + "IDASetOptions", + [](void* ida_mem, const std::string& idaid, const std::string& file_name, + int argc, const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return IDASetOptions(ida_mem, idaid.empty() ? nullptr : idaid.c_str(), + file_name.empty() ? nullptr : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("ida_mem"), nb::arg("idaid"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + m.def( + "IDACreate", + [](SUNContext sunctx) + { return std::make_shared(IDACreate(sunctx)); }, + nb::arg("sunctx"), nb::keep_alive<0, 1>()); + + m.def("IDAInit", + [](void* ida_mem, std::function> res, + sunrealtype t0, N_Vector yy0, N_Vector yp0) + { + int ida_status = IDAInit(ida_mem, idas_res_wrapper, t0, yy0, yp0); + + auto fn_table = idas_user_supplied_fn_table_alloc(); + static_cast(ida_mem)->python = fn_table; + + ida_status = IDASetUserData(ida_mem, ida_mem); + if (ida_status != IDA_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in IDAS memory"); + } + + fn_table->res = nb::cast(res); + + return ida_status; + }); + + m.def("IDARootInit", + [](void* ida_mem, int nrtfn, + std::function> fn) + { + auto fn_table = get_idas_fn_table(ida_mem); + fn_table->rootfn = nb::cast(fn); + return IDARootInit(ida_mem, nrtfn, &idas_rootfn_wrapper); + }); + + m.def("IDAQuadInit", + [](void* ida_mem, + std::function> resQ, N_Vector yQ0) + { + auto fn_table = get_idas_fn_table(ida_mem); + fn_table->resQ = nb::cast(resQ); + return IDAQuadInit(ida_mem, &idas_resQ_wrapper, yQ0); + }); + + BIND_IDA_CALLBACK(IDAWFtolerances, IDAEwtFn, ewtn, idas_ewtfn_wrapper, + nb::arg("ida_mem"), nb::arg("efun").none()); + + BIND_IDA_CALLBACK(IDASetNlsResFn, IDAResFn, resNLS, idas_nlsresfn_wrapper, + nb::arg("ida_mem"), nb::arg("res").none()); + + BIND_IDA_CALLBACK(IDASetJacFn, IDALsJacFn, lsjacfn, idas_lsjacfn_wrapper, + nb::arg("ida_mem"), nb::arg("jac").none()); + + BIND_IDA_CALLBACK2(IDASetPreconditioner, IDALsPrecSetupFn, lsprecsetupfn, + idas_lsprecsetupfn_wrapper, IDALsPrecSolveFn, lsprecsolvefn, + idas_lsprecsolvefn_wrapper, nb::arg("ida_mem"), + nb::arg("pset").none(), nb::arg("psolve").none()); + + BIND_IDA_CALLBACK2(IDASetJacTimes, IDALsJacTimesSetupFn, lsjactimessetupfn, + idas_lsjactimessetupfn_wrapper, IDALsJacTimesVecFn, + lsjactimesvecfn, idas_lsjactimesvecfn_wrapper, + nb::arg("ida_mem"), nb::arg("jtsetup").none(), + nb::arg("jtimes").none()); + + BIND_IDA_CALLBACK(IDASetJacTimesResFn, IDALsJacTimesVecFn, lsjacresfn, + idas_lsjacresfn_wrapper, nb::arg("ida_mem"), + nb::arg("jtimesResFn").none()); + + // + // Sensitivity and quadrature sensitivity bindings + // + + m.def("IDAQuadSensInit", + [](void* ida_mem, std::function resQS, + std::vector yQS0) + { + auto fn_table = get_idas_fn_table(ida_mem); + fn_table->resQS = nb::cast(resQS); + return IDAQuadSensInit(ida_mem, idas_resQS_wrapper, yQS0.data()); + }); + + m.def("IDASensInit", + [](void* ida_mem, int Ns, int ism, std::function resS, + std::vector yS0, std::vector ypS0) + { + auto fn_table = get_idas_fn_table(ida_mem); + fn_table->resS = nb::cast(resS); + return IDASensInit(ida_mem, Ns, ism, idas_resS_wrapper, yS0.data(), + ypS0.data()); + }); + + /// + // IDAS adjoint bindings + /// + + m.def("IDAInitB", + [](void* ida_mem, int which, + std::function> resB, + sunrealtype tB0, N_Vector yyB0, N_Vector ypB0) + { + int ida_status = IDAInitB(ida_mem, which, idas_resB_wrapper, tB0, + yyB0, ypB0); + + auto fn_table = idasa_user_supplied_fn_table_alloc(); + auto idab_mem = static_cast(IDAGetAdjIDABmem(ida_mem, which)); + idab_mem->python = fn_table; + + ida_status = IDASetUserDataB(ida_mem, which, idab_mem); + if (ida_status != IDA_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in IDAS memory"); + } + + fn_table->resB = nb::cast(resB); + return ida_status; + }); + + m.def("IDAQuadInitB", + [](void* ida_mem, int which, + std::function> resQB, + N_Vector yQBO) + { + auto fn_table = get_idasa_fn_table(ida_mem, which); + fn_table->resQB = nb::cast(resQB); + return IDAQuadInitB(ida_mem, which, idas_resQB_wrapper, yQBO); + }); + + BIND_IDAB_CALLBACK(IDASetJacFnB, IDALsJacFnB, lsjacfnB, idas_lsjacfnB_wrapper, + nb::arg("ida_mem"), nb::arg("which"), + nb::arg("jacB").none()); + + BIND_IDAB_CALLBACK2(IDASetPreconditionerB, IDALsPrecSetupFnB, lsprecsetupfnB, + idas_lsprecsetupfnB_wrapper, IDALsPrecSolveFnB, + lsprecsolvefnB, idas_lsprecsolvefnB_wrapper, + nb::arg("ida_mem"), nb::arg("which"), + nb::arg("psetB").none(), nb::arg("psolveB").none()); + + BIND_IDAB_CALLBACK2(IDASetJacTimesB, IDALsJacTimesSetupFnB, lsjactimessetupfnB, + idas_lsjactimessetupfnB_wrapper, IDALsJacTimesVecFnB, + lsjactimesvecfnB, idas_lsjactimesvecfnB_wrapper, + nb::arg("ida_mem"), nb::arg("which"), + nb::arg("jsetupB").none(), nb::arg("jtimesB").none()); + + m.def("IDAInitBS", + [](void* ida_mem, int which, std::function resBS, + sunrealtype tB0, N_Vector yyB0, N_Vector ypB0) + { + int ida_status = IDAInitBS(ida_mem, which, ida_resBS_wrapper, tB0, + yyB0, ypB0); + + auto fn_table = idasa_user_supplied_fn_table_alloc(); + auto idab_mem = static_cast(IDAGetAdjIDABmem(ida_mem, which)); + idab_mem->python = fn_table; + + ida_status = IDASetUserDataB(ida_mem, which, idab_mem); + if (ida_status != IDA_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in IDA memory"); + } + + fn_table->resBS = nb::cast(resBS); + return ida_status; + }); + + m.def("IDAQuadInitBS", + [](void* ida_mem, int which, std::function resQBS, + N_Vector yQBO) + { + auto fn_table = get_idasa_fn_table(ida_mem, which); + fn_table->resQBS = nb::cast(resQBS); + return IDAQuadInitBS(ida_mem, which, idas_resQBS_wrapper, yQBO); + }); + + BIND_IDAB_CALLBACK(IDASetJacFnBS, IDALsJacStdFnBS, lsjacfnBS, + idas_lsjacfnBS_wrapper, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("jacBS").none()); + + BIND_IDAB_CALLBACK2(IDASetPreconditionerBS, IDALsPrecSetupStdFnBS, + lsprecsetupfnBS, idas_lsprecsetupfnBS_wrapper, + IDALsPrecSolveStdFnBS, lsprecsolvefnBS, + idas_lsprecsolvefnBS_wrapper, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("psetBS").none(), + nb::arg("psolveBS").none()); + + BIND_IDAB_CALLBACK2(IDASetJacTimesBS, IDALsJacTimesSetupStdFnBS, + lsjactimessetupfnBS, idas_lsjactimessetupfnBS_wrapper, + IDALsJacTimesVecStdFnBS, lsjactimesvecfnBS, + idas_lsjactimesvecfnBS_wrapper, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("jsetupBS").none(), + nb::arg("jtimesBS").none()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/idas/idas_generated.hpp b/bindings/sundials4py/idas/idas_generated.hpp new file mode 100644 index 0000000000..ccd79f63f8 --- /dev/null +++ b/bindings/sundials4py/idas/idas_generated.hpp @@ -0,0 +1,1881 @@ +// #ifndef _IDAS_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("IDA_NORMAL") = 1; +m.attr("IDA_ONE_STEP") = 2; +m.attr("IDA_YA_YDP_INIT") = 1; +m.attr("IDA_Y_INIT") = 2; +m.attr("IDA_SIMULTANEOUS") = 1; +m.attr("IDA_STAGGERED") = 2; +m.attr("IDA_CENTERED") = 1; +m.attr("IDA_FORWARD") = 2; +m.attr("IDA_HERMITE") = 1; +m.attr("IDA_POLYNOMIAL") = 2; +m.attr("IDA_SUCCESS") = 0; +m.attr("IDA_TSTOP_RETURN") = 1; +m.attr("IDA_ROOT_RETURN") = 2; +m.attr("IDA_WARNING") = 99; +m.attr("IDA_TOO_MUCH_WORK") = -1; +m.attr("IDA_TOO_MUCH_ACC") = -2; +m.attr("IDA_ERR_FAIL") = -3; +m.attr("IDA_CONV_FAIL") = -4; +m.attr("IDA_LINIT_FAIL") = -5; +m.attr("IDA_LSETUP_FAIL") = -6; +m.attr("IDA_LSOLVE_FAIL") = -7; +m.attr("IDA_RES_FAIL") = -8; +m.attr("IDA_REP_RES_ERR") = -9; +m.attr("IDA_RTFUNC_FAIL") = -10; +m.attr("IDA_CONSTR_FAIL") = -11; +m.attr("IDA_FIRST_RES_FAIL") = -12; +m.attr("IDA_LINESEARCH_FAIL") = -13; +m.attr("IDA_NO_RECOVERY") = -14; +m.attr("IDA_NLS_INIT_FAIL") = -15; +m.attr("IDA_NLS_SETUP_FAIL") = -16; +m.attr("IDA_NLS_FAIL") = -17; +m.attr("IDA_MEM_NULL") = -20; +m.attr("IDA_MEM_FAIL") = -21; +m.attr("IDA_ILL_INPUT") = -22; +m.attr("IDA_NO_MALLOC") = -23; +m.attr("IDA_BAD_EWT") = -24; +m.attr("IDA_BAD_K") = -25; +m.attr("IDA_BAD_T") = -26; +m.attr("IDA_BAD_DKY") = -27; +m.attr("IDA_VECTOROP_ERR") = -28; +m.attr("IDA_CONTEXT_ERR") = -29; +m.attr("IDA_NO_QUAD") = -30; +m.attr("IDA_QRHS_FAIL") = -31; +m.attr("IDA_FIRST_QRHS_ERR") = -32; +m.attr("IDA_REP_QRHS_ERR") = -33; +m.attr("IDA_NO_SENS") = -40; +m.attr("IDA_SRES_FAIL") = -41; +m.attr("IDA_REP_SRES_ERR") = -42; +m.attr("IDA_BAD_IS") = -43; +m.attr("IDA_NO_QUADSENS") = -50; +m.attr("IDA_QSRHS_FAIL") = -51; +m.attr("IDA_FIRST_QSRHS_ERR") = -52; +m.attr("IDA_REP_QSRHS_ERR") = -53; +m.attr("IDA_UNRECOGNIZED_ERROR") = -99; +m.attr("IDA_NO_ADJ") = -101; +m.attr("IDA_NO_FWD") = -102; +m.attr("IDA_NO_BCK") = -103; +m.attr("IDA_BAD_TB0") = -104; +m.attr("IDA_REIFWD_FAIL") = -105; +m.attr("IDA_FWD_FAIL") = -106; +m.attr("IDA_GETY_BADT") = -107; + +m.def("IDAReInit", IDAReInit, nb::arg("ida_mem"), nb::arg("t0"), nb::arg("yy0"), + nb::arg("yp0")); + +m.def("IDASStolerances", IDASStolerances, nb::arg("ida_mem"), nb::arg("reltol"), + nb::arg("abstol")); + +m.def("IDASVtolerances", IDASVtolerances, nb::arg("ida_mem"), nb::arg("reltol"), + nb::arg("abstol")); + +m.def("IDACalcIC", IDACalcIC, nb::arg("ida_mem"), nb::arg("icopt"), + nb::arg("tout1"), "Initial condition calculation function"); + +m.def("IDASetNonlinConvCoefIC", IDASetNonlinConvCoefIC, nb::arg("ida_mem"), + nb::arg("epiccon")); + +m.def("IDASetMaxNumStepsIC", IDASetMaxNumStepsIC, nb::arg("ida_mem"), + nb::arg("maxnh")); + +m.def("IDASetMaxNumJacsIC", IDASetMaxNumJacsIC, nb::arg("ida_mem"), + nb::arg("maxnj")); + +m.def("IDASetMaxNumItersIC", IDASetMaxNumItersIC, nb::arg("ida_mem"), + nb::arg("maxnit")); + +m.def("IDASetLineSearchOffIC", IDASetLineSearchOffIC, nb::arg("ida_mem"), + nb::arg("lsoff")); + +m.def("IDASetStepToleranceIC", IDASetStepToleranceIC, nb::arg("ida_mem"), + nb::arg("steptol")); + +m.def("IDASetMaxBacksIC", IDASetMaxBacksIC, nb::arg("ida_mem"), + nb::arg("maxbacks")); + +m.def("IDASetDeltaCjLSetup", IDASetDeltaCjLSetup, nb::arg("ida_max"), + nb::arg("dcj")); + +m.def("IDASetMaxOrd", IDASetMaxOrd, nb::arg("ida_mem"), nb::arg("maxord")); + +m.def("IDASetMaxNumSteps", IDASetMaxNumSteps, nb::arg("ida_mem"), + nb::arg("mxsteps")); + +m.def("IDASetInitStep", IDASetInitStep, nb::arg("ida_mem"), nb::arg("hin")); + +m.def("IDASetMaxStep", IDASetMaxStep, nb::arg("ida_mem"), nb::arg("hmax")); + +m.def("IDASetMinStep", IDASetMinStep, nb::arg("ida_mem"), nb::arg("hmin")); + +m.def("IDASetStopTime", IDASetStopTime, nb::arg("ida_mem"), nb::arg("tstop")); + +m.def("IDAClearStopTime", IDAClearStopTime, nb::arg("ida_mem")); + +m.def("IDASetMaxErrTestFails", IDASetMaxErrTestFails, nb::arg("ida_mem"), + nb::arg("maxnef")); + +m.def("IDASetSuppressAlg", IDASetSuppressAlg, nb::arg("ida_mem"), + nb::arg("suppressalg")); + +m.def("IDASetId", IDASetId, nb::arg("ida_mem"), nb::arg("id")); + +m.def("IDASetConstraints", IDASetConstraints, nb::arg("ida_mem"), + nb::arg("constraints")); + +m.def("IDASetEtaFixedStepBounds", IDASetEtaFixedStepBounds, nb::arg("ida_mem"), + nb::arg("eta_min_fx"), nb::arg("eta_max_fx")); + +m.def("IDASetEtaMin", IDASetEtaMin, nb::arg("ida_mem"), nb::arg("eta_min")); + +m.def("IDASetEtaMax", IDASetEtaMax, nb::arg("ida_mem"), nb::arg("eta_max")); + +m.def("IDASetEtaLow", IDASetEtaLow, nb::arg("ida_mem"), nb::arg("eta_low")); + +m.def("IDASetEtaMinErrFail", IDASetEtaMinErrFail, nb::arg("ida_mem"), + nb::arg("eta_min_ef")); + +m.def("IDASetEtaConvFail", IDASetEtaConvFail, nb::arg("ida_mem"), + nb::arg("eta_cf")); + +m.def("IDASetMaxConvFails", IDASetMaxConvFails, nb::arg("ida_mem"), + nb::arg("maxncf")); + +m.def("IDASetMaxNonlinIters", IDASetMaxNonlinIters, nb::arg("ida_mem"), + nb::arg("maxcor")); + +m.def("IDASetNonlinConvCoef", IDASetNonlinConvCoef, nb::arg("ida_mem"), + nb::arg("epcon")); + +m.def("IDASetNonlinearSolver", IDASetNonlinearSolver, nb::arg("ida_mem"), + nb::arg("NLS")); + +m.def( + "IDASetRootDirection", + [](void* ida_mem) -> std::tuple + { + auto IDASetRootDirection_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + int rootdir_adapt_modifiable; + + int r = IDASetRootDirection(ida_mem, &rootdir_adapt_modifiable); + return std::make_tuple(r, rootdir_adapt_modifiable); + }; + + return IDASetRootDirection_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDASetNoInactiveRootWarn", IDASetNoInactiveRootWarn, nb::arg("ida_mem")); + +m.def( + "IDASolve", + [](void* ida_mem, sunrealtype tout, N_Vector yret, N_Vector ypret, + int itask) -> std::tuple + { + auto IDASolve_adapt_modifiable_immutable_to_return = + [](void* ida_mem, sunrealtype tout, N_Vector yret, N_Vector ypret, + int itask) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDASolve(ida_mem, tout, &tret_adapt_modifiable, yret, ypret, itask); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDASolve_adapt_modifiable_immutable_to_return(ida_mem, tout, yret, + ypret, itask); + }, + nb::arg("ida_mem"), nb::arg("tout"), nb::arg("yret"), nb::arg("ypret"), + nb::arg("itask"), "Solver function"); + +m.def("IDAComputeY", IDAComputeY, nb::arg("ida_mem"), nb::arg("ycor"), + nb::arg("y")); + +m.def("IDAComputeYp", IDAComputeYp, nb::arg("ida_mem"), nb::arg("ycor"), + nb::arg("yp")); + +m.def( + "IDAComputeYSens", + [](void* ida_mem, std::vector ycor_1d, + std::vector yyS_1d) -> int + { + auto IDAComputeYSens_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, std::vector ycor_1d, + std::vector yyS_1d) -> int + { + N_Vector* ycor_1d_ptr = + reinterpret_cast(ycor_1d.empty() ? nullptr : ycor_1d.data()); + N_Vector* yyS_1d_ptr = + reinterpret_cast(yyS_1d.empty() ? nullptr : yyS_1d.data()); + + auto lambda_result = IDAComputeYSens(ida_mem, ycor_1d_ptr, yyS_1d_ptr); + return lambda_result; + }; + + return IDAComputeYSens_adapt_arr_ptr_to_std_vector(ida_mem, ycor_1d, yyS_1d); + }, + nb::arg("ida_mem"), nb::arg("ycor_1d"), nb::arg("yyS_1d")); + +m.def( + "IDAComputeYpSens", + [](void* ida_mem, std::vector ycor_1d, + std::vector ypS_1d) -> int + { + auto IDAComputeYpSens_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, std::vector ycor_1d, + std::vector ypS_1d) -> int + { + N_Vector* ycor_1d_ptr = + reinterpret_cast(ycor_1d.empty() ? nullptr : ycor_1d.data()); + N_Vector* ypS_1d_ptr = + reinterpret_cast(ypS_1d.empty() ? nullptr : ypS_1d.data()); + + auto lambda_result = IDAComputeYpSens(ida_mem, ycor_1d_ptr, ypS_1d_ptr); + return lambda_result; + }; + + return IDAComputeYpSens_adapt_arr_ptr_to_std_vector(ida_mem, ycor_1d, ypS_1d); + }, + nb::arg("ida_mem"), nb::arg("ycor_1d"), nb::arg("ypS_1d")); + +m.def("IDAGetDky", IDAGetDky, nb::arg("ida_mem"), nb::arg("t"), nb::arg("k"), + nb::arg("dky"), "Dense output function"); + +m.def( + "IDAGetNumSteps", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumSteps_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nsteps_adapt_modifiable; + + int r = IDAGetNumSteps(ida_mem, &nsteps_adapt_modifiable); + return std::make_tuple(r, nsteps_adapt_modifiable); + }; + + return IDAGetNumSteps_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumResEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumResEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nrevals_adapt_modifiable; + + int r = IDAGetNumResEvals(ida_mem, &nrevals_adapt_modifiable); + return std::make_tuple(r, nrevals_adapt_modifiable); + }; + + return IDAGetNumResEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumLinSolvSetups", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumLinSolvSetups_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nlinsetups_adapt_modifiable; + + int r = IDAGetNumLinSolvSetups(ida_mem, &nlinsetups_adapt_modifiable); + return std::make_tuple(r, nlinsetups_adapt_modifiable); + }; + + return IDAGetNumLinSolvSetups_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumErrTestFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long netfails_adapt_modifiable; + + int r = IDAGetNumErrTestFails(ida_mem, &netfails_adapt_modifiable); + return std::make_tuple(r, netfails_adapt_modifiable); + }; + + return IDAGetNumErrTestFails_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumBacktrackOps", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumBacktrackOps_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nbacktr_adapt_modifiable; + + int r = IDAGetNumBacktrackOps(ida_mem, &nbacktr_adapt_modifiable); + return std::make_tuple(r, nbacktr_adapt_modifiable); + }; + + return IDAGetNumBacktrackOps_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAGetConsistentIC", IDAGetConsistentIC, nb::arg("ida_mem"), + nb::arg("yy0_mod"), nb::arg("yp0_mod")); + +m.def( + "IDAGetLastOrder", + [](void* ida_mem) -> std::tuple + { + auto IDAGetLastOrder_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + int klast_adapt_modifiable; + + int r = IDAGetLastOrder(ida_mem, &klast_adapt_modifiable); + return std::make_tuple(r, klast_adapt_modifiable); + }; + + return IDAGetLastOrder_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetCurrentOrder", + [](void* ida_mem) -> std::tuple + { + auto IDAGetCurrentOrder_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + int kcur_adapt_modifiable; + + int r = IDAGetCurrentOrder(ida_mem, &kcur_adapt_modifiable); + return std::make_tuple(r, kcur_adapt_modifiable); + }; + + return IDAGetCurrentOrder_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetCurrentCj", + [](void* ida_mem) -> std::tuple + { + auto IDAGetCurrentCj_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype cj_adapt_modifiable; + + int r = IDAGetCurrentCj(ida_mem, &cj_adapt_modifiable); + return std::make_tuple(r, cj_adapt_modifiable); + }; + + return IDAGetCurrentCj_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetCurrentY", + [](void* ida_mem) -> std::tuple + { + auto IDAGetCurrentY_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + N_Vector ycur_adapt_modifiable; + + int r = IDAGetCurrentY(ida_mem, &ycur_adapt_modifiable); + return std::make_tuple(r, ycur_adapt_modifiable); + }; + + return IDAGetCurrentY_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "IDAGetCurrentYp", + [](void* ida_mem) -> std::tuple + { + auto IDAGetCurrentYp_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + N_Vector ypcur_adapt_modifiable; + + int r = IDAGetCurrentYp(ida_mem, &ypcur_adapt_modifiable); + return std::make_tuple(r, ypcur_adapt_modifiable); + }; + + return IDAGetCurrentYp_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "IDAGetActualInitStep", + [](void* ida_mem) -> std::tuple + { + auto IDAGetActualInitStep_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype hinused_adapt_modifiable; + + int r = IDAGetActualInitStep(ida_mem, &hinused_adapt_modifiable); + return std::make_tuple(r, hinused_adapt_modifiable); + }; + + return IDAGetActualInitStep_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetLastStep", + [](void* ida_mem) -> std::tuple + { + auto IDAGetLastStep_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype hlast_adapt_modifiable; + + int r = IDAGetLastStep(ida_mem, &hlast_adapt_modifiable); + return std::make_tuple(r, hlast_adapt_modifiable); + }; + + return IDAGetLastStep_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetCurrentStep", + [](void* ida_mem) -> std::tuple + { + auto IDAGetCurrentStep_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype hcur_adapt_modifiable; + + int r = IDAGetCurrentStep(ida_mem, &hcur_adapt_modifiable); + return std::make_tuple(r, hcur_adapt_modifiable); + }; + + return IDAGetCurrentStep_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetCurrentTime", + [](void* ida_mem) -> std::tuple + { + auto IDAGetCurrentTime_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype tcur_adapt_modifiable; + + int r = IDAGetCurrentTime(ida_mem, &tcur_adapt_modifiable); + return std::make_tuple(r, tcur_adapt_modifiable); + }; + + return IDAGetCurrentTime_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetTolScaleFactor", + [](void* ida_mem) -> std::tuple + { + auto IDAGetTolScaleFactor_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype tolsfact_adapt_modifiable; + + int r = IDAGetTolScaleFactor(ida_mem, &tolsfact_adapt_modifiable); + return std::make_tuple(r, tolsfact_adapt_modifiable); + }; + + return IDAGetTolScaleFactor_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAGetErrWeights", IDAGetErrWeights, nb::arg("ida_mem"), + nb::arg("eweight")); + +m.def("IDAGetEstLocalErrors", IDAGetEstLocalErrors, nb::arg("ida_mem"), + nb::arg("ele")); + +m.def( + "IDAGetNumGEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumGEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long ngevals_adapt_modifiable; + + int r = IDAGetNumGEvals(ida_mem, &ngevals_adapt_modifiable); + return std::make_tuple(r, ngevals_adapt_modifiable); + }; + + return IDAGetNumGEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetRootInfo", + [](void* ida_mem) -> std::tuple + { + auto IDAGetRootInfo_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + int rootsfound_adapt_modifiable; + + int r = IDAGetRootInfo(ida_mem, &rootsfound_adapt_modifiable); + return std::make_tuple(r, rootsfound_adapt_modifiable); + }; + + return IDAGetRootInfo_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetIntegratorStats", + [](void* ida_mem) -> std::tuple + { + auto IDAGetIntegratorStats_adapt_modifiable_immutable_to_return = + [](void* ida_mem) + -> std::tuple + { + long nsteps_adapt_modifiable; + long nrevals_adapt_modifiable; + long nlinsetups_adapt_modifiable; + long netfails_adapt_modifiable; + int qlast_adapt_modifiable; + int qcur_adapt_modifiable; + sunrealtype hinused_adapt_modifiable; + sunrealtype hlast_adapt_modifiable; + sunrealtype hcur_adapt_modifiable; + sunrealtype tcur_adapt_modifiable; + + int r = + IDAGetIntegratorStats(ida_mem, &nsteps_adapt_modifiable, + &nrevals_adapt_modifiable, + &nlinsetups_adapt_modifiable, + &netfails_adapt_modifiable, + &qlast_adapt_modifiable, &qcur_adapt_modifiable, + &hinused_adapt_modifiable, &hlast_adapt_modifiable, + &hcur_adapt_modifiable, &tcur_adapt_modifiable); + return std::make_tuple(r, nsteps_adapt_modifiable, nrevals_adapt_modifiable, + nlinsetups_adapt_modifiable, + netfails_adapt_modifiable, qlast_adapt_modifiable, + qcur_adapt_modifiable, hinused_adapt_modifiable, + hlast_adapt_modifiable, hcur_adapt_modifiable, + tcur_adapt_modifiable); + }; + + return IDAGetIntegratorStats_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumNonlinSolvIters", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nniters_adapt_modifiable; + + int r = IDAGetNumNonlinSolvIters(ida_mem, &nniters_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable); + }; + + return IDAGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumNonlinSolvConvFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nnfails_adapt_modifiable; + + int r = IDAGetNumNonlinSolvConvFails(ida_mem, &nnfails_adapt_modifiable); + return std::make_tuple(r, nnfails_adapt_modifiable); + }; + + return IDAGetNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return( + ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNonlinSolvStats", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNonlinSolvStats_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nniters_adapt_modifiable; + long nnfails_adapt_modifiable; + + int r = IDAGetNonlinSolvStats(ida_mem, &nniters_adapt_modifiable, + &nnfails_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable, + nnfails_adapt_modifiable); + }; + + return IDAGetNonlinSolvStats_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumStepSolveFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumStepSolveFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nncfails_adapt_modifiable; + + int r = IDAGetNumStepSolveFails(ida_mem, &nncfails_adapt_modifiable); + return std::make_tuple(r, nncfails_adapt_modifiable); + }; + + return IDAGetNumStepSolveFails_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAPrintAllStats", IDAPrintAllStats, nb::arg("ida_mem"), + nb::arg("outfile"), nb::arg("fmt")); + +m.def("IDAGetReturnFlagName", IDAGetReturnFlagName, nb::arg("flag")); + +m.def("IDAQuadReInit", IDAQuadReInit, nb::arg("ida_mem"), nb::arg("yQ0")); + +m.def("IDAQuadSStolerances", IDAQuadSStolerances, nb::arg("ida_mem"), + nb::arg("reltolQ"), nb::arg("abstolQ")); + +m.def("IDAQuadSVtolerances", IDAQuadSVtolerances, nb::arg("ida_mem"), + nb::arg("reltolQ"), nb::arg("abstolQ")); + +m.def("IDASetQuadErrCon", IDASetQuadErrCon, nb::arg("ida_mem"), + nb::arg("errconQ"), "Optional input specification functions"); + +m.def( + "IDAGetQuad", + [](void* ida_mem, N_Vector yQout) -> std::tuple + { + auto IDAGetQuad_adapt_modifiable_immutable_to_return = + [](void* ida_mem, N_Vector yQout) -> std::tuple + { + sunrealtype t_adapt_modifiable; + + int r = IDAGetQuad(ida_mem, &t_adapt_modifiable, yQout); + return std::make_tuple(r, t_adapt_modifiable); + }; + + return IDAGetQuad_adapt_modifiable_immutable_to_return(ida_mem, yQout); + }, + nb::arg("ida_mem"), nb::arg("yQout")); + +m.def("IDAGetQuadDky", IDAGetQuadDky, nb::arg("ida_mem"), nb::arg("t"), + nb::arg("k"), nb::arg("dky")); + +m.def( + "IDAGetQuadNumRhsEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetQuadNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nrhsQevals_adapt_modifiable; + + int r = IDAGetQuadNumRhsEvals(ida_mem, &nrhsQevals_adapt_modifiable); + return std::make_tuple(r, nrhsQevals_adapt_modifiable); + }; + + return IDAGetQuadNumRhsEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetQuadNumErrTestFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetQuadNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nQetfails_adapt_modifiable; + + int r = IDAGetQuadNumErrTestFails(ida_mem, &nQetfails_adapt_modifiable); + return std::make_tuple(r, nQetfails_adapt_modifiable); + }; + + return IDAGetQuadNumErrTestFails_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAGetQuadErrWeights", IDAGetQuadErrWeights, nb::arg("ida_mem"), + nb::arg("eQweight")); + +m.def( + "IDAGetQuadStats", + [](void* ida_mem) -> std::tuple + { + auto IDAGetQuadStats_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nrhsQevals_adapt_modifiable; + long nQetfails_adapt_modifiable; + + int r = IDAGetQuadStats(ida_mem, &nrhsQevals_adapt_modifiable, + &nQetfails_adapt_modifiable); + return std::make_tuple(r, nrhsQevals_adapt_modifiable, + nQetfails_adapt_modifiable); + }; + + return IDAGetQuadStats_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDASensReInit", + [](void* ida_mem, int ism, std::vector yS0_1d, + std::vector ypS0_1d) -> int + { + auto IDASensReInit_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, int ism, std::vector yS0_1d, + std::vector ypS0_1d) -> int + { + N_Vector* yS0_1d_ptr = + reinterpret_cast(yS0_1d.empty() ? nullptr : yS0_1d.data()); + N_Vector* ypS0_1d_ptr = + reinterpret_cast(ypS0_1d.empty() ? nullptr : ypS0_1d.data()); + + auto lambda_result = IDASensReInit(ida_mem, ism, yS0_1d_ptr, ypS0_1d_ptr); + return lambda_result; + }; + + return IDASensReInit_adapt_arr_ptr_to_std_vector(ida_mem, ism, yS0_1d, + ypS0_1d); + }, + nb::arg("ida_mem"), nb::arg("ism"), nb::arg("yS0_1d"), nb::arg("ypS0_1d")); + +m.def( + "IDASensSStolerances", + [](void* ida_mem, sunrealtype reltolS, sundials4py::Array1d abstolS_1d) -> int + { + auto IDASensSStolerances_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype reltolS, sundials4py::Array1d abstolS_1d) -> int + { + sunrealtype* abstolS_1d_ptr = + reinterpret_cast(abstolS_1d.data()); + + auto lambda_result = IDASensSStolerances(ida_mem, reltolS, abstolS_1d_ptr); + return lambda_result; + }; + + return IDASensSStolerances_adapt_arr_ptr_to_std_vector(ida_mem, reltolS, + abstolS_1d); + }, + nb::arg("ida_mem"), nb::arg("reltolS"), nb::arg("abstolS_1d")); + +m.def( + "IDASensSVtolerances", + [](void* ida_mem, sunrealtype reltolS, std::vector abstolS_1d) -> int + { + auto IDASensSVtolerances_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype reltolS, + std::vector abstolS_1d) -> int + { + N_Vector* abstolS_1d_ptr = reinterpret_cast( + abstolS_1d.empty() ? nullptr : abstolS_1d.data()); + + auto lambda_result = IDASensSVtolerances(ida_mem, reltolS, abstolS_1d_ptr); + return lambda_result; + }; + + return IDASensSVtolerances_adapt_arr_ptr_to_std_vector(ida_mem, reltolS, + abstolS_1d); + }, + nb::arg("ida_mem"), nb::arg("reltolS"), nb::arg("abstolS_1d")); + +m.def("IDASensEEtolerances", IDASensEEtolerances, nb::arg("ida_mem")); + +m.def( + "IDAGetSensConsistentIC", + [](void* ida_mem, std::vector yyS0_1d, + std::vector ypS0_1d) -> int + { + auto IDAGetSensConsistentIC_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, std::vector yyS0_1d, + std::vector ypS0_1d) -> int + { + N_Vector* yyS0_1d_ptr = + reinterpret_cast(yyS0_1d.empty() ? nullptr : yyS0_1d.data()); + N_Vector* ypS0_1d_ptr = + reinterpret_cast(ypS0_1d.empty() ? nullptr : ypS0_1d.data()); + + auto lambda_result = IDAGetSensConsistentIC(ida_mem, yyS0_1d_ptr, + ypS0_1d_ptr); + return lambda_result; + }; + + return IDAGetSensConsistentIC_adapt_arr_ptr_to_std_vector(ida_mem, yyS0_1d, + ypS0_1d); + }, + nb::arg("ida_mem"), nb::arg("yyS0_1d"), nb::arg("ypS0_1d"), + "Initial condition calculation function"); + +m.def("IDASetSensDQMethod", IDASetSensDQMethod, nb::arg("ida_mem"), + nb::arg("DQtype"), nb::arg("DQrhomax")); + +m.def("IDASetSensErrCon", IDASetSensErrCon, nb::arg("ida_mem"), + nb::arg("errconS")); + +m.def("IDASetSensMaxNonlinIters", IDASetSensMaxNonlinIters, nb::arg("ida_mem"), + nb::arg("maxcorS")); + +m.def( + "IDASetSensParams", + [](void* ida_mem, sundials4py::Array1d p_1d, sundials4py::Array1d pbar_1d, + std::vector plist_1d) -> int + { + auto IDASetSensParams_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sundials4py::Array1d p_1d, sundials4py::Array1d pbar_1d, + std::vector plist_1d) -> int + { + sunrealtype* p_1d_ptr = reinterpret_cast(p_1d.data()); + sunrealtype* pbar_1d_ptr = reinterpret_cast(pbar_1d.data()); + int* plist_1d_ptr = + reinterpret_cast(plist_1d.empty() ? nullptr : plist_1d.data()); + + auto lambda_result = IDASetSensParams(ida_mem, p_1d_ptr, pbar_1d_ptr, + plist_1d_ptr); + return lambda_result; + }; + + return IDASetSensParams_adapt_arr_ptr_to_std_vector(ida_mem, p_1d, pbar_1d, + plist_1d); + }, + nb::arg("ida_mem"), nb::arg("p_1d"), nb::arg("pbar_1d"), nb::arg("plist_1d")); + +m.def("IDASetNonlinearSolverSensSim", IDASetNonlinearSolverSensSim, + nb::arg("ida_mem"), nb::arg("NLS")); + +m.def("IDASetNonlinearSolverSensStg", IDASetNonlinearSolverSensStg, + nb::arg("ida_mem"), nb::arg("NLS")); + +m.def("IDASensToggleOff", IDASensToggleOff, nb::arg("ida_mem"), + "Enable/disable sensitivities"); + +m.def( + "IDAGetSens", + [](void* ida_mem, std::vector yySout_1d) -> std::tuple + { + auto IDAGetSens_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype* tret, std::vector yySout_1d) -> int + { + N_Vector* yySout_1d_ptr = reinterpret_cast( + yySout_1d.empty() ? nullptr : yySout_1d.data()); + + auto lambda_result = IDAGetSens(ida_mem, tret, yySout_1d_ptr); + return lambda_result; + }; + auto IDAGetSens_adapt_modifiable_immutable_to_return = + [&IDAGetSens_adapt_arr_ptr_to_std_vector](void* ida_mem, + std::vector yySout_1d) + -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDAGetSens_adapt_arr_ptr_to_std_vector(ida_mem, + &tret_adapt_modifiable, + yySout_1d); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDAGetSens_adapt_modifiable_immutable_to_return(ida_mem, yySout_1d); + }, + nb::arg("ida_mem"), nb::arg("yySout_1d")); + +m.def( + "IDAGetSens1", + [](void* ida_mem, int is, N_Vector yySret) -> std::tuple + { + auto IDAGetSens1_adapt_modifiable_immutable_to_return = + [](void* ida_mem, int is, N_Vector yySret) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDAGetSens1(ida_mem, &tret_adapt_modifiable, is, yySret); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDAGetSens1_adapt_modifiable_immutable_to_return(ida_mem, is, yySret); + }, + nb::arg("ida_mem"), nb::arg("is_"), nb::arg("yySret")); + +m.def( + "IDAGetSensDky", + [](void* ida_mem, sunrealtype t, int k, std::vector dkyS_1d) -> int + { + auto IDAGetSensDky_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype t, int k, std::vector dkyS_1d) -> int + { + N_Vector* dkyS_1d_ptr = + reinterpret_cast(dkyS_1d.empty() ? nullptr : dkyS_1d.data()); + + auto lambda_result = IDAGetSensDky(ida_mem, t, k, dkyS_1d_ptr); + return lambda_result; + }; + + return IDAGetSensDky_adapt_arr_ptr_to_std_vector(ida_mem, t, k, dkyS_1d); + }, + nb::arg("ida_mem"), nb::arg("t"), nb::arg("k"), nb::arg("dkyS_1d")); + +m.def("IDAGetSensDky1", IDAGetSensDky1, nb::arg("ida_mem"), nb::arg("t"), + nb::arg("k"), nb::arg("is_"), nb::arg("dkyS")); + +m.def( + "IDAGetSensNumResEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensNumResEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nresSevals_adapt_modifiable; + + int r = IDAGetSensNumResEvals(ida_mem, &nresSevals_adapt_modifiable); + return std::make_tuple(r, nresSevals_adapt_modifiable); + }; + + return IDAGetSensNumResEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumResEvalsSens", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumResEvalsSens_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nresevalsS_adapt_modifiable; + + int r = IDAGetNumResEvalsSens(ida_mem, &nresevalsS_adapt_modifiable); + return std::make_tuple(r, nresevalsS_adapt_modifiable); + }; + + return IDAGetNumResEvalsSens_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetSensNumErrTestFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nSetfails_adapt_modifiable; + + int r = IDAGetSensNumErrTestFails(ida_mem, &nSetfails_adapt_modifiable); + return std::make_tuple(r, nSetfails_adapt_modifiable); + }; + + return IDAGetSensNumErrTestFails_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetSensNumLinSolvSetups", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensNumLinSolvSetups_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nlinsetupsS_adapt_modifiable; + + int r = IDAGetSensNumLinSolvSetups(ida_mem, &nlinsetupsS_adapt_modifiable); + return std::make_tuple(r, nlinsetupsS_adapt_modifiable); + }; + + return IDAGetSensNumLinSolvSetups_adapt_modifiable_immutable_to_return( + ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetSensErrWeights", + [](void* ida_mem, std::vector eSweight_1d) -> int + { + auto IDAGetSensErrWeights_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, std::vector eSweight_1d) -> int + { + N_Vector* eSweight_1d_ptr = reinterpret_cast( + eSweight_1d.empty() ? nullptr : eSweight_1d.data()); + + auto lambda_result = IDAGetSensErrWeights(ida_mem, eSweight_1d_ptr); + return lambda_result; + }; + + return IDAGetSensErrWeights_adapt_arr_ptr_to_std_vector(ida_mem, eSweight_1d); + }, + nb::arg("ida_mem"), nb::arg("eSweight_1d")); + +m.def( + "IDAGetSensStats", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensStats_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nresSevals_adapt_modifiable; + long nresevalsS_adapt_modifiable; + long nSetfails_adapt_modifiable; + long nlinsetupsS_adapt_modifiable; + + int r = IDAGetSensStats(ida_mem, &nresSevals_adapt_modifiable, + &nresevalsS_adapt_modifiable, + &nSetfails_adapt_modifiable, + &nlinsetupsS_adapt_modifiable); + return std::make_tuple(r, nresSevals_adapt_modifiable, + nresevalsS_adapt_modifiable, + nSetfails_adapt_modifiable, + nlinsetupsS_adapt_modifiable); + }; + + return IDAGetSensStats_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetSensNumNonlinSolvIters", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nSniters_adapt_modifiable; + + int r = IDAGetSensNumNonlinSolvIters(ida_mem, &nSniters_adapt_modifiable); + return std::make_tuple(r, nSniters_adapt_modifiable); + }; + + return IDAGetSensNumNonlinSolvIters_adapt_modifiable_immutable_to_return( + ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetSensNumNonlinSolvConvFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nSnfails_adapt_modifiable; + + int r = IDAGetSensNumNonlinSolvConvFails(ida_mem, + &nSnfails_adapt_modifiable); + return std::make_tuple(r, nSnfails_adapt_modifiable); + }; + + return IDAGetSensNumNonlinSolvConvFails_adapt_modifiable_immutable_to_return( + ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetSensNonlinSolvStats", + [](void* ida_mem) -> std::tuple + { + auto IDAGetSensNonlinSolvStats_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nSniters_adapt_modifiable; + long nSnfails_adapt_modifiable; + + int r = IDAGetSensNonlinSolvStats(ida_mem, &nSniters_adapt_modifiable, + &nSnfails_adapt_modifiable); + return std::make_tuple(r, nSniters_adapt_modifiable, + nSnfails_adapt_modifiable); + }; + + return IDAGetSensNonlinSolvStats_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumStepSensSolveFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumStepSensSolveFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nSncfails_adapt_modifiable; + + int r = IDAGetNumStepSensSolveFails(ida_mem, &nSncfails_adapt_modifiable); + return std::make_tuple(r, nSncfails_adapt_modifiable); + }; + + return IDAGetNumStepSensSolveFails_adapt_modifiable_immutable_to_return( + ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAQuadSensReInit", + [](void* ida_mem, std::vector yQS0_1d) -> int + { + auto IDAQuadSensReInit_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, std::vector yQS0_1d) -> int + { + N_Vector* yQS0_1d_ptr = + reinterpret_cast(yQS0_1d.empty() ? nullptr : yQS0_1d.data()); + + auto lambda_result = IDAQuadSensReInit(ida_mem, yQS0_1d_ptr); + return lambda_result; + }; + + return IDAQuadSensReInit_adapt_arr_ptr_to_std_vector(ida_mem, yQS0_1d); + }, + nb::arg("ida_mem"), nb::arg("yQS0_1d")); + +m.def( + "IDAQuadSensSStolerances", + [](void* ida_mem, sunrealtype reltolQS) -> std::tuple + { + auto IDAQuadSensSStolerances_adapt_modifiable_immutable_to_return = + [](void* ida_mem, sunrealtype reltolQS) -> std::tuple + { + sunrealtype abstolQS_adapt_modifiable; + + int r = IDAQuadSensSStolerances(ida_mem, reltolQS, + &abstolQS_adapt_modifiable); + return std::make_tuple(r, abstolQS_adapt_modifiable); + }; + + return IDAQuadSensSStolerances_adapt_modifiable_immutable_to_return(ida_mem, + reltolQS); + }, + nb::arg("ida_mem"), nb::arg("reltolQS")); + +m.def( + "IDAQuadSensSVtolerances", + [](void* ida_mem, sunrealtype reltolQS, std::vector abstolQS_1d) -> int + { + auto IDAQuadSensSVtolerances_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype reltolQS, + std::vector abstolQS_1d) -> int + { + N_Vector* abstolQS_1d_ptr = reinterpret_cast( + abstolQS_1d.empty() ? nullptr : abstolQS_1d.data()); + + auto lambda_result = IDAQuadSensSVtolerances(ida_mem, reltolQS, + abstolQS_1d_ptr); + return lambda_result; + }; + + return IDAQuadSensSVtolerances_adapt_arr_ptr_to_std_vector(ida_mem, reltolQS, + abstolQS_1d); + }, + nb::arg("ida_mem"), nb::arg("reltolQS"), nb::arg("abstolQS_1d")); + +m.def("IDAQuadSensEEtolerances", IDAQuadSensEEtolerances, nb::arg("ida_mem")); + +m.def("IDASetQuadSensErrCon", IDASetQuadSensErrCon, nb::arg("ida_mem"), + nb::arg("errconQS"), "Optional input specification functions"); + +m.def( + "IDAGetQuadSens", + [](void* ida_mem, + std::vector yyQSout_1d) -> std::tuple + { + auto IDAGetQuadSens_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype* tret, std::vector yyQSout_1d) -> int + { + N_Vector* yyQSout_1d_ptr = reinterpret_cast( + yyQSout_1d.empty() ? nullptr : yyQSout_1d.data()); + + auto lambda_result = IDAGetQuadSens(ida_mem, tret, yyQSout_1d_ptr); + return lambda_result; + }; + auto IDAGetQuadSens_adapt_modifiable_immutable_to_return = + [&IDAGetQuadSens_adapt_arr_ptr_to_std_vector](void* ida_mem, + std::vector yyQSout_1d) + -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDAGetQuadSens_adapt_arr_ptr_to_std_vector(ida_mem, + &tret_adapt_modifiable, + yyQSout_1d); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDAGetQuadSens_adapt_modifiable_immutable_to_return(ida_mem, + yyQSout_1d); + }, + nb::arg("ida_mem"), nb::arg("yyQSout_1d")); + +m.def( + "IDAGetQuadSens1", + [](void* ida_mem, int is, N_Vector yyQSret) -> std::tuple + { + auto IDAGetQuadSens1_adapt_modifiable_immutable_to_return = + [](void* ida_mem, int is, N_Vector yyQSret) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDAGetQuadSens1(ida_mem, &tret_adapt_modifiable, is, yyQSret); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDAGetQuadSens1_adapt_modifiable_immutable_to_return(ida_mem, is, + yyQSret); + }, + nb::arg("ida_mem"), nb::arg("is_"), nb::arg("yyQSret")); + +m.def( + "IDAGetQuadSensDky", + [](void* ida_mem, sunrealtype t, int k, std::vector dkyQS_1d) -> int + { + auto IDAGetQuadSensDky_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, sunrealtype t, int k, std::vector dkyQS_1d) -> int + { + N_Vector* dkyQS_1d_ptr = reinterpret_cast( + dkyQS_1d.empty() ? nullptr : dkyQS_1d.data()); + + auto lambda_result = IDAGetQuadSensDky(ida_mem, t, k, dkyQS_1d_ptr); + return lambda_result; + }; + + return IDAGetQuadSensDky_adapt_arr_ptr_to_std_vector(ida_mem, t, k, dkyQS_1d); + }, + nb::arg("ida_mem"), nb::arg("t"), nb::arg("k"), nb::arg("dkyQS_1d")); + +m.def("IDAGetQuadSensDky1", IDAGetQuadSensDky1, nb::arg("ida_mem"), + nb::arg("t"), nb::arg("k"), nb::arg("is_"), nb::arg("dkyQS")); + +m.def( + "IDAGetQuadSensNumRhsEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetQuadSensNumRhsEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nrhsQSevals_adapt_modifiable; + + int r = IDAGetQuadSensNumRhsEvals(ida_mem, &nrhsQSevals_adapt_modifiable); + return std::make_tuple(r, nrhsQSevals_adapt_modifiable); + }; + + return IDAGetQuadSensNumRhsEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetQuadSensNumErrTestFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetQuadSensNumErrTestFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nQSetfails_adapt_modifiable; + + int r = IDAGetQuadSensNumErrTestFails(ida_mem, + &nQSetfails_adapt_modifiable); + return std::make_tuple(r, nQSetfails_adapt_modifiable); + }; + + return IDAGetQuadSensNumErrTestFails_adapt_modifiable_immutable_to_return( + ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetQuadSensErrWeights", + [](void* ida_mem, std::vector eQSweight_1d) -> int + { + auto IDAGetQuadSensErrWeights_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, std::vector eQSweight_1d) -> int + { + N_Vector* eQSweight_1d_ptr = reinterpret_cast( + eQSweight_1d.empty() ? nullptr : eQSweight_1d.data()); + + auto lambda_result = IDAGetQuadSensErrWeights(ida_mem, eQSweight_1d_ptr); + return lambda_result; + }; + + return IDAGetQuadSensErrWeights_adapt_arr_ptr_to_std_vector(ida_mem, + eQSweight_1d); + }, + nb::arg("ida_mem"), nb::arg("eQSweight_1d")); + +m.def( + "IDAGetQuadSensStats", + [](void* ida_mem) -> std::tuple + { + auto IDAGetQuadSensStats_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nrhsQSevals_adapt_modifiable; + long nQSetfails_adapt_modifiable; + + int r = IDAGetQuadSensStats(ida_mem, &nrhsQSevals_adapt_modifiable, + &nQSetfails_adapt_modifiable); + return std::make_tuple(r, nrhsQSevals_adapt_modifiable, + nQSetfails_adapt_modifiable); + }; + + return IDAGetQuadSensStats_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAAdjInit", IDAAdjInit, nb::arg("ida_mem"), nb::arg("steps"), + nb::arg("interp")); + +m.def("IDAAdjReInit", IDAAdjReInit, nb::arg("ida_mem")); + +m.def( + "IDACreateB", + [](void* ida_mem) -> std::tuple + { + auto IDACreateB_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + int which_adapt_modifiable; + + int r = IDACreateB(ida_mem, &which_adapt_modifiable); + return std::make_tuple(r, which_adapt_modifiable); + }; + + return IDACreateB_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAReInitB", IDAReInitB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("tB0"), nb::arg("yyB0"), nb::arg("ypB0")); + +m.def("IDASStolerancesB", IDASStolerancesB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("relTolB"), nb::arg("absTolB")); + +m.def("IDASVtolerancesB", IDASVtolerancesB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("relTolB"), nb::arg("absTolB")); + +m.def("IDAQuadReInitB", IDAQuadReInitB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("yQB0")); + +m.def("IDAQuadSStolerancesB", IDAQuadSStolerancesB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("reltolQB"), nb::arg("abstolQB")); + +m.def("IDAQuadSVtolerancesB", IDAQuadSVtolerancesB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("reltolQB"), nb::arg("abstolQB")); + +m.def("IDACalcICB", IDACalcICB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("tout1"), nb::arg("yy0"), nb::arg("yp0")); + +m.def( + "IDACalcICBS", + [](void* ida_mem, int which, sunrealtype tout1, N_Vector yy0, N_Vector yp0, + std::vector yyS0_1d, std::vector ypS0_1d) -> int + { + auto IDACalcICBS_adapt_arr_ptr_to_std_vector = + [](void* ida_mem, int which, sunrealtype tout1, N_Vector yy0, N_Vector yp0, + std::vector yyS0_1d, std::vector ypS0_1d) -> int + { + N_Vector* yyS0_1d_ptr = + reinterpret_cast(yyS0_1d.empty() ? nullptr : yyS0_1d.data()); + N_Vector* ypS0_1d_ptr = + reinterpret_cast(ypS0_1d.empty() ? nullptr : ypS0_1d.data()); + + auto lambda_result = IDACalcICBS(ida_mem, which, tout1, yy0, yp0, + yyS0_1d_ptr, ypS0_1d_ptr); + return lambda_result; + }; + + return IDACalcICBS_adapt_arr_ptr_to_std_vector(ida_mem, which, tout1, yy0, + yp0, yyS0_1d, ypS0_1d); + }, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("tout1"), nb::arg("yy0"), + nb::arg("yp0"), nb::arg("yyS0_1d"), nb::arg("ypS0_1d")); + +m.def( + "IDASolveF", + [](void* ida_mem, sunrealtype tout, N_Vector yret, N_Vector ypret, + int itask) -> std::tuple + { + auto IDASolveF_adapt_modifiable_immutable_to_return = + [](void* ida_mem, sunrealtype tout, N_Vector yret, N_Vector ypret, + int itask) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + int ncheckPtr_adapt_modifiable; + + int r = IDASolveF(ida_mem, tout, &tret_adapt_modifiable, yret, ypret, + itask, &ncheckPtr_adapt_modifiable); + return std::make_tuple(r, tret_adapt_modifiable, + ncheckPtr_adapt_modifiable); + }; + + return IDASolveF_adapt_modifiable_immutable_to_return(ida_mem, tout, yret, + ypret, itask); + }, + nb::arg("ida_mem"), nb::arg("tout"), nb::arg("yret"), nb::arg("ypret"), + nb::arg("itask")); + +m.def("IDASolveB", IDASolveB, nb::arg("ida_mem"), nb::arg("tBout"), + nb::arg("itaskB")); + +m.def("IDAAdjSetNoSensi", IDAAdjSetNoSensi, nb::arg("ida_mem")); + +m.def("IDASetMaxOrdB", IDASetMaxOrdB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("maxordB")); + +m.def("IDASetMaxNumStepsB", IDASetMaxNumStepsB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("mxstepsB")); + +m.def("IDASetInitStepB", IDASetInitStepB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("hinB")); + +m.def("IDASetMaxStepB", IDASetMaxStepB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("hmaxB")); + +m.def("IDASetSuppressAlgB", IDASetSuppressAlgB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("suppressalgB")); + +m.def("IDASetIdB", IDASetIdB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("idB")); + +m.def("IDASetConstraintsB", IDASetConstraintsB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("constraintsB")); + +m.def("IDASetQuadErrConB", IDASetQuadErrConB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("errconQB")); + +m.def("IDASetNonlinearSolverB", IDASetNonlinearSolverB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("NLS")); + +m.def( + "IDAGetB", + [](void* ida_mem, int which, N_Vector yy, + N_Vector yp) -> std::tuple + { + auto IDAGetB_adapt_modifiable_immutable_to_return = + [](void* ida_mem, int which, N_Vector yy, + N_Vector yp) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDAGetB(ida_mem, which, &tret_adapt_modifiable, yy, yp); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDAGetB_adapt_modifiable_immutable_to_return(ida_mem, which, yy, yp); + }, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("yy"), nb::arg("yp")); + +m.def( + "IDAGetQuadB", + [](void* ida_mem, int which, N_Vector qB) -> std::tuple + { + auto IDAGetQuadB_adapt_modifiable_immutable_to_return = + [](void* ida_mem, int which, N_Vector qB) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + int r = IDAGetQuadB(ida_mem, which, &tret_adapt_modifiable, qB); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return IDAGetQuadB_adapt_modifiable_immutable_to_return(ida_mem, which, qB); + }, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("qB")); + +m.def("IDAGetAdjIDABmem", IDAGetAdjIDABmem, nb::arg("ida_mem"), nb::arg("which")); + +m.def("IDAGetConsistentICB", IDAGetConsistentICB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("yyB0"), nb::arg("ypB0")); + +m.def("IDAGetAdjY", IDAGetAdjY, nb::arg("ida_mem"), nb::arg("t"), nb::arg("yy"), + nb::arg("yp")); + +m.def("IDAGetAdjCheckPointsInfo", IDAGetAdjCheckPointsInfo, nb::arg("ida_mem"), + nb::arg("ckpnt")); + +m.def( + "IDAGetAdjDataPointHermite", + [](void* ida_mem, int which, N_Vector yy, + N_Vector yd) -> std::tuple + { + auto IDAGetAdjDataPointHermite_adapt_modifiable_immutable_to_return = + [](void* ida_mem, int which, N_Vector yy, + N_Vector yd) -> std::tuple + { + sunrealtype t_adapt_modifiable; + + int r = IDAGetAdjDataPointHermite(ida_mem, which, &t_adapt_modifiable, yy, + yd); + return std::make_tuple(r, t_adapt_modifiable); + }; + + return IDAGetAdjDataPointHermite_adapt_modifiable_immutable_to_return(ida_mem, + which, + yy, yd); + }, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("yy"), nb::arg("yd")); + +m.def( + "IDAGetAdjDataPointPolynomial", + [](void* ida_mem, int which, N_Vector y) -> std::tuple + { + auto IDAGetAdjDataPointPolynomial_adapt_modifiable_immutable_to_return = + [](void* ida_mem, int which, N_Vector y) -> std::tuple + { + sunrealtype t_adapt_modifiable; + int order_adapt_modifiable; + + int r = IDAGetAdjDataPointPolynomial(ida_mem, which, &t_adapt_modifiable, + &order_adapt_modifiable, y); + return std::make_tuple(r, t_adapt_modifiable, order_adapt_modifiable); + }; + + return IDAGetAdjDataPointPolynomial_adapt_modifiable_immutable_to_return(ida_mem, + which, + y); + }, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("y")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _IDASLS_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("IDALS_SUCCESS") = 0; +m.attr("IDALS_MEM_NULL") = -1; +m.attr("IDALS_LMEM_NULL") = -2; +m.attr("IDALS_ILL_INPUT") = -3; +m.attr("IDALS_MEM_FAIL") = -4; +m.attr("IDALS_PMEM_NULL") = -5; +m.attr("IDALS_JACFUNC_UNRECVR") = -6; +m.attr("IDALS_JACFUNC_RECVR") = -7; +m.attr("IDALS_SUNMAT_FAIL") = -8; +m.attr("IDALS_SUNLS_FAIL") = -9; +m.attr("IDALS_NO_ADJ") = -101; +m.attr("IDALS_LMEMB_NULL") = -102; + +m.def( + "IDASetLinearSolver", + [](void* ida_mem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + auto IDASetLinearSolver_adapt_optional_arg_with_default_null = + [](void* ida_mem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = IDASetLinearSolver(ida_mem, LS, A_adapt_default_null); + return lambda_result; + }; + + return IDASetLinearSolver_adapt_optional_arg_with_default_null(ida_mem, LS, + A); + }, + nb::arg("ida_mem"), nb::arg("LS"), nb::arg("A").none() = nb::none()); + +m.def("IDASetEpsLin", IDASetEpsLin, nb::arg("ida_mem"), nb::arg("eplifac")); + +m.def("IDASetLSNormFactor", IDASetLSNormFactor, nb::arg("ida_mem"), + nb::arg("nrmfac")); + +m.def("IDASetLinearSolutionScaling", IDASetLinearSolutionScaling, + nb::arg("ida_mem"), nb::arg("onoff")); + +m.def("IDASetIncrementFactor", IDASetIncrementFactor, nb::arg("ida_mem"), + nb::arg("dqincfac")); + +m.def( + "IDAGetJac", + [](void* ida_mem) -> std::tuple + { + auto IDAGetJac_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + SUNMatrix J_adapt_modifiable; + + int r = IDAGetJac(ida_mem, &J_adapt_modifiable); + return std::make_tuple(r, J_adapt_modifiable); + }; + + return IDAGetJac_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "IDAGetJacCj", + [](void* ida_mem) -> std::tuple + { + auto IDAGetJacCj_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype cj_J_adapt_modifiable; + + int r = IDAGetJacCj(ida_mem, &cj_J_adapt_modifiable); + return std::make_tuple(r, cj_J_adapt_modifiable); + }; + + return IDAGetJacCj_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetJacTime", + [](void* ida_mem) -> std::tuple + { + auto IDAGetJacTime_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + sunrealtype t_J_adapt_modifiable; + + int r = IDAGetJacTime(ida_mem, &t_J_adapt_modifiable); + return std::make_tuple(r, t_J_adapt_modifiable); + }; + + return IDAGetJacTime_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetJacNumSteps", + [](void* ida_mem) -> std::tuple + { + auto IDAGetJacNumSteps_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nst_J_adapt_modifiable; + + int r = IDAGetJacNumSteps(ida_mem, &nst_J_adapt_modifiable); + return std::make_tuple(r, nst_J_adapt_modifiable); + }; + + return IDAGetJacNumSteps_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumJacEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumJacEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long njevals_adapt_modifiable; + + int r = IDAGetNumJacEvals(ida_mem, &njevals_adapt_modifiable); + return std::make_tuple(r, njevals_adapt_modifiable); + }; + + return IDAGetNumJacEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumPrecEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumPrecEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long npevals_adapt_modifiable; + + int r = IDAGetNumPrecEvals(ida_mem, &npevals_adapt_modifiable); + return std::make_tuple(r, npevals_adapt_modifiable); + }; + + return IDAGetNumPrecEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumPrecSolves", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumPrecSolves_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long npsolves_adapt_modifiable; + + int r = IDAGetNumPrecSolves(ida_mem, &npsolves_adapt_modifiable); + return std::make_tuple(r, npsolves_adapt_modifiable); + }; + + return IDAGetNumPrecSolves_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumLinIters", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumLinIters_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nliters_adapt_modifiable; + + int r = IDAGetNumLinIters(ida_mem, &nliters_adapt_modifiable); + return std::make_tuple(r, nliters_adapt_modifiable); + }; + + return IDAGetNumLinIters_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumLinConvFails", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumLinConvFails_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nlcfails_adapt_modifiable; + + int r = IDAGetNumLinConvFails(ida_mem, &nlcfails_adapt_modifiable); + return std::make_tuple(r, nlcfails_adapt_modifiable); + }; + + return IDAGetNumLinConvFails_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumJTSetupEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumJTSetupEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long njtsetups_adapt_modifiable; + + int r = IDAGetNumJTSetupEvals(ida_mem, &njtsetups_adapt_modifiable); + return std::make_tuple(r, njtsetups_adapt_modifiable); + }; + + return IDAGetNumJTSetupEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumJtimesEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumJtimesEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long njvevals_adapt_modifiable; + + int r = IDAGetNumJtimesEvals(ida_mem, &njvevals_adapt_modifiable); + return std::make_tuple(r, njvevals_adapt_modifiable); + }; + + return IDAGetNumJtimesEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetNumLinResEvals", + [](void* ida_mem) -> std::tuple + { + auto IDAGetNumLinResEvals_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long nrevalsLS_adapt_modifiable; + + int r = IDAGetNumLinResEvals(ida_mem, &nrevalsLS_adapt_modifiable); + return std::make_tuple(r, nrevalsLS_adapt_modifiable); + }; + + return IDAGetNumLinResEvals_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def( + "IDAGetLastLinFlag", + [](void* ida_mem) -> std::tuple + { + auto IDAGetLastLinFlag_adapt_modifiable_immutable_to_return = + [](void* ida_mem) -> std::tuple + { + long flag_adapt_modifiable; + + int r = IDAGetLastLinFlag(ida_mem, &flag_adapt_modifiable); + return std::make_tuple(r, flag_adapt_modifiable); + }; + + return IDAGetLastLinFlag_adapt_modifiable_immutable_to_return(ida_mem); + }, + nb::arg("ida_mem")); + +m.def("IDAGetLinReturnFlagName", IDAGetLinReturnFlagName, nb::arg("flag")); + +m.def( + "IDASetLinearSolverB", + [](void* ida_mem, int which, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + auto IDASetLinearSolverB_adapt_optional_arg_with_default_null = + [](void* ida_mem, int which, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = IDASetLinearSolverB(ida_mem, which, LS, + A_adapt_default_null); + return lambda_result; + }; + + return IDASetLinearSolverB_adapt_optional_arg_with_default_null(ida_mem, + which, LS, A); + }, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("LS"), + nb::arg("A").none() = nb::none()); + +m.def("IDASetEpsLinB", IDASetEpsLinB, nb::arg("ida_mem"), nb::arg("which"), + nb::arg("eplifacB")); + +m.def("IDASetLSNormFactorB", IDASetLSNormFactorB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("nrmfacB")); + +m.def("IDASetLinearSolutionScalingB", IDASetLinearSolutionScalingB, + nb::arg("ida_mem"), nb::arg("which"), nb::arg("onoffB")); + +m.def("IDASetIncrementFactorB", IDASetIncrementFactorB, nb::arg("ida_mem"), + nb::arg("which"), nb::arg("dqincfacB")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/idas/idas_usersupplied.hpp b/bindings/sundials4py/idas/idas_usersupplied.hpp new file mode 100644 index 0000000000..6c71c1c1ab --- /dev/null +++ b/bindings/sundials4py/idas/idas_usersupplied.hpp @@ -0,0 +1,497 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_IDAS_USERSUPPLIED_HPP +#define _SUNDIALS4PY_IDAS_USERSUPPLIED_HPP + +#include +#include + +#include + +#include "sundials4py_helpers.hpp" + +/////////////////////////////////////////////////////////////////////////////// +// IDAS user-supplied function table +// Every integrator-level user-supplied function must be in this table. +// The user-supplied function table is passed to IDAS as user_data. +/////////////////////////////////////////////////////////////////////////////// + +struct idas_user_supplied_fn_table +{ + // user-supplied function pointers + nb::object res, rootfn, ewtn, rwtn, resNLS; + + // idas_ls user-supplied function pointers + nb::object lsjacfn, lsprecsetupfn, lsprecsolvefn, lsjactimessetupfn, + lsjactimesvecfn, lsjacresfn; + + // idas quadrature user-supplied function pointers + nanobind::object resQ, resQS; + + // idas FSA user-supplied function pointers + nanobind::object resS; +}; + +inline idas_user_supplied_fn_table* idas_user_supplied_fn_table_alloc() +{ + // We must use malloc since IDAFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(idas_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(idas_user_supplied_fn_table)); + + return fn_table; +} + +inline idas_user_supplied_fn_table* get_idas_fn_table(void* ida_mem) +{ + auto mem = static_cast(ida_mem); + auto fn_table = static_cast(mem->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python function table from IDAS memory"); + } + return fn_table; +} + +/////////////////////////////////////////////////////////////////////////////// +// IDAS user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +template +inline int idas_res_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::res, std::forward(args)...); +} + +using IDARootStdFn = int(sunrealtype t, N_Vector y, N_Vector yp, + sundials4py::Array1d gout, void* user_data); + +inline int idas_rootfn_wrapper(sunrealtype t, N_Vector y, N_Vector yp, + sunrealtype* gout_1d, void* user_data) +{ + auto ida_mem = static_cast(user_data); + auto fn_table = get_idas_fn_table(user_data); + auto fn = nb::cast>(fn_table->rootfn); + auto nrtfn = ida_mem->ida_nrtfn; + + sundials4py::Array1d gout(gout_1d, {static_cast(nrtfn)}, + nb::find(gout_1d)); + + return fn(t, y, yp, gout, nullptr); +} + +template +inline int idas_ewtfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::ewtn, std::forward(args)...); +} + +template +inline int idas_nlsresfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::resNLS, std::forward(args)...); +} + +template +inline int idas_lsjacfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 4>(&idas_user_supplied_fn_table::lsjacfn, std::forward(args)...); +} + +template +inline int idas_lsprecsetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::lsprecsetupfn, std::forward(args)...); +} + +template +inline int idas_lsprecsolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::lsprecsolvefn, std::forward(args)...); +} + +template +inline int idas_lsjactimessetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, + IDAMem, 1>(&idas_user_supplied_fn_table::lsjactimessetupfn, + std::forward(args)...); +} + +template +inline int idas_lsjactimesvecfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, + IDAMem, 3>(&idas_user_supplied_fn_table::lsjactimesvecfn, + std::forward(args)...); +} + +template +inline int idas_lsjacresfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::lsjacresfn, std::forward(args)...); +} + +template +inline int idas_resQ_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idas_user_supplied_fn_table, IDAMem, + 1>(&idas_user_supplied_fn_table::resQ, std::forward(args)...); +} + +using IDAQuadSensRhsStdFn = int(int Ns, sunrealtype t, N_Vector yy, N_Vector yp, + std::vector yyS, + std::vector ypS, N_Vector rrQ, + std::vector rhsvalQS, void* user_data, + N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); + +inline int idas_resQS_wrapper(int Ns, sunrealtype t, N_Vector yy, N_Vector yp, + N_Vector* yyS, N_Vector* ypS, N_Vector rrQ, + N_Vector* rhsvalQS, void* user_data, + N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) +{ + auto fn_table = static_cast(user_data); + auto fn = nb::cast>(fn_table->resQS); + + std::vector yyS_1d(yyS, yyS + Ns); + std::vector ypS_1d(ypS, ypS + Ns); + std::vector rhsvalQS_1d(rhsvalQS, rhsvalQS + Ns); + + return fn(Ns, t, yy, yp, yyS_1d, ypS_1d, rrQ, rhsvalQS_1d, nullptr, tmp1, + tmp2, tmp3); +} + +using IDASensResStdFn = int(int Ns, sunrealtype t, N_Vector yy, N_Vector yp, + N_Vector resval, std::vector yyS_1d, + std::vector ypS_1d, + std::vector resvalS_1d, void* user_data, + N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); + +inline int idas_resS_wrapper(int Ns, sunrealtype t, N_Vector yy, N_Vector yp, + N_Vector resval, N_Vector* yyS_1d, N_Vector* ypS_1d, + N_Vector* resvalS_1d, void* user_data, + N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) +{ + auto fn_table = get_idas_fn_table(user_data); + auto fn = nb::cast>(fn_table->resS); + + std::vector yyS(yyS_1d, yyS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + std::vector resvalS(resvalS_1d, resvalS_1d + Ns); + + return fn(Ns, t, yy, yp, resval, yyS, ypS, resvalS, nullptr, tmp1, tmp2, tmp3); +} + +/////////////////////////////////////////////////////////////////////////////// +// IDAS Adjoint user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +struct idasa_user_supplied_fn_table +{ + // idas adjoint user-supplied function pointers + nb::object resB, resQB, resBS, resQBS; + + // idas_ls adjoint user-supplied function pointers + nb::object lsjacfnB, lsjacfnBS, lsprecsetupfnB, lsprecsetupfnBS, + lsprecsolvefnB, lsprecsolvefnBS, lsjactimessetupfnB, lsjactimessetupfnBS, + lsjactimesvecfnB, lsjactimesvecfnBS; +}; + +inline idasa_user_supplied_fn_table* idasa_user_supplied_fn_table_alloc() +{ + // We must use malloc since IDASFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(idasa_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(idasa_user_supplied_fn_table)); + + return fn_table; +} + +inline idasa_user_supplied_fn_table* get_idasa_fn_table(void* ida_mem) +{ + auto mem = static_cast(ida_mem); + auto fn_table = static_cast( + static_cast(ida_mem)->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python function table from IDAS memory"); + } + return fn_table; +} + +inline idasa_user_supplied_fn_table* get_idasa_fn_table(void* ida_mem, int which) +{ + auto mem = static_cast(IDAGetAdjIDABmem(ida_mem, which)); + auto fn_table = static_cast(mem->python); + if (!fn_table) + { + throw sundials4py::null_function_table( + "Failed to get Python function table from IDAS memory"); + } + return fn_table; +} + +template +inline int idas_resB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, IDAMem, + 1>(&idasa_user_supplied_fn_table::resB, std::forward(args)...); +} + +template +inline int idas_resQB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, IDAMem, + 1>(&idasa_user_supplied_fn_table::resQB, std::forward(args)...); +} + +template +inline int idas_lsjacfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, IDAMem, + 4>(&idasa_user_supplied_fn_table::lsjacfnB, std::forward(args)...); +} + +template +inline int idas_lsprecsetupfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, + IDAMem, 1>(&idasa_user_supplied_fn_table::lsprecsetupfnB, + std::forward(args)...); +} + +template +inline int idas_lsprecsolvefnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, + IDAMem, 1>(&idasa_user_supplied_fn_table::lsprecsolvefnB, + std::forward(args)...); +} + +template +inline int idas_lsjactimessetupfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, + IDAMem, 1>(&idasa_user_supplied_fn_table::lsjactimessetupfnB, + std::forward(args)...); +} + +template +inline int idas_lsjactimesvecfnB_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, idasa_user_supplied_fn_table, + IDAMem, 3>(&idasa_user_supplied_fn_table::lsjactimesvecfnB, + std::forward(args)...); +} + +using IDAResStdFnBS = int(sunrealtype t, N_Vector y, N_Vector yp, + std::vector yS_1d, + std::vector ypS_1d, N_Vector yB, + N_Vector ypB, N_Vector yBdot, void* user_dataB); + +inline int ida_resBS_wrapper(sunrealtype t, N_Vector y, N_Vector yp, + N_Vector* yS_1d, N_Vector* ypS_1d, N_Vector yB, + N_Vector ypB, N_Vector yBdot, void* user_dataB) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = get_idasa_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->resBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(t, y, yp, yS, ypS, yB, ypB, yBdot, nullptr); +} + +using IDAQuadRhsStdFnBS = int(sunrealtype t, N_Vector y, N_Vector yp, + std::vector yS_1d, + std::vector ypS_1d, N_Vector yB, + N_Vector ypB, N_Vector rhsvalBQS, void* user_dataB); + +inline int idas_resQBS_wrapper(sunrealtype t, N_Vector y, N_Vector yp, + N_Vector* yS_1d, N_Vector* ypS_1d, N_Vector yB, + N_Vector ypB, N_Vector rhsvalBQS, void* user_dataB) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = static_cast(user_dataB); + auto fn = nb::cast>(fn_table->resQBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(t, y, yp, yS, ypS, yB, ypB, rhsvalBQS, nullptr); +} + +using IDALsJacStdFnBS = int(sunrealtype tt, sunrealtype c_jB, N_Vector yy, + N_Vector yp, std::vector yS_1d, + std::vector ypS_1d, N_Vector yyB, + N_Vector ypB, N_Vector rrB, SUNMatrix JacB, + void* user_dataB, N_Vector tmp1B, N_Vector tmp2B, + N_Vector tmp3B); + +inline int idas_lsjacfnBS_wrapper(sunrealtype tt, sunrealtype c_jB, N_Vector yy, + N_Vector yp, N_Vector* yS_1d, + N_Vector* ypS_1d, N_Vector yyB, N_Vector ypB, + N_Vector rrB, SUNMatrix JacB, void* user_dataB, + N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = get_idasa_fn_table(user_dataB); + auto fn = nb::cast>(fn_table->lsjacfnBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yS(yS_1d, yS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(tt, c_jB, yy, yp, yS, ypS, yyB, ypB, rrB, JacB, nullptr, tmp1B, + tmp2B, tmp3B); +} + +using IDALsPrecSetupStdFnBS = int(sunrealtype tt, N_Vector yy, N_Vector yp, + std::vector yyS_1d, + std::vector ypS_1d, N_Vector yyB, + N_Vector ypB, N_Vector rrB, sunrealtype c_jB, + void* user_dataB); + +inline int idas_lsprecsetupfnBS_wrapper(sunrealtype tt, N_Vector yy, N_Vector yp, + N_Vector* yyS_1d, N_Vector* ypS_1d, + N_Vector yyB, N_Vector ypB, N_Vector rrB, + sunrealtype c_jB, void* user_dataB) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = get_idasa_fn_table(user_dataB); + auto fn = + nb::cast>(fn_table->lsprecsetupfnBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yyS(yyS_1d, yyS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(tt, yy, yp, yyS, ypS, yyB, ypB, rrB, c_jB, nullptr); +} + +using IDALsPrecSolveStdFnBS = int(sunrealtype tt, N_Vector yy, N_Vector yp, + std::vector yyS_1d, + std::vector ypS_1d, N_Vector yyB, + N_Vector ypB, N_Vector rrB, N_Vector rvecB, + N_Vector zvecB, sunrealtype c_jB, + sunrealtype deltaB, void* user_dataB); + +inline int idas_lsprecsolvefnBS_wrapper(sunrealtype tt, N_Vector yy, N_Vector yp, + N_Vector* yyS_1d, N_Vector* ypS_1d, + N_Vector yyB, N_Vector ypB, + N_Vector rrB, N_Vector rvecB, + N_Vector zvecB, sunrealtype c_jB, + sunrealtype deltaB, void* user_dataB) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = get_idasa_fn_table(user_dataB); + auto fn = + nb::cast>(fn_table->lsprecsolvefnBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yyS(yyS_1d, yyS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(tt, yy, yp, yyS, ypS, yyB, ypB, rrB, rvecB, zvecB, c_jB, deltaB, + nullptr); +} + +using IDALsJacTimesSetupStdFnBS = int(sunrealtype t, N_Vector yy, N_Vector yp, + std::vector yyS_1d, + std::vector ypS_1d, + N_Vector yyB, N_Vector ypB, N_Vector rrB, + sunrealtype c_jB, void* user_dataB); + +inline int idas_lsjactimessetupfnBS_wrapper(sunrealtype t, N_Vector yy, + N_Vector yp, N_Vector* yyS_1d, + N_Vector* ypS_1d, N_Vector yyB, + N_Vector ypB, N_Vector rrB, + sunrealtype c_jB, void* user_dataB) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = get_idasa_fn_table(user_dataB); + auto fn = nb::cast>( + fn_table->lsjactimessetupfnBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yyS(yyS_1d, yyS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(t, yy, yp, yyS, ypS, yyB, ypB, rrB, c_jB, nullptr); +} + +using IDALsJacTimesVecStdFnBS = int(sunrealtype t, N_Vector yy, N_Vector yp, + std::vector yyS_1d, + std::vector ypS_1d, N_Vector yyB, + N_Vector ypB, N_Vector rrB, N_Vector vB, + N_Vector JvB, sunrealtype c_jB, + void* user_dataB, N_Vector tmp1B, + N_Vector tmp2B); + +inline int idas_lsjactimesvecfnBS_wrapper( + sunrealtype t, N_Vector yy, N_Vector yp, N_Vector* yyS_1d, N_Vector* ypS_1d, + N_Vector yyB, N_Vector ypB, N_Vector rrB, N_Vector vB, N_Vector JvB, + sunrealtype c_jB, void* user_dataB, N_Vector tmp1B, N_Vector tmp2B) +{ + auto ida_mem = static_cast(user_dataB); + auto fn_table = get_idasa_fn_table(user_dataB); + auto fn = nb::cast>( + fn_table->lsjactimesvecfnBS); + auto Ns = ida_mem->ida_Ns; + + std::vector yyS(yyS_1d, yyS_1d + Ns); + std::vector ypS(ypS_1d, ypS_1d + Ns); + + return fn(t, yy, yp, yyS, ypS, yyB, ypB, rrB, vB, JvB, c_jB, nullptr, tmp1B, + tmp2B); +} + +#endif diff --git a/bindings/sundials4py/include/sundials4py.hpp b/bindings/sundials4py/include/sundials4py.hpp new file mode 100644 index 0000000000..2eb350be31 --- /dev/null +++ b/bindings/sundials4py/include/sundials4py.hpp @@ -0,0 +1,28 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sundials4py_helpers.hpp" +#include "sundials4py_types.hpp" diff --git a/bindings/sundials4py/include/sundials4py_helpers.hpp b/bindings/sundials4py/include/sundials4py_helpers.hpp new file mode 100644 index 0000000000..9be3b3bb15 --- /dev/null +++ b/bindings/sundials4py/include/sundials4py_helpers.hpp @@ -0,0 +1,164 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_HELPERS_HPP +#define _SUNDIALS4PY_HELPERS_HPP + +#include "sundials4py.hpp" + +namespace nb = nanobind; + +namespace sundials4py { + +/// This function will call a user-supplied Python function through C++ side wrappers +/// \tparam FnType is the function signature, e.g., std::remove_pointer_t +/// \tparam FnTableType is the struct function table that holds the user-supplied Python functions as std::function +/// \tparam UserDataArg is the index of the void* user_data argument of the C function. We are counting from the last arg to the first arg, so if user_data is the last arg then this should be 1. +/// \tparam Args is the template parameter pack that contains all of the types of the function arguments to the C function +/// +/// \param fn_member is the name of the function in the FnTableType to call +/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function, except user_data, which is intercepted and passed as a nullptr. +template +int user_supplied_fn_caller(nb::object FnTableType::*fn_member, Args... args) +{ + constexpr size_t N = sizeof...(Args); + constexpr int user_data_index = N - UserDataArg; + auto args_tuple = std::tuple(args...); + + // Extract user_data from the specified position + void* user_data = std::get(args_tuple); + + // Cast user_data to FnTableType* + auto fn_table = static_cast(user_data); + auto fn = nb::cast>(fn_table->*fn_member); + + // Pass nullptr as user_data since we do not want the user to mess with user_data (which holds our function table) + std::get(args_tuple) = nullptr; + return std::apply([&](auto&&... call_args) { return fn(call_args...); }, + args_tuple); +} + +/// This function will call a user-supplied Python function through C++ side wrappers +/// \tparam FnType is the function signature, e.g., std::remove_pointer_t +/// \tparam FnTableType is the struct function table that holds the user-supplied Python functions as std::function +/// \tparam MemType the type that user_data will be cast to +/// \tparam UserDataArg is the index of the void* user_data argument of the C function. We are counting from the last arg to the first arg, so if user_data is the last arg then this should be 1. +/// \tparam Args is the template parameter pack that contains all of the types of the function arguments to the C function +/// +/// \param fn_member is the name of the function in the FnTableType to call +/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function, except user_data, which is intercepted and passed as a nullptr. +template +int user_supplied_fn_caller(nb::object FnTableType::*fn_member, Args... args) +{ + constexpr size_t N = sizeof...(Args); + constexpr int user_data_index = N - UserDataArg; + auto args_tuple = std::tuple(args...); + + // Extract user_data from the specified position + void* user_data = std::get(args_tuple); + + // Cast user_data to FnTableType* + auto mem = static_cast(user_data); + auto fn_table = static_cast(mem->python); + auto fn = nb::cast>(fn_table->*fn_member); + + // Pass nullptr as user_data since we do not want the user to mess with user_data (which holds our function table) + std::get(args_tuple) = nullptr; + return std::apply([&](auto&&... call_args) { return fn(call_args...); }, + args_tuple); +} + +/// This function will call a user-supplied Python function through C++ side wrappers +/// \tparam FnType is the function signature, e.g., std::remove_pointer_t +/// \tparam FnTableType is the struct function table that holds the user-supplied Python functions as std::function +/// \tparam Args is the template parameter pack that contains all of the types of the function arguments to the C function +/// +/// \param fn_member is the name of the function in the FnTableType to call +/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function. +template +int user_supplied_fn_caller(nb::object FnTableType::*fn_member, void* user_data, + Args... args) +{ + auto args_tuple = std::tuple(args...); + + // Cast user_data to FnTableType* + auto fn_table = static_cast(user_data); + auto fn = nb::cast>(fn_table->*fn_member); + + return std::apply([&](auto&&... call_args) { return fn(call_args...); }, + args_tuple); +} + +/// This function will call a user-supplied Python function through C++ side wrappers +/// \tparam FnType is the function signature, e.g., std::remove_pointer_t +/// \tparam FnTableType is the struct function table that holds the user-supplied Python functions as std::function +/// \tparam Args is the template parameter pack that contains all of the types of the function arguments to the C function +/// +/// \param fn_member is the name of the function in the FnTableType to call +/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function. +template +int user_supplied_fn_caller(nb::object FnTableType::*fn_member, Args... args) +{ + auto args_tuple = std::tuple(args...); + + // Cast object->python to FnTableType* + auto object = static_cast(std::get<0>(args_tuple)); + auto fn_table = static_cast(object->python); + auto fn = nb::cast>(fn_table->*fn_member); + + return std::apply([&](auto&&... call_args) { return fn(call_args...); }, + args_tuple); +} + +/// +/// \brief Helper struct to manage reference lifetimes for function return values in Python bindings. +/// +/// Enables the nb::keep_alive paradigm when the function returns a sequence where +/// elements of the sequence (e.g., a tuple) are Nurses. +/// +/// \tparam IP Index of the input argument who is kept alive by the nurses. +/// \tparam IN Indexes of the return values in the returned sequence who keep the patient alive. +// +/// See https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4I0EN8nanobind11call_policyE. +/// +template +struct returns_references_to +{ + static void precall(PyObject**, size_t, nb::detail::cleanup_list*) {} + + template + static void postcall(PyObject** args, std::integral_constant, + nb::handle ret) + { + static_assert(IP > 0 && IP <= N, + "IP in returns_references_to must be in the " + "range [1, number of C++ function arguments]"); + + if (!nb::isinstance(ret)) + { + throw sundials4py::error_returned("return value should be a sequence"); + } + + // Directly apply keep_alive for each IN using a fold expression + (nb::detail::keep_alive(ret[IN].ptr(), args[IP - 1]), ...); + } +}; + +} // namespace sundials4py + +#endif diff --git a/bindings/sundials4py/include/sundials4py_types.hpp b/bindings/sundials4py/include/sundials4py_types.hpp new file mode 100644 index 0000000000..3921032e54 --- /dev/null +++ b/bindings/sundials4py/include/sundials4py_types.hpp @@ -0,0 +1,88 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_TYPES_HPP +#define _SUNDIALS4PY_TYPES_HPP + +#include + +#include + +#include "sundials4py.hpp" + +namespace nb = nanobind; + +namespace sundials4py { + +using Array1d = nb::ndarray, nb::c_contig>; + +class error_returned : public std::runtime_error +{ +public: + explicit error_returned(const char* message) + : std::runtime_error(base_message + message) + {} + + // Constructor that takes a std::string message + explicit error_returned(const std::string& message) + : std::runtime_error(base_message + message) + {} + +private: + inline static const std::string base_message = + "[sundials4py] a SUNDIALS function returned a code indicating an error, " + "details are given below:\n\t"; +}; + +class illegal_value : public std::runtime_error +{ +public: + explicit illegal_value(const char* message) + : std::runtime_error(base_message + message) + {} + + // Constructor that takes a std::string message + explicit illegal_value(const std::string& message) + : std::runtime_error(base_message + message) + {} + +private: + inline static const std::string base_message = + "[sundials4py] an illegal value was given, " + "details are given below:\n\t"; +}; + +class null_function_table : public std::runtime_error +{ +public: + explicit null_function_table(const char* message) + : std::runtime_error(base_message + message) + {} + + // Constructor that takes a std::string message + explicit null_function_table(const std::string& message) + : std::runtime_error(base_message + message) + {} + +private: + inline static const std::string base_message = + "[sundials4py] the python function table was null:\n\t"; +}; + +} // namespace sundials4py + +#endif diff --git a/bindings/sundials4py/kinsol/generate.yaml b/bindings/sundials4py/kinsol/generate.yaml new file mode 100644 index 0000000000..a42f5ed274 --- /dev/null +++ b/bindings/sundials4py/kinsol/generate.yaml @@ -0,0 +1,52 @@ + +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# KINSOL module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + fn_exclude_by_name__regex: + - "Free" # Free and destroy functions should not need to be called as all objects on the Python side are RAII objects + - "Destroy" + - "Space" # Space functions are deprecated, so dont expose them in Python + # Due to the need to convert between sys.argv and C argv, we need to do custom wrappers of these + - "SetOptions" + macro_define_include_by_name__regex: + - "^SUN_" + - "^KIN_" + - "^KINLS_" + kinsol: + path: kinsol/kinsol_generated.hpp + headers: + - ../../include/kinsol/kinsol.h + - ../../include/kinsol/kinsol_ls.h + # this option describes the functions which have optional pointer arguments, + # i.e., one where you could provide NULL + fn_params_optional_with_default_null: + "SetLinearSolver": + - "A" + fn_exclude_by_name__regex: + # We do custom handling of Create so we can wrap the void* in a KINView + - "^KINCreate$" + # we use user_data for sneaking in python contexts, users can instead capture their states in a class + - "^KINGetUserData$" + # generator cannot handle setting of function pointers, so we do something custom + - "KINInit.*" + - "KINSet.*Fn" + - "KINSet.*Preconditioner" + - "KINSetJacTimes.*" + - "KINSetSysFunc" \ No newline at end of file diff --git a/bindings/sundials4py/kinsol/kinsol.cpp b/bindings/sundials4py/kinsol/kinsol.cpp new file mode 100644 index 0000000000..0f85f2162d --- /dev/null +++ b/bindings/sundials4py/kinsol/kinsol.cpp @@ -0,0 +1,147 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +#include +#include +#include + +#include "kinsol/kinsol_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +#include "kinsol_usersupplied.hpp" + +#define BIND_KINSOL_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER, ...) \ + m.def( \ + #NAME, \ + [](void* kin_mem, std::function> fn) \ + { \ + auto fntable = get_kinsol_fn_table(kin_mem); \ + fntable->MEMBER = nb::cast(fn); \ + if (fn) { return NAME(kin_mem, WRAPPER); } \ + else { return NAME(kin_mem, nullptr); } \ + }, \ + __VA_ARGS__) + +#define BIND_KINSOL_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \ + MEMBER2, WRAPPER2, ...) \ + m.def( \ + #NAME, \ + [](void* kin_mem, std::function> fn1, \ + std::function> fn2) \ + { \ + auto fntable = get_kinsol_fn_table(kin_mem); \ + fntable->MEMBER1 = nb::cast(fn1); \ + fntable->MEMBER2 = nb::cast(fn2); \ + if (fn1) { return NAME(kin_mem, WRAPPER1, WRAPPER2); } \ + else { return NAME(kin_mem, nullptr, WRAPPER2); } \ + }, \ + __VA_ARGS__) + +namespace sundials4py { + +void bind_kinsol(nb::module_& m) +{ +#include "kinsol_generated.hpp" + + nb::class_(m, "KINView") + .def("get", nb::overload_cast<>(&KINView::get, nb::const_), + nb::rv_policy::reference); + + m.def( + "KINSetOptions", + [](void* kin_mem, const std::string& kinid, const std::string& file_name, + int argc, const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return KINSetOptions(kin_mem, kinid.empty() ? nullptr : kinid.c_str(), + file_name.empty() ? nullptr : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("kin_mem"), nb::arg("kinid"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + m.def( + "KINCreate", + [](SUNContext sunctx) + { return std::make_shared(KINCreate(sunctx)); }, + nb::arg("sunctx"), nb::keep_alive<0, 1>()); + + m.def("KINInit", + [](void* kin_mem, std::function> sysfn, + N_Vector tmpl) + { + int kin_status = KINInit(kin_mem, kinsol_sysfn_wrapper, tmpl); + + auto fn_table = kinsol_user_supplied_fn_table_alloc(); + auto kinsol_mem = static_cast(kin_mem); + kinsol_mem->python = fn_table; + kin_status = KINSetUserData(kin_mem, kin_mem); + if (kin_status != KIN_SUCCESS) + { + free(fn_table); + throw sundials4py::error_returned( + "Failed to set user data in KINSOL memory"); + } + + fn_table->sysfn = nb::cast(sysfn); + return kin_status; + }); + + BIND_KINSOL_CALLBACK(KINSetSysFunc, KINSysFn, sysfn, kinsol_sysfn_wrapper, + nb::arg("kin_mem"), nb::arg("sysfn")); + + BIND_KINSOL_CALLBACK(KINSetDampingFn, KINDampingStdFn, dampingfn, + kinsol_dampingfn_wrapper, nb::arg("kin_mem"), + nb::arg("damping_fn").none()); + + BIND_KINSOL_CALLBACK(KINSetDepthFn, KINDepthStdFn, depthfn, + kinsol_depthfn_wrapper, nb::arg("kin_mem"), + nb::arg("depth_fn").none()); + + BIND_KINSOL_CALLBACK2(KINSetPreconditioner, KINLsPrecSetupFn, lsprecsetupfn, + kinsol_lsprecsetupfn_wrapper, KINLsPrecSolveFn, + lsprecsolvefn, kinsol_lsprecsolvefn_wrapper, + nb::arg("kin_mem"), nb::arg("psetup").none(), + nb::arg("psolve").none()); + + BIND_KINSOL_CALLBACK(KINSetJacFn, KINSysFn, lsjacfn, kinsol_lsjacfn_wrapper, + nb::arg("kin_mem"), nb::arg("jac").none()); + + BIND_KINSOL_CALLBACK(KINSetJacTimesVecFn, KINLsJacTimesVecStdFn, + lsjactimesvecfn, kinsol_lsjactimesvecfn_wrapper, + nb::arg("kin_mem"), nb::arg("jtimes").none()); + + BIND_KINSOL_CALLBACK(KINSetJacTimesVecSysFn, KINSysFn, lsjtvsysfn, + kinsol_lsjtvsysfn_wrapper, nb::arg("kin_mem"), + nb::arg("jtvSysFn").none()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/kinsol/kinsol_generated.hpp b/bindings/sundials4py/kinsol/kinsol_generated.hpp new file mode 100644 index 0000000000..151915d495 --- /dev/null +++ b/bindings/sundials4py/kinsol/kinsol_generated.hpp @@ -0,0 +1,431 @@ +// #ifndef _KINSOL_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("KIN_SUCCESS") = 0; +m.attr("KIN_INITIAL_GUESS_OK") = 1; +m.attr("KIN_STEP_LT_STPTOL") = 2; +m.attr("KIN_WARNING") = 99; +m.attr("KIN_MEM_NULL") = -1; +m.attr("KIN_ILL_INPUT") = -2; +m.attr("KIN_NO_MALLOC") = -3; +m.attr("KIN_MEM_FAIL") = -4; +m.attr("KIN_LINESEARCH_NONCONV") = -5; +m.attr("KIN_MAXITER_REACHED") = -6; +m.attr("KIN_MXNEWT_5X_EXCEEDED") = -7; +m.attr("KIN_LINESEARCH_BCFAIL") = -8; +m.attr("KIN_LINSOLV_NO_RECOVERY") = -9; +m.attr("KIN_LINIT_FAIL") = -10; +m.attr("KIN_LSETUP_FAIL") = -11; +m.attr("KIN_LSOLVE_FAIL") = -12; +m.attr("KIN_SYSFUNC_FAIL") = -13; +m.attr("KIN_FIRST_SYSFUNC_ERR") = -14; +m.attr("KIN_REPTD_SYSFUNC_ERR") = -15; +m.attr("KIN_VECTOROP_ERR") = -16; +m.attr("KIN_CONTEXT_ERR") = -17; +m.attr("KIN_DAMPING_FN_ERR") = -18; +m.attr("KIN_DEPTH_FN_ERR") = -19; +m.attr("KIN_ORTH_MGS") = 0; +m.attr("KIN_ORTH_ICWY") = 1; +m.attr("KIN_ORTH_CGS2") = 2; +m.attr("KIN_ORTH_DCGS2") = 3; +m.attr("KIN_ETACHOICE1") = 1; +m.attr("KIN_ETACHOICE2") = 2; +m.attr("KIN_ETACONSTANT") = 3; +m.attr("KIN_NONE") = 0; +m.attr("KIN_LINESEARCH") = 1; +m.attr("KIN_PICARD") = 2; +m.attr("KIN_FP") = 3; + +m.def("KINSol", KINSol, nb::arg("kinmem"), nb::arg("uu"), nb::arg("strategy"), + nb::arg("u_scale"), nb::arg("f_scale"), "Solver function"); + +m.def("KINSetUserData", KINSetUserData, nb::arg("kinmem"), nb::arg("user_data")); + +m.def("KINSetDamping", KINSetDamping, nb::arg("kinmem"), nb::arg("beta")); + +m.def("KINSetMAA", KINSetMAA, nb::arg("kinmem"), nb::arg("maa")); + +m.def("KINSetOrthAA", KINSetOrthAA, nb::arg("kinmem"), nb::arg("orthaa")); + +m.def("KINSetDelayAA", KINSetDelayAA, nb::arg("kinmem"), nb::arg("delay")); + +m.def("KINSetDampingAA", KINSetDampingAA, nb::arg("kinmem"), nb::arg("beta")); + +m.def("KINSetReturnNewest", KINSetReturnNewest, nb::arg("kinmem"), + nb::arg("ret_newest")); + +m.def("KINSetNumMaxIters", KINSetNumMaxIters, nb::arg("kinmem"), + nb::arg("mxiter")); + +m.def("KINSetNoInitSetup", KINSetNoInitSetup, nb::arg("kinmem"), + nb::arg("noInitSetup")); + +m.def("KINSetNoResMon", KINSetNoResMon, nb::arg("kinmem"), + nb::arg("noNNIResMon")); + +m.def("KINSetMaxSetupCalls", KINSetMaxSetupCalls, nb::arg("kinmem"), + nb::arg("msbset")); + +m.def("KINSetMaxSubSetupCalls", KINSetMaxSubSetupCalls, nb::arg("kinmem"), + nb::arg("msbsetsub")); + +m.def("KINSetEtaForm", KINSetEtaForm, nb::arg("kinmem"), nb::arg("etachoice")); + +m.def("KINSetEtaConstValue", KINSetEtaConstValue, nb::arg("kinmem"), + nb::arg("eta")); + +m.def("KINSetEtaParams", KINSetEtaParams, nb::arg("kinmem"), nb::arg("egamma"), + nb::arg("ealpha")); + +m.def("KINSetResMonParams", KINSetResMonParams, nb::arg("kinmem"), + nb::arg("omegamin"), nb::arg("omegamax")); + +m.def("KINSetResMonConstValue", KINSetResMonConstValue, nb::arg("kinmem"), + nb::arg("omegaconst")); + +m.def("KINSetNoMinEps", KINSetNoMinEps, nb::arg("kinmem"), nb::arg("noMinEps")); + +m.def("KINSetMaxNewtonStep", KINSetMaxNewtonStep, nb::arg("kinmem"), + nb::arg("mxnewtstep")); + +m.def("KINSetMaxBetaFails", KINSetMaxBetaFails, nb::arg("kinmem"), + nb::arg("mxnbcf")); + +m.def("KINSetRelErrFunc", KINSetRelErrFunc, nb::arg("kinmem"), + nb::arg("relfunc")); + +m.def("KINSetFuncNormTol", KINSetFuncNormTol, nb::arg("kinmem"), + nb::arg("fnormtol")); + +m.def("KINSetScaledStepTol", KINSetScaledStepTol, nb::arg("kinmem"), + nb::arg("scsteptol")); + +m.def("KINSetConstraints", KINSetConstraints, nb::arg("kinmem"), + nb::arg("constraints")); + +m.def( + "KINGetNumNonlinSolvIters", + [](void* kinmem) -> std::tuple + { + auto KINGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nniters_adapt_modifiable; + + int r = KINGetNumNonlinSolvIters(kinmem, &nniters_adapt_modifiable); + return std::make_tuple(r, nniters_adapt_modifiable); + }; + + return KINGetNumNonlinSolvIters_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumFuncEvals", + [](void* kinmem) -> std::tuple + { + auto KINGetNumFuncEvals_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nfevals_adapt_modifiable; + + int r = KINGetNumFuncEvals(kinmem, &nfevals_adapt_modifiable); + return std::make_tuple(r, nfevals_adapt_modifiable); + }; + + return KINGetNumFuncEvals_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumBetaCondFails", + [](void* kinmem) -> std::tuple + { + auto KINGetNumBetaCondFails_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nbcfails_adapt_modifiable; + + int r = KINGetNumBetaCondFails(kinmem, &nbcfails_adapt_modifiable); + return std::make_tuple(r, nbcfails_adapt_modifiable); + }; + + return KINGetNumBetaCondFails_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumBacktrackOps", + [](void* kinmem) -> std::tuple + { + auto KINGetNumBacktrackOps_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nbacktr_adapt_modifiable; + + int r = KINGetNumBacktrackOps(kinmem, &nbacktr_adapt_modifiable); + return std::make_tuple(r, nbacktr_adapt_modifiable); + }; + + return KINGetNumBacktrackOps_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetFuncNorm", + [](void* kinmem) -> std::tuple + { + auto KINGetFuncNorm_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + sunrealtype fnorm_adapt_modifiable; + + int r = KINGetFuncNorm(kinmem, &fnorm_adapt_modifiable); + return std::make_tuple(r, fnorm_adapt_modifiable); + }; + + return KINGetFuncNorm_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetStepLength", + [](void* kinmem) -> std::tuple + { + auto KINGetStepLength_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + sunrealtype steplength_adapt_modifiable; + + int r = KINGetStepLength(kinmem, &steplength_adapt_modifiable); + return std::make_tuple(r, steplength_adapt_modifiable); + }; + + return KINGetStepLength_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def("KINPrintAllStats", KINPrintAllStats, nb::arg("kinmem"), + nb::arg("outfile"), nb::arg("fmt")); + +m.def("KINGetReturnFlagName", KINGetReturnFlagName, nb::arg("flag")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _KINLS_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("KINLS_SUCCESS") = 0; +m.attr("KINLS_MEM_NULL") = -1; +m.attr("KINLS_LMEM_NULL") = -2; +m.attr("KINLS_ILL_INPUT") = -3; +m.attr("KINLS_MEM_FAIL") = -4; +m.attr("KINLS_PMEM_NULL") = -5; +m.attr("KINLS_JACFUNC_ERR") = -6; +m.attr("KINLS_SUNMAT_FAIL") = -7; +m.attr("KINLS_SUNLS_FAIL") = -8; + +m.def( + "KINSetLinearSolver", + [](void* kinmem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + auto KINSetLinearSolver_adapt_optional_arg_with_default_null = + [](void* kinmem, SUNLinearSolver LS, + std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = KINSetLinearSolver(kinmem, LS, A_adapt_default_null); + return lambda_result; + }; + + return KINSetLinearSolver_adapt_optional_arg_with_default_null(kinmem, LS, A); + }, + nb::arg("kinmem"), nb::arg("LS"), nb::arg("A").none() = nb::none()); + +m.def( + "KINGetJac", + [](void* kinmem) -> std::tuple + { + auto KINGetJac_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + SUNMatrix J_adapt_modifiable; + + int r = KINGetJac(kinmem, &J_adapt_modifiable); + return std::make_tuple(r, J_adapt_modifiable); + }; + + return KINGetJac_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "KINGetJacNumIters", + [](void* kinmem) -> std::tuple + { + auto KINGetJacNumIters_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nni_J_adapt_modifiable; + + int r = KINGetJacNumIters(kinmem, &nni_J_adapt_modifiable); + return std::make_tuple(r, nni_J_adapt_modifiable); + }; + + return KINGetJacNumIters_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumJacEvals", + [](void* kinmem) -> std::tuple + { + auto KINGetNumJacEvals_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long njevals_adapt_modifiable; + + int r = KINGetNumJacEvals(kinmem, &njevals_adapt_modifiable); + return std::make_tuple(r, njevals_adapt_modifiable); + }; + + return KINGetNumJacEvals_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumLinFuncEvals", + [](void* kinmem) -> std::tuple + { + auto KINGetNumLinFuncEvals_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nfevals_adapt_modifiable; + + int r = KINGetNumLinFuncEvals(kinmem, &nfevals_adapt_modifiable); + return std::make_tuple(r, nfevals_adapt_modifiable); + }; + + return KINGetNumLinFuncEvals_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumPrecEvals", + [](void* kinmem) -> std::tuple + { + auto KINGetNumPrecEvals_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long npevals_adapt_modifiable; + + int r = KINGetNumPrecEvals(kinmem, &npevals_adapt_modifiable); + return std::make_tuple(r, npevals_adapt_modifiable); + }; + + return KINGetNumPrecEvals_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumPrecSolves", + [](void* kinmem) -> std::tuple + { + auto KINGetNumPrecSolves_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long npsolves_adapt_modifiable; + + int r = KINGetNumPrecSolves(kinmem, &npsolves_adapt_modifiable); + return std::make_tuple(r, npsolves_adapt_modifiable); + }; + + return KINGetNumPrecSolves_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumLinIters", + [](void* kinmem) -> std::tuple + { + auto KINGetNumLinIters_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nliters_adapt_modifiable; + + int r = KINGetNumLinIters(kinmem, &nliters_adapt_modifiable); + return std::make_tuple(r, nliters_adapt_modifiable); + }; + + return KINGetNumLinIters_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumLinConvFails", + [](void* kinmem) -> std::tuple + { + auto KINGetNumLinConvFails_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long nlcfails_adapt_modifiable; + + int r = KINGetNumLinConvFails(kinmem, &nlcfails_adapt_modifiable); + return std::make_tuple(r, nlcfails_adapt_modifiable); + }; + + return KINGetNumLinConvFails_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetNumJtimesEvals", + [](void* kinmem) -> std::tuple + { + auto KINGetNumJtimesEvals_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long njvevals_adapt_modifiable; + + int r = KINGetNumJtimesEvals(kinmem, &njvevals_adapt_modifiable); + return std::make_tuple(r, njvevals_adapt_modifiable); + }; + + return KINGetNumJtimesEvals_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def( + "KINGetLastLinFlag", + [](void* kinmem) -> std::tuple + { + auto KINGetLastLinFlag_adapt_modifiable_immutable_to_return = + [](void* kinmem) -> std::tuple + { + long flag_adapt_modifiable; + + int r = KINGetLastLinFlag(kinmem, &flag_adapt_modifiable); + return std::make_tuple(r, flag_adapt_modifiable); + }; + + return KINGetLastLinFlag_adapt_modifiable_immutable_to_return(kinmem); + }, + nb::arg("kinmem")); + +m.def("KINGetLinReturnFlagName", KINGetLinReturnFlagName, nb::arg("flag")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/kinsol/kinsol_usersupplied.hpp b/bindings/sundials4py/kinsol/kinsol_usersupplied.hpp new file mode 100644 index 0000000000..23cda733e8 --- /dev/null +++ b/bindings/sundials4py/kinsol/kinsol_usersupplied.hpp @@ -0,0 +1,200 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_KINSOL_USERSUPPLIED_HPP +#define _SUNDIALS4PY_KINSOL_USERSUPPLIED_HPP + +#include +#include + +#include "sundials/sundials_types.h" +#include "sundials4py.hpp" + +#include +#include + +#include "kinsol_impl.h" + +#include + +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +/////////////////////////////////////////////////////////////////////////////// +// KINSOL user-supplied function table +// Every package-level user-supplied function must be in this table. +// The user-supplied function table is passed to KINSOL as user_data. +/////////////////////////////////////////////////////////////////////////////// + +struct kinsol_user_supplied_fn_table +{ + // KINSOL user-supplied function pointers + nb::object sysfn, dampingfn, depthfn; + + // KINSOL LS user-supplied function pointers + nb::object lsjacfn, lsjactimesvecfn, lsjtvsysfn, lsprecsetupfn, lsprecsolvefn; +}; + +// Helper to extract KINMem and function table +inline kinsol_user_supplied_fn_table* get_kinsol_fn_table(void* kin_mem) +{ + auto mem = static_cast(kin_mem); + auto fntable = static_cast(mem->python); + if (!fntable) + { + throw sundials4py::null_function_table( + "Failed to get Python function table from KINSOL memory"); + } + return fntable; +} + +/////////////////////////////////////////////////////////////////////////////// +// KINSOL user-supplied functions +/////////////////////////////////////////////////////////////////////////////// + +inline kinsol_user_supplied_fn_table* kinsol_user_supplied_fn_table_alloc() +{ + // We must use malloc since KINFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(kinsol_user_supplied_fn_table))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(kinsol_user_supplied_fn_table)); + + return fn_table; +} + +template +inline int kinsol_sysfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, kinsol_user_supplied_fn_table, KINMem, + 1>(&kinsol_user_supplied_fn_table::sysfn, std::forward(args)...); +} + +using KINDampingStdFn = std::tuple( + long int iter, N_Vector u_val, N_Vector g_val, sundials4py::Array1d qt_fn, + long int depth, void* user_data); + +inline int kinsol_dampingfn_wrapper(long int iter, N_Vector u_val, N_Vector g_val, + sunrealtype* qt_fn_1d, long int depth, + void* user_data, sunrealtype* damping_factor) +{ + auto fn_table = get_kinsol_fn_table(user_data); + auto fn = nb::cast>(fn_table->dampingfn); + + sundials4py::Array1d qt_fn(qt_fn_1d, {static_cast(depth)}, + nb::find(qt_fn_1d)); + + auto result = fn(iter, u_val, g_val, qt_fn, depth, nullptr); + + *damping_factor = std::get<1>(result); + + return std::get<0>(result); +} + +using KINDepthStdFn = std::tuple( + long int iter, N_Vector u_val, N_Vector g_val, N_Vector f_val, + std::vector df, sundials4py::Array1d R_mat, long int depth, + void* user_data, std::vector remove_indices); + +inline int kinsol_depthfn_wrapper(long int iter, N_Vector u_val, N_Vector g_val, + N_Vector f_val, N_Vector* df_1d, + sunrealtype* R_mat_1d, long int depth, + void* user_data, long int* new_depth, + sunbooleantype* remove_indices_1d) +{ + auto fn_table = get_kinsol_fn_table(user_data); + auto fn = nb::cast>(fn_table->depthfn); + + std::vector df(df_1d, df_1d + depth); + sundials4py::Array1d R_mat(R_mat_1d, + {static_cast(depth * depth)}, + nb::find(R_mat_1d)); + if (remove_indices_1d) + { + std::vector remove_indices(remove_indices_1d, + remove_indices_1d + depth); + auto result = fn(iter, u_val, g_val, f_val, df, R_mat, depth, nullptr, + remove_indices); + *new_depth = std::get<1>(result); + return std::get<0>(result); + } + + std::vector remove_indices(0); + auto result = fn(iter, u_val, g_val, f_val, df, R_mat, depth, nullptr, + remove_indices); + *new_depth = std::get<1>(result); + return std::get<0>(result); +} + +template +inline int kinsol_lsjacfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, kinsol_user_supplied_fn_table, KINMem, + 3>(&kinsol_user_supplied_fn_table::lsjacfn, std::forward(args)...); +} + +template +inline int kinsol_lsprecsetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, kinsol_user_supplied_fn_table, + KINMem, 1>(&kinsol_user_supplied_fn_table::lsprecsetupfn, + std::forward(args)...); +} + +template +inline int kinsol_lsprecsolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, kinsol_user_supplied_fn_table, + KINMem, 1>(&kinsol_user_supplied_fn_table::lsprecsolvefn, + std::forward(args)...); +} + +using KINLsJacTimesVecStdFn = std::tuple(N_Vector v, + N_Vector Jv, + N_Vector u, + void* user_data); + +inline int kinsol_lsjactimesvecfn_wrapper(N_Vector v, N_Vector Jv, N_Vector u, + sunbooleantype* new_u, void* user_data) +{ + auto fn_table = get_kinsol_fn_table(user_data); + auto fn = + nb::cast>(fn_table->lsjactimesvecfn); + + auto result = fn(v, Jv, u, nullptr); + + *new_u = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int kinsol_lsjtvsysfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, kinsol_user_supplied_fn_table, KINMem, + 1>(&kinsol_user_supplied_fn_table::lsjtvsysfn, std::forward(args)...); +} + +#endif // _SUNDIALS4PY_KINSOL_USERSUPPLIED_HPP diff --git a/bindings/sundials4py/nvector/generate.yaml b/bindings/sundials4py/nvector/generate.yaml new file mode 100644 index 0000000000..291bb596a0 --- /dev/null +++ b/bindings/sundials4py/nvector/generate.yaml @@ -0,0 +1,103 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^N_VGetVectorID_.*" + - "^N_VCloneEmpty_.*" + - "^N_VClone_.*" + - "^N_VDestroy_.*" + - "^N_VAbs_.*" + - "^N_VAddConst_.*" + - "^N_VBufPack_.*" + - "^N_VBufSize_.*" + - "^N_VBufUnpack_.*" + - "^N_VCloneEmptyVectorArray_.*" + - "^N_VCloneVectorArray_.*" + - "^N_VCompare_.*" + - "^N_VConst_.*" + - "^N_VConstrMask_.*" + - "^N_VConstrMaskLocal_.*" + - "^N_VConstVectorArray_.*" + - "^N_VDestroyVectorArray_.*" + - "^N_VDiv_.*" + - "^N_VDotProd_.*" + - "^N_VDotProdLocal_.*" + - "^N_VDotProdMulti_.*" + - "^N_VDotProdMultiAllReduce_.*" + - "^N_VDotProdMultiLocal_.*" + - "^N_VGetArrayPointer_.*" + - "^N_VGetCommunicator_.*" + - "^N_VGetDeviceArrayPointer_.*" + - "^N_VGetLength_.*" + - "^N_VGetLocalLength_.*" + - "^N_VGetVecAtIndexVectorArray_.*" + - "^N_VInv_.*" + - "^N_VInvTest_.*" + - "^N_VInvTestLocal_.*" + - "^N_VL1Norm_.*" + - "^N_VL1NormLocal_.*" + - "^N_VLinearCombination_.*" + - "^N_VLinearCombinationVectorArray_.*" + - "^N_VLinearSum_.*" + - "^N_VLinearSumVectorArray_.*" + - "^N_VMaxNorm_.*" + - "^N_VMaxNormLocal_.*" + - "^N_VMin_.*" + - "^N_VMinLocal_.*" + - "^N_VMinQuotient_.*" + - "^N_VMinQuotientLocal_.*" + - "^N_VNewVectorArray_.*" + - "^N_VPrint_.*" + - "^N_VPrintFile_.*" + - "^N_VProd_.*" + - "^N_VScale_.*" + - "^N_VScaleAddMulti_.*" + - "^N_VScaleAddMultiVectorArray_.*" + - "^N_VScaleVectorArray_.*" + - "^N_VSetArrayPointer_.*" + - "^N_VSetVecAtIndexVectorArray_.*" + - "^N_VSpace_.*" + - "^N_VWL2Norm_.*" + - "^N_VWrmsNorm_.*" + - "^N_VWrmsNormMask_.*" + - "^N_VWrmsNormMaskVectorArray_.*" + - "^N_VWrmsNormVectorArray_.*" + - "^N_VWSqrSumLocal_.*" + - "^N_VWSqrSumMaskLocal_.*" + nvector_serial: + path: nvector/nvector_serial_generated.hpp + headers: + - ../../include/nvector/nvector_serial.h + nvector_manyvector: + path: nvector/nvector_manyvector_generated.hpp + headers: + - ../../include/nvector/nvector_manyvector.h + fn_exclude_by_name__regex: + # N_VNew_ManyVector takes an array of N_Vectors for the subvectors. We need to keep + # the subvectors alive as long as the ManyVector exists. Since we also need to keep + # the SUNContext alive, this case falls outside what our litgen extension adapt_sundials_types_returns.py + # can currently handle, so we manually bind to it. + - "^N_VNew_ManyVector$" + # Getting and setting array pointers requires us to do something custom + - "^N_VGetSubvectorArrayPointer_ManyVector$" + - "^N_VSetSubvectorArrayPointer_ManyVector$" diff --git a/bindings/sundials4py/nvector/nvector_manyvector.cpp b/bindings/sundials4py/nvector/nvector_manyvector.cpp new file mode 100644 index 0000000000..ffa1276e90 --- /dev/null +++ b/bindings/sundials4py/nvector/nvector_manyvector.cpp @@ -0,0 +1,31 @@ +#include "sundials4py.hpp" + +#include +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_nvector_manyvector(nb::module_& m) +{ +#include "nvector_manyvector_generated.hpp" + + m.def( + "N_VNew_ManyVector", + [](sunindextype num_subvectors, std::vector vec_array_1d, + SUNContext sunctx) -> std::shared_ptr> + { + N_Vector* vec_array_1d_ptr = reinterpret_cast( + vec_array_1d.empty() ? nullptr : vec_array_1d.data()); + return our_make_shared, N_VectorDeleter>( + N_VNew_ManyVector(num_subvectors, vec_array_1d_ptr, sunctx)); + }, + nb::arg("num_subvectors"), nb::arg("vec_array_1d"), nb::arg("sunctx"), + nb::keep_alive<0, 3>() /* keep the SUNContext alive as long as the N_Vector is */, + nb::keep_alive<0, 2>() /* keep the list, and thus the elements, alive as long as the N_Vector is */); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/nvector/nvector_manyvector_generated.hpp b/bindings/sundials4py/nvector/nvector_manyvector_generated.hpp new file mode 100644 index 0000000000..3e395fef6a --- /dev/null +++ b/bindings/sundials4py/nvector/nvector_manyvector_generated.hpp @@ -0,0 +1,55 @@ +// #ifndef _NVECTOR_MANY_VECTOR_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_N_VectorContent_ManyVector = + nb::class_<_N_VectorContent_ManyVector>(m, "_N_VectorContent_ManyVector", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("N_VGetSubvector_ManyVector", N_VGetSubvector_ManyVector, nb::arg("v"), + nb::arg("vec_num"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def("N_VGetNumSubvectors_ManyVector", N_VGetNumSubvectors_ManyVector, + nb::arg("v")); + +m.def("N_VGetSubvectorLocalLength_ManyVector", + N_VGetSubvectorLocalLength_ManyVector, nb::arg("v"), nb::arg("vec_num")); + +m.def("N_VEnableFusedOps_ManyVector", N_VEnableFusedOps_ManyVector, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableLinearCombination_ManyVector", + N_VEnableLinearCombination_ManyVector, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableScaleAddMulti_ManyVector", N_VEnableScaleAddMulti_ManyVector, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableDotProdMulti_ManyVector", N_VEnableDotProdMulti_ManyVector, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableLinearSumVectorArray_ManyVector", + N_VEnableLinearSumVectorArray_ManyVector, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableScaleVectorArray_ManyVector", + N_VEnableScaleVectorArray_ManyVector, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableConstVectorArray_ManyVector", + N_VEnableConstVectorArray_ManyVector, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableWrmsNormVectorArray_ManyVector", + N_VEnableWrmsNormVectorArray_ManyVector, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableWrmsNormMaskVectorArray_ManyVector", + N_VEnableWrmsNormMaskVectorArray_ManyVector, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableDotProdMultiLocal_ManyVector", + N_VEnableDotProdMultiLocal_ManyVector, nb::arg("v"), nb::arg("tf")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/nvector/nvector_serial.cpp b/bindings/sundials4py/nvector/nvector_serial.cpp new file mode 100644 index 0000000000..2ec2d784f0 --- /dev/null +++ b/bindings/sundials4py/nvector/nvector_serial.cpp @@ -0,0 +1,18 @@ +#include "sundials4py.hpp" + +#include +#include + +#include "sundials/sundials_classview.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_nvector_serial(nb::module_& m) +{ +#include "nvector_serial_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/nvector/nvector_serial_generated.hpp b/bindings/sundials4py/nvector/nvector_serial_generated.hpp new file mode 100644 index 0000000000..c15e78fe2a --- /dev/null +++ b/bindings/sundials4py/nvector/nvector_serial_generated.hpp @@ -0,0 +1,123 @@ +// #ifndef _NVECTOR_SERIAL_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_N_VectorContent_Serial = + nb::class_<_N_VectorContent_Serial>(m, "_N_VectorContent_Serial", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "N_VNewEmpty_Serial", + [](sunindextype vec_length, + SUNContext sunctx) -> std::shared_ptr> + { + auto N_VNewEmpty_Serial_adapt_return_type_to_shared_ptr = + [](sunindextype vec_length, + SUNContext sunctx) -> std::shared_ptr> + { + auto lambda_result = N_VNewEmpty_Serial(vec_length, sunctx); + + return our_make_shared, N_VectorDeleter>( + lambda_result); + }; + + return N_VNewEmpty_Serial_adapt_return_type_to_shared_ptr(vec_length, sunctx); + }, + nb::arg("vec_length"), nb::arg("sunctx"), "nb::keep_alive<0, 2>()", + nb::keep_alive<0, 2>()); + +m.def( + "N_VNew_Serial", + [](sunindextype vec_length, + SUNContext sunctx) -> std::shared_ptr> + { + auto N_VNew_Serial_adapt_return_type_to_shared_ptr = + [](sunindextype vec_length, + SUNContext sunctx) -> std::shared_ptr> + { + auto lambda_result = N_VNew_Serial(vec_length, sunctx); + + return our_make_shared, N_VectorDeleter>( + lambda_result); + }; + + return N_VNew_Serial_adapt_return_type_to_shared_ptr(vec_length, sunctx); + }, + nb::arg("vec_length"), nb::arg("sunctx"), "nb::keep_alive<0, 2>()", + nb::keep_alive<0, 2>()); + +m.def( + "N_VMake_Serial", + [](sunindextype vec_length, sundials4py::Array1d v_data_1d, + SUNContext sunctx) -> std::shared_ptr> + { + auto N_VMake_Serial_adapt_arr_ptr_to_std_vector = + [](sunindextype vec_length, sundials4py::Array1d v_data_1d, + SUNContext sunctx) -> N_Vector + { + sunrealtype* v_data_1d_ptr = + reinterpret_cast(v_data_1d.data()); + + auto lambda_result = N_VMake_Serial(vec_length, v_data_1d_ptr, sunctx); + return lambda_result; + }; + auto N_VMake_Serial_adapt_return_type_to_shared_ptr = + [&N_VMake_Serial_adapt_arr_ptr_to_std_vector](sunindextype vec_length, + sundials4py::Array1d v_data_1d, + SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = + N_VMake_Serial_adapt_arr_ptr_to_std_vector(vec_length, v_data_1d, sunctx); + + return our_make_shared, N_VectorDeleter>( + lambda_result); + }; + + return N_VMake_Serial_adapt_return_type_to_shared_ptr(vec_length, v_data_1d, + sunctx); + }, + nb::arg("vec_length"), nb::arg("v_data_1d"), nb::arg("sunctx"), + "nb::keep_alive<0, 3>()", nb::keep_alive<0, 3>()); + +m.def("N_VEnableFusedOps_Serial", N_VEnableFusedOps_Serial, nb::arg("v"), + nb::arg("tf")); + +m.def("N_VEnableLinearCombination_Serial", N_VEnableLinearCombination_Serial, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableScaleAddMulti_Serial", N_VEnableScaleAddMulti_Serial, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableDotProdMulti_Serial", N_VEnableDotProdMulti_Serial, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableLinearSumVectorArray_Serial", + N_VEnableLinearSumVectorArray_Serial, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableScaleVectorArray_Serial", N_VEnableScaleVectorArray_Serial, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableConstVectorArray_Serial", N_VEnableConstVectorArray_Serial, + nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableWrmsNormVectorArray_Serial", + N_VEnableWrmsNormVectorArray_Serial, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableWrmsNormMaskVectorArray_Serial", + N_VEnableWrmsNormMaskVectorArray_Serial, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableScaleAddMultiVectorArray_Serial", + N_VEnableScaleAddMultiVectorArray_Serial, nb::arg("v"), nb::arg("tf")); + +m.def("N_VEnableLinearCombinationVectorArray_Serial", + N_VEnableLinearCombinationVectorArray_Serial, nb::arg("v"), nb::arg("tf")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunadaptcontroller/generate.yaml b/bindings/sundials4py/sunadaptcontroller/generate.yaml new file mode 100644 index 0000000000..f785f09535 --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/generate.yaml @@ -0,0 +1,47 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNAdaptController_GetType.*" + - "^SUNAdaptController_EstimateStep.*" + - "^SUNAdaptController_EstimateStepTol.*" + - "^SUNAdaptController_Reset.*" + - "^SUNAdaptController_SetOptions.*" + - "^SUNAdaptController_SetDefaults.*" + - "^SUNAdaptController_Write.*" + - "^SUNAdaptController_SetErrorBias.*" + - "^SUNAdaptController_UpdateH.*" + - "^SUNAdaptController_UpdateMRIHTol.*" + - "^SUNAdaptController_Space.*" + sunadaptcontroller_imexgus: + path: sunadaptcontroller/sunadaptcontroller_imexgus_generated.hpp + headers: + - ../../include/sunadaptcontroller/sunadaptcontroller_imexgus.h + sunadaptcontroller_mrihtol: + path: sunadaptcontroller/sunadaptcontroller_mrihtol_generated.hpp + headers: + - ../../include/sunadaptcontroller/sunadaptcontroller_mrihtol.h + sunadaptcontroller_soderlind: + path: sunadaptcontroller/sunadaptcontroller_soderlind_generated.hpp + headers: + - ../../include/sunadaptcontroller/sunadaptcontroller_soderlind.h diff --git a/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_imexgus.cpp b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_imexgus.cpp new file mode 100644 index 0000000000..c7f4b99c28 --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_imexgus.cpp @@ -0,0 +1,34 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadaptcontroller_imexgus(nb::module_& m) +{ +#include "sunadaptcontroller_imexgus_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_imexgus_generated.hpp b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_imexgus_generated.hpp new file mode 100644 index 0000000000..633aba299a --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_imexgus_generated.hpp @@ -0,0 +1,38 @@ +// #ifndef _SUNADAPTCONTROLLER_IMEXGUS_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNAdaptControllerContent_ImExGus = + nb::class_<_SUNAdaptControllerContent_ImExGus>(m, "_SUNAdaptControllerContent_ImExGus", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNAdaptController_ImExGus", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_ImExGus_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_ImExGus(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_ImExGus_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_ImExGus", + SUNAdaptController_SetParams_ImExGus, nb::arg("C"), nb::arg("k1e"), + nb::arg("k2e"), nb::arg("k1i"), nb::arg("k2i")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_mrihtol.cpp b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_mrihtol.cpp new file mode 100644 index 0000000000..ddbde500ac --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_mrihtol.cpp @@ -0,0 +1,34 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadaptcontroller_mrihtol(nb::module_& m) +{ +#include "sunadaptcontroller_mrihtol_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_mrihtol_generated.hpp b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_mrihtol_generated.hpp new file mode 100644 index 0000000000..0783fdd933 --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_mrihtol_generated.hpp @@ -0,0 +1,83 @@ +// #ifndef _SUNADAPTCONTROLLER_MRIHTOL_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClassSUNAdaptControllerContent_MRIHTol_ = + nb::class_(m, "SUNAdaptControllerContent_MRIHTol_", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNAdaptController_MRIHTol", + [](SUNAdaptController HControl, SUNAdaptController TolControl, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_MRIHTol_adapt_return_type_to_shared_ptr = + [](SUNAdaptController HControl, SUNAdaptController TolControl, + SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_MRIHTol(HControl, TolControl, + sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_MRIHTol_adapt_return_type_to_shared_ptr(HControl, + TolControl, + sunctx); + }, + nb::arg("HControl"), nb::arg("TolControl"), nb::arg("sunctx"), + "nb::keep_alive<0, 3>()", nb::keep_alive<0, 3>()); + +m.def("SUNAdaptController_SetParams_MRIHTol", + SUNAdaptController_SetParams_MRIHTol, nb::arg("C"), + nb::arg("inner_max_relch"), nb::arg("inner_min_tolfac"), + nb::arg("inner_max_tolfac")); + +m.def( + "SUNAdaptController_GetSlowController_MRIHTol", + [](SUNAdaptController C) -> std::tuple + { + auto SUNAdaptController_GetSlowController_MRIHTol_adapt_modifiable_immutable_to_return = + [](SUNAdaptController C) -> std::tuple + { + SUNAdaptController Cslow_adapt_modifiable; + + SUNErrCode r = + SUNAdaptController_GetSlowController_MRIHTol(C, &Cslow_adapt_modifiable); + return std::make_tuple(r, Cslow_adapt_modifiable); + }; + + return SUNAdaptController_GetSlowController_MRIHTol_adapt_modifiable_immutable_to_return( + C); + }, + nb::arg("C"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def( + "SUNAdaptController_GetFastController_MRIHTol", + [](SUNAdaptController C) -> std::tuple + { + auto SUNAdaptController_GetFastController_MRIHTol_adapt_modifiable_immutable_to_return = + [](SUNAdaptController C) -> std::tuple + { + SUNAdaptController Cfast_adapt_modifiable; + + SUNErrCode r = + SUNAdaptController_GetFastController_MRIHTol(C, &Cfast_adapt_modifiable); + return std::make_tuple(r, Cfast_adapt_modifiable); + }; + + return SUNAdaptController_GetFastController_MRIHTol_adapt_modifiable_immutable_to_return( + C); + }, + nb::arg("C"), "nb::rv_policy::reference", nb::rv_policy::reference); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_soderlind.cpp b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_soderlind.cpp new file mode 100644 index 0000000000..a1d2f9484d --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_soderlind.cpp @@ -0,0 +1,34 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadaptcontroller_soderlind(nb::module_& m) +{ +#include "sunadaptcontroller_soderlind_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_soderlind_generated.hpp b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_soderlind_generated.hpp new file mode 100644 index 0000000000..af8206c319 --- /dev/null +++ b/bindings/sundials4py/sunadaptcontroller/sunadaptcontroller_soderlind_generated.hpp @@ -0,0 +1,215 @@ +// #ifndef _SUNADAPTCONTROLLER_SODERLIND_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNAdaptControllerContent_Soderlind = + nb::class_<_SUNAdaptControllerContent_Soderlind>(m, "_SUNAdaptControllerContent_Soderlind", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNAdaptController_Soderlind", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_Soderlind_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_Soderlind(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_Soderlind_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_Soderlind", + SUNAdaptController_SetParams_Soderlind, nb::arg("C"), nb::arg("k1"), + nb::arg("k2"), nb::arg("k3"), nb::arg("k4"), nb::arg("k5")); + +m.def( + "SUNAdaptController_PID", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_PID_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_PID(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_PID_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_PID", SUNAdaptController_SetParams_PID, + nb::arg("C"), nb::arg("k1"), nb::arg("k2"), nb::arg("k3")); + +m.def( + "SUNAdaptController_PI", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_PI_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_PI(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_PI_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_PI", SUNAdaptController_SetParams_PI, + nb::arg("C"), nb::arg("k1"), nb::arg("k2")); + +m.def( + "SUNAdaptController_I", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_I_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_I(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_I_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_I", SUNAdaptController_SetParams_I, + nb::arg("C"), nb::arg("k1")); + +m.def( + "SUNAdaptController_ExpGus", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_ExpGus_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_ExpGus(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_ExpGus_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_ExpGus", SUNAdaptController_SetParams_ExpGus, + nb::arg("C"), nb::arg("k1"), nb::arg("k2")); + +m.def( + "SUNAdaptController_ImpGus", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_ImpGus_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_ImpGus(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_ImpGus_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def("SUNAdaptController_SetParams_ImpGus", SUNAdaptController_SetParams_ImpGus, + nb::arg("C"), nb::arg("k1"), nb::arg("k2")); + +m.def( + "SUNAdaptController_H0211", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_H0211_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_H0211(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_H0211_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def( + "SUNAdaptController_H0321", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_H0321_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_H0321(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_H0321_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def( + "SUNAdaptController_H211", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_H211_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_H211(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_H211_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); + +m.def( + "SUNAdaptController_H312", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNAdaptController_H312_adapt_return_type_to_shared_ptr = + [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNAdaptController_H312(sunctx); + + return our_make_shared, + SUNAdaptControllerDeleter>(lambda_result); + }; + + return SUNAdaptController_H312_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sunadjointcheckpointscheme/generate.yaml b/bindings/sundials4py/sunadjointcheckpointscheme/generate.yaml new file mode 100644 index 0000000000..fe6e32ac43 --- /dev/null +++ b/bindings/sundials4py/sunadjointcheckpointscheme/generate.yaml @@ -0,0 +1,40 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNAdaptController_GetType.*" + - "^SUNAdjointCheckpointScheme_NeedsSaving_.*" + - "^SUNAdjointCheckpointScheme_InsertVector_.*" + - "^SUNAdjointCheckpointScheme_LoadVector_.*" + - "^SUNAdjointCheckpointScheme_Destroy_.*" + - "^SUNAdjointCheckpointScheme_EnableDense_.*" + sunadjointcheckpointscheme_fixed: + path: sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed_generated.hpp + headers: + - ../../include/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed.h + fn_exclude_by_name__regex: + # SUNAdjointCheckpointScheme_Create_Fixed takes both a SUNMemoryHelper and SUNContext + # which must be kept alive as long as the SUNAdjointCheckpointScheme object. + # Currently, our adapt_sundials_type_returns.py litgen extension cannot handle cases + # where there are two patients in the Nurse,Patient scheme. So we wrap this one manually. + - "^SUNAdjointCheckpointScheme_Create_Fixed$" diff --git a/bindings/sundials4py/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed.cpp b/bindings/sundials4py/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed.cpp new file mode 100644 index 0000000000..7395e36f43 --- /dev/null +++ b/bindings/sundials4py/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed.cpp @@ -0,0 +1,63 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include +#include + +#include "sundials/sundials_adjointcheckpointscheme.h" +#include "sundials_adjointcheckpointscheme_impl.h" + +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadjointcheckpointscheme_fixed(nb::module_& m) +{ +#include "sunadjointcheckpointscheme_fixed_generated.hpp" + + m.def( + "SUNAdjointCheckpointScheme_Create_Fixed", + [](SUNDataIOMode io_mode, SUNMemoryHelper mem_helper, suncountertype interval, + suncountertype estimate, sunbooleantype keep, SUNContext sunctx) + -> std::tuple>> + { + SUNAdjointCheckpointScheme check_scheme; + SUNErrCode err = + SUNAdjointCheckpointScheme_Create_Fixed(io_mode, mem_helper, interval, + estimate, keep, sunctx, + &check_scheme); + return std::make_tuple(err, + our_make_shared< + std::remove_pointer_t, + SUNAdjointCheckpointSchemeDeleter>(check_scheme)); + }, + nb::arg("io_mode"), nb::arg("mem_helper"), nb::arg("interval"), + nb::arg("estimate"), nb::arg("keep"), nb::arg("sunctx"), + nb::call_policy>() /* keep SUNMemoryHelper alive as long as SUNAdjointCheckpointScheme is alive */, + nb::call_policy>() /* keep SUNContext alive as long as SUNAdjointCheckpointScheme is alive */); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed_generated.hpp b/bindings/sundials4py/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed_generated.hpp new file mode 100644 index 0000000000..b4f8705998 --- /dev/null +++ b/bindings/sundials4py/sunadjointcheckpointscheme/sunadjointcheckpointscheme_fixed_generated.hpp @@ -0,0 +1,10 @@ +// #ifndef _SUNADJOINTCHECKPOINTSCHEME_FIXED_H +// +// #ifdef __cplusplus +// #endif +// +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/generate.yaml b/bindings/sundials4py/sundials/generate.yaml new file mode 100644 index 0000000000..3474e2127b --- /dev/null +++ b/bindings/sundials4py/sundials/generate.yaml @@ -0,0 +1,154 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + fn_exclude_by_name__regex: + - "Free" # Free and destroy functions should not need to be called as all objects on the Python side are RAII objects + - "Destroy" + - "NewEmpty" # No need to call NewEmpty functions in Python + - "CopyOps" # No need to call CopyOps functions in Python + - "Space" # Space functions are deprecated, so dont expose them in Python + # Due to the need to convert between sys.argv and C argv, we need to do custom wrappers of these + - "SetOptions" + macro_define_include_by_name__regex: + - "^SUN_" + sundials_adaptcontroller: + path: sundials/sundials_adaptcontroller_generated.hpp + headers: + - ../../include/sundials/sundials_adaptcontroller.h + sundials_adjointcheckpointscheme: + path: sundials/sundials_adjointcheckpointscheme_generated.hpp + headers: + - ../../include/sundials/sundials_adjointcheckpointscheme.h + fn_exclude_by_name__regex: + - "Set.*Fn" # nanobind cannot bind to functions which take a function pointer, so we do something custom + - "SetContent" # no need to set content from Python, so we don't interface it + - "GetContent" # content should not be accessible from Python, so we don't interface it + sundials_adjointstepper: + path: sundials/sundials_adjointstepper_generated.hpp + headers: + - ../../include/sundials/sundials_adjointstepper.h + sundials_context: + path: sundials/sundials_context_generated.hpp + headers: + - ../../include/sundials/sundials_context.h + fn_exclude_by_name__regex: + - "^SUNContext_PushErrHandler$" + - "^SUNContext_PopErrHandler$" + sundials_domeigestimator: + path: sundials/sundials_domeigestimator_generated.hpp + headers: + - ../../include/sundials/sundials_domeigestimator.h + fn_exclude_by_name__regex: + - "SUNDomEigEstimator_SetATimes" # nanobind cannot bind to functions which take a function pointer, so we do something custom + sundials_errors: + path: sundials/sundials_errors_generated.hpp + headers: + - ../../include/sundials/sundials_errors.h + enum_exclude_by_name__regex: + - "SUNErrCode_" # not needed in Python + sundials_linearsolver: + path: sundials/sundials_linearsolver_generated.hpp + headers: + - ../../include/sundials/sundials_iterative.h + - ../../include/sundials/sundials_linearsolver.h + fn_exclude_by_name__regex: + # SUNLinSolSolve has an optional argument, A, followed by non-optional. Litgen can't handle this yet, so we just manually wrap it. + - "^SUNLinSolSolve$" + # need to do custom handling of functions which take a function pointer + - "^SUNLinSolSetATimes$" + - "^SUNLinSolSetPreconditioner$" + fn_params_optional_with_default_null: + "SUNLinSolSetup": + - "A" + sundials_logger: + path: sundials/sundials_logger_generated.hpp + headers: + - ../../include/sundials/sundials_logger.h + sundials_matrix: + path: sundials/sundials_matrix_generated.hpp + headers: + - ../../include/sundials/sundials_matrix.h + sundials_memory: + path: sundials/sundials_memory_generated.hpp + headers: + - ../../include/sundials/sundials_memory.h + fn_exclude_by_name__regex: + # these functions are not needed in a Python code, so we don't interface them + - "Alias" + - "Wrap" + - "Alloc*" + - "Dealloc" + - "Copy*" + class_exclude_by_name__regex: + - "SUNMemory_" # not needed in Python + sundials_nonlinearsolver: + path: sundials/sundials_nonlinearsolver_generated.hpp + headers: + - ../../include/sundials/sundials_nonlinearsolver.h + fn_exclude_by_name__regex: + # nanobind cannot bind to functions with the nullable void* mem argument through std::optional + # (which is what litgen produces for these) so we have to instead bind to it manually + - "^SUNNonlinSolSetup$" + - "^SUNNonlinSolSolve$" + # nanobind cannot bind to functions which take a function pointer. + # Furthermore, several of the callback functions for the SUNNonlinearSolver module + # do not pass back user_data, instead they pass back the integrator memory. This makes + # it to where we cannot access our python function table stored in user_data in generic way + # (you'd have to call the integrator's GetUserData function, but we don't know which integrator is being used). + - "Set.*Fn" + sundials_nvector: + path: sundials/sundials_nvector_generated.hpp + headers: + - ../../include/sundials/sundials_nvector.h + fn_exclude_by_name__regex: + # we have to have custom handling for the following to get the proper dimensions set for the numpy array + - "^N_VGetArrayPointer$" + - "^N_VSetArrayPointer$" + - "^N_VGetDeviceArrayPointer$" + # these operations for supporting xbraid are not yet available via Python, we will have to evaluate if they are useful/needed + - "^N_VBufPack$" + - "^N_VBufUnpack$" + - "^N_VBufSize$" + # vector arrays do not need to be allocated through these in Python, so we do not interface them + - "^N_VNewVectorArray$" + - "^N_VCloneVectorArray$" + - "^N_VCloneEmptyVectorArray$" + - "^N_VGetVecAtIndexVectorArray$" + - "^N_VSetVecAtIndexVectorArray$" + # ** parameters are not yet supported by litgen yet, so we do something custom + - "^N_VLinearCombinationVectorArray$" + - "^N_VScaleAddMultiVectorArray$" + sundials_profiler: + path: sundials/sundials_profiler_generated.hpp + headers: + - ../../include/sundials/sundials_profiler.h + sundials_stepper: + path: sundials/sundials_stepper_generated.hpp + headers: + - ../../include/sundials/sundials_stepper.h + fn_exclude_by_name__regex: + - "Set.*Fn" # nanobind cannot bind to functions which take a function pointer, so we do something custom + # we dont allow access to content from Python + - "SetContent" + - "GetContent" + sundials_types: + path: sundials/sundials_types_generated.hpp + headers: + - ../../include/sundials/sundials_types.h diff --git a/bindings/sundials4py/sundials/sundials_adaptcontroller.cpp b/bindings/sundials4py/sundials/sundials_adaptcontroller.cpp new file mode 100644 index 0000000000..c0d691d0f7 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_adaptcontroller.cpp @@ -0,0 +1,62 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNAdaptController class. It contains hand-written code for + * functions that require special treatment, and includes the generated + * code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials/sundials_adaptcontroller.h" +#include "sundials4py.hpp" + +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadaptcontroller(nb::module_& m) +{ +#include "sundials_adaptcontroller_generated.hpp" + + m.def( + "SUNAdaptController_SetOptions", + [](SUNAdaptController self, const std::string& id, + const std::string& file_name, int argc, + const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return SUNAdaptController_SetOptions(self, + id.empty() ? nullptr : id.c_str(), + file_name.empty() ? nullptr + : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("self"), nb::arg("id"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_adaptcontroller_generated.hpp b/bindings/sundials4py/sundials/sundials_adaptcontroller_generated.hpp new file mode 100644 index 0000000000..ad1c4d9b92 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_adaptcontroller_generated.hpp @@ -0,0 +1,109 @@ +// #ifndef _SUNDIALS_ADAPTCONTROLLER_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNAdaptController_Type = + nb::enum_(m, "SUNAdaptController_Type", + nb::is_arithmetic(), "") + .value("SUN_ADAPTCONTROLLER_NONE", SUN_ADAPTCONTROLLER_NONE, "") + .value("SUN_ADAPTCONTROLLER_H", SUN_ADAPTCONTROLLER_H, "") + .value("SUN_ADAPTCONTROLLER_MRI_H_TOL", SUN_ADAPTCONTROLLER_MRI_H_TOL, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClass_generic_SUNAdaptController_Ops = + nb::class_<_generic_SUNAdaptController_Ops>(m, + "_generic_SUNAdaptController_Ops", "Structure containing function pointers to controller operations") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClass_generic_SUNAdaptController = + nb::class_<_generic_SUNAdaptController>(m, + "_generic_SUNAdaptController", " A SUNAdaptController is a structure with an implementation-dependent\n 'content' field, and a pointer to a structure of\n operations corresponding to that implementation.") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("SUNAdaptController_GetType", SUNAdaptController_GetType, nb::arg("C")); + +m.def( + "SUNAdaptController_EstimateStep", + [](SUNAdaptController C, sunrealtype h, int p, + sunrealtype dsm) -> std::tuple + { + auto SUNAdaptController_EstimateStep_adapt_modifiable_immutable_to_return = + [](SUNAdaptController C, sunrealtype h, int p, + sunrealtype dsm) -> std::tuple + { + sunrealtype hnew_adapt_modifiable; + + SUNErrCode r = SUNAdaptController_EstimateStep(C, h, p, dsm, + &hnew_adapt_modifiable); + return std::make_tuple(r, hnew_adapt_modifiable); + }; + + return SUNAdaptController_EstimateStep_adapt_modifiable_immutable_to_return(C, + h, + p, + dsm); + }, + nb::arg("C"), nb::arg("h"), nb::arg("p"), nb::arg("dsm")); + +m.def( + "SUNAdaptController_EstimateStepTol", + [](SUNAdaptController C, sunrealtype H, sunrealtype tolfac, int P, + sunrealtype DSM, + sunrealtype dsm) -> std::tuple + { + auto SUNAdaptController_EstimateStepTol_adapt_modifiable_immutable_to_return = + [](SUNAdaptController C, sunrealtype H, sunrealtype tolfac, int P, + sunrealtype DSM, + sunrealtype dsm) -> std::tuple + { + sunrealtype Hnew_adapt_modifiable; + sunrealtype tolfacnew_adapt_modifiable; + + SUNErrCode r = + SUNAdaptController_EstimateStepTol(C, H, tolfac, P, DSM, dsm, + &Hnew_adapt_modifiable, + &tolfacnew_adapt_modifiable); + return std::make_tuple(r, Hnew_adapt_modifiable, + tolfacnew_adapt_modifiable); + }; + + return SUNAdaptController_EstimateStepTol_adapt_modifiable_immutable_to_return(C, + H, + tolfac, + P, + DSM, + dsm); + }, + nb::arg("C"), nb::arg("H"), nb::arg("tolfac"), nb::arg("P"), nb::arg("DSM"), + nb::arg("dsm")); + +m.def("SUNAdaptController_Reset", SUNAdaptController_Reset, nb::arg("C")); + +m.def("SUNAdaptController_SetDefaults", SUNAdaptController_SetDefaults, + nb::arg("C")); + +m.def("SUNAdaptController_Write", SUNAdaptController_Write, nb::arg("C"), + nb::arg("fptr")); + +m.def("SUNAdaptController_SetErrorBias", SUNAdaptController_SetErrorBias, + nb::arg("C"), nb::arg("bias")); + +m.def("SUNAdaptController_UpdateH", SUNAdaptController_UpdateH, nb::arg("C"), + nb::arg("h"), nb::arg("dsm")); + +m.def("SUNAdaptController_UpdateMRIHTol", SUNAdaptController_UpdateMRIHTol, + nb::arg("C"), nb::arg("H"), nb::arg("tolfac"), nb::arg("DSM"), + nb::arg("dsm")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_adjointcheckpointscheme.cpp b/bindings/sundials4py/sundials/sundials_adjointcheckpointscheme.cpp new file mode 100644 index 0000000000..1994ebcf80 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_adjointcheckpointscheme.cpp @@ -0,0 +1,44 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNAdjointCheckpointScheme class. It contains hand-written + * code for functions that require special treatment, and includes the + * generated code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +#include "sundials_adjointcheckpointscheme_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadjointcheckpointscheme(nb::module_& m) +{ +#include "sundials_adjointcheckpointscheme_generated.hpp" + + nb::class_(m, "SUNAdjointCheckpointScheme_"); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_adjointcheckpointscheme_generated.hpp b/bindings/sundials4py/sundials/sundials_adjointcheckpointscheme_generated.hpp new file mode 100644 index 0000000000..7471762754 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_adjointcheckpointscheme_generated.hpp @@ -0,0 +1,74 @@ +// #ifndef _SUNADJOINT_CHECKPOINTSCHEME_H +// +// #ifdef __cplusplus +// #endif +// + +m.def( + "SUNAdjointCheckpointScheme_NeedsSaving", + [](SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, + suncountertype stage_num, + sunrealtype t) -> std::tuple + { + auto SUNAdjointCheckpointScheme_NeedsSaving_adapt_modifiable_immutable_to_return = + [](SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, + suncountertype stage_num, + sunrealtype t) -> std::tuple + { + sunbooleantype yes_or_no_adapt_modifiable; + + SUNErrCode r = + SUNAdjointCheckpointScheme_NeedsSaving(check_scheme, step_num, stage_num, + t, &yes_or_no_adapt_modifiable); + return std::make_tuple(r, yes_or_no_adapt_modifiable); + }; + + return SUNAdjointCheckpointScheme_NeedsSaving_adapt_modifiable_immutable_to_return(check_scheme, + step_num, + stage_num, + t); + }, + nb::arg("check_scheme"), nb::arg("step_num"), nb::arg("stage_num"), + nb::arg("t")); + +m.def("SUNAdjointCheckpointScheme_InsertVector", + SUNAdjointCheckpointScheme_InsertVector, nb::arg("check_scheme"), + nb::arg("step_num"), nb::arg("stage_num"), nb::arg("t"), nb::arg("state")); + +m.def( + "SUNAdjointCheckpointScheme_LoadVector", + [](SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, + suncountertype stage_num, + sunbooleantype peek) -> std::tuple + { + auto SUNAdjointCheckpointScheme_LoadVector_adapt_modifiable_immutable_to_return = + [](SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, + suncountertype stage_num, + sunbooleantype peek) -> std::tuple + { + N_Vector out_adapt_modifiable; + sunrealtype tout_adapt_modifiable; + + SUNErrCode r = + SUNAdjointCheckpointScheme_LoadVector(check_scheme, step_num, stage_num, + peek, &out_adapt_modifiable, + &tout_adapt_modifiable); + return std::make_tuple(r, out_adapt_modifiable, tout_adapt_modifiable); + }; + + return SUNAdjointCheckpointScheme_LoadVector_adapt_modifiable_immutable_to_return(check_scheme, + step_num, + stage_num, + peek); + }, + nb::arg("check_scheme"), nb::arg("step_num"), nb::arg("stage_num"), + nb::arg("peek"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def("SUNAdjointCheckpointScheme_EnableDense", + SUNAdjointCheckpointScheme_EnableDense, nb::arg("check_scheme"), + nb::arg("on_or_off")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_adjointstepper.cpp b/bindings/sundials4py/sundials/sundials_adjointstepper.cpp new file mode 100644 index 0000000000..00a9522093 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_adjointstepper.cpp @@ -0,0 +1,44 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNAdjointStepper class. It contains hand-written code for + * functions that require special treatment, and includes the generated + * code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +#include "sundials_adjointcheckpointscheme_impl.h" +#include "sundials_adjointstepper_impl.h" +#include "sundials_stepper_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunadjointstepper(nb::module_& m) +{ +#include "sundials_adjointstepper_generated.hpp" + + nb::class_(m, "SUNAdjointStepper_"); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_adjointstepper_generated.hpp b/bindings/sundials4py/sundials/sundials_adjointstepper_generated.hpp new file mode 100644 index 0000000000..2effeb7c66 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_adjointstepper_generated.hpp @@ -0,0 +1,179 @@ +// #ifndef _SUNADJOINT_STEPPER_H +// +// #ifdef __cplusplus +// +// #endif +// + +m.def( + "SUNAdjointStepper_Create", + [](SUNStepper fwd_sunstepper, sunbooleantype own_fwd, SUNStepper adj_sunstepper, + sunbooleantype own_adj, suncountertype final_step_idx, sunrealtype tf, + N_Vector sf, SUNAdjointCheckpointScheme checkpoint_scheme, SUNContext sunctx) + -> std::tuple>> + { + auto SUNAdjointStepper_Create_adapt_modifiable_immutable_to_return = + [](SUNStepper fwd_sunstepper, sunbooleantype own_fwd, + SUNStepper adj_sunstepper, sunbooleantype own_adj, + suncountertype final_step_idx, sunrealtype tf, N_Vector sf, + SUNAdjointCheckpointScheme checkpoint_scheme, + SUNContext sunctx) -> std::tuple + { + SUNAdjointStepper adj_stepper_adapt_modifiable; + + SUNErrCode r = SUNAdjointStepper_Create(fwd_sunstepper, own_fwd, + adj_sunstepper, own_adj, + final_step_idx, tf, sf, + checkpoint_scheme, sunctx, + &adj_stepper_adapt_modifiable); + return std::make_tuple(r, adj_stepper_adapt_modifiable); + }; + auto SUNAdjointStepper_Create_adapt_return_type_to_shared_ptr = + [&SUNAdjointStepper_Create_adapt_modifiable_immutable_to_return](SUNStepper fwd_sunstepper, + sunbooleantype + own_fwd, + SUNStepper adj_sunstepper, + sunbooleantype + own_adj, + suncountertype + final_step_idx, + sunrealtype tf, + N_Vector sf, + SUNAdjointCheckpointScheme + checkpoint_scheme, + SUNContext sunctx) + -> std::tuple>> + { + auto lambda_result = + SUNAdjointStepper_Create_adapt_modifiable_immutable_to_return(fwd_sunstepper, + own_fwd, + adj_sunstepper, + own_adj, + final_step_idx, + tf, sf, + checkpoint_scheme, + sunctx); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNAdjointStepperDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNAdjointStepper_Create_adapt_return_type_to_shared_ptr(fwd_sunstepper, + own_fwd, + adj_sunstepper, + own_adj, + final_step_idx, + tf, sf, + checkpoint_scheme, + sunctx); + }, + nb::arg("fwd_sunstepper"), nb::arg("own_fwd"), nb::arg("adj_sunstepper"), + nb::arg("own_adj"), nb::arg("final_step_idx"), nb::arg("tf"), nb::arg("sf"), + nb::arg("checkpoint_scheme"), nb::arg("sunctx"), + "nb::call_policy>()", + nb::rv_policy::reference, + nb::call_policy>()); + +m.def("SUNAdjointStepper_ReInit", SUNAdjointStepper_ReInit, nb::arg("adj"), + nb::arg("t0"), nb::arg("y0"), nb::arg("tf"), nb::arg("sf")); + +m.def( + "SUNAdjointStepper_Evolve", + [](SUNAdjointStepper adj_stepper, sunrealtype tout, + N_Vector sens) -> std::tuple + { + auto SUNAdjointStepper_Evolve_adapt_modifiable_immutable_to_return = + [](SUNAdjointStepper adj_stepper, sunrealtype tout, + N_Vector sens) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + SUNErrCode r = SUNAdjointStepper_Evolve(adj_stepper, tout, sens, + &tret_adapt_modifiable); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return SUNAdjointStepper_Evolve_adapt_modifiable_immutable_to_return(adj_stepper, + tout, + sens); + }, + nb::arg("adj_stepper"), nb::arg("tout"), nb::arg("sens")); + +m.def( + "SUNAdjointStepper_OneStep", + [](SUNAdjointStepper adj_stepper, sunrealtype tout, + N_Vector sens) -> std::tuple + { + auto SUNAdjointStepper_OneStep_adapt_modifiable_immutable_to_return = + [](SUNAdjointStepper adj_stepper, sunrealtype tout, + N_Vector sens) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + SUNErrCode r = SUNAdjointStepper_OneStep(adj_stepper, tout, sens, + &tret_adapt_modifiable); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return SUNAdjointStepper_OneStep_adapt_modifiable_immutable_to_return(adj_stepper, + tout, + sens); + }, + nb::arg("adj_stepper"), nb::arg("tout"), nb::arg("sens")); + +m.def("SUNAdjointStepper_RecomputeFwd", SUNAdjointStepper_RecomputeFwd, + nb::arg("adj_stepper"), nb::arg("start_idx"), nb::arg("t0"), + nb::arg("y0"), nb::arg("tf")); + +m.def("SUNAdjointStepper_SetUserData", SUNAdjointStepper_SetUserData, + nb::arg("param_0"), nb::arg("user_data")); + +m.def( + "SUNAdjointStepper_GetNumSteps", + [](SUNAdjointStepper adj_stepper) -> std::tuple + { + auto SUNAdjointStepper_GetNumSteps_adapt_modifiable_immutable_to_return = + [](SUNAdjointStepper adj_stepper) -> std::tuple + { + suncountertype num_steps_adapt_modifiable; + + SUNErrCode r = SUNAdjointStepper_GetNumSteps(adj_stepper, + &num_steps_adapt_modifiable); + return std::make_tuple(r, num_steps_adapt_modifiable); + }; + + return SUNAdjointStepper_GetNumSteps_adapt_modifiable_immutable_to_return( + adj_stepper); + }, + nb::arg("adj_stepper")); + +m.def( + "SUNAdjointStepper_GetNumRecompute", + [](SUNAdjointStepper adj_stepper) -> std::tuple + { + auto SUNAdjointStepper_GetNumRecompute_adapt_modifiable_immutable_to_return = + [](SUNAdjointStepper adj_stepper) -> std::tuple + { + suncountertype num_recompute_adapt_modifiable; + + SUNErrCode r = + SUNAdjointStepper_GetNumRecompute(adj_stepper, + &num_recompute_adapt_modifiable); + return std::make_tuple(r, num_recompute_adapt_modifiable); + }; + + return SUNAdjointStepper_GetNumRecompute_adapt_modifiable_immutable_to_return( + adj_stepper); + }, + nb::arg("adj_stepper")); + +m.def("SUNAdjointStepper_PrintAllStats", SUNAdjointStepper_PrintAllStats, + nb::arg("adj_stepper"), nb::arg("outfile"), nb::arg("fmt")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_context.cpp b/bindings/sundials4py/sundials/sundials_context.cpp new file mode 100644 index 0000000000..744060af14 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_context.cpp @@ -0,0 +1,110 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNContext class. It contains hand-written code for + * functions that require special treatment, and includes the generated + * code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials/sundials_errors.h" +#include "sundials4py.hpp" + +#include +#include +#include + +#include "sundials/sundials_types.h" +#include "sundials_logger_impl.h" +#include "sundials_profiler_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; +using sundials::SUNContextDeleter; + +#include "sundials_context_usersupplied.hpp" + +namespace sundials4py { + +using namespace sundials::experimental; + +void bind_suncontext(nb::module_& m) +{ +#include "sundials_context_generated.hpp" + + nb::class_(m, "SUNContext_"); + + // Note: only one error handler can be pushed from python + m.def("SUNContext_PushErrHandler", + [](SUNContext sunctx, + std::function> err_fn) + { + if (!sunctx->python) + { + sunctx->python = SUNContextFunctionTable_Alloc(); + + // Only push the wrapper the first time this is called + SUNErrCode status = + SUNContext_PushErrHandler(sunctx, suncontext_errhandler_wrapper, + sunctx->python); + if (status) + { + throw sundials4py::error_returned( + "SUNContext_PushErrHandler returned an error"); + } + } + + auto fn_table = static_cast(sunctx->python); + + fn_table->err_handlers.push_back(nb::cast(err_fn)); + + return SUN_SUCCESS; + }); + + m.def("SUNContext_PopErrHandler", + [](SUNContext sunctx) -> SUNErrCode + { + if (!sunctx->python) { return SUN_SUCCESS; } + + auto fn_table = static_cast(sunctx->python); + + if (fn_table->err_handlers.size() > 0) + { + // pop the python functions off the interface layer stack + fn_table->err_handlers.pop_back(); + } + + if (fn_table->err_handlers.size() == 0) + { + // now we can pop the suncontext_errhandler_wrapper off the C side stack + return SUNContext_PopErrHandler(sunctx); + } + + return SUN_SUCCESS; + }); + + m.def( + "SUNContext_TestErrHandler", + [](SUNContext sunctx) + { + SUNHandleErrWithMsg(__LINE__, __func__, __FILE__, + "create an error to test the error handlers", + SUN_ERR_ARG_CORRUPT, sunctx); + }, + "This function is for testing purposes and should not be called."); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_context_generated.hpp b/bindings/sundials4py/sundials/sundials_context_generated.hpp new file mode 100644 index 0000000000..3ecf466a5b --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_context_generated.hpp @@ -0,0 +1,88 @@ +// #ifndef _SUNDIALS_CONTEXT_H +// +// #ifdef __cplusplus +// #endif +// + +m.def( + "SUNContext_Create", + [](SUNComm comm) + -> std::tuple>> + { + auto SUNContext_Create_adapt_modifiable_immutable_to_return = + [](SUNComm comm) -> std::tuple + { + SUNContext sunctx_out_adapt_modifiable; + + SUNErrCode r = SUNContext_Create(comm, &sunctx_out_adapt_modifiable); + return std::make_tuple(r, sunctx_out_adapt_modifiable); + }; + auto SUNContext_Create_adapt_return_type_to_shared_ptr = + [&SUNContext_Create_adapt_modifiable_immutable_to_return](SUNComm comm) + -> std::tuple>> + { + auto lambda_result = + SUNContext_Create_adapt_modifiable_immutable_to_return(comm); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNContextDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNContext_Create_adapt_return_type_to_shared_ptr(comm); + }, + nb::arg("comm"), nb::rv_policy::reference); + +m.def("SUNContext_GetLastError", SUNContext_GetLastError, nb::arg("sunctx")); + +m.def("SUNContext_PeekLastError", SUNContext_PeekLastError, nb::arg("sunctx")); + +m.def("SUNContext_ClearErrHandlers", SUNContext_ClearErrHandlers, + nb::arg("sunctx")); + +m.def( + "SUNContext_GetProfiler", + [](SUNContext sunctx) -> std::tuple + { + auto SUNContext_GetProfiler_adapt_modifiable_immutable_to_return = + [](SUNContext sunctx) -> std::tuple + { + SUNProfiler profiler_adapt_modifiable; + + SUNErrCode r = SUNContext_GetProfiler(sunctx, &profiler_adapt_modifiable); + return std::make_tuple(r, profiler_adapt_modifiable); + }; + + return SUNContext_GetProfiler_adapt_modifiable_immutable_to_return(sunctx); + }, + nb::arg("sunctx"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def("SUNContext_SetProfiler", SUNContext_SetProfiler, nb::arg("sunctx"), + nb::arg("profiler")); + +m.def( + "SUNContext_GetLogger", + [](SUNContext sunctx) -> std::tuple + { + auto SUNContext_GetLogger_adapt_modifiable_immutable_to_return = + [](SUNContext sunctx) -> std::tuple + { + SUNLogger logger_adapt_modifiable; + + SUNErrCode r = SUNContext_GetLogger(sunctx, &logger_adapt_modifiable); + return std::make_tuple(r, logger_adapt_modifiable); + }; + + return SUNContext_GetLogger_adapt_modifiable_immutable_to_return(sunctx); + }, + nb::arg("sunctx"), "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def("SUNContext_SetLogger", SUNContext_SetLogger, nb::arg("sunctx"), + nb::arg("logger")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sundials/sundials_context_usersupplied.hpp b/bindings/sundials4py/sundials/sundials_context_usersupplied.hpp new file mode 100644 index 0000000000..f271e7144c --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_context_usersupplied.hpp @@ -0,0 +1,59 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_SUNCONTEXT_USERSUPPLIED_HPP +#define _SUNDIALS4PY_SUNCONTEXT_USERSUPPLIED_HPP + +#include +#include +#include +#include + +#include +#include "sundials4py.hpp" +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +// Function table for user-supplied error handler for SUNContext +struct SUNContextFunctionTable +{ + std::vector err_handlers; +}; + +inline SUNContextFunctionTable* SUNContextFunctionTable_Alloc() +{ + auto fn_table = static_cast( + std::malloc(sizeof(SUNContextFunctionTable))); + std::memset(fn_table, 0, sizeof(SUNContextFunctionTable)); + return fn_table; +} + +inline void suncontext_errhandler_wrapper(int line, const char* func, + const char* file, const char* msg, + SUNErrCode err_code, + void* err_user_data, SUNContext sunctx) +{ + auto fn_table = static_cast(err_user_data); + for (int i = fn_table->err_handlers.size() - 1; i >= 0; i--) + { + fn_table->err_handlers[i](line, func, file, msg, err_code, nullptr, sunctx); + } +} + +#endif \ No newline at end of file diff --git a/bindings/sundials4py/sundials/sundials_core.cpp b/bindings/sundials4py/sundials/sundials_core.cpp new file mode 100644 index 0000000000..2b92bea7c1 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_core.cpp @@ -0,0 +1,100 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file defines the sundials4py.core module. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_nvector(nb::module_& m); +void bind_sunadaptcontroller(nb::module_& m); +void bind_sunadjointcheckpointscheme(nb::module_& m); +void bind_sunadjointstepper(nb::module_& m); +void bind_suncontext(nb::module_& m); +void bind_sundomeigestimator(nb::module_& m); +void bind_sunlinearsolver(nb::module_& m); +void bind_sunlogger(nb::module_& m); +void bind_sunmatrix(nb::module_& m); +void bind_sunmemory(nb::module_& m); +void bind_sunnonlinearsolver(nb::module_& m); +void bind_sunprofiler(nb::module_& m); +void bind_sunstepper(nb::module_& m); + +void bind_core(nb::module_& m) +{ +#include "sundials_errors.hpp" +#include "sundials_types_generated.hpp" + + // handle opening and closing C files + nb::class_(m, "FILE"); + m.def("SUNFileOpen", + [](const char* filename, const char* modes) + { + FILE* tmp = nullptr; + std::shared_ptr fp; + SUNErrCode status = SUNFileOpen(filename, modes, &tmp); + if (status) { fp = nullptr; } + else { fp = std::shared_ptr(tmp, std::fclose); } + return std::make_tuple(status, fp); + }); + + bind_nvector(m); + bind_sunadaptcontroller(m); + bind_sunadjointcheckpointscheme(m); + bind_sunadjointstepper(m); + bind_suncontext(m); + bind_sundomeigestimator(m); + bind_sunlinearsolver(m); + bind_sunlogger(m); + bind_sunmatrix(m); + bind_sunmemory(m); + bind_sunnonlinearsolver(m); + bind_sunprofiler(m); + bind_sunstepper(m); + + // + // Expose sunrealtye and sunindextype as the corresponding numpy types + // + + nb::object np = nb::module_::import_("numpy"); +#if defined(SUNDIALS_SINGLE_PRECISION) + m.attr("sunrealtype") = np.attr("float32"); +#elif defined(SUNDIALS_DOUBLE_PRECISION) + m.attr("sunrealtype") = np.attr("float64"); +#elif defined(SUNDIALS_EXTENDED_PRECISION) + m.attr("sunrealtype") = np.attr("longdouble"); +#else +#error Unknown sunrealtype, email sundials-users@llnl.gov +#endif + +#if defined(SUNDIALS_INT64_T) + m.attr("sunindextype") = np.attr("int64"); +#elif defined(SUNDIALS_INT32_T) + m.attr("sunindextype") = np.attr("int32"); +#else +#error Unknown sunindextype, email sundials-users@llnl.gov +#endif +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_domeigestimator.cpp b/bindings/sundials4py/sundials/sundials_domeigestimator.cpp new file mode 100644 index 0000000000..08c8e38592 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_domeigestimator.cpp @@ -0,0 +1,90 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS N_Vector class. It contains hand-written code for functions + * that require special treatment, and includes the generated code + * produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include +#include + +#include "sundials4py.hpp" + +#include +#include + +#include "sundials_domeigestimator_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sundomeigestimator(nb::module_& m) +{ +#include "sundials_domeigestimator_generated.hpp" + + m.def( + "SUNDomEigEstimator_SetOptions", + [](SUNDomEigEstimator self, const std::string& id, + const std::string& file_name, int argc, + const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return SUNDomEigEstimator_SetOptions(self, + id.empty() ? nullptr : id.c_str(), + file_name.empty() ? nullptr + : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("self"), nb::arg("id"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + m.def( + "SUNDomEigEstimator_SetATimes", + [](SUNDomEigEstimator dee, + std::function> ATimes) -> SUNErrCode + { + if (!dee->python) + { + dee->python = SUNDomEigEstimatorFunctionTable_Alloc(); + } + + auto fntable = static_cast(dee->python); + + fntable->atimes = nb::cast(ATimes); + + if (ATimes) + { + return SUNDomEigEstimator_SetATimes(dee, fntable, + sundomeigestimator_atimes_wrapper); + } + else { return SUNDomEigEstimator_SetATimes(dee, fntable, nullptr); } + }, + nb::arg("DEE"), nb::arg("ATimes").none()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_domeigestimator_generated.hpp b/bindings/sundials4py/sundials/sundials_domeigestimator_generated.hpp new file mode 100644 index 0000000000..c39c1575db --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_domeigestimator_generated.hpp @@ -0,0 +1,117 @@ +// #ifndef _SUNDOMEIGEST_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClassSUNDomEigEstimator_Ops_ = + nb::class_(m, + "SUNDomEigEstimator_Ops_", "Structure containing function pointers to estimator operations") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClassSUNDomEigEstimator_ = + nb::class_(m, + "SUNDomEigEstimator_", " An estimator is a structure with an implementation-dependent\n 'content' field, and a pointer to a structure of estimator\n operations corresponding to that implementation.") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("SUNDomEigEstimator_SetMaxIters", SUNDomEigEstimator_SetMaxIters, + nb::arg("DEE"), nb::arg("max_iters")); + +m.def("SUNDomEigEstimator_SetNumPreprocessIters", + SUNDomEigEstimator_SetNumPreprocessIters, nb::arg("DEE"), + nb::arg("num_iters")); + +m.def("SUNDomEigEstimator_SetRelTol", SUNDomEigEstimator_SetRelTol, + nb::arg("DEE"), nb::arg("tol")); + +m.def("SUNDomEigEstimator_SetInitialGuess", SUNDomEigEstimator_SetInitialGuess, + nb::arg("DEE"), nb::arg("q")); + +m.def("SUNDomEigEstimator_Initialize", SUNDomEigEstimator_Initialize, + nb::arg("DEE")); + +m.def( + "SUNDomEigEstimator_Estimate", + [](SUNDomEigEstimator DEE) -> std::tuple + { + auto SUNDomEigEstimator_Estimate_adapt_modifiable_immutable_to_return = + [](SUNDomEigEstimator DEE) -> std::tuple + { + sunrealtype lambdaR_adapt_modifiable; + sunrealtype lambdaI_adapt_modifiable; + + SUNErrCode r = SUNDomEigEstimator_Estimate(DEE, &lambdaR_adapt_modifiable, + &lambdaI_adapt_modifiable); + return std::make_tuple(r, lambdaR_adapt_modifiable, + lambdaI_adapt_modifiable); + }; + + return SUNDomEigEstimator_Estimate_adapt_modifiable_immutable_to_return(DEE); + }, + nb::arg("DEE")); + +m.def( + "SUNDomEigEstimator_GetRes", + [](SUNDomEigEstimator DEE) -> std::tuple + { + auto SUNDomEigEstimator_GetRes_adapt_modifiable_immutable_to_return = + [](SUNDomEigEstimator DEE) -> std::tuple + { + sunrealtype res_adapt_modifiable; + + SUNErrCode r = SUNDomEigEstimator_GetRes(DEE, &res_adapt_modifiable); + return std::make_tuple(r, res_adapt_modifiable); + }; + + return SUNDomEigEstimator_GetRes_adapt_modifiable_immutable_to_return(DEE); + }, + nb::arg("DEE")); + +m.def( + "SUNDomEigEstimator_GetNumIters", + [](SUNDomEigEstimator DEE) -> std::tuple + { + auto SUNDomEigEstimator_GetNumIters_adapt_modifiable_immutable_to_return = + [](SUNDomEigEstimator DEE) -> std::tuple + { + long num_iters_adapt_modifiable; + + SUNErrCode r = + SUNDomEigEstimator_GetNumIters(DEE, &num_iters_adapt_modifiable); + return std::make_tuple(r, num_iters_adapt_modifiable); + }; + + return SUNDomEigEstimator_GetNumIters_adapt_modifiable_immutable_to_return( + DEE); + }, + nb::arg("DEE")); + +m.def( + "SUNDomEigEstimator_GetNumATimesCalls", + [](SUNDomEigEstimator DEE) -> std::tuple + { + auto SUNDomEigEstimator_GetNumATimesCalls_adapt_modifiable_immutable_to_return = + [](SUNDomEigEstimator DEE) -> std::tuple + { + long num_ATimes_adapt_modifiable; + + SUNErrCode r = + SUNDomEigEstimator_GetNumATimesCalls(DEE, &num_ATimes_adapt_modifiable); + return std::make_tuple(r, num_ATimes_adapt_modifiable); + }; + + return SUNDomEigEstimator_GetNumATimesCalls_adapt_modifiable_immutable_to_return( + DEE); + }, + nb::arg("DEE")); + +m.def("SUNDomEigEstimator_Write", SUNDomEigEstimator_Write, nb::arg("DEE"), + nb::arg("outfile")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sundials/sundials_domeigestimator_usersupplied.hpp b/bindings/sundials4py/sundials/sundials_domeigestimator_usersupplied.hpp new file mode 100644 index 0000000000..417028985c --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_domeigestimator_usersupplied.hpp @@ -0,0 +1,58 @@ +/*------------------------------------------------------------------------------ + * Programmer(s): Cody J. Balos @ LLNL + *------------------------------------------------------------------------------ + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + *----------------------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_SUNDOMEIGESTIMATOR_USERSUPPLIED_HPP +#define _SUNDIALS4PY_SUNDOMEIGESTIMATOR_USERSUPPLIED_HPP + +#include +#include + +#include "sundials4py.hpp" + +#include + +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +struct SUNDomEigEstimatorFunctionTable +{ + nb::object atimes; +}; + +inline SUNDomEigEstimatorFunctionTable* SUNDomEigEstimatorFunctionTable_Alloc() +{ + // We must use malloc since ARKodeFree calls free + auto fn_table = static_cast( + std::malloc(sizeof(SUNDomEigEstimatorFunctionTable))); + + // Zero out the memory + std::memset(fn_table, 0, sizeof(SUNDomEigEstimatorFunctionTable)); + + return fn_table; +} + +template +SUNErrCode sundomeigestimator_atimes_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNDomEigEstimatorFunctionTable, + 3>(&SUNDomEigEstimatorFunctionTable::atimes, std::forward(args)...); +} + +#endif \ No newline at end of file diff --git a/bindings/sundials4py/sundials/sundials_errors.hpp b/bindings/sundials4py/sundials/sundials_errors.hpp new file mode 100644 index 0000000000..4e8d19629d --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_errors.hpp @@ -0,0 +1,34 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file contains hand-written bindings for sundials_errors.h + * -----------------------------------------------------------------*/ + +#include + +#include "sundials_errors_generated.hpp" + +/* Expand SUN_ERR_CODE_LIST to enum */ +#define SUN_EXPAND_TO_NB_BINDING(name, description) \ + .value(#name, name, description) + +auto pyEnumSUNErrCode_ = nb::enum_(m, "SUNErrCode", + nb::is_arithmetic(), "") + .value("SUN_ERR_MINIMUM", SUN_ERR_MINIMUM, "") + SUN_ERR_CODE_LIST(SUN_EXPAND_TO_NB_BINDING) + .value("SUN_ERR_MAXIMUM", SUN_ERR_MAXIMUM, "") + .value("SUN_SUCCESS", SUN_SUCCESS, "") + .export_values(); diff --git a/bindings/sundials4py/sundials/sundials_errors_generated.hpp b/bindings/sundials4py/sundials/sundials_errors_generated.hpp new file mode 100644 index 0000000000..4602aba2fd --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_errors_generated.hpp @@ -0,0 +1,19 @@ +// #ifndef _SUNDIALS_ERRORS_H +// +// #ifdef __cplusplus +// #endif +// + +m.def("SUNLogErrHandlerFn", SUNLogErrHandlerFn, nb::arg("line"), + nb::arg("func"), nb::arg("file"), nb::arg("msg"), nb::arg("err_code"), + nb::arg("err_user_data"), nb::arg("sunctx")); + +m.def("SUNAbortErrHandlerFn", SUNAbortErrHandlerFn, nb::arg("line"), + nb::arg("func"), nb::arg("file"), nb::arg("msg"), nb::arg("err_code"), + nb::arg("err_user_data"), nb::arg("sunctx")); + +m.def("SUNGetErrMsg", SUNGetErrMsg, nb::arg("code")); +// #ifdef __cplusplus +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_linearsolver.cpp b/bindings/sundials4py/sundials/sundials_linearsolver.cpp new file mode 100644 index 0000000000..9064e7a38b --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_linearsolver.cpp @@ -0,0 +1,108 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNLinearSolver class. It contains hand-written code for + * functions that require special treatment, and includes the generated + * code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials/sundials_linearsolver.h" +#include "sundials/sundials_iterative.h" +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +#include "sundials_linearsolver_usersupplied.hpp" + +namespace sundials4py { + +void bind_sunlinearsolver(nb::module_& m) +{ +#include "sundials_linearsolver_generated.hpp" + + m.def( + "SUNLinSolSetOptions", + [](SUNLinearSolver self, const std::string& id, const std::string& file_name, + int argc, const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return SUNLinSolSetOptions(self, id.empty() ? nullptr : id.c_str(), + file_name.empty() ? nullptr : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("self"), nb::arg("id"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + m.def("SUNLinSolSolve", SUNLinSolSolve, nb::arg("S"), nb::arg("A").none(), + nb::arg("x"), nb::arg("b"), nb::arg("tol")); + + m.def( + "SUNLinSolSetATimes", + [](SUNLinearSolver LS, + std::function> ATimesFn) -> SUNErrCode + { + if (!LS->python) { LS->python = SUNLinearSolverFunctionTable_Alloc(); } + auto fn_table = static_cast(LS->python); + fn_table->ATimesFn = nb::cast(ATimesFn); + if (ATimesFn) + { + return SUNLinSolSetATimes(LS, LS->python, + sunlinearsolver_atimesfn_wrapper); + } + else { return SUNLinSolSetATimes(LS, nullptr, nullptr); } + }, + nb::arg("LS"), nb::arg("ATimes").none()); + + m.def( + "SUNLinSolSetPreconditioner", + [](SUNLinearSolver LS, + std::function> PSetupFn, + std::function> PSolveFn) -> SUNErrCode + { + if (!LS->python) { LS->python = SUNLinearSolverFunctionTable_Alloc(); } + auto fn_table = static_cast(LS->python); + fn_table->PSetupFn = nb::cast(PSetupFn); + fn_table->PSolveFn = nb::cast(PSolveFn); + if (!PSetupFn && PSolveFn) + { + return SUNLinSolSetPreconditioner(LS, LS->python, nullptr, + sunlinearsolver_psolvefn_wrapper); + } + else if (PSetupFn && PSolveFn) + { + return SUNLinSolSetPreconditioner(LS, LS->python, + sunlinearsolver_psetupfn_wrapper, + sunlinearsolver_psolvefn_wrapper); + } + else { return SUNLinSolSetPreconditioner(LS, nullptr, nullptr, nullptr); } + }, + nb::arg("LS"), nb::arg("PSetupFn").none(), nb::arg("PSolveFn")); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_linearsolver_generated.hpp b/bindings/sundials4py/sundials/sundials_linearsolver_generated.hpp new file mode 100644 index 0000000000..b99e1baeed --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_linearsolver_generated.hpp @@ -0,0 +1,398 @@ +// #ifndef _SUNDIALS_ITERATIVE_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNPrecType = nb::enum_(m, "SUNPrecType", + nb::is_arithmetic(), "") + .value("SUN_PREC_NONE", SUN_PREC_NONE, "") + .value("SUN_PREC_LEFT", SUN_PREC_LEFT, "") + .value("SUN_PREC_RIGHT", SUN_PREC_RIGHT, "") + .value("SUN_PREC_BOTH", SUN_PREC_BOTH, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyEnumSUNGramSchmidtType = + nb::enum_(m, "SUNGramSchmidtType", nb::is_arithmetic(), "") + .value("SUN_MODIFIED_GS", SUN_MODIFIED_GS, "") + .value("SUN_CLASSICAL_GS", SUN_CLASSICAL_GS, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def( + "SUNModifiedGS", + [](std::vector v_1d, sundials4py::Array1d h_2d, int k, + int p) -> std::tuple + { + auto SUNModifiedGS_adapt_arr_ptr_to_std_vector = + [](std::vector v_1d, sundials4py::Array1d h_2d, int k, int p, + sunrealtype* new_vk_norm) -> SUNErrCode + { + N_Vector* v_1d_ptr = + reinterpret_cast(v_1d.empty() ? nullptr : v_1d.data()); + sunrealtype** h_2d_ptr = reinterpret_cast(h_2d.data()); + + auto lambda_result = SUNModifiedGS(v_1d_ptr, h_2d_ptr, k, p, new_vk_norm); + return lambda_result; + }; + auto SUNModifiedGS_adapt_modifiable_immutable_to_return = + [&SUNModifiedGS_adapt_arr_ptr_to_std_vector](std::vector v_1d, + sundials4py::Array1d h_2d, + int k, int p) + -> std::tuple + { + sunrealtype new_vk_norm_adapt_modifiable; + + SUNErrCode r = + SUNModifiedGS_adapt_arr_ptr_to_std_vector(v_1d, h_2d, k, p, + &new_vk_norm_adapt_modifiable); + return std::make_tuple(r, new_vk_norm_adapt_modifiable); + }; + + return SUNModifiedGS_adapt_modifiable_immutable_to_return(v_1d, h_2d, k, p); + }, + nb::arg("v_1d"), nb::arg("h_2d"), nb::arg("k"), nb::arg("p")); + +m.def( + "SUNClassicalGS", + [](std::vector v_1d, sundials4py::Array1d h_2d, int k, int p, + sundials4py::Array1d stemp_1d, + std::vector vtemp_1d) -> std::tuple + { + auto SUNClassicalGS_adapt_arr_ptr_to_std_vector = + [](std::vector v_1d, sundials4py::Array1d h_2d, int k, int p, + sunrealtype* new_vk_norm, sundials4py::Array1d stemp_1d, + std::vector vtemp_1d) -> SUNErrCode + { + N_Vector* v_1d_ptr = + reinterpret_cast(v_1d.empty() ? nullptr : v_1d.data()); + sunrealtype** h_2d_ptr = reinterpret_cast(h_2d.data()); + sunrealtype* stemp_1d_ptr = reinterpret_cast(stemp_1d.data()); + N_Vector* vtemp_1d_ptr = reinterpret_cast( + vtemp_1d.empty() ? nullptr : vtemp_1d.data()); + + auto lambda_result = SUNClassicalGS(v_1d_ptr, h_2d_ptr, k, p, new_vk_norm, + stemp_1d_ptr, vtemp_1d_ptr); + return lambda_result; + }; + auto SUNClassicalGS_adapt_modifiable_immutable_to_return = + [&SUNClassicalGS_adapt_arr_ptr_to_std_vector](std::vector v_1d, + sundials4py::Array1d h_2d, + int k, int p, + sundials4py::Array1d stemp_1d, + std::vector vtemp_1d) + -> std::tuple + { + sunrealtype new_vk_norm_adapt_modifiable; + + SUNErrCode r = + SUNClassicalGS_adapt_arr_ptr_to_std_vector(v_1d, h_2d, k, p, + &new_vk_norm_adapt_modifiable, + stemp_1d, vtemp_1d); + return std::make_tuple(r, new_vk_norm_adapt_modifiable); + }; + + return SUNClassicalGS_adapt_modifiable_immutable_to_return(v_1d, h_2d, k, p, + stemp_1d, + vtemp_1d); + }, + nb::arg("v_1d"), nb::arg("h_2d"), nb::arg("k"), nb::arg("p"), + nb::arg("stemp_1d"), nb::arg("vtemp_1d")); + +m.def( + "SUNQRfact", + [](int n, sundials4py::Array1d h_2d, sundials4py::Array1d q_1d, int job) -> int + { + auto SUNQRfact_adapt_arr_ptr_to_std_vector = + [](int n, sundials4py::Array1d h_2d, sundials4py::Array1d q_1d, + int job) -> int + { + sunrealtype** h_2d_ptr = reinterpret_cast(h_2d.data()); + sunrealtype* q_1d_ptr = reinterpret_cast(q_1d.data()); + + auto lambda_result = SUNQRfact(n, h_2d_ptr, q_1d_ptr, job); + return lambda_result; + }; + + return SUNQRfact_adapt_arr_ptr_to_std_vector(n, h_2d, q_1d, job); + }, + nb::arg("n"), nb::arg("h_2d"), nb::arg("q_1d"), nb::arg("job")); + +m.def( + "SUNQRsol", + [](int n, sundials4py::Array1d h_2d, sundials4py::Array1d q_1d, + sundials4py::Array1d b_1d) -> int + { + auto SUNQRsol_adapt_arr_ptr_to_std_vector = + [](int n, sundials4py::Array1d h_2d, sundials4py::Array1d q_1d, + sundials4py::Array1d b_1d) -> int + { + sunrealtype** h_2d_ptr = reinterpret_cast(h_2d.data()); + sunrealtype* q_1d_ptr = reinterpret_cast(q_1d.data()); + sunrealtype* b_1d_ptr = reinterpret_cast(b_1d.data()); + + auto lambda_result = SUNQRsol(n, h_2d_ptr, q_1d_ptr, b_1d_ptr); + return lambda_result; + }; + + return SUNQRsol_adapt_arr_ptr_to_std_vector(n, h_2d, q_1d, b_1d); + }, + nb::arg("n"), nb::arg("h_2d"), nb::arg("q_1d"), nb::arg("b_1d")); + +m.def( + "SUNQRAdd_MGS", + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, int m, + int mMax, void* QRdata) -> SUNErrCode + { + auto SUNQRAdd_MGS_adapt_arr_ptr_to_std_vector = + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, + int m, int mMax, void* QRdata) -> SUNErrCode + { + N_Vector* Q_1d_ptr = + reinterpret_cast(Q_1d.empty() ? nullptr : Q_1d.data()); + sunrealtype* R_1d_ptr = reinterpret_cast(R_1d.data()); + + auto lambda_result = SUNQRAdd_MGS(Q_1d_ptr, R_1d_ptr, df, m, mMax, QRdata); + return lambda_result; + }; + + return SUNQRAdd_MGS_adapt_arr_ptr_to_std_vector(Q_1d, R_1d, df, m, mMax, + QRdata); + }, + nb::arg("Q_1d"), nb::arg("R_1d"), nb::arg("df"), nb::arg("m"), + nb::arg("mMax"), nb::arg("QRdata")); + +m.def( + "SUNQRAdd_ICWY", + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, int m, + int mMax, void* QRdata) -> SUNErrCode + { + auto SUNQRAdd_ICWY_adapt_arr_ptr_to_std_vector = + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, + int m, int mMax, void* QRdata) -> SUNErrCode + { + N_Vector* Q_1d_ptr = + reinterpret_cast(Q_1d.empty() ? nullptr : Q_1d.data()); + sunrealtype* R_1d_ptr = reinterpret_cast(R_1d.data()); + + auto lambda_result = SUNQRAdd_ICWY(Q_1d_ptr, R_1d_ptr, df, m, mMax, QRdata); + return lambda_result; + }; + + return SUNQRAdd_ICWY_adapt_arr_ptr_to_std_vector(Q_1d, R_1d, df, m, mMax, + QRdata); + }, + nb::arg("Q_1d"), nb::arg("R_1d"), nb::arg("df"), nb::arg("m"), + nb::arg("mMax"), nb::arg("QRdata")); + +m.def( + "SUNQRAdd_ICWY_SB", + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, int m, + int mMax, void* QRdata) -> SUNErrCode + { + auto SUNQRAdd_ICWY_SB_adapt_arr_ptr_to_std_vector = + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, + int m, int mMax, void* QRdata) -> SUNErrCode + { + N_Vector* Q_1d_ptr = + reinterpret_cast(Q_1d.empty() ? nullptr : Q_1d.data()); + sunrealtype* R_1d_ptr = reinterpret_cast(R_1d.data()); + + auto lambda_result = SUNQRAdd_ICWY_SB(Q_1d_ptr, R_1d_ptr, df, m, mMax, + QRdata); + return lambda_result; + }; + + return SUNQRAdd_ICWY_SB_adapt_arr_ptr_to_std_vector(Q_1d, R_1d, df, m, mMax, + QRdata); + }, + nb::arg("Q_1d"), nb::arg("R_1d"), nb::arg("df"), nb::arg("m"), + nb::arg("mMax"), nb::arg("QRdata")); + +m.def( + "SUNQRAdd_CGS2", + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, int m, + int mMax, void* QRdata) -> SUNErrCode + { + auto SUNQRAdd_CGS2_adapt_arr_ptr_to_std_vector = + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, + int m, int mMax, void* QRdata) -> SUNErrCode + { + N_Vector* Q_1d_ptr = + reinterpret_cast(Q_1d.empty() ? nullptr : Q_1d.data()); + sunrealtype* R_1d_ptr = reinterpret_cast(R_1d.data()); + + auto lambda_result = SUNQRAdd_CGS2(Q_1d_ptr, R_1d_ptr, df, m, mMax, QRdata); + return lambda_result; + }; + + return SUNQRAdd_CGS2_adapt_arr_ptr_to_std_vector(Q_1d, R_1d, df, m, mMax, + QRdata); + }, + nb::arg("Q_1d"), nb::arg("R_1d"), nb::arg("df"), nb::arg("m"), + nb::arg("mMax"), nb::arg("QRdata")); + +m.def( + "SUNQRAdd_DCGS2", + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, int m, + int mMax, void* QRdata) -> SUNErrCode + { + auto SUNQRAdd_DCGS2_adapt_arr_ptr_to_std_vector = + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, + int m, int mMax, void* QRdata) -> SUNErrCode + { + N_Vector* Q_1d_ptr = + reinterpret_cast(Q_1d.empty() ? nullptr : Q_1d.data()); + sunrealtype* R_1d_ptr = reinterpret_cast(R_1d.data()); + + auto lambda_result = SUNQRAdd_DCGS2(Q_1d_ptr, R_1d_ptr, df, m, mMax, + QRdata); + return lambda_result; + }; + + return SUNQRAdd_DCGS2_adapt_arr_ptr_to_std_vector(Q_1d, R_1d, df, m, mMax, + QRdata); + }, + nb::arg("Q_1d"), nb::arg("R_1d"), nb::arg("df"), nb::arg("m"), + nb::arg("mMax"), nb::arg("QRdata")); + +m.def( + "SUNQRAdd_DCGS2_SB", + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, int m, + int mMax, void* QRdata) -> SUNErrCode + { + auto SUNQRAdd_DCGS2_SB_adapt_arr_ptr_to_std_vector = + [](std::vector Q_1d, sundials4py::Array1d R_1d, N_Vector df, + int m, int mMax, void* QRdata) -> SUNErrCode + { + N_Vector* Q_1d_ptr = + reinterpret_cast(Q_1d.empty() ? nullptr : Q_1d.data()); + sunrealtype* R_1d_ptr = reinterpret_cast(R_1d.data()); + + auto lambda_result = SUNQRAdd_DCGS2_SB(Q_1d_ptr, R_1d_ptr, df, m, mMax, + QRdata); + return lambda_result; + }; + + return SUNQRAdd_DCGS2_SB_adapt_arr_ptr_to_std_vector(Q_1d, R_1d, df, m, + mMax, QRdata); + }, + nb::arg("Q_1d"), nb::arg("R_1d"), nb::arg("df"), nb::arg("m"), + nb::arg("mMax"), nb::arg("QRdata")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// +// #ifndef _SUNLINEARSOLVER_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNLinearSolver_Type = + nb::enum_(m, "SUNLinearSolver_Type", + nb::is_arithmetic(), "") + .value("SUNLINEARSOLVER_DIRECT", SUNLINEARSOLVER_DIRECT, "") + .value("SUNLINEARSOLVER_ITERATIVE", SUNLINEARSOLVER_ITERATIVE, "") + .value("SUNLINEARSOLVER_MATRIX_ITERATIVE", SUNLINEARSOLVER_MATRIX_ITERATIVE, + "") + .value("SUNLINEARSOLVER_MATRIX_EMBEDDED", SUNLINEARSOLVER_MATRIX_EMBEDDED, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyEnumSUNLinearSolver_ID = + nb::enum_(m, "SUNLinearSolver_ID", nb::is_arithmetic(), "") + .value("SUNLINEARSOLVER_BAND", SUNLINEARSOLVER_BAND, "") + .value("SUNLINEARSOLVER_DENSE", SUNLINEARSOLVER_DENSE, "") + .value("SUNLINEARSOLVER_KLU", SUNLINEARSOLVER_KLU, "") + .value("SUNLINEARSOLVER_LAPACKBAND", SUNLINEARSOLVER_LAPACKBAND, "") + .value("SUNLINEARSOLVER_LAPACKDENSE", SUNLINEARSOLVER_LAPACKDENSE, "") + .value("SUNLINEARSOLVER_PCG", SUNLINEARSOLVER_PCG, "") + .value("SUNLINEARSOLVER_SPBCGS", SUNLINEARSOLVER_SPBCGS, "") + .value("SUNLINEARSOLVER_SPFGMR", SUNLINEARSOLVER_SPFGMR, "") + .value("SUNLINEARSOLVER_SPGMR", SUNLINEARSOLVER_SPGMR, "") + .value("SUNLINEARSOLVER_SPTFQMR", SUNLINEARSOLVER_SPTFQMR, "") + .value("SUNLINEARSOLVER_SUPERLUDIST", SUNLINEARSOLVER_SUPERLUDIST, "") + .value("SUNLINEARSOLVER_SUPERLUMT", SUNLINEARSOLVER_SUPERLUMT, "") + .value("SUNLINEARSOLVER_CUSOLVERSP_BATCHQR", + SUNLINEARSOLVER_CUSOLVERSP_BATCHQR, "") + .value("SUNLINEARSOLVER_MAGMADENSE", SUNLINEARSOLVER_MAGMADENSE, "") + .value("SUNLINEARSOLVER_ONEMKLDENSE", SUNLINEARSOLVER_ONEMKLDENSE, "") + .value("SUNLINEARSOLVER_GINKGO", SUNLINEARSOLVER_GINKGO, "") + .value("SUNLINEARSOLVER_GINKGOBATCH", SUNLINEARSOLVER_GINKGOBATCH, "") + .value("SUNLINEARSOLVER_KOKKOSDENSE", SUNLINEARSOLVER_KOKKOSDENSE, "") + .value("SUNLINEARSOLVER_CUSTOM", SUNLINEARSOLVER_CUSTOM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClass_generic_SUNLinearSolver_Ops = + nb::class_<_generic_SUNLinearSolver_Ops>(m, + "_generic_SUNLinearSolver_Ops", "Structure containing function pointers to linear solver operations") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClass_generic_SUNLinearSolver = + nb::class_<_generic_SUNLinearSolver>(m, + "_generic_SUNLinearSolver", " A linear solver is a structure with an implementation-dependent\n 'content' field, and a pointer to a structure of linear solver\n operations corresponding to that implementation.") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("SUNLinSolGetType", SUNLinSolGetType, nb::arg("S")); + +m.def("SUNLinSolGetID", SUNLinSolGetID, nb::arg("S")); + +m.def("SUNLinSolSetScalingVectors", SUNLinSolSetScalingVectors, nb::arg("S"), + nb::arg("s1"), nb::arg("s2")); + +m.def("SUNLinSolSetZeroGuess", SUNLinSolSetZeroGuess, nb::arg("S"), + nb::arg("onoff")); + +m.def("SUNLinSolInitialize", SUNLinSolInitialize, nb::arg("S")); + +m.def( + "SUNLinSolSetup", + [](SUNLinearSolver S, std::optional A = std::nullopt) -> int + { + auto SUNLinSolSetup_adapt_optional_arg_with_default_null = + [](SUNLinearSolver S, std::optional A = std::nullopt) -> int + { + SUNMatrix A_adapt_default_null = nullptr; + if (A.has_value()) A_adapt_default_null = A.value(); + + auto lambda_result = SUNLinSolSetup(S, A_adapt_default_null); + return lambda_result; + }; + + return SUNLinSolSetup_adapt_optional_arg_with_default_null(S, A); + }, + nb::arg("S"), nb::arg("A").none() = nb::none()); + +m.def("SUNLinSolNumIters", SUNLinSolNumIters, nb::arg("S")); + +m.def("SUNLinSolResNorm", SUNLinSolResNorm, nb::arg("S")); + +m.def("SUNLinSolResid", SUNLinSolResid, nb::arg("S"), + "nb::rv_policy::reference", nb::rv_policy::reference); + +m.def("SUNLinSolLastFlag", SUNLinSolLastFlag, nb::arg("S")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sundials/sundials_linearsolver_usersupplied.hpp b/bindings/sundials4py/sundials/sundials_linearsolver_usersupplied.hpp new file mode 100644 index 0000000000..97b4c624b7 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_linearsolver_usersupplied.hpp @@ -0,0 +1,72 @@ +/* ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_LINEARSOLVER_USERSUPPLIED_HPP +#define _SUNDIALS4PY_LINEARSOLVER_USERSUPPLIED_HPP + +#include +#include + +#include "sundials/sundials_iterative.h" +#include "sundials4py.hpp" + +#include +#include + +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +struct SUNLinearSolverFunctionTable +{ + nb::object ATimesFn; + nb::object PSetupFn; + nb::object PSolveFn; +}; + +inline SUNLinearSolverFunctionTable* SUNLinearSolverFunctionTable_Alloc() +{ + auto fn_table = static_cast( + std::malloc(sizeof(SUNLinearSolverFunctionTable))); + std::memset(fn_table, 0, sizeof(SUNLinearSolverFunctionTable)); + return fn_table; +} + +template +inline int sunlinearsolver_atimesfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNLinearSolverFunctionTable, + 3>(&SUNLinearSolverFunctionTable::ATimesFn, std::forward(args)...); +} + +template +inline int sunlinearsolver_psetupfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNLinearSolverFunctionTable, + 1>(&SUNLinearSolverFunctionTable::PSetupFn, std::forward(args)...); +} + +template +inline int sunlinearsolver_psolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNLinearSolverFunctionTable, + 5>(&SUNLinearSolverFunctionTable::PSolveFn, std::forward(args)...); +} + +#endif // _SUNDIALS4PY_LINEARSOLVER_USERSUPPLIED_HPP diff --git a/bindings/sundials4py/sundials/sundials_logger.cpp b/bindings/sundials4py/sundials/sundials_logger.cpp new file mode 100644 index 0000000000..7567461596 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_logger.cpp @@ -0,0 +1,41 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNLogger class. It contains hand-written code for + * functions that require special treatment, and includes the generated + * code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +#include "sundials_logger_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlogger(nb::module_& m) +{ +#include "sundials_logger_generated.hpp" + nb::class_(m, "SUNLogger_"); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_logger_generated.hpp b/bindings/sundials4py/sundials/sundials_logger_generated.hpp new file mode 100644 index 0000000000..155a875cc5 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_logger_generated.hpp @@ -0,0 +1,137 @@ +// #ifndef _SUNDIALS_LOGGER_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNLogLevel = + nb::enum_(m, "SUNLogLevel", nb::is_arithmetic(), "") + .value("SUN_LOGLEVEL_ALL", SUN_LOGLEVEL_ALL, "") + .value("SUN_LOGLEVEL_NONE", SUN_LOGLEVEL_NONE, "") + .value("SUN_LOGLEVEL_ERROR", SUN_LOGLEVEL_ERROR, "") + .value("SUN_LOGLEVEL_WARNING", SUN_LOGLEVEL_WARNING, "") + .value("SUN_LOGLEVEL_INFO", SUN_LOGLEVEL_INFO, "") + .value("SUN_LOGLEVEL_DEBUG", SUN_LOGLEVEL_DEBUG, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def( + "SUNLogger_Create", + [](SUNComm comm, int output_rank) + -> std::tuple>> + { + auto SUNLogger_Create_adapt_modifiable_immutable_to_return = + [](SUNComm comm, int output_rank) -> std::tuple + { + SUNLogger logger_adapt_modifiable; + + SUNErrCode r = SUNLogger_Create(comm, output_rank, + &logger_adapt_modifiable); + return std::make_tuple(r, logger_adapt_modifiable); + }; + auto SUNLogger_Create_adapt_return_type_to_shared_ptr = + [&SUNLogger_Create_adapt_modifiable_immutable_to_return](SUNComm comm, + int output_rank) + -> std::tuple>> + { + auto lambda_result = + SUNLogger_Create_adapt_modifiable_immutable_to_return(comm, output_rank); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNLoggerDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNLogger_Create_adapt_return_type_to_shared_ptr(comm, output_rank); + }, + nb::arg("comm"), nb::arg("output_rank"), nb::rv_policy::reference); + +m.def( + "SUNLogger_CreateFromEnv", + [](SUNComm comm) + -> std::tuple>> + { + auto SUNLogger_CreateFromEnv_adapt_modifiable_immutable_to_return = + [](SUNComm comm) -> std::tuple + { + SUNLogger logger_adapt_modifiable; + + SUNErrCode r = SUNLogger_CreateFromEnv(comm, &logger_adapt_modifiable); + return std::make_tuple(r, logger_adapt_modifiable); + }; + auto SUNLogger_CreateFromEnv_adapt_return_type_to_shared_ptr = + [&SUNLogger_CreateFromEnv_adapt_modifiable_immutable_to_return](SUNComm comm) + -> std::tuple>> + { + auto lambda_result = + SUNLogger_CreateFromEnv_adapt_modifiable_immutable_to_return(comm); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNLoggerDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNLogger_CreateFromEnv_adapt_return_type_to_shared_ptr(comm); + }, + nb::arg("comm"), nb::rv_policy::reference); + +m.def("SUNLogger_SetErrorFilename", SUNLogger_SetErrorFilename, + nb::arg("logger"), nb::arg("error_filename")); + +m.def("SUNLogger_SetWarningFilename", SUNLogger_SetWarningFilename, + nb::arg("logger"), nb::arg("warning_filename")); + +m.def("SUNLogger_SetDebugFilename", SUNLogger_SetDebugFilename, + nb::arg("logger"), nb::arg("debug_filename")); + +m.def("SUNLogger_SetInfoFilename", SUNLogger_SetInfoFilename, nb::arg("logger"), + nb::arg("info_filename")); + +m.def( + "SUNLogger_QueueMsg", + [](SUNLogger logger, SUNLogLevel lvl, const char* scope, const char* label, + const char* msg_txt) -> SUNErrCode + { + auto SUNLogger_QueueMsg_adapt_variadic_format = + [](SUNLogger logger, SUNLogLevel lvl, const char* scope, + const char* label, const char* msg_txt) -> SUNErrCode + { + auto lambda_result = SUNLogger_QueueMsg(logger, lvl, scope, label, "%s", + msg_txt); + return lambda_result; + }; + + return SUNLogger_QueueMsg_adapt_variadic_format(logger, lvl, scope, label, + msg_txt); + }, + nb::arg("logger"), nb::arg("lvl"), nb::arg("scope"), nb::arg("label"), + nb::arg("msg_txt")); + +m.def("SUNLogger_Flush", SUNLogger_Flush, nb::arg("logger"), nb::arg("lvl")); + +m.def( + "SUNLogger_GetOutputRank", + [](SUNLogger logger) -> std::tuple + { + auto SUNLogger_GetOutputRank_adapt_modifiable_immutable_to_return = + [](SUNLogger logger) -> std::tuple + { + int output_rank_adapt_modifiable; + + SUNErrCode r = SUNLogger_GetOutputRank(logger, + &output_rank_adapt_modifiable); + return std::make_tuple(r, output_rank_adapt_modifiable); + }; + + return SUNLogger_GetOutputRank_adapt_modifiable_immutable_to_return(logger); + }, + nb::arg("logger")); +// #ifdef __cplusplus +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_matrix.cpp b/bindings/sundials4py/sundials/sundials_matrix.cpp new file mode 100644 index 0000000000..a8349b4a8b --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_matrix.cpp @@ -0,0 +1,37 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNMatrix class. It contains hand-written code for functions + * that require special treatment, and includes the generated code + * produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunmatrix(nb::module_& m) +{ +#include "sundials_matrix_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_matrix_generated.hpp b/bindings/sundials4py/sundials/sundials_matrix_generated.hpp new file mode 100644 index 0000000000..672fac632d --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_matrix_generated.hpp @@ -0,0 +1,75 @@ +// #ifndef _SUNMATRIX_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNMatrix_ID = + nb::enum_(m, "SUNMatrix_ID", nb::is_arithmetic(), "") + .value("SUNMATRIX_DENSE", SUNMATRIX_DENSE, "") + .value("SUNMATRIX_MAGMADENSE", SUNMATRIX_MAGMADENSE, "") + .value("SUNMATRIX_ONEMKLDENSE", SUNMATRIX_ONEMKLDENSE, "") + .value("SUNMATRIX_BAND", SUNMATRIX_BAND, "") + .value("SUNMATRIX_SPARSE", SUNMATRIX_SPARSE, "") + .value("SUNMATRIX_SLUNRLOC", SUNMATRIX_SLUNRLOC, "") + .value("SUNMATRIX_CUSPARSE", SUNMATRIX_CUSPARSE, "") + .value("SUNMATRIX_GINKGO", SUNMATRIX_GINKGO, "") + .value("SUNMATRIX_GINKGOBATCH", SUNMATRIX_GINKGOBATCH, "") + .value("SUNMATRIX_KOKKOSDENSE", SUNMATRIX_KOKKOSDENSE, "") + .value("SUNMATRIX_CUSTOM", SUNMATRIX_CUSTOM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClass_generic_SUNMatrix_Ops = + nb::class_<_generic_SUNMatrix_Ops>(m, + "_generic_SUNMatrix_Ops", "Structure containing function pointers to matrix operations") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClass_generic_SUNMatrix = + nb::class_<_generic_SUNMatrix>(m, + "_generic_SUNMatrix", " A matrix is a structure with an implementation-dependent\n 'content' field, and a pointer to a structure of matrix\n operations corresponding to that implementation.") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("SUNMatGetID", SUNMatGetID, nb::arg("A")); + +m.def( + "SUNMatClone", + [](SUNMatrix A) -> std::shared_ptr> + { + auto SUNMatClone_adapt_return_type_to_shared_ptr = + [](SUNMatrix A) -> std::shared_ptr> + { + auto lambda_result = SUNMatClone(A); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNMatClone_adapt_return_type_to_shared_ptr(A); + }, + nb::arg("A")); + +m.def("SUNMatZero", SUNMatZero, nb::arg("A")); + +m.def("SUNMatCopy", SUNMatCopy, nb::arg("A"), nb::arg("B")); + +m.def("SUNMatScaleAdd", SUNMatScaleAdd, nb::arg("c"), nb::arg("A"), nb::arg("B")); + +m.def("SUNMatScaleAddI", SUNMatScaleAddI, nb::arg("c"), nb::arg("A")); + +m.def("SUNMatMatvecSetup", SUNMatMatvecSetup, nb::arg("A")); + +m.def("SUNMatMatvec", SUNMatMatvec, nb::arg("A"), nb::arg("x"), nb::arg("y")); + +m.def("SUNMatHermitianTransposeVec", SUNMatHermitianTransposeVec, nb::arg("A"), + nb::arg("x"), nb::arg("y")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_memory.cpp b/bindings/sundials4py/sundials/sundials_memory.cpp new file mode 100644 index 0000000000..fbb3b1fc06 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_memory.cpp @@ -0,0 +1,37 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNMemoryHelper class. It contains hand-written code for + * functions that require special treatment, and includes the + * generated code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunmemory(nb::module_& m) +{ +#include "sundials_memory_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_memory_generated.hpp b/bindings/sundials4py/sundials/sundials_memory_generated.hpp new file mode 100644 index 0000000000..5678423e1b --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_memory_generated.hpp @@ -0,0 +1,62 @@ +// #ifndef _SUNDIALS_MEMORY_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNMemoryType = nb::enum_(m, "SUNMemoryType", + nb::is_arithmetic(), "") + .value("SUNMEMTYPE_HOST", SUNMEMTYPE_HOST, + "pageable memory accessible on the host") + .value("SUNMEMTYPE_PINNED", SUNMEMTYPE_PINNED, + "page-locked memory accessible on the host") + .value("SUNMEMTYPE_DEVICE", SUNMEMTYPE_DEVICE, + "memory accessible from the device") + .value("SUNMEMTYPE_UVM", SUNMEMTYPE_UVM, + "memory accessible from the host or device") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClassSUNMemoryHelper_ = + nb::class_(m, "SUNMemoryHelper_", "") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClassSUNMemoryHelper_Ops_ = + nb::class_(m, "SUNMemoryHelper_Ops_", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNMemoryHelper_Clone", + [](SUNMemoryHelper param_0) + -> std::shared_ptr> + { + auto SUNMemoryHelper_Clone_adapt_return_type_to_shared_ptr = + [](SUNMemoryHelper param_0) + -> std::shared_ptr> + { + auto lambda_result = SUNMemoryHelper_Clone(param_0); + + return our_make_shared, + SUNMemoryHelperDeleter>(lambda_result); + }; + + return SUNMemoryHelper_Clone_adapt_return_type_to_shared_ptr(param_0); + }, + nb::arg("param_0")); + +m.def("SUNMemoryHelper_SetDefaultQueue", SUNMemoryHelper_SetDefaultQueue, + nb::arg("param_0"), nb::arg("queue")); + +m.def("SUNMemoryHelper_ImplementsRequiredOps", + SUNMemoryHelper_ImplementsRequiredOps, nb::arg("param_0")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sundials/sundials_nonlinearsolver.cpp b/bindings/sundials4py/sundials/sundials_nonlinearsolver.cpp new file mode 100644 index 0000000000..45d669a235 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_nonlinearsolver.cpp @@ -0,0 +1,167 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNNonlinearSolver class. It contains hand-written code + * for functions that require special treatment, and includes the + * generated code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials/sundials_nonlinearsolver.h" +#include "sundials4py.hpp" + +#include + +#include "sundials_nonlinearsolver_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunnonlinearsolver(nb::module_& m) +{ +#include "sundials_nonlinearsolver_generated.hpp" + + m.def( + "SUNNonlinSolSetOptions", + [](SUNNonlinearSolver self, const std::string& id, + const std::string& file_name, int argc, + const std::vector& args) + { + std::vector argv; + + for (const auto& arg : args) + { + // We need a non-const char*, so we use data() and an explicit cast. + // This is safe as long as the underlying std::string is not modified. + argv.push_back(const_cast(arg.data())); + } + + return SUNNonlinSolSetOptions(self, id.empty() ? nullptr : id.c_str(), + file_name.empty() ? nullptr + : file_name.c_str(), + argc, argv.data()); + }, + nb::arg("self"), nb::arg("id"), nb::arg("file_name"), nb::arg("argc"), + nb::arg("args")); + + m.def( + "SUNNonlinSolSetup", + [](SUNNonlinearSolver NLS, N_Vector y) + { + if (!NLS->python) + { + NLS->python = SUNNonlinearSolverFunctionTable_Alloc(); + } + return SUNNonlinSolSetup(NLS, y, NLS->python); + }, + nb::arg("NLS"), nb::arg("y")); + + m.def( + "SUNNonlinSolSolve", + [](SUNNonlinearSolver NLS, N_Vector y0, N_Vector y, N_Vector w, + sunrealtype tol, sunbooleantype callLSetup) + { + if (!NLS->python) + { + NLS->python = SUNNonlinearSolverFunctionTable_Alloc(); + } + return SUNNonlinSolSolve(NLS, y0, y, w, tol, callLSetup, NLS->python); + }, + nb::arg("NLS"), nb::arg("y0"), nb::arg("y"), nb::arg("w"), nb::arg("tol"), + nb::arg("callLSetup")); + + m.def( + "SUNNonlinSolSetSysFn", + [](SUNNonlinearSolver NLS, + std::function> SysFn) -> SUNErrCode + { + if (!NLS->python) + { + NLS->python = SUNNonlinearSolverFunctionTable_Alloc(); + } + auto fntable = static_cast(NLS->python); + fntable->sysfn = nb::cast(SysFn); + if (SysFn) + { + return SUNNonlinSolSetSysFn(NLS, sunnonlinearsolver_sysfn_wrapper); + } + else { return SUNNonlinSolSetSysFn(NLS, nullptr); } + }, + nb::arg("NLS"), nb::arg("SysFn").none()); + + m.def( + "SUNNonlinSolSetLSetupFn", + [](SUNNonlinearSolver NLS, + std::function SetupFn) -> SUNErrCode + { + if (!NLS->python) + { + NLS->python = SUNNonlinearSolverFunctionTable_Alloc(); + } + auto fntable = static_cast(NLS->python); + fntable->lsetupfn = nb::cast(SetupFn); + if (SetupFn) + { + return SUNNonlinSolSetLSetupFn(NLS, sunnonlinearsolver_lsetupfn_wrapper); + } + else { return SUNNonlinSolSetLSetupFn(NLS, nullptr); } + }, + nb::arg("NLS"), nb::arg("SetupFn").none()); + + m.def( + "SUNNonlinSolSetLSolveFn", + [](SUNNonlinearSolver NLS, + std::function> SolveFn) -> SUNErrCode + { + if (!NLS->python) + { + NLS->python = SUNNonlinearSolverFunctionTable_Alloc(); + } + auto fntable = static_cast(NLS->python); + fntable->lsolvefn = nb::cast(SolveFn); + if (SolveFn) + { + return SUNNonlinSolSetLSolveFn(NLS, sunnonlinearsolver_lsolvefn_wrapper); + } + else { return SUNNonlinSolSetLSolveFn(NLS, nullptr); } + }, + nb::arg("NLS"), nb::arg("SolveFn").none()); + + m.def( + "SUNNonlinSolSetConvTestFn", + [](SUNNonlinearSolver NLS, + std::function> CTestFn) -> SUNErrCode + { + if (!NLS->python) + { + NLS->python = SUNNonlinearSolverFunctionTable_Alloc(); + } + auto fntable = static_cast(NLS->python); + fntable->convtestfn = nb::cast(CTestFn); + if (CTestFn) + { + return SUNNonlinSolSetConvTestFn(NLS, + sunnonlinearsolver_convtestfn_wrapper, + NLS->python); + } + else { return SUNNonlinSolSetConvTestFn(NLS, nullptr, nullptr); } + }, + nb::arg("NLS"), nb::arg("CTestFn").none()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_nonlinearsolver_generated.hpp b/bindings/sundials4py/sundials/sundials_nonlinearsolver_generated.hpp new file mode 100644 index 0000000000..a92b81a617 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_nonlinearsolver_generated.hpp @@ -0,0 +1,97 @@ +// #ifndef _SUNNONLINEARSOLVER_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumSUNNonlinearSolver_Type = + nb::enum_(m, "SUNNonlinearSolver_Type", + nb::is_arithmetic(), "") + .value("SUNNONLINEARSOLVER_ROOTFIND", SUNNONLINEARSOLVER_ROOTFIND, "") + .value("SUNNONLINEARSOLVER_FIXEDPOINT", SUNNONLINEARSOLVER_FIXEDPOINT, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClass_generic_SUNNonlinearSolver_Ops = + nb::class_< + _generic_SUNNonlinearSolver_Ops>(m, "_generic_SUNNonlinearSolver_Ops", + "Structure containing function pointers " + "to nonlinear solver operations") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClass_generic_SUNNonlinearSolver = + nb::class_<_generic_SUNNonlinearSolver>(m, + "_generic_SUNNonlinearSolver", " A nonlinear solver is a structure with an implementation-dependent 'content'\n field, and a pointer to a structure of solver nonlinear solver operations\n corresponding to that implementation.") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("SUNNonlinSolGetType", SUNNonlinSolGetType, nb::arg("NLS")); + +m.def("SUNNonlinSolInitialize", SUNNonlinSolInitialize, nb::arg("NLS")); + +m.def("SUNNonlinSolSetMaxIters", SUNNonlinSolSetMaxIters, nb::arg("NLS"), + nb::arg("maxiters")); + +m.def( + "SUNNonlinSolGetNumIters", + [](SUNNonlinearSolver NLS) -> std::tuple + { + auto SUNNonlinSolGetNumIters_adapt_modifiable_immutable_to_return = + [](SUNNonlinearSolver NLS) -> std::tuple + { + long niters_adapt_modifiable; + + SUNErrCode r = SUNNonlinSolGetNumIters(NLS, &niters_adapt_modifiable); + return std::make_tuple(r, niters_adapt_modifiable); + }; + + return SUNNonlinSolGetNumIters_adapt_modifiable_immutable_to_return(NLS); + }, + nb::arg("NLS")); + +m.def( + "SUNNonlinSolGetCurIter", + [](SUNNonlinearSolver NLS) -> std::tuple + { + auto SUNNonlinSolGetCurIter_adapt_modifiable_immutable_to_return = + [](SUNNonlinearSolver NLS) -> std::tuple + { + int iter_adapt_modifiable; + + SUNErrCode r = SUNNonlinSolGetCurIter(NLS, &iter_adapt_modifiable); + return std::make_tuple(r, iter_adapt_modifiable); + }; + + return SUNNonlinSolGetCurIter_adapt_modifiable_immutable_to_return(NLS); + }, + nb::arg("NLS")); + +m.def( + "SUNNonlinSolGetNumConvFails", + [](SUNNonlinearSolver NLS) -> std::tuple + { + auto SUNNonlinSolGetNumConvFails_adapt_modifiable_immutable_to_return = + [](SUNNonlinearSolver NLS) -> std::tuple + { + long nconvfails_adapt_modifiable; + + SUNErrCode r = SUNNonlinSolGetNumConvFails(NLS, + &nconvfails_adapt_modifiable); + return std::make_tuple(r, nconvfails_adapt_modifiable); + }; + + return SUNNonlinSolGetNumConvFails_adapt_modifiable_immutable_to_return(NLS); + }, + nb::arg("NLS")); +m.attr("SUN_NLS_CONTINUE") = +901; +m.attr("SUN_NLS_CONV_RECVR") = +902; +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sundials/sundials_nonlinearsolver_usersupplied.hpp b/bindings/sundials4py/sundials/sundials_nonlinearsolver_usersupplied.hpp new file mode 100644 index 0000000000..26d5a69a3f --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_nonlinearsolver_usersupplied.hpp @@ -0,0 +1,88 @@ +/* ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_NONLINEARSOLVER_USERSUPPLIED_HPP +#define _SUNDIALS4PY_NONLINEARSOLVER_USERSUPPLIED_HPP + +#include +#include + +#include "sundials4py.hpp" + +#include +#include + +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +struct SUNNonlinearSolverFunctionTable +{ + nb::object sysfn; + nb::object lsetupfn; + nb::object lsolvefn; + nb::object convtestfn; +}; + +inline SUNNonlinearSolverFunctionTable* SUNNonlinearSolverFunctionTable_Alloc() +{ + auto fn_table = static_cast( + std::malloc(sizeof(SUNNonlinearSolverFunctionTable))); + std::memset(fn_table, 0, sizeof(SUNNonlinearSolverFunctionTable)); + return fn_table; +} + +template +inline int sunnonlinearsolver_sysfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNNonlinearSolverFunctionTable, + 1>(&SUNNonlinearSolverFunctionTable::sysfn, std::forward(args)...); +} + +using SUNNonlinSolLSetupStdFn = std::tuple(sunbooleantype jbad, + void* mem); + +inline int sunnonlinearsolver_lsetupfn_wrapper(sunbooleantype jbad, + sunbooleantype* jcur, void* mem) +{ + auto fn_table = static_cast(mem); + auto fn = nb::cast>(fn_table->lsetupfn); + + auto result = fn(jbad, nullptr); + + *jcur = std::get<1>(result); + + return std::get<0>(result); +} + +template +inline int sunnonlinearsolver_lsolvefn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNNonlinearSolverFunctionTable, + 1>(&SUNNonlinearSolverFunctionTable::lsolvefn, std::forward(args)...); +} + +template +inline int sunnonlinearsolver_convtestfn_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNNonlinearSolverFunctionTable, + 1>(&SUNNonlinearSolverFunctionTable::convtestfn, std::forward(args)...); +} + +#endif // _SUNDIALS4PY_NONLINEARSOLVER_USERSUPPLIED_HPP diff --git a/bindings/sundials4py/sundials/sundials_nvector.cpp b/bindings/sundials4py/sundials/sundials_nvector.cpp new file mode 100644 index 0000000000..f120354d73 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_nvector.cpp @@ -0,0 +1,127 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS N_Vector class. It contains hand-written code for functions + * that require special treatment, and includes the generated code + * produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_nvector(nb::module_& m) +{ +#include "sundials_nvector_generated.hpp" + + m.def( + "N_VGetArrayPointer", + [](N_Vector v) + { + auto ptr = N_VGetArrayPointer(v); + if (!ptr) + { + throw sundials4py::error_returned("Failed to get array pointer"); + } + auto owner = nb::find(v); + size_t shape[1]{static_cast(N_VGetLength(v))}; + return sundials4py::Array1d(ptr, 1, shape, owner); + }, + nb::rv_policy::reference); + + m.def( + "N_VGetDeviceArrayPointer", + [](N_Vector v) + { + auto ptr = N_VGetDeviceArrayPointer(v); + if (!ptr) + { + throw sundials4py::error_returned("Failed to get array pointer"); + } + auto owner = nb::find(v); + size_t shape[1]{static_cast(N_VGetLength(v))}; + return sundials4py::Array1d(ptr, 1, shape, owner); + }, + nb::rv_policy::reference); + + m.def("N_VSetArrayPointer", + [](sundials4py::Array1d arr, N_Vector v) + { + if (arr.shape(0) != N_VGetLength(v)) + { + throw sundials4py::error_returned( + "Array shape does not match vector length"); + } + N_VSetArrayPointer(arr.data(), v); + }); + + m.def( + "N_VScaleAddMultiVectorArray", + [](int nvec, int nsum, sundials4py::Array1d c_1d, + std::vector X_1d, std::vector> Y_2d, + std::vector> Z_2d) -> SUNErrCode + { + sunrealtype* c_1d_ptr = reinterpret_cast(c_1d.data()); + N_Vector* X_1d_ptr = + reinterpret_cast(X_1d.empty() ? nullptr : X_1d.data()); + + // Convert Y_2d and Z_2d to N_Vector** + std::vector Y_2d_ptrs, Z_2d_ptrs; + for (auto& row : Y_2d) { Y_2d_ptrs.push_back(row.data()); } + for (auto& row : Z_2d) { Z_2d_ptrs.push_back(row.data()); } + + N_Vector** Y_2d_ptr = Y_2d_ptrs.data(); + N_Vector** Z_2d_ptr = Z_2d_ptrs.data(); + + auto lambda_result = N_VScaleAddMultiVectorArray(nvec, nsum, c_1d_ptr, + X_1d_ptr, Y_2d_ptr, + Z_2d_ptr); + return lambda_result; + }, + nb::arg("nvec"), nb::arg("nsum"), nb::arg("c_1d"), nb::arg("X_1d"), + nb::arg("Y_2d"), nb::arg("Z_2d")); + + m.def( + "N_VLinearCombinationVectorArray", + [](int nvec, int nsum, sundials4py::Array1d c_1d, + std::vector> X_2d, + std::vector Z_1d) -> SUNErrCode + { + sunrealtype* c_1d_ptr = reinterpret_cast(c_1d.data()); + + // Convert X_2d to N_Vector** + std::vector X_2d_ptrs; + for (auto& row : X_2d) { X_2d_ptrs.push_back(row.data()); } + N_Vector** X_2d_ptr = X_2d_ptrs.data(); + + N_Vector* Z_1d_ptr = + reinterpret_cast(Z_1d.empty() ? nullptr : Z_1d.data()); + + auto lambda_result = N_VLinearCombinationVectorArray(nvec, nsum, c_1d_ptr, + X_2d_ptr, Z_1d_ptr); + return lambda_result; + }, + nb::arg("nvec"), nb::arg("nsum"), nb::arg("c_1d"), nb::arg("X_2d"), + nb::arg("Z_1d")); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_nvector_generated.hpp b/bindings/sundials4py/sundials/sundials_nvector_generated.hpp new file mode 100644 index 0000000000..1b21253d7a --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_nvector_generated.hpp @@ -0,0 +1,389 @@ +// #ifndef _NVECTOR_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyEnumN_Vector_ID = + nb::enum_(m, "N_Vector_ID", nb::is_arithmetic(), "") + .value("SUNDIALS_NVEC_SERIAL", SUNDIALS_NVEC_SERIAL, "") + .value("SUNDIALS_NVEC_PARALLEL", SUNDIALS_NVEC_PARALLEL, "") + .value("SUNDIALS_NVEC_OPENMP", SUNDIALS_NVEC_OPENMP, "") + .value("SUNDIALS_NVEC_PTHREADS", SUNDIALS_NVEC_PTHREADS, "") + .value("SUNDIALS_NVEC_PARHYP", SUNDIALS_NVEC_PARHYP, "") + .value("SUNDIALS_NVEC_PETSC", SUNDIALS_NVEC_PETSC, "") + .value("SUNDIALS_NVEC_CUDA", SUNDIALS_NVEC_CUDA, "") + .value("SUNDIALS_NVEC_HIP", SUNDIALS_NVEC_HIP, "") + .value("SUNDIALS_NVEC_SYCL", SUNDIALS_NVEC_SYCL, "") + .value("SUNDIALS_NVEC_RAJA", SUNDIALS_NVEC_RAJA, "") + .value("SUNDIALS_NVEC_KOKKOS", SUNDIALS_NVEC_KOKKOS, "") + .value("SUNDIALS_NVEC_OPENMPDEV", SUNDIALS_NVEC_OPENMPDEV, "") + .value("SUNDIALS_NVEC_TRILINOS", SUNDIALS_NVEC_TRILINOS, "") + .value("SUNDIALS_NVEC_MANYVECTOR", SUNDIALS_NVEC_MANYVECTOR, "") + .value("SUNDIALS_NVEC_MPIMANYVECTOR", SUNDIALS_NVEC_MPIMANYVECTOR, "") + .value("SUNDIALS_NVEC_MPIPLUSX", SUNDIALS_NVEC_MPIPLUSX, "") + .value("SUNDIALS_NVEC_CUSTOM", SUNDIALS_NVEC_CUSTOM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +auto pyClass_generic_N_Vector_Ops = + nb::class_<_generic_N_Vector_Ops>(m, + "_generic_N_Vector_Ops", "Structure containing function pointers to vector operations") + .def(nb::init<>()) // implicit default constructor + ; + +auto pyClass_generic_N_Vector = + nb::class_<_generic_N_Vector>(m, + "_generic_N_Vector", " A vector is a structure with an implementation-dependent\n 'content' field, and a pointer to a structure of vector\n operations corresponding to that implementation.") + .def(nb::init<>()) // implicit default constructor + ; + +m.def("N_VGetVectorID", N_VGetVectorID, nb::arg("w")); + +m.def( + "N_VClone", + [](N_Vector w) -> std::shared_ptr> + { + auto N_VClone_adapt_return_type_to_shared_ptr = + [](N_Vector w) -> std::shared_ptr> + { + auto lambda_result = N_VClone(w); + + return our_make_shared, N_VectorDeleter>( + lambda_result); + }; + + return N_VClone_adapt_return_type_to_shared_ptr(w); + }, + nb::arg("w")); + +m.def( + "N_VCloneEmpty", + [](N_Vector w) -> std::shared_ptr> + { + auto N_VCloneEmpty_adapt_return_type_to_shared_ptr = + [](N_Vector w) -> std::shared_ptr> + { + auto lambda_result = N_VCloneEmpty(w); + + return our_make_shared, N_VectorDeleter>( + lambda_result); + }; + + return N_VCloneEmpty_adapt_return_type_to_shared_ptr(w); + }, + nb::arg("w")); + +m.def("N_VGetCommunicator", N_VGetCommunicator, nb::arg("v")); + +m.def("N_VGetLength", N_VGetLength, nb::arg("v")); + +m.def("N_VGetLocalLength", N_VGetLocalLength, nb::arg("v")); + +m.def("N_VLinearSum", N_VLinearSum, nb::arg("a"), nb::arg("x"), nb::arg("b"), + nb::arg("y"), nb::arg("z")); + +m.def("N_VConst", N_VConst, nb::arg("c"), nb::arg("z")); + +m.def("N_VProd", N_VProd, nb::arg("x"), nb::arg("y"), nb::arg("z")); + +m.def("N_VDiv", N_VDiv, nb::arg("x"), nb::arg("y"), nb::arg("z")); + +m.def("N_VScale", N_VScale, nb::arg("c"), nb::arg("x"), nb::arg("z")); + +m.def("N_VAbs", N_VAbs, nb::arg("x"), nb::arg("z")); + +m.def("N_VInv", N_VInv, nb::arg("x"), nb::arg("z")); + +m.def("N_VAddConst", N_VAddConst, nb::arg("x"), nb::arg("b"), nb::arg("z")); + +m.def("N_VDotProd", N_VDotProd, nb::arg("x"), nb::arg("y")); + +m.def("N_VMaxNorm", N_VMaxNorm, nb::arg("x")); + +m.def("N_VWrmsNorm", N_VWrmsNorm, nb::arg("x"), nb::arg("w")); + +m.def("N_VWrmsNormMask", N_VWrmsNormMask, nb::arg("x"), nb::arg("w"), + nb::arg("id")); + +m.def("N_VMin", N_VMin, nb::arg("x")); + +m.def("N_VWL2Norm", N_VWL2Norm, nb::arg("x"), nb::arg("w")); + +m.def("N_VL1Norm", N_VL1Norm, nb::arg("x")); + +m.def("N_VCompare", N_VCompare, nb::arg("c"), nb::arg("x"), nb::arg("z")); + +m.def("N_VInvTest", N_VInvTest, nb::arg("x"), nb::arg("z")); + +m.def("N_VConstrMask", N_VConstrMask, nb::arg("c"), nb::arg("x"), nb::arg("m")); + +m.def("N_VMinQuotient", N_VMinQuotient, nb::arg("num"), nb::arg("denom")); + +m.def( + "N_VLinearCombination", + [](int nvec, sundials4py::Array1d c_1d, std::vector X_1d, + N_Vector z) -> SUNErrCode + { + auto N_VLinearCombination_adapt_arr_ptr_to_std_vector = + [](int nvec, sundials4py::Array1d c_1d, std::vector X_1d, + N_Vector z) -> SUNErrCode + { + sunrealtype* c_1d_ptr = reinterpret_cast(c_1d.data()); + N_Vector* X_1d_ptr = + reinterpret_cast(X_1d.empty() ? nullptr : X_1d.data()); + + auto lambda_result = N_VLinearCombination(nvec, c_1d_ptr, X_1d_ptr, z); + return lambda_result; + }; + + return N_VLinearCombination_adapt_arr_ptr_to_std_vector(nvec, c_1d, X_1d, z); + }, + nb::arg("nvec"), nb::arg("c_1d"), nb::arg("X_1d"), nb::arg("z")); + +m.def( + "N_VScaleAddMulti", + [](int nvec, sundials4py::Array1d a_1d, N_Vector x, + std::vector Y_1d, std::vector Z_1d) -> SUNErrCode + { + auto N_VScaleAddMulti_adapt_arr_ptr_to_std_vector = + [](int nvec, sundials4py::Array1d a_1d, N_Vector x, + std::vector Y_1d, std::vector Z_1d) -> SUNErrCode + { + sunrealtype* a_1d_ptr = reinterpret_cast(a_1d.data()); + N_Vector* Y_1d_ptr = + reinterpret_cast(Y_1d.empty() ? nullptr : Y_1d.data()); + N_Vector* Z_1d_ptr = + reinterpret_cast(Z_1d.empty() ? nullptr : Z_1d.data()); + + auto lambda_result = N_VScaleAddMulti(nvec, a_1d_ptr, x, Y_1d_ptr, + Z_1d_ptr); + return lambda_result; + }; + + return N_VScaleAddMulti_adapt_arr_ptr_to_std_vector(nvec, a_1d, x, Y_1d, + Z_1d); + }, + nb::arg("nvec"), nb::arg("a_1d"), nb::arg("x"), nb::arg("Y_1d"), + nb::arg("Z_1d")); + +m.def( + "N_VDotProdMulti", + [](int nvec, N_Vector x, std::vector Y_1d, + sundials4py::Array1d dotprods_1d) -> SUNErrCode + { + auto N_VDotProdMulti_adapt_arr_ptr_to_std_vector = + [](int nvec, N_Vector x, std::vector Y_1d, + sundials4py::Array1d dotprods_1d) -> SUNErrCode + { + N_Vector* Y_1d_ptr = + reinterpret_cast(Y_1d.empty() ? nullptr : Y_1d.data()); + sunrealtype* dotprods_1d_ptr = + reinterpret_cast(dotprods_1d.data()); + + auto lambda_result = N_VDotProdMulti(nvec, x, Y_1d_ptr, dotprods_1d_ptr); + return lambda_result; + }; + + return N_VDotProdMulti_adapt_arr_ptr_to_std_vector(nvec, x, Y_1d, + dotprods_1d); + }, + nb::arg("nvec"), nb::arg("x"), nb::arg("Y_1d"), nb::arg("dotprods_1d")); + +m.def( + "N_VLinearSumVectorArray", + [](int nvec, sunrealtype a, std::vector X_1d, sunrealtype b, + std::vector Y_1d, std::vector Z_1d) -> SUNErrCode + { + auto N_VLinearSumVectorArray_adapt_arr_ptr_to_std_vector = + [](int nvec, sunrealtype a, std::vector X_1d, sunrealtype b, + std::vector Y_1d, std::vector Z_1d) -> SUNErrCode + { + N_Vector* X_1d_ptr = + reinterpret_cast(X_1d.empty() ? nullptr : X_1d.data()); + N_Vector* Y_1d_ptr = + reinterpret_cast(Y_1d.empty() ? nullptr : Y_1d.data()); + N_Vector* Z_1d_ptr = + reinterpret_cast(Z_1d.empty() ? nullptr : Z_1d.data()); + + auto lambda_result = N_VLinearSumVectorArray(nvec, a, X_1d_ptr, b, + Y_1d_ptr, Z_1d_ptr); + return lambda_result; + }; + + return N_VLinearSumVectorArray_adapt_arr_ptr_to_std_vector(nvec, a, X_1d, b, + Y_1d, Z_1d); + }, + nb::arg("nvec"), nb::arg("a"), nb::arg("X_1d"), nb::arg("b"), nb::arg("Y_1d"), + nb::arg("Z_1d")); + +m.def( + "N_VScaleVectorArray", + [](int nvec, sundials4py::Array1d c_1d, std::vector X_1d, + std::vector Z_1d) -> SUNErrCode + { + auto N_VScaleVectorArray_adapt_arr_ptr_to_std_vector = + [](int nvec, sundials4py::Array1d c_1d, std::vector X_1d, + std::vector Z_1d) -> SUNErrCode + { + sunrealtype* c_1d_ptr = reinterpret_cast(c_1d.data()); + N_Vector* X_1d_ptr = + reinterpret_cast(X_1d.empty() ? nullptr : X_1d.data()); + N_Vector* Z_1d_ptr = + reinterpret_cast(Z_1d.empty() ? nullptr : Z_1d.data()); + + auto lambda_result = N_VScaleVectorArray(nvec, c_1d_ptr, X_1d_ptr, + Z_1d_ptr); + return lambda_result; + }; + + return N_VScaleVectorArray_adapt_arr_ptr_to_std_vector(nvec, c_1d, X_1d, + Z_1d); + }, + nb::arg("nvec"), nb::arg("c_1d"), nb::arg("X_1d"), nb::arg("Z_1d")); + +m.def( + "N_VConstVectorArray", + [](int nvec, sunrealtype c, std::vector Z_1d) -> SUNErrCode + { + auto N_VConstVectorArray_adapt_arr_ptr_to_std_vector = + [](int nvec, sunrealtype c, std::vector Z_1d) -> SUNErrCode + { + N_Vector* Z_1d_ptr = + reinterpret_cast(Z_1d.empty() ? nullptr : Z_1d.data()); + + auto lambda_result = N_VConstVectorArray(nvec, c, Z_1d_ptr); + return lambda_result; + }; + + return N_VConstVectorArray_adapt_arr_ptr_to_std_vector(nvec, c, Z_1d); + }, + nb::arg("nvec"), nb::arg("c"), nb::arg("Z_1d")); + +m.def( + "N_VWrmsNormVectorArray", + [](int nvec, std::vector X_1d, std::vector W_1d, + sundials4py::Array1d nrm_1d) -> SUNErrCode + { + auto N_VWrmsNormVectorArray_adapt_arr_ptr_to_std_vector = + [](int nvec, std::vector X_1d, std::vector W_1d, + sundials4py::Array1d nrm_1d) -> SUNErrCode + { + N_Vector* X_1d_ptr = + reinterpret_cast(X_1d.empty() ? nullptr : X_1d.data()); + N_Vector* W_1d_ptr = + reinterpret_cast(W_1d.empty() ? nullptr : W_1d.data()); + sunrealtype* nrm_1d_ptr = reinterpret_cast(nrm_1d.data()); + + auto lambda_result = N_VWrmsNormVectorArray(nvec, X_1d_ptr, W_1d_ptr, + nrm_1d_ptr); + return lambda_result; + }; + + return N_VWrmsNormVectorArray_adapt_arr_ptr_to_std_vector(nvec, X_1d, W_1d, + nrm_1d); + }, + nb::arg("nvec"), nb::arg("X_1d"), nb::arg("W_1d"), nb::arg("nrm_1d")); + +m.def( + "N_VWrmsNormMaskVectorArray", + [](int nvec, std::vector X_1d, std::vector W_1d, + N_Vector id, sundials4py::Array1d nrm_1d) -> SUNErrCode + { + auto N_VWrmsNormMaskVectorArray_adapt_arr_ptr_to_std_vector = + [](int nvec, std::vector X_1d, std::vector W_1d, + N_Vector id, sundials4py::Array1d nrm_1d) -> SUNErrCode + { + N_Vector* X_1d_ptr = + reinterpret_cast(X_1d.empty() ? nullptr : X_1d.data()); + N_Vector* W_1d_ptr = + reinterpret_cast(W_1d.empty() ? nullptr : W_1d.data()); + sunrealtype* nrm_1d_ptr = reinterpret_cast(nrm_1d.data()); + + auto lambda_result = N_VWrmsNormMaskVectorArray(nvec, X_1d_ptr, W_1d_ptr, + id, nrm_1d_ptr); + return lambda_result; + }; + + return N_VWrmsNormMaskVectorArray_adapt_arr_ptr_to_std_vector(nvec, X_1d, + W_1d, id, + nrm_1d); + }, + nb::arg("nvec"), nb::arg("X_1d"), nb::arg("W_1d"), nb::arg("id"), + nb::arg("nrm_1d")); + +m.def("N_VDotProdLocal", N_VDotProdLocal, nb::arg("x"), nb::arg("y")); + +m.def("N_VMaxNormLocal", N_VMaxNormLocal, nb::arg("x")); + +m.def("N_VMinLocal", N_VMinLocal, nb::arg("x")); + +m.def("N_VL1NormLocal", N_VL1NormLocal, nb::arg("x")); + +m.def("N_VWSqrSumLocal", N_VWSqrSumLocal, nb::arg("x"), nb::arg("w")); + +m.def("N_VWSqrSumMaskLocal", N_VWSqrSumMaskLocal, nb::arg("x"), nb::arg("w"), + nb::arg("id")); + +m.def("N_VInvTestLocal", N_VInvTestLocal, nb::arg("x"), nb::arg("z")); + +m.def("N_VConstrMaskLocal", N_VConstrMaskLocal, nb::arg("c"), nb::arg("x"), + nb::arg("m")); + +m.def("N_VMinQuotientLocal", N_VMinQuotientLocal, nb::arg("num"), + nb::arg("denom")); + +m.def( + "N_VDotProdMultiLocal", + [](int nvec, N_Vector x, std::vector Y_1d, + sundials4py::Array1d dotprods_1d) -> SUNErrCode + { + auto N_VDotProdMultiLocal_adapt_arr_ptr_to_std_vector = + [](int nvec, N_Vector x, std::vector Y_1d, + sundials4py::Array1d dotprods_1d) -> SUNErrCode + { + N_Vector* Y_1d_ptr = + reinterpret_cast(Y_1d.empty() ? nullptr : Y_1d.data()); + sunrealtype* dotprods_1d_ptr = + reinterpret_cast(dotprods_1d.data()); + + auto lambda_result = N_VDotProdMultiLocal(nvec, x, Y_1d_ptr, + dotprods_1d_ptr); + return lambda_result; + }; + + return N_VDotProdMultiLocal_adapt_arr_ptr_to_std_vector(nvec, x, Y_1d, + dotprods_1d); + }, + nb::arg("nvec"), nb::arg("x"), nb::arg("Y_1d"), nb::arg("dotprods_1d")); + +m.def( + "N_VDotProdMultiAllReduce", + [](int nvec_total, N_Vector x, sundials4py::Array1d sum_1d) -> SUNErrCode + { + auto N_VDotProdMultiAllReduce_adapt_arr_ptr_to_std_vector = + [](int nvec_total, N_Vector x, sundials4py::Array1d sum_1d) -> SUNErrCode + { + sunrealtype* sum_1d_ptr = reinterpret_cast(sum_1d.data()); + + auto lambda_result = N_VDotProdMultiAllReduce(nvec_total, x, sum_1d_ptr); + return lambda_result; + }; + + return N_VDotProdMultiAllReduce_adapt_arr_ptr_to_std_vector(nvec_total, x, + sum_1d); + }, + nb::arg("nvec_total"), nb::arg("x"), nb::arg("sum_1d")); + +m.def("N_VPrint", N_VPrint, nb::arg("v")); + +m.def("N_VPrintFile", N_VPrintFile, nb::arg("v"), nb::arg("outfile")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sundials/sundials_profiler.cpp b/bindings/sundials4py/sundials/sundials_profiler.cpp new file mode 100644 index 0000000000..e2d8d83909 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_profiler.cpp @@ -0,0 +1,42 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNProfiler class. It contains hand-written code for + * functions that require special treatment, and includes the generated + * code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include + +#include "sundials/sundials_types.h" +#include "sundials_profiler_impl.h" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunprofiler(nb::module_& m) +{ +#include "sundials_profiler_generated.hpp" + + nb::class_(m, "SUNProfiler_"); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_profiler_generated.hpp b/bindings/sundials4py/sundials/sundials_profiler_generated.hpp new file mode 100644 index 0000000000..fc62ba7923 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_profiler_generated.hpp @@ -0,0 +1,85 @@ +// #ifndef _SUNDIALS_PROFILER_H +// +// #ifdef __cplusplus +// #endif +// + +m.def( + "SUNProfiler_Create", + [](SUNComm comm, const char* title) + -> std::tuple>> + { + auto SUNProfiler_Create_adapt_modifiable_immutable_to_return = + [](SUNComm comm, const char* title) -> std::tuple + { + SUNProfiler p_adapt_modifiable; + + SUNErrCode r = SUNProfiler_Create(comm, title, &p_adapt_modifiable); + return std::make_tuple(r, p_adapt_modifiable); + }; + auto SUNProfiler_Create_adapt_return_type_to_shared_ptr = + [&SUNProfiler_Create_adapt_modifiable_immutable_to_return](SUNComm comm, + const char* title) + -> std::tuple>> + { + auto lambda_result = + SUNProfiler_Create_adapt_modifiable_immutable_to_return(comm, title); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNProfilerDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNProfiler_Create_adapt_return_type_to_shared_ptr(comm, title); + }, + nb::arg("comm"), nb::arg("title"), nb::rv_policy::reference); + +m.def("SUNProfiler_Begin", SUNProfiler_Begin, nb::arg("p"), nb::arg("name")); + +m.def("SUNProfiler_End", SUNProfiler_End, nb::arg("p"), nb::arg("name")); + +m.def( + "SUNProfiler_GetTimerResolution", + [](SUNProfiler p) -> std::tuple + { + auto SUNProfiler_GetTimerResolution_adapt_modifiable_immutable_to_return = + [](SUNProfiler p) -> std::tuple + { + double resolution_adapt_modifiable; + + SUNErrCode r = + SUNProfiler_GetTimerResolution(p, &resolution_adapt_modifiable); + return std::make_tuple(r, resolution_adapt_modifiable); + }; + + return SUNProfiler_GetTimerResolution_adapt_modifiable_immutable_to_return(p); + }, + nb::arg("p")); + +m.def( + "SUNProfiler_GetElapsedTime", + [](SUNProfiler p, const char* name) -> std::tuple + { + auto SUNProfiler_GetElapsedTime_adapt_modifiable_immutable_to_return = + [](SUNProfiler p, const char* name) -> std::tuple + { + double time_adapt_modifiable; + + SUNErrCode r = SUNProfiler_GetElapsedTime(p, name, &time_adapt_modifiable); + return std::make_tuple(r, time_adapt_modifiable); + }; + + return SUNProfiler_GetElapsedTime_adapt_modifiable_immutable_to_return(p, + name); + }, + nb::arg("p"), nb::arg("name")); + +m.def("SUNProfiler_Print", SUNProfiler_Print, nb::arg("p"), nb::arg("fp")); + +m.def("SUNProfiler_Reset", SUNProfiler_Reset, nb::arg("p")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_stepper.cpp b/bindings/sundials4py/sundials/sundials_stepper.cpp new file mode 100644 index 0000000000..20f3d4725e --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_stepper.cpp @@ -0,0 +1,236 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file is the entrypoint for the Python binding code for the + * SUNDIALS SUNStepper class. It contains hand-written code + * for functions that require special treatment, and includes the + * generated code produced with the generate.py script. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +#include "sundials_stepper_impl.h" +#include "sundials_stepper_usersupplied.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunstepper(nb::module_& m) +{ +#include "sundials_stepper_generated.hpp" + + nb::class_(m, "SUNStepper_"); + + m.def( + "SUNStepper_SetEvolveFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->evolve = nb::cast(fn); + if (fn) + { + return SUNStepper_SetEvolveFn(stepper, sunstepper_evolve_wrapper); + } + else { return SUNStepper_SetEvolveFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetOneStepFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->one_step = nb::cast(fn); + if (fn) + { + return SUNStepper_SetOneStepFn(stepper, sunstepper_one_step_wrapper); + } + else { return SUNStepper_SetOneStepFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetFullRhsFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->full_rhs = nb::cast(fn); + if (fn) + { + return SUNStepper_SetFullRhsFn(stepper, sunstepper_full_rhs_wrapper); + } + else { return SUNStepper_SetFullRhsFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetReInitFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->reinit = nb::cast(fn); + if (fn) + { + return SUNStepper_SetReInitFn(stepper, sunstepper_reinit_wrapper); + } + else { return SUNStepper_SetReInitFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetResetFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->reset = nb::cast(fn); + if (fn) + { + return SUNStepper_SetResetFn(stepper, sunstepper_reset_wrapper); + } + else { return SUNStepper_SetResetFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetResetCheckpointIndexFn", + [](SUNStepper stepper, + std::function> fn) + -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->reset_ckpt_idx = nb::cast(fn); + if (fn) + { + return SUNStepper_SetResetCheckpointIndexFn(stepper, + sunstepper_reset_ckpt_idx_wrapper); + } + else { return SUNStepper_SetResetCheckpointIndexFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetStopTimeFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->set_stop_time = nb::cast(fn); + if (fn) + { + return SUNStepper_SetStopTimeFn(stepper, + sunstepper_set_stop_time_wrapper); + } + else { return SUNStepper_SetStopTimeFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetStepDirectionFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->set_step_direction = nb::cast(fn); + if (fn) + { + return SUNStepper_SetStepDirectionFn(stepper, + sunstepper_set_step_direction_wrapper); + } + else { return SUNStepper_SetStepDirectionFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetForcingFn", + [](SUNStepper stepper, std::function fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->set_forcing = nb::cast(fn); + if (fn) + { + return SUNStepper_SetForcingFn(stepper, sunstepper_set_forcing_wrapper); + } + else { return SUNStepper_SetForcingFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + + m.def( + "SUNStepper_SetGetNumStepsFn", + [](SUNStepper stepper, std::function fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->get_num_steps = nb::cast(fn); + if (fn) + { + return SUNStepper_SetGetNumStepsFn(stepper, + sunstepper_get_num_steps_wrapper); + } + else { return SUNStepper_SetGetNumStepsFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundials/sundials_stepper_generated.hpp b/bindings/sundials4py/sundials/sundials_stepper_generated.hpp new file mode 100644 index 0000000000..721bdda75d --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_stepper_generated.hpp @@ -0,0 +1,176 @@ +// #ifndef _SUNDIALS_STEPPER_H +// +// #ifdef __cplusplus +// +// #endif +// + +auto pyEnumSUNFullRhsMode = nb::enum_(m, "SUNFullRhsMode", + nb::is_arithmetic(), "") + .value("SUN_FULLRHS_START", SUN_FULLRHS_START, "") + .value("SUN_FULLRHS_END", SUN_FULLRHS_END, "") + .value("SUN_FULLRHS_OTHER", SUN_FULLRHS_OTHER, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// + +m.def( + "SUNStepper_Create", + [](SUNContext sunctx) + -> std::tuple>> + { + auto SUNStepper_Create_adapt_modifiable_immutable_to_return = + [](SUNContext sunctx) -> std::tuple + { + SUNStepper stepper_adapt_modifiable; + + SUNErrCode r = SUNStepper_Create(sunctx, &stepper_adapt_modifiable); + return std::make_tuple(r, stepper_adapt_modifiable); + }; + auto SUNStepper_Create_adapt_return_type_to_shared_ptr = + [&SUNStepper_Create_adapt_modifiable_immutable_to_return](SUNContext sunctx) + -> std::tuple>> + { + auto lambda_result = + SUNStepper_Create_adapt_modifiable_immutable_to_return(sunctx); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNStepperDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNStepper_Create_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), + "nb::call_policy>()", + nb::rv_policy::reference, + nb::call_policy>()); + +m.def( + "SUNStepper_Evolve", + [](SUNStepper stepper, sunrealtype tout, + N_Vector vret) -> std::tuple + { + auto SUNStepper_Evolve_adapt_modifiable_immutable_to_return = + [](SUNStepper stepper, sunrealtype tout, + N_Vector vret) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + SUNErrCode r = SUNStepper_Evolve(stepper, tout, vret, + &tret_adapt_modifiable); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return SUNStepper_Evolve_adapt_modifiable_immutable_to_return(stepper, tout, + vret); + }, + nb::arg("stepper"), nb::arg("tout"), nb::arg("vret")); + +m.def( + "SUNStepper_OneStep", + [](SUNStepper stepper, sunrealtype tout, + N_Vector vret) -> std::tuple + { + auto SUNStepper_OneStep_adapt_modifiable_immutable_to_return = + [](SUNStepper stepper, sunrealtype tout, + N_Vector vret) -> std::tuple + { + sunrealtype tret_adapt_modifiable; + + SUNErrCode r = SUNStepper_OneStep(stepper, tout, vret, + &tret_adapt_modifiable); + return std::make_tuple(r, tret_adapt_modifiable); + }; + + return SUNStepper_OneStep_adapt_modifiable_immutable_to_return(stepper, + tout, vret); + }, + nb::arg("stepper"), nb::arg("tout"), nb::arg("vret")); + +m.def("SUNStepper_FullRhs", SUNStepper_FullRhs, nb::arg("stepper"), + nb::arg("t"), nb::arg("v"), nb::arg("f"), nb::arg("mode")); + +m.def("SUNStepper_ReInit", SUNStepper_ReInit, nb::arg("stepper"), nb::arg("t0"), + nb::arg("v0")); + +m.def("SUNStepper_Reset", SUNStepper_Reset, nb::arg("stepper"), nb::arg("tR"), + nb::arg("vR")); + +m.def("SUNStepper_ResetCheckpointIndex", SUNStepper_ResetCheckpointIndex, + nb::arg("stepper"), nb::arg("ckptIdxR")); + +m.def("SUNStepper_SetStopTime", SUNStepper_SetStopTime, nb::arg("stepper"), + nb::arg("tstop")); + +m.def("SUNStepper_SetStepDirection", SUNStepper_SetStepDirection, + nb::arg("stepper"), nb::arg("stepdir")); + +m.def( + "SUNStepper_SetForcing", + [](SUNStepper stepper, sunrealtype tshift, sunrealtype tscale, + std::vector forcing_1d, int nforcing) -> SUNErrCode + { + auto SUNStepper_SetForcing_adapt_arr_ptr_to_std_vector = + [](SUNStepper stepper, sunrealtype tshift, sunrealtype tscale, + std::vector forcing_1d, int nforcing) -> SUNErrCode + { + N_Vector* forcing_1d_ptr = reinterpret_cast( + forcing_1d.empty() ? nullptr : forcing_1d.data()); + + auto lambda_result = SUNStepper_SetForcing(stepper, tshift, tscale, + forcing_1d_ptr, nforcing); + return lambda_result; + }; + + return SUNStepper_SetForcing_adapt_arr_ptr_to_std_vector(stepper, tshift, + tscale, forcing_1d, + nforcing); + }, + nb::arg("stepper"), nb::arg("tshift"), nb::arg("tscale"), + nb::arg("forcing_1d"), nb::arg("nforcing")); + +m.def("SUNStepper_SetLastFlag", SUNStepper_SetLastFlag, nb::arg("stepper"), + nb::arg("last_flag")); + +m.def( + "SUNStepper_GetLastFlag", + [](SUNStepper stepper) -> std::tuple + { + auto SUNStepper_GetLastFlag_adapt_modifiable_immutable_to_return = + [](SUNStepper stepper) -> std::tuple + { + int last_flag_adapt_modifiable; + + SUNErrCode r = SUNStepper_GetLastFlag(stepper, &last_flag_adapt_modifiable); + return std::make_tuple(r, last_flag_adapt_modifiable); + }; + + return SUNStepper_GetLastFlag_adapt_modifiable_immutable_to_return(stepper); + }, + nb::arg("stepper")); + +m.def( + "SUNStepper_GetNumSteps", + [](SUNStepper stepper) -> std::tuple + { + auto SUNStepper_GetNumSteps_adapt_modifiable_immutable_to_return = + [](SUNStepper stepper) -> std::tuple + { + suncountertype nst_adapt_modifiable; + + SUNErrCode r = SUNStepper_GetNumSteps(stepper, &nst_adapt_modifiable); + return std::make_tuple(r, nst_adapt_modifiable); + }; + + return SUNStepper_GetNumSteps_adapt_modifiable_immutable_to_return(stepper); + }, + nb::arg("stepper")); +// #ifdef __cplusplus +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials/sundials_stepper_usersupplied.hpp b/bindings/sundials4py/sundials/sundials_stepper_usersupplied.hpp new file mode 100644 index 0000000000..a93300eb7a --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_stepper_usersupplied.hpp @@ -0,0 +1,158 @@ +/* ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#ifndef _SUNDIALS4PY_STEPPER_USERSUPPLIED_HPP +#define _SUNDIALS4PY_STEPPER_USERSUPPLIED_HPP + +#include +#include +#include "sundials4py.hpp" + +#include + +// If helpers are available, include them +#include "sundials4py_helpers.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +struct SUNStepperFunctionTable +{ + nb::object evolve; + nb::object one_step; + nb::object full_rhs; + nb::object reinit; + nb::object reset; + nb::object reset_ckpt_idx; + nb::object set_stop_time; + nb::object set_step_direction; + nb::object set_forcing; + nb::object get_num_steps; +}; + +inline SUNStepperFunctionTable* SUNStepperFunctionTable_Alloc() +{ + auto fn_table = static_cast( + std::malloc(sizeof(SUNStepperFunctionTable))); + std::memset(fn_table, 0, sizeof(SUNStepperFunctionTable)); + return fn_table; +} + +template +inline SUNErrCode sunstepper_evolve_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::evolve, std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_one_step_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::one_step, std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_full_rhs_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::full_rhs, std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_reinit_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::reinit, std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_reset_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::reset, std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_reset_ckpt_idx_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, + SUNStepperFunctionTable, SUNStepper>(&SUNStepperFunctionTable::reset_ckpt_idx, + std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_set_stop_time_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::set_stop_time, + std::forward(args)...); +} + +template +inline SUNErrCode sunstepper_set_step_direction_wrapper(Args... args) +{ + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::set_step_direction, + std::forward(args)...); +} + +using SUNStepperSetForcingStdFn = SUNErrCode(SUNStepper stepper, + sunrealtype tshift, + sunrealtype tscale, + std::vector forcing, + int nforcing); + +inline SUNErrCode sunstepper_set_forcing_wrapper(SUNStepper stepper, + sunrealtype tshift, + sunrealtype tscale, + N_Vector* forcing_1d, + int nforcing) +{ + auto fn_table = static_cast(stepper->python); + auto fn = + nb::cast>(fn_table->set_forcing); + + std::vector forcing(forcing_1d, forcing_1d + nforcing); + + return fn(stepper, tshift, tscale, forcing, nforcing); +} + +using SUNStepperGetNumStepsStdFn = + std::tuple(SUNStepper); + +inline SUNErrCode sunstepper_get_num_steps_wrapper(SUNStepper stepper, + suncountertype* num_steps) +{ + auto fn_table = static_cast(stepper->python); + auto fn = + nb::cast>(fn_table->get_num_steps); + + auto result = fn(stepper); + + *num_steps = std::get<1>(result); + + return std::get<0>(result); +} + +#endif // _SUNDIALS4PY_STEPPER_USERSUPPLIED_HPP diff --git a/bindings/sundials4py/sundials/sundials_types_generated.hpp b/bindings/sundials4py/sundials/sundials_types_generated.hpp new file mode 100644 index 0000000000..46a5f6a0e9 --- /dev/null +++ b/bindings/sundials4py/sundials/sundials_types_generated.hpp @@ -0,0 +1,40 @@ +// #ifndef _SUNDIALS_TYPES_H +// +// #ifdef __cplusplus +// #endif +// +// #ifdef SWIG +// +// #else +// +// #endif +// + +auto pyEnumSUNOutputFormat = + nb::enum_(m, "SUNOutputFormat", nb::is_arithmetic(), "") + .value("SUN_OUTPUTFORMAT_TABLE", SUN_OUTPUTFORMAT_TABLE, "") + .value("SUN_OUTPUTFORMAT_CSV", SUN_OUTPUTFORMAT_CSV, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// +// #ifndef SWIG +// +m.attr("SUN_COMM_NULL") = 0; +// #endif +// +// #ifdef __cplusplus +// +// #endif +// + +auto pyEnumSUNDataIOMode = + nb::enum_(m, "SUNDataIOMode", nb::is_arithmetic(), "") + .value("SUNDATAIOMODE_INMEM", SUNDATAIOMODE_INMEM, "") + .export_values(); +// #ifndef SWIG +// +// #endif +// +// #endif diff --git a/bindings/sundials4py/sundials4py-generate b/bindings/sundials4py/sundials4py-generate new file mode 160000 index 0000000000..b685905c8d --- /dev/null +++ b/bindings/sundials4py/sundials4py-generate @@ -0,0 +1 @@ +Subproject commit b685905c8d808a8f0233f2ade0cff45062e1360b diff --git a/bindings/sundials4py/sundials4py.cpp b/bindings/sundials4py/sundials4py.cpp new file mode 100644 index 0000000000..5d912539f5 --- /dev/null +++ b/bindings/sundials4py/sundials4py.cpp @@ -0,0 +1,132 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------- + * This file defines the sundials4py Python module and includes all + * of the submodule pieces. + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +namespace nb = nanobind; + +namespace sundials4py { + +// +// Forward declarations of all of the binding functions +// + +void bind_core(nb::module_& m); + +void bind_arkode(nb::module_& m); +void bind_cvodes(nb::module_& m); +void bind_idas(nb::module_& m); +void bind_kinsol(nb::module_& m); + +void bind_nvector_serial(nb::module_& m); +void bind_nvector_manyvector(nb::module_& m); + +void bind_sumemoryhelper_sys(nb::module_& m); + +void bind_sunadaptcontroller_imexgus(nb::module_& m); +void bind_sunadaptcontroller_mrihtol(nb::module_& m); +void bind_sunadaptcontroller_soderlind(nb::module_& m); + +void bind_sunadjointcheckpointscheme_fixed(nb::module_& m); + +void bind_sundomeigest_power(nb::module_& m); + +void bind_sunlinsol_band(nb::module_& m); +void bind_sunlinsol_dense(nb::module_& m); +void bind_sunlinsol_pcg(nb::module_& m); +void bind_sunlinsol_spbcgs(nb::module_& m); +void bind_sunlinsol_spfgmr(nb::module_& m); +void bind_sunlinsol_spgmr(nb::module_& m); +void bind_sunlinsol_sptfqmr(nb::module_& m); + +void bind_sunmatrix_band(nb::module_& m); +void bind_sunmatrix_dense(nb::module_& m); +void bind_sunmatrix_sparse(nb::module_& m); + +void bind_sunnonlinsol_fixedpoint(nb::module_& m); +void bind_sunnonlinsol_newton(nb::module_& m); + +} // namespace sundials4py + +// +// Define main module, sundials4py, and all of its submodules +// + +NB_MODULE(sundials4py, m) +{ +#ifdef NDEBUG + // The nanobind leak warnings can be quite noisy due to leaks within Python itself, so we disable them for Release builds. + nb::set_leak_warnings(false); +#endif + + nb::module_ core_m = m.def_submodule("core", "A submodule of 'sundials4py'"); + sundials4py::bind_core(core_m); + + // + // Create submodules for each package + // + + nb::module_ arkode_m = m.def_submodule("arkode", + "A submodule of 'sundials4py'"); + sundials4py::bind_arkode(arkode_m); + + nb::module_ cvodes_m = m.def_submodule("cvodes", + "A submodule of 'sundials4py'"); + sundials4py::bind_cvodes(cvodes_m); + + nb::module_ idas_m = m.def_submodule("idas", "A submodule of 'sundials4py'"); + sundials4py::bind_idas(idas_m); + + nb::module_ kinsol_m = m.def_submodule("kinsol", + "A submodule of 'sundials4py'"); + sundials4py::bind_kinsol(kinsol_m); + + // + // Bind all implementation modules directly to core_m + // + + sundials4py::bind_nvector_serial(core_m); + sundials4py::bind_nvector_manyvector(core_m); + + sundials4py::bind_sunadaptcontroller_imexgus(core_m); + sundials4py::bind_sunadaptcontroller_mrihtol(core_m); + sundials4py::bind_sunadaptcontroller_soderlind(core_m); + + sundials4py::bind_sunadjointcheckpointscheme_fixed(core_m); + + sundials4py::bind_sundomeigest_power(core_m); + + sundials4py::bind_sunlinsol_band(core_m); + sundials4py::bind_sunlinsol_dense(core_m); + sundials4py::bind_sunlinsol_pcg(core_m); + sundials4py::bind_sunlinsol_spbcgs(core_m); + sundials4py::bind_sunlinsol_spfgmr(core_m); + sundials4py::bind_sunlinsol_spgmr(core_m); + sundials4py::bind_sunlinsol_sptfqmr(core_m); + + sundials4py::bind_sunmatrix_band(core_m); + sundials4py::bind_sunmatrix_dense(core_m); + sundials4py::bind_sunmatrix_sparse(core_m); + + sundials4py::bind_sumemoryhelper_sys(core_m); + + sundials4py::bind_sunnonlinsol_fixedpoint(core_m); + sundials4py::bind_sunnonlinsol_newton(core_m); +} \ No newline at end of file diff --git a/bindings/sundials4py/sundomeigest/generate.yaml b/bindings/sundials4py/sundomeigest/generate.yaml new file mode 100644 index 0000000000..c82a4ce8e1 --- /dev/null +++ b/bindings/sundials4py/sundomeigest/generate.yaml @@ -0,0 +1,41 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNDomEigEstimator_SetATimes_.*" + - "^SUNDomEigEstimator_SetOptions_.*" + - "^SUNDomEigEstimator_SetMaxIters_.*" + - "^SUNDomEigEstimator_SetNumPreprocessIters_.*" + - "^SUNDomEigEstimator_SetRelTol_.*" + - "^SUNDomEigEstimator_SetInitialGuess_.*" + - "^SUNDomEigEstimator_Initialize_.*" + - "^SUNDomEigEstimator_Estimate_.*" + - "^SUNDomEigEstimator_GetRes_.*" + - "^SUNDomEigEstimator_GetNumIters_.*" + - "^SUNDomEigEstimator_GetNumATimesCalls_.*" + - "^SUNDomEigEstimator_Write_.*" + - "^SUNDomEigEstimator_Destroy_.*" + sundomeigest_power: + path: sundomeigest/sundomeigest_power_generated.hpp + headers: + - ../../include/sundomeigest/sundomeigest_power.h diff --git a/bindings/sundials4py/sundomeigest/sundomeigest_power.cpp b/bindings/sundials4py/sundomeigest/sundomeigest_power.cpp new file mode 100644 index 0000000000..266a8b0cec --- /dev/null +++ b/bindings/sundials4py/sundomeigest/sundomeigest_power.cpp @@ -0,0 +1,34 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include +#include +#include + +#include "sundials4py.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sundomeigest_power(nb::module_& m) +{ +#include "sundomeigest_power_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sundomeigest/sundomeigest_power_generated.hpp b/bindings/sundials4py/sundomeigest/sundomeigest_power_generated.hpp new file mode 100644 index 0000000000..996531f676 --- /dev/null +++ b/bindings/sundials4py/sundomeigest/sundomeigest_power_generated.hpp @@ -0,0 +1,40 @@ +// #ifndef _SUNDOMEIGEST_POWER_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClassSUNDomEigEstimatorContent_Power_ = + nb::class_(m, "SUNDomEigEstimatorContent_Power_", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNDomEigEstimator_Power", + [](N_Vector q, long max_iters, sunrealtype rel_tol, SUNContext sunctx) + -> std::shared_ptr> + { + auto SUNDomEigEstimator_Power_adapt_return_type_to_shared_ptr = + [](N_Vector q, long max_iters, sunrealtype rel_tol, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNDomEigEstimator_Power(q, max_iters, rel_tol, + sunctx); + + return our_make_shared, + SUNDomEigEstimatorDeleter>(lambda_result); + }; + + return SUNDomEigEstimator_Power_adapt_return_type_to_shared_ptr(q, max_iters, + rel_tol, + sunctx); + }, + nb::arg("q"), nb::arg("max_iters"), nb::arg("rel_tol"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/generate.yaml b/bindings/sundials4py/sunlinsol/generate.yaml new file mode 100644 index 0000000000..142cd7ee25 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/generate.yaml @@ -0,0 +1,68 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNLinSolGetType_.*" + - "^SUNLinSolGetID_.*" + - "^SUNLinSolSetATimes_.*" + - "^SUNLinSolSetPreconditioner_.*" + - "^SUNLinSolSetScalingVectors_.*" + - "^SUNLinSolSetOptions_.*" + - "^SUNLinSolSetZeroGuess_.*" + - "^SUNLinSolInitialize_.*" + - "^SUNLinSolSetup_.*" + - "^SUNLinSolSolve_.*" + - "^SUNLinSolNumIters_.*" + - "^SUNLinSolResNorm_.*" + - "^SUNLinSolResid_.*" + - "^SUNLinSolLastFlag_.*" + - "^SUNLinSolSpace_.*" + - "^SUNLinSolFree_.*" + sunlinsol_band: + path: sunlinsol/sunlinsol_band_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_band.h + sunlinsol_dense: + path: sunlinsol/sunlinsol_dense_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_dense.h + sunlinsol_spbcgs: + path: sunlinsol/sunlinsol_spbcgs_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_spbcgs.h + sunlinsol_spfgmr: + path: sunlinsol/sunlinsol_spfgmr_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_spfgmr.h + sunlinsol_spgmr: + path: sunlinsol/sunlinsol_spgmr_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_spgmr.h + sunlinsol_sptfqmr: + path: sunlinsol/sunlinsol_sptfqmr_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_sptfqmr.h + sunlinsol_pcg: + path: sunlinsol/sunlinsol_pcg_generated.hpp + headers: + - ../../include/sunlinsol/sunlinsol_pcg.h diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_band.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_band.cpp new file mode 100644 index 0000000000..da34c33afc --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_band.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_band(nb::module_& m) +{ +#include "sunlinsol_band_generated.hpp" +} + +} // namespace sundials4py \ No newline at end of file diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_band_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_band_generated.hpp new file mode 100644 index 0000000000..baf315d191 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_band_generated.hpp @@ -0,0 +1,36 @@ +// #ifndef _SUNLINSOL_BAND_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_Band = + nb::class_<_SUNLinearSolverContent_Band>(m, "_SUNLinearSolverContent_Band", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_Band", + [](N_Vector y, SUNMatrix A, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_Band_adapt_return_type_to_shared_ptr = + [](N_Vector y, SUNMatrix A, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_Band(y, A, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_Band_adapt_return_type_to_shared_ptr(y, A, sunctx); + }, + nb::arg("y"), nb::arg("A"), nb::arg("sunctx"), "nb::keep_alive<0, 3>()", + nb::keep_alive<0, 3>()); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_dense.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_dense.cpp new file mode 100644 index 0000000000..04c5840866 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_dense.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_dense(nb::module_& m) +{ +#include "sunlinsol_dense_generated.hpp" +} + +} // namespace sundials4py \ No newline at end of file diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_dense_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_dense_generated.hpp new file mode 100644 index 0000000000..4b7767eda4 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_dense_generated.hpp @@ -0,0 +1,36 @@ +// #ifndef _SUNLINSOL_DENSE_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_Dense = + nb::class_<_SUNLinearSolverContent_Dense>(m, "_SUNLinearSolverContent_Dense", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_Dense", + [](N_Vector y, SUNMatrix A, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_Dense_adapt_return_type_to_shared_ptr = + [](N_Vector y, SUNMatrix A, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_Dense(y, A, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_Dense_adapt_return_type_to_shared_ptr(y, A, sunctx); + }, + nb::arg("y"), nb::arg("A"), nb::arg("sunctx"), "nb::keep_alive<0, 3>()", + nb::keep_alive<0, 3>()); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_pcg.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_pcg.cpp new file mode 100644 index 0000000000..09b692a5a4 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_pcg.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_pcg(nb::module_& m) +{ +#include "sunlinsol_pcg_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_pcg_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_pcg_generated.hpp new file mode 100644 index 0000000000..24c0e89af6 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_pcg_generated.hpp @@ -0,0 +1,43 @@ +// #ifndef _SUNLINSOL_PCG_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_PCG = + nb::class_<_SUNLinearSolverContent_PCG>(m, "_SUNLinearSolverContent_PCG", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_PCG", + [](N_Vector y, int pretype, int maxl, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_PCG_adapt_return_type_to_shared_ptr = + [](N_Vector y, int pretype, int maxl, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_PCG(y, pretype, maxl, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_PCG_adapt_return_type_to_shared_ptr(y, pretype, maxl, + sunctx); + }, + nb::arg("y"), nb::arg("pretype"), nb::arg("maxl"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def("SUNLinSol_PCGSetPrecType", SUNLinSol_PCGSetPrecType, nb::arg("S"), + nb::arg("pretype")); + +m.def("SUNLinSol_PCGSetMaxl", SUNLinSol_PCGSetMaxl, nb::arg("S"), + nb::arg("maxl")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_spbcgs.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_spbcgs.cpp new file mode 100644 index 0000000000..b24861d1ba --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_spbcgs.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_spbcgs(nb::module_& m) +{ +#include "sunlinsol_spbcgs_generated.hpp" +} + +} // namespace sundials4py \ No newline at end of file diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_spbcgs_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_spbcgs_generated.hpp new file mode 100644 index 0000000000..fe7052bbb0 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_spbcgs_generated.hpp @@ -0,0 +1,44 @@ +// #ifndef _SUNLINSOL_SPBCGS_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_SPBCGS = + nb::class_<_SUNLinearSolverContent_SPBCGS>(m, "_SUNLinearSolverContent_SPBCGS", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_SPBCGS", + [](N_Vector y, int pretype, int maxl, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_SPBCGS_adapt_return_type_to_shared_ptr = + [](N_Vector y, int pretype, int maxl, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_SPBCGS(y, pretype, maxl, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_SPBCGS_adapt_return_type_to_shared_ptr(y, pretype, maxl, + sunctx); + }, + nb::arg("y"), nb::arg("pretype"), nb::arg("maxl"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def("SUNLinSol_SPBCGSSetPrecType", SUNLinSol_SPBCGSSetPrecType, nb::arg("S"), + nb::arg("pretype")); + +m.def("SUNLinSol_SPBCGSSetMaxl", SUNLinSol_SPBCGSSetMaxl, nb::arg("S"), + nb::arg("maxl")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_spfgmr.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_spfgmr.cpp new file mode 100644 index 0000000000..3de2f2d06d --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_spfgmr.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_spfgmr(nb::module_& m) +{ +#include "sunlinsol_spfgmr_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_spfgmr_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_spfgmr_generated.hpp new file mode 100644 index 0000000000..1607f5395e --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_spfgmr_generated.hpp @@ -0,0 +1,47 @@ +// #ifndef _SUNLINSOL_SPFGMR_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_SPFGMR = + nb::class_<_SUNLinearSolverContent_SPFGMR>(m, "_SUNLinearSolverContent_SPFGMR", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_SPFGMR", + [](N_Vector y, int pretype, int maxl, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_SPFGMR_adapt_return_type_to_shared_ptr = + [](N_Vector y, int pretype, int maxl, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_SPFGMR(y, pretype, maxl, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_SPFGMR_adapt_return_type_to_shared_ptr(y, pretype, maxl, + sunctx); + }, + nb::arg("y"), nb::arg("pretype"), nb::arg("maxl"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def("SUNLinSol_SPFGMRSetPrecType", SUNLinSol_SPFGMRSetPrecType, nb::arg("S"), + nb::arg("pretype")); + +m.def("SUNLinSol_SPFGMRSetGSType", SUNLinSol_SPFGMRSetGSType, nb::arg("S"), + nb::arg("gstype")); + +m.def("SUNLinSol_SPFGMRSetMaxRestarts", SUNLinSol_SPFGMRSetMaxRestarts, + nb::arg("S"), nb::arg("maxrs")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_spgmr.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_spgmr.cpp new file mode 100644 index 0000000000..f83e05d931 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_spgmr.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_spgmr(nb::module_& m) +{ +#include "sunlinsol_spgmr_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_spgmr_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_spgmr_generated.hpp new file mode 100644 index 0000000000..2ffd3a7725 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_spgmr_generated.hpp @@ -0,0 +1,46 @@ +// #ifndef _SUNLINSOL_SPGMR_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_SPGMR = + nb::class_<_SUNLinearSolverContent_SPGMR>(m, "_SUNLinearSolverContent_SPGMR", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_SPGMR", + [](N_Vector y, int pretype, int maxl, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_SPGMR_adapt_return_type_to_shared_ptr = + [](N_Vector y, int pretype, int maxl, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_SPGMR(y, pretype, maxl, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_SPGMR_adapt_return_type_to_shared_ptr(y, pretype, maxl, + sunctx); + }, + nb::arg("y"), nb::arg("pretype"), nb::arg("maxl"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def("SUNLinSol_SPGMRSetPrecType", SUNLinSol_SPGMRSetPrecType, nb::arg("S"), + nb::arg("pretype")); + +m.def("SUNLinSol_SPGMRSetGSType", SUNLinSol_SPGMRSetGSType, nb::arg("S"), + nb::arg("gstype")); + +m.def("SUNLinSol_SPGMRSetMaxRestarts", SUNLinSol_SPGMRSetMaxRestarts, + nb::arg("S"), nb::arg("maxrs")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_sptfqmr.cpp b/bindings/sundials4py/sunlinsol/sunlinsol_sptfqmr.cpp new file mode 100644 index 0000000000..c4147dfd9c --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_sptfqmr.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunlinsol_sptfqmr(nb::module_& m) +{ +#include "sunlinsol_sptfqmr_generated.hpp" +} + +} // namespace sundials4py \ No newline at end of file diff --git a/bindings/sundials4py/sunlinsol/sunlinsol_sptfqmr_generated.hpp b/bindings/sundials4py/sunlinsol/sunlinsol_sptfqmr_generated.hpp new file mode 100644 index 0000000000..a867cd03c2 --- /dev/null +++ b/bindings/sundials4py/sunlinsol/sunlinsol_sptfqmr_generated.hpp @@ -0,0 +1,45 @@ +// #ifndef _SUNLINSOL_SPTFQMR_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNLinearSolverContent_SPTFQMR = + nb::class_<_SUNLinearSolverContent_SPTFQMR>(m, + "_SUNLinearSolverContent_SPTFQMR", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNLinSol_SPTFQMR", + [](N_Vector y, int pretype, int maxl, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNLinSol_SPTFQMR_adapt_return_type_to_shared_ptr = + [](N_Vector y, int pretype, int maxl, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNLinSol_SPTFQMR(y, pretype, maxl, sunctx); + + return our_make_shared, + SUNLinearSolverDeleter>(lambda_result); + }; + + return SUNLinSol_SPTFQMR_adapt_return_type_to_shared_ptr(y, pretype, maxl, + sunctx); + }, + nb::arg("y"), nb::arg("pretype"), nb::arg("maxl"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def("SUNLinSol_SPTFQMRSetPrecType", SUNLinSol_SPTFQMRSetPrecType, + nb::arg("S"), nb::arg("pretype")); + +m.def("SUNLinSol_SPTFQMRSetMaxl", SUNLinSol_SPTFQMRSetMaxl, nb::arg("S"), + nb::arg("maxl")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunmatrix/generate.yaml b/bindings/sundials4py/sunmatrix/generate.yaml new file mode 100644 index 0000000000..b24dbe2bf2 --- /dev/null +++ b/bindings/sundials4py/sunmatrix/generate.yaml @@ -0,0 +1,54 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNMat.*" + sunmatrix_band: + path: sunmatrix/sunmatrix_band_generated.hpp + headers: + - ../../include/sunmatrix/sunmatrix_band.h + fn_exclude_by_name__regex: + # We need to wrap the Data function ourselves to ensure translation to a numpy array + - "^SUNBandMatrix_Data$" + # We don't interface these function in Python. Instead users can index the numpy array returned by Data. + - "^SUNBandMatrix_Cols$" + - "^SUNBandMatrix_Column$" + sunmatrix_dense: + path: sunmatrix/sunmatrix_dense_generated.hpp + headers: + - ../../include/sunmatrix/sunmatrix_dense.h + fn_exclude_by_name__regex: + # We need to wrap the Data function ourselves to ensure translation to a numpy array + - "^SUNDenseMatrix_Data$" + # We don't interface these functions to Python. Instead users can index the numpy array returned by Data. + - "^SUNDenseMatrix_Cols$" + - "^SUNDenseMatrix_Column$" + sunmatrix_sparse: + path: sunmatrix/sunmatrix_sparse_generated.hpp + headers: + - ../../include/sunmatrix/sunmatrix_sparse.h + fn_exclude_by_name__regex: + # We need to wrap the Data function ourselves to ensure translation to a numpy array + - "^SUNSparseMatrix_Data$" + - "^SUNSparseMatrix_IndexValues$" + - "^SUNSparseMatrix_IndexPointers$" diff --git a/bindings/sundials4py/sunmatrix/sunmatrix_band.cpp b/bindings/sundials4py/sunmatrix/sunmatrix_band.cpp new file mode 100644 index 0000000000..df107c1cd5 --- /dev/null +++ b/bindings/sundials4py/sunmatrix/sunmatrix_band.cpp @@ -0,0 +1,47 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunmatrix_band(nb::module_& m) +{ +#include "sunmatrix_band_generated.hpp" + + m.def( + "SUNBandMatrix_Data", + [](SUNMatrix A) + { + auto ldata = static_cast(SUNBandMatrix_LData(A)); + auto owner = nb::find(A); + auto ptr = SUNBandMatrix_Data(A); + // SUNBandMatrix_Data returns data that cannot be directly indexed as a 2-dimensional numpy array + return nb::ndarray, nb::c_contig>(ptr, + {ldata}, + owner); + }, + nb::arg("A"), nb::rv_policy::reference); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunmatrix/sunmatrix_band_generated.hpp b/bindings/sundials4py/sunmatrix/sunmatrix_band_generated.hpp new file mode 100644 index 0000000000..e00a2d1771 --- /dev/null +++ b/bindings/sundials4py/sunmatrix/sunmatrix_band_generated.hpp @@ -0,0 +1,75 @@ +// #ifndef _SUNMATRIX_BAND_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNMatrixContent_Band = + nb::class_<_SUNMatrixContent_Band>(m, "_SUNMatrixContent_Band", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNBandMatrix", + [](sunindextype N, sunindextype mu, sunindextype ml, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNBandMatrix_adapt_return_type_to_shared_ptr = + [](sunindextype N, sunindextype mu, sunindextype ml, + SUNContext sunctx) -> std::shared_ptr> + { + auto lambda_result = SUNBandMatrix(N, mu, ml, sunctx); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNBandMatrix_adapt_return_type_to_shared_ptr(N, mu, ml, sunctx); + }, + nb::arg("N"), nb::arg("mu"), nb::arg("ml"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def( + "SUNBandMatrixStorage", + [](sunindextype N, sunindextype mu, sunindextype ml, sunindextype smu, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNBandMatrixStorage_adapt_return_type_to_shared_ptr = + [](sunindextype N, sunindextype mu, sunindextype ml, sunindextype smu, + SUNContext sunctx) -> std::shared_ptr> + { + auto lambda_result = SUNBandMatrixStorage(N, mu, ml, smu, sunctx); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNBandMatrixStorage_adapt_return_type_to_shared_ptr(N, mu, ml, smu, + sunctx); + }, + nb::arg("N"), nb::arg("mu"), nb::arg("ml"), nb::arg("smu"), nb::arg("sunctx"), + "nb::keep_alive<0, 5>()", nb::keep_alive<0, 5>()); + +m.def("SUNBandMatrix_Print", SUNBandMatrix_Print, nb::arg("A"), + nb::arg("outfile")); + +m.def("SUNBandMatrix_Rows", SUNBandMatrix_Rows, nb::arg("A")); + +m.def("SUNBandMatrix_Columns", SUNBandMatrix_Columns, nb::arg("A")); + +m.def("SUNBandMatrix_LowerBandwidth", SUNBandMatrix_LowerBandwidth, nb::arg("A")); + +m.def("SUNBandMatrix_UpperBandwidth", SUNBandMatrix_UpperBandwidth, nb::arg("A")); + +m.def("SUNBandMatrix_StoredUpperBandwidth", SUNBandMatrix_StoredUpperBandwidth, + nb::arg("A")); + +m.def("SUNBandMatrix_LDim", SUNBandMatrix_LDim, nb::arg("A")); + +m.def("SUNBandMatrix_LData", SUNBandMatrix_LData, nb::arg("A")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunmatrix/sunmatrix_dense.cpp b/bindings/sundials4py/sunmatrix/sunmatrix_dense.cpp new file mode 100644 index 0000000000..0f31ec286d --- /dev/null +++ b/bindings/sundials4py/sunmatrix/sunmatrix_dense.cpp @@ -0,0 +1,47 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunmatrix_dense(nb::module_& m) +{ +#include "sunmatrix_dense_generated.hpp" + + m.def( + "SUNDenseMatrix_Data", + [](SUNMatrix A) + { + auto rows = static_cast(SUNDenseMatrix_Rows(A)); + auto cols = static_cast(SUNDenseMatrix_Columns(A)); + auto owner = nb::find(A); + auto ptr = SUNDenseMatrix_Data(A); + // SUNDenseMatrix_Data returns a column-major ordered array (i.e., Fortran style) + return nb::ndarray, + nb::f_contig>(ptr, {rows, cols}, owner); + }, + nb::arg("A"), nb::rv_policy::reference); +} + +} // namespace sundials4py \ No newline at end of file diff --git a/bindings/sundials4py/sunmatrix/sunmatrix_dense_generated.hpp b/bindings/sundials4py/sunmatrix/sunmatrix_dense_generated.hpp new file mode 100644 index 0000000000..9b5b6c9eb5 --- /dev/null +++ b/bindings/sundials4py/sunmatrix/sunmatrix_dense_generated.hpp @@ -0,0 +1,45 @@ +// #ifndef _SUNMATRIX_DENSE_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNMatrixContent_Dense = + nb::class_<_SUNMatrixContent_Dense>(m, "_SUNMatrixContent_Dense", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNDenseMatrix", + [](sunindextype M, sunindextype N, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNDenseMatrix_adapt_return_type_to_shared_ptr = + [](sunindextype M, sunindextype N, + SUNContext sunctx) -> std::shared_ptr> + { + auto lambda_result = SUNDenseMatrix(M, N, sunctx); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNDenseMatrix_adapt_return_type_to_shared_ptr(M, N, sunctx); + }, + nb::arg("M"), nb::arg("N"), nb::arg("sunctx"), "nb::keep_alive<0, 3>()", + nb::keep_alive<0, 3>()); + +m.def("SUNDenseMatrix_Print", SUNDenseMatrix_Print, nb::arg("A"), + nb::arg("outfile")); + +m.def("SUNDenseMatrix_Rows", SUNDenseMatrix_Rows, nb::arg("A")); + +m.def("SUNDenseMatrix_Columns", SUNDenseMatrix_Columns, nb::arg("A")); + +m.def("SUNDenseMatrix_LData", SUNDenseMatrix_LData, nb::arg("A")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunmatrix/sunmatrix_sparse.cpp b/bindings/sundials4py/sunmatrix/sunmatrix_sparse.cpp new file mode 100644 index 0000000000..1749d7bdf3 --- /dev/null +++ b/bindings/sundials4py/sunmatrix/sunmatrix_sparse.cpp @@ -0,0 +1,73 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunmatrix_sparse(nb::module_& m) +{ +#include "sunmatrix_sparse_generated.hpp" + + m.def( + "SUNSparseMatrix_Data", + [](SUNMatrix A) + { + auto nnz = static_cast(SUNSparseMatrix_NNZ(A)); + auto owner = nb::find(A); + auto ptr = SUNSparseMatrix_Data(A); + // SUNSparseMatrix_Data returns data that cannot be directly indexed as a 2-dimensional numpy array + return nb::ndarray, nb::c_contig>(ptr, + {nnz}, + owner); + }, + nb::arg("A"), nb::rv_policy::reference); + + m.def( + "SUNSparseMatrix_IndexValues", + [](SUNMatrix A) + { + auto nnz = static_cast(SUNSparseMatrix_NNZ(A)); + auto owner = nb::find(A); + auto ptr = SUNSparseMatrix_IndexValues(A); + return nb::ndarray, nb::c_contig>(ptr, + {nnz}, + owner); + }, + nb::arg("A"), nb::rv_policy::reference); + + m.def( + "SUNSparseMatrix_IndexPointers", + [](SUNMatrix A) + { + auto nnz = static_cast(SUNSparseMatrix_NP(A) + 1); + auto owner = nb::find(A); + auto ptr = SUNSparseMatrix_IndexPointers(A); + return nb::ndarray, nb::c_contig>(ptr, + {nnz}, + owner); + }, + nb::arg("A"), nb::rv_policy::reference); +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunmatrix/sunmatrix_sparse_generated.hpp b/bindings/sundials4py/sunmatrix/sunmatrix_sparse_generated.hpp new file mode 100644 index 0000000000..87973e4021 --- /dev/null +++ b/bindings/sundials4py/sunmatrix/sunmatrix_sparse_generated.hpp @@ -0,0 +1,159 @@ +// #ifndef _SUNMATRIX_SPARSE_H +// +// #ifdef __cplusplus +// #endif +// +m.attr("SUN_CSC_MAT") = 0; +m.attr("SUN_CSR_MAT") = 1; + +auto pyClass_SUNMatrixContent_Sparse = + nb::class_<_SUNMatrixContent_Sparse>(m, "_SUNMatrixContent_Sparse", "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNSparseMatrix", + [](sunindextype M, sunindextype N, sunindextype NNZ, int sparsetype, + SUNContext sunctx) -> std::shared_ptr> + { + auto SUNSparseMatrix_adapt_return_type_to_shared_ptr = + [](sunindextype M, sunindextype N, sunindextype NNZ, int sparsetype, + SUNContext sunctx) -> std::shared_ptr> + { + auto lambda_result = SUNSparseMatrix(M, N, NNZ, sparsetype, sunctx); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNSparseMatrix_adapt_return_type_to_shared_ptr(M, N, NNZ, + sparsetype, sunctx); + }, + nb::arg("M"), nb::arg("N"), nb::arg("NNZ"), nb::arg("sparsetype"), + nb::arg("sunctx"), "nb::keep_alive<0, 5>()", nb::keep_alive<0, 5>()); + +m.def( + "SUNSparseFromDenseMatrix", + [](SUNMatrix A, sunrealtype droptol, + int sparsetype) -> std::shared_ptr> + { + auto SUNSparseFromDenseMatrix_adapt_return_type_to_shared_ptr = + [](SUNMatrix A, sunrealtype droptol, + int sparsetype) -> std::shared_ptr> + { + auto lambda_result = SUNSparseFromDenseMatrix(A, droptol, sparsetype); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNSparseFromDenseMatrix_adapt_return_type_to_shared_ptr(A, droptol, + sparsetype); + }, + nb::arg("A"), nb::arg("droptol"), nb::arg("sparsetype")); + +m.def( + "SUNSparseFromBandMatrix", + [](SUNMatrix A, sunrealtype droptol, + int sparsetype) -> std::shared_ptr> + { + auto SUNSparseFromBandMatrix_adapt_return_type_to_shared_ptr = + [](SUNMatrix A, sunrealtype droptol, + int sparsetype) -> std::shared_ptr> + { + auto lambda_result = SUNSparseFromBandMatrix(A, droptol, sparsetype); + + return our_make_shared, SUNMatrixDeleter>( + lambda_result); + }; + + return SUNSparseFromBandMatrix_adapt_return_type_to_shared_ptr(A, droptol, + sparsetype); + }, + nb::arg("A"), nb::arg("droptol"), nb::arg("sparsetype")); + +m.def( + "SUNSparseMatrix_ToCSR", + [](const SUNMatrix A) + -> std::tuple>> + { + auto SUNSparseMatrix_ToCSR_adapt_modifiable_immutable_to_return = + [](const SUNMatrix A) -> std::tuple + { + SUNMatrix Bout_adapt_modifiable; + + SUNErrCode r = SUNSparseMatrix_ToCSR(A, &Bout_adapt_modifiable); + return std::make_tuple(r, Bout_adapt_modifiable); + }; + auto SUNSparseMatrix_ToCSR_adapt_return_type_to_shared_ptr = + [&SUNSparseMatrix_ToCSR_adapt_modifiable_immutable_to_return]( + const SUNMatrix A) + -> std::tuple>> + { + auto lambda_result = + SUNSparseMatrix_ToCSR_adapt_modifiable_immutable_to_return(A); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNMatrixDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNSparseMatrix_ToCSR_adapt_return_type_to_shared_ptr(A); + }, + nb::arg("A"), nb::rv_policy::reference); + +m.def( + "SUNSparseMatrix_ToCSC", + [](const SUNMatrix A) + -> std::tuple>> + { + auto SUNSparseMatrix_ToCSC_adapt_modifiable_immutable_to_return = + [](const SUNMatrix A) -> std::tuple + { + SUNMatrix Bout_adapt_modifiable; + + SUNErrCode r = SUNSparseMatrix_ToCSC(A, &Bout_adapt_modifiable); + return std::make_tuple(r, Bout_adapt_modifiable); + }; + auto SUNSparseMatrix_ToCSC_adapt_return_type_to_shared_ptr = + [&SUNSparseMatrix_ToCSC_adapt_modifiable_immutable_to_return]( + const SUNMatrix A) + -> std::tuple>> + { + auto lambda_result = + SUNSparseMatrix_ToCSC_adapt_modifiable_immutable_to_return(A); + + return std::make_tuple(std::get<0>(lambda_result), + our_make_shared, + SUNMatrixDeleter>( + std::get<1>(lambda_result))); + }; + + return SUNSparseMatrix_ToCSC_adapt_return_type_to_shared_ptr(A); + }, + nb::arg("A"), nb::rv_policy::reference); + +m.def("SUNSparseMatrix_Realloc", SUNSparseMatrix_Realloc, nb::arg("A")); + +m.def("SUNSparseMatrix_Reallocate", SUNSparseMatrix_Reallocate, nb::arg("A"), + nb::arg("NNZ")); + +m.def("SUNSparseMatrix_Print", SUNSparseMatrix_Print, nb::arg("A"), + nb::arg("outfile")); + +m.def("SUNSparseMatrix_Rows", SUNSparseMatrix_Rows, nb::arg("A")); + +m.def("SUNSparseMatrix_Columns", SUNSparseMatrix_Columns, nb::arg("A")); + +m.def("SUNSparseMatrix_NNZ", SUNSparseMatrix_NNZ, nb::arg("A")); + +m.def("SUNSparseMatrix_NP", SUNSparseMatrix_NP, nb::arg("A")); + +m.def("SUNSparseMatrix_SparseType", SUNSparseMatrix_SparseType, nb::arg("A")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunmemory/generate.yaml b/bindings/sundials4py/sunmemory/generate.yaml new file mode 100644 index 0000000000..3642870461 --- /dev/null +++ b/bindings/sundials4py/sunmemory/generate.yaml @@ -0,0 +1,37 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNMemoryHelper_Alloc_.*" + - "^SUNMemoryHelper_AllocStrided_.*" + - "^SUNMemoryHelper_Dealloc_.*" + - "^SUNMemoryHelper_Copy_.*" + - "^SUNMemoryHelper_CopyAsync_.*" + - "^SUNMemoryHelper_GetAllocStats_.*" + - "^SUNMemoryHelper_Clone_.*" + - "^SUNMemoryHelper_Destroy_.*" + - "^SUNMemoryHelper_SetDefaultQueue_.*" + sunmemory_system: + path: sunmemory/sunmemory_system_generated.hpp + headers: + - ../../include/sunmemory/sunmemory_system.h diff --git a/bindings/sundials4py/sunmemory/sunmemory_system.cpp b/bindings/sundials4py/sunmemory/sunmemory_system.cpp new file mode 100644 index 0000000000..f8cf4a2908 --- /dev/null +++ b/bindings/sundials4py/sunmemory/sunmemory_system.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include "sundials4py.hpp" + +#include +#include + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sumemoryhelper_sys(nb::module_& m) +{ +#include "sunmemory_system_generated.hpp" +} + +} // namespace sundials4py \ No newline at end of file diff --git a/bindings/sundials4py/sunmemory/sunmemory_system_generated.hpp b/bindings/sundials4py/sunmemory/sunmemory_system_generated.hpp new file mode 100644 index 0000000000..479a74698e --- /dev/null +++ b/bindings/sundials4py/sunmemory/sunmemory_system_generated.hpp @@ -0,0 +1,28 @@ +// #ifndef _SUNDIALS_SYSMEMORY_H +// +// #ifdef __cplusplus +// #endif +// + +m.def( + "SUNMemoryHelper_Sys", + [](SUNContext sunctx) -> std::shared_ptr> + { + auto SUNMemoryHelper_Sys_adapt_return_type_to_shared_ptr = [](SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNMemoryHelper_Sys(sunctx); + + return our_make_shared, + SUNMemoryHelperDeleter>(lambda_result); + }; + + return SUNMemoryHelper_Sys_adapt_return_type_to_shared_ptr(sunctx); + }, + nb::arg("sunctx"), "nb::keep_alive<0, 1>()", nb::keep_alive<0, 1>()); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunnonlinsol/generate.yaml b/bindings/sundials4py/sunnonlinsol/generate.yaml new file mode 100644 index 0000000000..080d466167 --- /dev/null +++ b/bindings/sundials4py/sunnonlinsol/generate.yaml @@ -0,0 +1,52 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2002-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# This YAML file is used to instruct the generate.py script on +# how to generate some of the nanobind code needed for the +# core module of sundials4py. +# ----------------------------------------------------------------- + +modules: + all: + macro_define_include_by_name__regex: + - "^SUN_" + fn_exclude_by_name__regex: + # Don't interface the implementation specific overrides of the generic routines + - "^SUNNonlinSolGetType_.*" + - "^SUNNonlinSolInitialize_.*" + - "^SUNNonlinSolSetup_.*" + - "^SUNNonlinSolSolve_.*" + - "^SUNNonlinSolFree_.*" + - "^SUNNonlinSolSetSysFn_.*" + - "^SUNNonlinSolSetLSetupFn_.*" + - "^SUNNonlinSolSetLSolveFn_.*" + - "^SUNNonlinSolSetConvTestFn_.*" + - "^SUNNonlinSolSetOptions_.*" + - "^SUNNonlinSolSetMaxIters_.*" + - "^SUNNonlinSolGetNumIters_.*" + - "^SUNNonlinSolGetCurIter_.*" + - "^SUNNonlinSolGetNumConvFails_.*" + sunnonlinsol_fixedpoint: + path: sunnonlinsol/sunnonlinsol_fixedpoint_generated.hpp + headers: + - ../../include/sunnonlinsol/sunnonlinsol_fixedpoint.h + fn_exclude_by_name__regex: + # We do not support getting a function pointer back to Python + - "^SUNNonlinSolGetSysFn_FixedPoint$" + sunnonlinsol_newton: + path: sunnonlinsol/sunnonlinsol_newton_generated.hpp + headers: + - ../../include/sunnonlinsol/sunnonlinsol_newton.h + fn_exclude_by_name__regex: + # We do not support getting a function pointer back to Python + - "^SUNNonlinSolGetSysFn_Newton$" diff --git a/bindings/sundials4py/sunnonlinsol/sunnonlinsol_fixedpoint.cpp b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_fixedpoint.cpp new file mode 100644 index 0000000000..7207a25d86 --- /dev/null +++ b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_fixedpoint.cpp @@ -0,0 +1,32 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include +#include +#include "sundials4py.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunnonlinsol_fixedpoint(nb::module_& m) +{ +#include "sunnonlinsol_fixedpoint_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunnonlinsol/sunnonlinsol_fixedpoint_generated.hpp b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_fixedpoint_generated.hpp new file mode 100644 index 0000000000..62e3ba0b2e --- /dev/null +++ b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_fixedpoint_generated.hpp @@ -0,0 +1,61 @@ +// #ifndef _SUNNONLINSOL_FIXEDPOINT_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNNonlinearSolverContent_FixedPoint = + nb::class_<_SUNNonlinearSolverContent_FixedPoint>(m, "_SUNNonlinearSolverContent_FixedPoint", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNNonlinSol_FixedPoint", + [](N_Vector y, int m, SUNContext sunctx) + -> std::shared_ptr> + { + auto SUNNonlinSol_FixedPoint_adapt_return_type_to_shared_ptr = + [](N_Vector y, int m, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNNonlinSol_FixedPoint(y, m, sunctx); + + return our_make_shared, + SUNNonlinearSolverDeleter>(lambda_result); + }; + + return SUNNonlinSol_FixedPoint_adapt_return_type_to_shared_ptr(y, m, sunctx); + }, + nb::arg("y"), nb::arg("m"), nb::arg("sunctx"), "nb::keep_alive<0, 3>()", + nb::keep_alive<0, 3>()); + +m.def( + "SUNNonlinSol_FixedPointSens", + [](int count, N_Vector y, int m, SUNContext sunctx) + -> std::shared_ptr> + { + auto SUNNonlinSol_FixedPointSens_adapt_return_type_to_shared_ptr = + [](int count, N_Vector y, int m, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNNonlinSol_FixedPointSens(count, y, m, sunctx); + + return our_make_shared, + SUNNonlinearSolverDeleter>(lambda_result); + }; + + return SUNNonlinSol_FixedPointSens_adapt_return_type_to_shared_ptr(count, y, + m, sunctx); + }, + nb::arg("count"), nb::arg("y"), nb::arg("m"), nb::arg("sunctx"), + "nb::keep_alive<0, 4>()", nb::keep_alive<0, 4>()); + +m.def("SUNNonlinSolSetDamping_FixedPoint", SUNNonlinSolSetDamping_FixedPoint, + nb::arg("NLS"), nb::arg("beta")); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/sunnonlinsol/sunnonlinsol_newton.cpp b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_newton.cpp new file mode 100644 index 0000000000..4228aa0895 --- /dev/null +++ b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_newton.cpp @@ -0,0 +1,33 @@ +/* ----------------------------------------------------------------- + * Programmer(s): Cody J. Balos @ LLNL + * ----------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2025-2025, Lawrence Livermore National Security, + * University of Maryland Baltimore County, and the SUNDIALS contributors. + * Copyright (c) 2013-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * Copyright (c) 2002-2013, Lawrence Livermore National Security. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * -----------------------------------------------------------------*/ + +#include +#include + +#include "sundials4py.hpp" + +namespace nb = nanobind; +using namespace sundials::experimental; + +namespace sundials4py { + +void bind_sunnonlinsol_newton(nb::module_& m) +{ +#include "sunnonlinsol_newton_generated.hpp" +} + +} // namespace sundials4py diff --git a/bindings/sundials4py/sunnonlinsol/sunnonlinsol_newton_generated.hpp b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_newton_generated.hpp new file mode 100644 index 0000000000..8c968395b8 --- /dev/null +++ b/bindings/sundials4py/sunnonlinsol/sunnonlinsol_newton_generated.hpp @@ -0,0 +1,58 @@ +// #ifndef _SUNNONLINSOL_NEWTON_H +// +// #ifdef __cplusplus +// #endif +// + +auto pyClass_SUNNonlinearSolverContent_Newton = + nb::class_<_SUNNonlinearSolverContent_Newton>(m, "_SUNNonlinearSolverContent_Newton", + "") + .def(nb::init<>()) // implicit default constructor + ; + +m.def( + "SUNNonlinSol_Newton", + [](N_Vector y, SUNContext sunctx) + -> std::shared_ptr> + { + auto SUNNonlinSol_Newton_adapt_return_type_to_shared_ptr = + [](N_Vector y, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNNonlinSol_Newton(y, sunctx); + + return our_make_shared, + SUNNonlinearSolverDeleter>(lambda_result); + }; + + return SUNNonlinSol_Newton_adapt_return_type_to_shared_ptr(y, sunctx); + }, + nb::arg("y"), nb::arg("sunctx"), "nb::keep_alive<0, 2>()", + nb::keep_alive<0, 2>()); + +m.def( + "SUNNonlinSol_NewtonSens", + [](int count, N_Vector y, SUNContext sunctx) + -> std::shared_ptr> + { + auto SUNNonlinSol_NewtonSens_adapt_return_type_to_shared_ptr = + [](int count, N_Vector y, SUNContext sunctx) + -> std::shared_ptr> + { + auto lambda_result = SUNNonlinSol_NewtonSens(count, y, sunctx); + + return our_make_shared, + SUNNonlinearSolverDeleter>(lambda_result); + }; + + return SUNNonlinSol_NewtonSens_adapt_return_type_to_shared_ptr(count, y, + sunctx); + }, + nb::arg("count"), nb::arg("y"), nb::arg("sunctx"), "nb::keep_alive<0, 3>()", + nb::keep_alive<0, 3>()); +// #ifdef __cplusplus +// +// #endif +// +// #endif +// diff --git a/bindings/sundials4py/test/fixtures.py b/bindings/sundials4py/test/fixtures.py new file mode 100644 index 0000000000..d9be35f9cb --- /dev/null +++ b/bindings/sundials4py/test/fixtures.py @@ -0,0 +1,85 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +from numpy import sqrt, finfo +from sundials4py.core import * + +SUNREALTYPE_RTOL = sqrt(finfo(sunrealtype).eps) +SUNREALTYPE_ATOL = sqrt(finfo(sunrealtype).eps) + + +@pytest.fixture +def sunctx(): + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + yield sunctx + + +@pytest.fixture +def nvec(sunctx): + nvec = N_VNew_Serial(10, sunctx) + yield nvec + + +@pytest.fixture +def sunstepper(sunctx): + status, stepper = SUNStepper_Create(sunctx) + + # Dummy callback implementations + def evolve_fn(stepper, tout, vret, tret): + return 0 + + def one_step_fn(stepper, tout, vret, tret): + return 0 + + def full_rhs_fn(stepper, t, v, f, mode): + return 0 + + def reinit_fn(stepper, t, y): + return 0 + + def reset_fn(stepper, t, y): + return 0 + + def reset_ckpt_idx_fn(stepper, idx): + return 0 + + def stop_time_fn(stepper, tstop): + return 0 + + def step_direction_fn(stepper, direction): + return 0 + + def forcing_fn(stepper, tshift, tscale, forcing, nforcing): + return 0 + + def get_num_steps_fn(stepper): + return 0, 0 + + # Set all function pointers + SUNStepper_SetEvolveFn(stepper, evolve_fn) + SUNStepper_SetOneStepFn(stepper, one_step_fn) + SUNStepper_SetFullRhsFn(stepper, full_rhs_fn) + SUNStepper_SetReInitFn(stepper, reinit_fn) + SUNStepper_SetResetFn(stepper, reset_fn) + SUNStepper_SetResetCheckpointIndexFn(stepper, reset_ckpt_idx_fn) + SUNStepper_SetStopTimeFn(stepper, stop_time_fn) + SUNStepper_SetStepDirectionFn(stepper, step_direction_fn) + SUNStepper_SetForcingFn(stepper, forcing_fn) + SUNStepper_SetGetNumStepsFn(stepper, get_num_steps_fn) + + yield stepper diff --git a/bindings/sundials4py/test/problems/__init__.py b/bindings/sundials4py/test/problems/__init__.py new file mode 100644 index 0000000000..c86da1af18 --- /dev/null +++ b/bindings/sundials4py/test/problems/__init__.py @@ -0,0 +1,2 @@ +from .analytic import * +from .harmonic_oscillator import * diff --git a/bindings/sundials4py/test/problems/analytic.py b/bindings/sundials4py/test/problems/analytic.py new file mode 100644 index 0000000000..8d1ce84fd1 --- /dev/null +++ b/bindings/sundials4py/test/problems/analytic.py @@ -0,0 +1,252 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from .problem import ODEProblem + + +class AnalyticODE(ODEProblem): + """ + * The following is a simple example problem with analytical + * solution, + * dy/dt = lambda*y + 1/(1+t^2) - lambda*atan(t) + * for t in the interval [0.0, 10.0], with initial condition: y=0. + * + * The stiffness of the problem is directly proportional to the + * value of "lambda". The value of lambda should be negative to + * result in a well-posed ODE; for values with magnitude larger + * than 100 the problem becomes quite stiff. + """ + + def __init__(self, lamb=-10.0): + self.lamb = lamb + self.inner_stepper = None + + def f(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + ydot[0] = self.lamb * y[0] + 1.0 / (1.0 + t * t) - self.lamb * np.arctan(t) + return 0 + + def dom_eig(self, t, yvec, fnvec, tempv1, tempv2, tempv3): + lamdbaR = self.lamb + lamdbaI = 0.0 + return 0, lamdbaR, lamdbaI + + def solution(self, y0vec, yvec, t): + y = N_VGetArrayPointer(yvec) + y[0] = np.atan(t) + return 0 + + def set_init_cond(self, y0vec): + y0 = N_VGetArrayPointer(y0vec) + y0[0] = 0.0 + return 0 + + +class AnalyticMultiscaleODE(ODEProblem): + """ + * We consider the initial value problem + * y' + lambda*y = y^2, y(0) = 1 + * proposed in + * + * Estep, D., et al. "An a posteriori–a priori analysis of multiscale operator + * splitting." SIAM Journal on Numerical Analysis 46.3 (2008): 1116-1146. + * + * The parameter lambda is positive, t is in [0, 1], and the exact solution is + * + * y(t) = lambda*y / (y(0) - (y(0) - lambda)*exp(lambda*t)) + """ + + T0 = 0.0 + TF = 1.0 + + def __init__(self, lamb=2.0): + self.lamb = lamb + + def f_linear(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + ydot[:] = -self.lamb * y + return 0 + + def f_nonlinear(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + ydot[:] = y * y + return 0 + + def solution(self, y0vec, yvec, tf): + y0 = N_VGetArrayPointer(y0vec)[0] + y = N_VGetArrayPointer(yvec) + y[0] = self.lamb * y0 / (y0 - (y0 - self.lamb) * np.exp(self.lamb * tf)) + return 0 + + def set_init_cond(self, y0vec): + y0 = N_VGetArrayPointer(y0vec) + y0[0] = 1.0 + return 0 + + +class AnalyticDAE: + """ + * The following is a simple example problem with analytical + * solution adapted from example 10.2 of Ascher & Petzold, "Computer + * Methods for Ordinary Differential Equations and + * Differential-Algebraic Equations," SIAM, 1998, page 267: + * x1'(t) = (1-alpha)/(t-2)*x1 - x1 + (alpha-1)*x2 + 2*exp(t) + * 0 = (t+2)*x1 - (t+2)*exp(t) + * for t in the interval [0.0, 1.0], with initial condition: + * x1(0) = 1 and x2(0) = -1/2. + * The problem has true solution + * x1(t) = exp(t) and x2(t) = exp(t)/(t-2) + """ + + T0 = 0.0 + TF = 1.0 + + def __init__(self, alpha=10.0): + self.alpha = alpha + + def res(self, t, yyvec, ypvec, resvec): + yy = N_VGetArrayPointer(yyvec) + yp = N_VGetArrayPointer(ypvec) + res = N_VGetArrayPointer(resvec) + alpha = self.alpha + + # System residual function: + # 0 = (1-alpha)/(t-2)*x1 - x1 + (alpha-1)*x2 + 2*exp(t) - x1'(t) + # 0 = (t+2)*x1 - (t+2)*exp(t) + res[0] = ( + (1.0 - alpha) / (t - 2.0) * yy[0] + - yy[0] + + (alpha - 1.0) * yy[1] + + 2.0 * np.exp(t) + - yp[0] + ) + res[1] = (t + 2.0) * yy[0] - (t + 2.0) * np.exp(t) + return 0 + + def solution(self, yyvec, ypvec, t): + yy = N_VGetArrayPointer(yyvec) + yp = N_VGetArrayPointer(ypvec) + yy[0] = np.exp(t) + yy[1] = np.exp(t) / (t - 2.0) + yp[0] = np.exp(t) + yp[1] = np.exp(t) / (t - 2.0) - np.exp(t) / (t - 2.0) / (t - 2.0) + return 0 + + def set_init_cond(self, yyvec, ypvec, t0): + return self.solution(yyvec, ypvec, t0) + + def psolve(self, t, yyvec, ypvec, rrvec, rvec, zvec, cj, delta): + """ + Exact solution as preconditioner + P = df/dy + cj*df/dyp + => + P = [ - cj - (alpha - 1)/(t - 2) - 1, alpha - 1] + [ t + 2, 0] + + z = P^{-1} r + */ + """ + yy = N_VGetArrayPointer(yyvec) + yp = N_VGetArrayPointer(ypvec) + r = N_VGetArrayPointer(rvec) + z = N_VGetArrayPointer(zvec) + alpha = self.alpha + a11 = -cj - (alpha - 1.0) / (t - 2.0) - 1.0 + a12 = alpha - 1.0 + a21 = t + 2.0 + z[0] = r[1] / a21 + z[1] = -(a11 * r[1] - a21 * r[0]) / (a12 * a21) + return 0 + + +class AnalyticNonlinearSys: + """ + * This implements the nonlinear system + * + * 3x - cos((y-1)z) - 1/2 = 0 + * x^2 - 81(y-0.9)^2 + sin(z) + 1.06 = 0 + * exp(-x(y-1)) + 20z + (10 pi - 3)/3 = 0 + * + * The nonlinear fixed point function is + * + * g1(x,y,z) = 1/3 cos((y-1)z) + 1/6 + * g2(x,y,z) = 1/9 sqrt(x^2 + sin(z) + 1.06) + 0.9 + * g3(x,y,z) = -1/20 exp(-x(y-1)) - (10 pi - 3) / 60 + * + * Corrector form g(x,y,z): + * + * g1(x,y,z) = 1/3 cos((y-1)yz) + 1/6 - x0 + * g2(x,y,z) = 1/9 sqrt(x^2 + sin(z) + 1.06) + 0.9 - y0 + * g3(x,y,z) = -1/20 exp(-x(y-1)) - (10 pi - 3) / 60 - z0 + * + * This system has the analytic solution x = 1/2, y = 1, z = -pi/6. + """ + + NEQ = 3 + + def __init__(self, u0vec=None): + self.u0vec = u0vec + + # CJB: __enter__ and __exit__ are defined so that this class can be + # use with python "with" statements. This is a workaround for the following scenario: + # u0 = N_VNew_Serial(...) + # problem = AnalyticNonlinearSys(u0) + # def g_fn(self, u, g, _): + # return problem.fixed_point_fn(u, g) + # Without using a "with" block, this code will cause nanobind to complain about reference + # leaks because `problem`, which holds a reference to u0, seems to not be garbage collected + # until the nanobind shutdown callback. + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.u0vec = None + + def fixed_point_fn(self, uvec, gvec): + u = N_VGetArrayPointer(uvec) + g = N_VGetArrayPointer(gvec) + x, y, z = u[0], u[1], u[2] + + g[0] = (1.0 / 3.0) * np.cos((y - 1.0) * z) + (1.0 / 6.0) + g[1] = (1.0 / 9.0) * np.sqrt(x * x + np.sin(z) + 1.06) + 0.9 + g[2] = -(1.0 / 20.0) * np.exp(-x * (y - 1.0)) - (10.0 * np.pi - 3.0) / 60.0 + + return 0 + + def corrector_fp_fn(self, uvec, gvec): + self.fixed_point_fn(uvec, gvec) + N_VLinearSum(1.0, gvec, -1.0, self.u0vec, gvec) + return 0 + + def conv_test(self, nls, yvec, delvec, tol, ewtvec): + delnrm = N_VMaxNorm(delvec) + if delnrm <= tol: + return SUN_SUCCESS + else: + return SUN_NLS_CONTINUE + + def solution(self, uvec): + u = N_VGetArrayPointer(uvec) + u[0] = 0.5 + u[1] = 1.0 + u[2] = -np.pi / 6.0 + return 0 diff --git a/bindings/sundials4py/test/problems/harmonic_oscillator.py b/bindings/sundials4py/test/problems/harmonic_oscillator.py new file mode 100644 index 0000000000..7d9a00a0e4 --- /dev/null +++ b/bindings/sundials4py/test/problems/harmonic_oscillator.py @@ -0,0 +1,51 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import numpy as np +from sundials4py.core import * +from .problem import ODEProblem + + +class HarmonicOscillatorODE(ODEProblem): + def __init__(self, A=10.0, phi=0.0, omega=1.0): + self.A = A + self.phi = phi + self.omega = omega + + def xdot(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + ydot[0] = y[1] + return 0 + + def vdot(self, t, yvec, ydotvec): + y = N_VGetArrayPointer(yvec) + ydot = N_VGetArrayPointer(ydotvec) + ydot[1] = -self.omega * self.omega * y[0] + return 0 + + def set_init_cond(self, yvec): + y = N_VGetArrayPointer(yvec) + y[0] = self.A * np.cos(self.phi) + y[1] = -self.A * self.omega * np.sin(self.phi) + + def solution(self, y0vec, yvec, t): + y0 = N_VGetArrayPointer(y0vec) + y = N_VGetArrayPointer(yvec) + y[0] = self.A * np.cos(self.omega * t + self.phi) + y[1] = -self.A * self.omega * np.sin(self.omega * t + self.phi) + return 0 diff --git a/bindings/sundials4py/test/problems/problem.py b/bindings/sundials4py/test/problems/problem.py new file mode 100644 index 0000000000..4f57fe8eba --- /dev/null +++ b/bindings/sundials4py/test/problems/problem.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Base classes for problems +# ----------------------------------------------------------------- + + +class ODEProblem: + + def set_init_cond(self, y0vec): + raise NotImplementedError("Subclasses must implement the set_init_cond method.") + + def solution(self, y0vec, yvec, t): + raise NotImplementedError("Subclasses must implement the set_init_cond method.") diff --git a/bindings/sundials4py/test/sunmatrix/test_sunmatrix_band.py b/bindings/sundials4py/test/sunmatrix/test_sunmatrix_band.py new file mode 100644 index 0000000000..5582223b12 --- /dev/null +++ b/bindings/sundials4py/test/sunmatrix/test_sunmatrix_band.py @@ -0,0 +1,119 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * + + +def test_create_band_matrix(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + assert A is not None + assert SUNMatGetID(A) == SUNMATRIX_BAND + # Ensure the shape is being translated correctly + dataA = SUNBandMatrix_Data(A) + ldata = SUNBandMatrix_LData(A) + assert dataA.shape[0] == ldata + + +def test_clone_matrix(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + B = SUNMatClone(A) + assert B is not None + + +def test_zero_matrix(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + ret = SUNMatZero(A) + assert ret == 0 + + +def test_copy_matrix(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + B = SUNBandMatrix(rows, mu, ml, sunctx) + smu = SUNBandMatrix_StoredUpperBandwidth(A) + dataA = SUNBandMatrix_Data(A) + dataA[smu - mu] = 1.0 + ret = SUNMatCopy(A, B) + assert ret == 0 + dataB = SUNBandMatrix_Data(B) + assert dataB[smu - mu] == 1.0 + + +def test_scale_add_matrix(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + B = SUNBandMatrix(rows, mu, ml, sunctx) + smu = SUNBandMatrix_StoredUpperBandwidth(A) + dataA = SUNBandMatrix_Data(A) + dataB = SUNBandMatrix_Data(B) + dataA[smu - mu : smu + ml] = 1.0 # column 0 set to 1.0 + dataB[smu - mu : smu + ml] = 2.0 + ret = SUNMatScaleAdd(3.0, A, B) + assert ret == 0 + # A should now be 3*A + B = 3*1 + 2 = 5 + assert np.allclose(dataA[smu - mu : smu + ml], 5.0) + + +def test_scale_add_identity(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + ldim = SUNBandMatrix_LDim(A) + smu = SUNBandMatrix_StoredUpperBandwidth(A) + dataA = SUNBandMatrix_Data(A) + ret = SUNMatScaleAddI(0.0, A) + assert ret == 0 + # A should now be I + diag = np.array([dataA[smu + i * ldim] for i in range(rows)], dtype=sunrealtype) + assert np.allclose(diag, 1.0) + + +def test_matvec(sunctx): + rows, mu, ml = 4, 1, 1 + A = SUNBandMatrix(rows, mu, ml, sunctx) + x = N_VNew_Serial(rows, sunctx) + y = N_VNew_Serial(rows, sunctx) + + N_VConst(1.0, x) + + # Fill band matrix data for a simple 4x4 banded matrix + # [3 2 0 0] + # [1 3 2 0] + # [0 1 3 2] + # [0 0 1 3] + dataA = SUNBandMatrix_Data(A) + ldim = SUNBandMatrix_LDim(A) + smu = SUNBandMatrix_StoredUpperBandwidth(A) + for j in range(rows): + # Diagonal + dataA[smu + j * ldim] = 3.0 + # Lower diagonal + if j > 0: + dataA[smu - 1 + j * ldim] = 2.0 + # Upper diagonal + if j < rows - 1: + dataA[smu + 1 + j * ldim] = 1.0 + + ret = SUNMatMatvec(A, x, y) + assert ret == 0 + + assert np.allclose(N_VGetArrayPointer(y), [5.0, 6.0, 6.0, 4.0]) diff --git a/bindings/sundials4py/test/sunmatrix/test_sunmatrix_dense.py b/bindings/sundials4py/test/sunmatrix/test_sunmatrix_dense.py new file mode 100644 index 0000000000..3ce7bbb6fe --- /dev/null +++ b/bindings/sundials4py/test/sunmatrix/test_sunmatrix_dense.py @@ -0,0 +1,92 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * + + +def test_create_dense_matrix(sunctx): + rows, cols = 3, 2 + A = SUNDenseMatrix(rows, cols, sunctx) + assert A is not None + assert SUNMatGetID(A) == SUNMATRIX_DENSE + # Ensure the shape is being translated correctly + dataA_shape = np.shape(SUNDenseMatrix_Data(A)) + assert dataA_shape[0] == rows and dataA_shape[1] == cols + + +def test_clone_matrix(sunctx): + A = SUNDenseMatrix(2, 2, sunctx) + B = SUNMatClone(A) + assert B is not None + + +def test_zero_matrix(sunctx): + A = SUNDenseMatrix(2, 2, sunctx) + ret = SUNMatZero(A) + assert ret == 0 + + +def test_copy_matrix(sunctx): + A = SUNDenseMatrix(2, 2, sunctx) + B = SUNDenseMatrix(2, 2, sunctx) + # Set some values in A + dataA = SUNDenseMatrix_Data(A) + dataA[0, 0] = 1.0 + ret = SUNMatCopy(A, B) + assert ret == 0 + dataB = SUNDenseMatrix_Data(B) + assert dataB[0, 0] == 1.0 + + +def test_scale_add_matrix(sunctx): + A = SUNDenseMatrix(2, 2, sunctx) + B = SUNDenseMatrix(2, 2, sunctx) + dataA = SUNDenseMatrix_Data(A) + dataB = SUNDenseMatrix_Data(B) + dataA[:, :] = 1.0 + dataB[:, :] = 2.0 + ret = SUNMatScaleAdd(3.0, A, B) + assert ret == 0 + # A should now be 3*A + B = 3*1 + 2 = 5 + assert np.allclose(dataA, 5.0) + + +def test_scale_add_identity(sunctx): + A = SUNDenseMatrix(2, 2, sunctx) + dataA = SUNDenseMatrix_Data(A) + dataA[:, :] = 2.0 + ret = SUNMatScaleAddI(3.0, A) + assert ret == 0 + # A should now be 3*A + I + expected = np.eye(2) + 6.0 + assert np.allclose(dataA, expected) + + +def test_matvec(sunctx, nvec): + A = SUNDenseMatrix(2, 2, sunctx) + dataA = SUNDenseMatrix_Data(A) + dataA[:, :] = [[1.0, 2.0], [3.0, 4.0]] + x = N_VNew_Serial(2, sunctx) + y = N_VNew_Serial(2, sunctx) + N_VConst(1.0, x) + ret = SUNMatMatvec(A, x, y) + assert ret == 0 + arr = N_VGetArrayPointer(y) + assert np.allclose(arr, [3.0, 7.0]) diff --git a/bindings/sundials4py/test/sunmatrix/test_sunmatrix_sparse.py b/bindings/sundials4py/test/sunmatrix/test_sunmatrix_sparse.py new file mode 100644 index 0000000000..8906d2dd02 --- /dev/null +++ b/bindings/sundials4py/test/sunmatrix/test_sunmatrix_sparse.py @@ -0,0 +1,129 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * + + +def test_create_sparse_matrix(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + assert A is not None + assert SUNMatGetID(A) == SUNMATRIX_SPARSE + # Check shape of data + dataA = SUNSparseMatrix_Data(A) + assert dataA.shape[0] == nnz + + +def test_clone_matrix(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + B = SUNMatClone(A) + assert B is not None + + +def test_zero_matrix(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + ret = SUNMatZero(A) + assert ret == 0 + + +def test_copy_matrix(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + B = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + dataA = SUNSparseMatrix_Data(A) + idx_vals = SUNSparseMatrix_IndexValues(A) + idx_ptrs = SUNSparseMatrix_IndexPointers(A) + # CSR: row 0: col 0 (1.0), row 1: col 1 (1.0), row 2: col 2 (1.0), row 2: col 0 (2.0) + dataA[:] = [1.0, 1.0, 1.0, 2.0] + idx_vals[:] = [0, 1, 2, 0] + idx_ptrs[:] = [0, 1, 2, 4] + ret = SUNMatCopy(A, B) + assert ret == 0 + dataB = SUNSparseMatrix_Data(B) + idx_valsB = SUNSparseMatrix_IndexValues(B) + idx_ptrsB = SUNSparseMatrix_IndexPointers(B) + assert np.allclose(dataB, dataA) + assert np.allclose(idx_valsB, idx_vals) + assert np.allclose(idx_ptrsB, idx_ptrs) + + +def test_scale_add_matrix(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + B = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + dataA = SUNSparseMatrix_Data(A) + dataB = SUNSparseMatrix_Data(B) + idx_vals = SUNSparseMatrix_IndexValues(A) + idx_ptrs = SUNSparseMatrix_IndexPointers(A) + idx_valsB = SUNSparseMatrix_IndexValues(B) + idx_ptrsB = SUNSparseMatrix_IndexPointers(B) + # CSR: row 0: col 0 (1.0), row 1: col 1 (1.0), row 2: col 2 (1.0), row 2: col 0 (2.0) + dataA[:] = [1.0, 1.0, 1.0, 2.0] + dataB[:] = [2.0, 2.0, 2.0, 4.0] + idx_vals[:] = [0, 1, 2, 0] + idx_ptrs[:] = [0, 1, 2, 4] + idx_valsB[:] = [0, 1, 2, 0] + idx_ptrsB[:] = [0, 1, 2, 4] + ret = SUNMatScaleAdd(3.0, A, B) + assert ret == 0 + # 3*A + B = [3+2, 3+2, 3+2, 6+4] = [5, 5, 5, 10] + assert np.allclose(dataA, [5.0, 5.0, 5.0, 10.0]) + + +def test_scale_add_identity(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + dataA = SUNSparseMatrix_Data(A) + idx_vals = SUNSparseMatrix_IndexValues(A) + idx_ptrs = SUNSparseMatrix_IndexPointers(A) + # CSR: row 0: col 0 (2.0), row 1: col 1 (2.0), row 2: col 2 (2.0), row 2: col 0 (2.0) + dataA[:] = [2.0, 2.0, 2.0, 2.0] + idx_vals[:] = [0, 1, 2, 0] + idx_ptrs[:] = [0, 1, 2, 4] + ret = SUNMatScaleAddI(3.0, A) + assert ret == 0 + # Diagonal elements should be 3*2+1=7, off-diagonal unchanged + # So dataA = [7.0, 7.0, 7.0, 2.0] (assuming diagonal at idx_vals[0:3]) + # But since idx_vals = [0,1,2,0], diagonal is at positions 0,1,2 + assert np.allclose(dataA[:3], [7.0, 7.0, 7.0]) + + +def test_matvec(sunctx): + rows, cols, nnz = 3, 3, 4 + A = SUNSparseMatrix(rows, cols, nnz, SUN_CSR_MAT, sunctx) + x = N_VNew_Serial(cols, sunctx) + y = N_VNew_Serial(rows, sunctx) + N_VConst(1.0, x) + # Fill a simple sparse matrix: 3x3 identity with one extra off-diagonal + dataA = SUNSparseMatrix_Data(A) + idx_vals = SUNSparseMatrix_IndexValues(A) + idx_ptrs = SUNSparseMatrix_IndexPointers(A) + # CSR: row 0: col 0 (1.0), row 1: col 1 (1.0), row 2: col 2 (1.0), row 2: col 0 (2.0) + dataA[:] = [1.0, 1.0, 1.0, 2.0] + idx_vals[:] = [0, 1, 2, 0] + idx_ptrs[:] = [0, 1, 2, 4] + ret = SUNMatMatvec(A, x, y) + assert ret == 0 + # y[0] = 1.0*x[0] = 1.0 + # y[1] = 1.0*x[1] = 1.0 + # y[2] = 1.0*x[2] + 2.0*x[0] = 1.0 + 2.0 = 3.0 + assert np.allclose(N_VGetArrayPointer(y), [1.0, 1.0, 3.0]) diff --git a/bindings/sundials4py/test/test_arkstep.py b/bindings/sundials4py/test/test_arkstep.py new file mode 100644 index 0000000000..bb27ba2d42 --- /dev/null +++ b/bindings/sundials4py/test/test_arkstep.py @@ -0,0 +1,124 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.arkode import * +from problems import AnalyticODE, AnalyticMultiscaleODE + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_explicit(sunctx): + y = N_VNew_Serial(1, sunctx) + + ode_problem = AnalyticODE() + + ode_problem.set_init_cond(y) + + ark = ARKStepCreate(lambda t, y, ydot, _: ode_problem.f(t, y, ydot), None, 0, y, sunctx) + + status = ARKodeSStolerances(ark.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == ARK_SUCCESS + + nrtfn = 2 + + def rootfn(t, y, gout, _): + # just a smoke test of the root finding callback + gout[:] = 1.0 + assert len(gout) == nrtfn + return 0 + + status = ARKodeRootInit(ark.get(), nrtfn, rootfn) + assert status == ARK_SUCCESS + + tout = 10.0 + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + assert status == ARK_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_implicit(sunctx): + y = N_VNew_Serial(1, sunctx) + ls = SUNLinSol_SPGMR(y, 0, 0, sunctx) + + ode_problem = AnalyticODE() + + ode_problem.set_init_cond(y) + + ark = ARKStepCreate(None, lambda t, y, ydot, _: ode_problem.f(t, y, ydot), 0, y, sunctx) + + status = ARKodeSStolerances(ark.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == ARK_SUCCESS + + status = ARKodeSetLinearSolver(ark.get(), ls, None) + assert status == ARK_SUCCESS + + tout = 10.0 + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + assert status == ARK_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_imex(sunctx): + y = N_VNew_Serial(1, sunctx) + ls = SUNLinSol_SPGMR(y, 0, 0, sunctx) + + ode_problem = AnalyticMultiscaleODE() + + ode_problem.set_init_cond(y) + + ark = ARKStepCreate( + lambda t, y, ydot, _: ode_problem.f_nonlinear(t, y, ydot), + lambda t, y, ydot, _: ode_problem.f_linear(t, y, ydot), + 0, + y, + sunctx, + ) + + status = ARKodeSStolerances(ark.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == ARK_SUCCESS + + status = ARKodeSetLinearSolver(ark.get(), ls, None) + assert status == ARK_SUCCESS + + tout = 10.0 + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + assert status == 0 + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) diff --git a/bindings/sundials4py/test/test_cvodes.py b/bindings/sundials4py/test/test_cvodes.py new file mode 100644 index 0000000000..3d2cc611c9 --- /dev/null +++ b/bindings/sundials4py/test/test_cvodes.py @@ -0,0 +1,367 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.cvodes import * +from problems import AnalyticODE + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_cvodes_ivp(sunctx): + NEQ = 1 + y = N_VNew_Serial(NEQ, sunctx) + ls = SUNLinSol_SPGMR(y, 0, 0, sunctx) + + ode_problem = AnalyticODE() + ode_problem.set_init_cond(y) + + cvode = CVodeCreate(CV_BDF, sunctx) + + status = CVodeInit(cvode.get(), lambda t, y, ydot, _: ode_problem.f(t, y, ydot), 0, y) + assert status == CV_SUCCESS + + status = CVodeSStolerances(cvode.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == CV_SUCCESS + + status = CVodeSetLinearSolver(cvode.get(), ls, None) + assert status == CV_SUCCESS + + nrtfn = 2 + + def rootfn(t, y, gout, _): + # just a smoke test of the root finding callback + gout[:] = 1.0 + assert len(gout) == nrtfn + return 0 + + status = CVodeRootInit(cvode.get(), nrtfn, rootfn) + assert status == CV_SUCCESS + + tout = 10.0 + status, tret = CVode(cvode.get(), tout, y, CV_NORMAL) + assert status == CV_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) + + status, num_steps = CVodeGetNumSteps(cvode.get()) + assert status == CV_SUCCESS + assert num_steps > 0 + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_cvodes_fsa(sunctx): + # Forward Sensitivity Analysis (FSA) with respect to initial condition + NEQ = 1 + y = N_VNew_Serial(NEQ, sunctx) + ls = SUNLinSol_SPGMR(y, 0, 0, sunctx) + + ode_problem = AnalyticODE() + ode_problem.set_init_cond(y) + + cvode = CVodeCreate(CV_BDF, sunctx) + + # This problem requires tighter tolerances in order to get the forward + # sensitivity to converge to the expected solution within a reasonable tolerance + atol = 1e-10 + rtol = 1e-10 + + status = CVodeInit(cvode.get(), lambda t, y, ydot, _: ode_problem.f(t, y, ydot), 0, y) + assert status == CV_SUCCESS + + status = CVodeSStolerances(cvode.get(), rtol, atol) + assert status == CV_SUCCESS + + status = CVodeSetLinearSolver(cvode.get(), ls, None) + assert status == CV_SUCCESS + + # Sensitivity setup + Ns = 1 # number of sensitivities (wrt y0) + ism = CV_SIMULTANEOUS + yS0 = [N_VClone(y)] + N_VConst(1.0, yS0[0]) + + def fS(Ns, t, y, ydot, yS, ySdot, _, tmp1, tmp2): + # Sensitivity RHS: df/dy * yS + df/dp (here, p = y0, so df/dp = 0) + yarr = N_VGetArrayPointer(y) + ySarr = N_VGetArrayPointer(yS[0]) + ySdotarr = N_VGetArrayPointer(ySdot[0]) + # df/dy = lambda + lamb = ode_problem.lamb + ySdotarr[0] = lamb * ySarr[0] + return 0 + + status = CVodeSensInit(cvode.get(), Ns, ism, fS, yS0) + assert status == CV_SUCCESS + + status = CVodeSensSStolerances(cvode.get(), rtol, np.array([atol], dtype=sunrealtype)) + assert status == CV_SUCCESS + + status = CVodeSetMaxNumSteps(cvode.get(), 100000) + assert status == CV_SUCCESS + + tout = 10.0 + ySout = [N_VClone(y) for _ in range(Ns)] + status, tret = CVode(cvode.get(), tout, y, CV_NORMAL) + assert status == CV_SUCCESS + + # Check IVP solution + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) + + # Get sensitivities + status, tret_sens = CVodeGetSens(cvode.get(), ySout) + assert status == CV_SUCCESS + + lamb = ode_problem.lamb + expected = np.exp(lamb * tret) + sens_val = N_VGetArrayPointer(ySout[0])[0] + assert np.allclose(sens_val, expected, atol=1e-2) + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_cvodes_adjoint(sunctx): + # Adjoint Sensitivity Analysis (ASA) for the same ODE problem as FSA + y = N_VNew_Serial(1, sunctx) + ls = SUNLinSol_SPGMR(y, 0, 0, sunctx) + + ode_problem = AnalyticODE() + ode_problem.set_init_cond(y) + + cvode = CVodeCreate(CV_BDF, sunctx) + + status = CVodeInit(cvode.get(), lambda t, y, ydot, _: ode_problem.f(t, y, ydot), 0, y) + assert status == CV_SUCCESS + + status = CVodeSStolerances(cvode.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == CV_SUCCESS + + status = CVodeSetLinearSolver(cvode.get(), ls, None) + assert status == CV_SUCCESS + + # Adjoint setup + steps = 1000 + interp = CV_HERMITE + status = CVodeAdjInit(cvode.get(), steps, interp) + assert status == CV_SUCCESS + + # Forward solve + tout = 10.0 + status, tret, ncheck = CVodeF(cvode.get(), tout, y, CV_NORMAL) + assert status == CV_SUCCESS + + # Define a simple functional: g = y(T) + # The gradient dg/dy0 = dy(T)/dy0, which is exp(lambda*T) + + # Backward problem setup + yB = N_VNew_Serial(1, sunctx) + N_VConst(1.0, yB) + lsB = SUNLinSol_SPGMR(yB, 0, 0, sunctx) + + status, whichB = CVodeCreateB(cvode.get(), CV_BDF) + assert status == CV_SUCCESS + + def fB(t, y, yB, yBdot, _): + # Adjoint RHS: -df/dy^T * lambdaB + yarr = N_VGetArrayPointer(y) + yBarr = N_VGetArrayPointer(yB) + yBdotarr = N_VGetArrayPointer(yBdot) + lamb = ode_problem.lamb + yBdotarr[0] = -lamb * yBarr[0] + return 0 + + tB0 = tret + status = CVodeInitB(cvode.get(), whichB, fB, tB0, yB) + assert status == CV_SUCCESS + + status = CVodeSStolerancesB(cvode.get(), whichB, SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == CV_SUCCESS + + status = CVodeSetLinearSolverB(cvode.get(), whichB, lsB, None) + assert status == CV_SUCCESS + + # Set terminal condition for lambda(T) = dG/dy(T) + yBarr = N_VGetArrayPointer(yB) + yBarr[0] = 1.0 + + # Integrate backward + tBout = 0.0 + status = CVodeB(cvode.get(), tBout, CV_NORMAL) + assert status >= CV_SUCCESS + + # Get lambda(0) = dg/dy0 + yB0 = N_VNew_Serial(1, sunctx) + status, tBret = CVodeGetB(cvode.get(), whichB, yB0) + assert status == CV_SUCCESS + + # Analytical result: exp(lambda * T) + lamb = ode_problem.lamb + expected = np.exp(lamb * tret) + sens_val = N_VGetArrayPointer(yB0)[0] + assert np.allclose(sens_val, expected, atol=1e-2) + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_cvodes_adjoint_quad(sunctx): + # Adjoint Sensitivity Analysis (ASA) for the same ODE problem as before but add quadratures + NEQ = 1 + + y = N_VNew_Serial(NEQ, sunctx) + ls = SUNLinSol_SPGMR(y, 0, 0, sunctx) + + ode_problem = AnalyticODE() + ode_problem.set_init_cond(y) + + # This problem requires tighter tolerances in order to get the forward + # sensitivity to converge to the expected solution within a reasonable tolerance + atol = 1e-10 + rtol = 1e-10 + + cvode = CVodeCreate(CV_BDF, sunctx) + + status = CVodeInit(cvode.get(), lambda t, y, ydot, _: ode_problem.f(t, y, ydot), 0, y) + assert status == CV_SUCCESS + + status = CVodeSStolerances(cvode.get(), rtol, atol) + assert status == CV_SUCCESS + + status = CVodeSetLinearSolver(cvode.get(), ls, None) + assert status == CV_SUCCESS + + # Forward sensitivity setup + Ns = 1 # number of sensitivities (wrt y0) + ism = CV_SIMULTANEOUS + yS0 = [N_VClone(y)] + N_VConst(1.0, yS0[0]) + + def fS(Ns, t, y, ydot, yS, ySdot, _, tmp1, tmp2): + # Sensitivity RHS: df/dy * yS + df/dp (here, p = y0, so df/dp = 0) + yarr = N_VGetArrayPointer(y) + ySarr = N_VGetArrayPointer(yS[0]) + ySdotarr = N_VGetArrayPointer(ySdot[0]) + # df/dy = lambda + lamb = ode_problem.lamb + ySdotarr[0] = lamb * ySarr[0] + return 0 + + status = CVodeSensInit(cvode.get(), Ns, ism, fS, yS0) + assert status == CV_SUCCESS + + status = CVodeSensSStolerances(cvode.get(), rtol, np.array([atol], dtype=sunrealtype)) + assert status == CV_SUCCESS + + # Forward quadrature setup + def fQ(t, y, qdot, _): + # Smoke test + qdotarr = N_VGetArrayPointer(qdot) + qdotarr[:] = 0 + return 0 + + Nq = 2 + yQ = N_VNew_Serial(Nq, sunctx) + N_VConst(0.0, yQ) + + status = CVodeQuadInit(cvode.get(), fQ, yQ) + assert status == CV_SUCCESS + + status = CVodeQuadSStolerances(cvode.get(), rtol, atol) + assert status == CV_SUCCESS + + # Adjoint setup + steps = 1000 + interp = CV_HERMITE + status = CVodeAdjInit(cvode.get(), steps, interp) + assert status == CV_SUCCESS + + # Forward solve + tout = 10.0 + status, tret, ncheck = CVodeF(cvode.get(), tout, y, CV_NORMAL) + assert status == CV_SUCCESS + + # Define a simple functional: g = y(T) + # The gradient dg/dy0 = dy(T)/dy0, which is exp(lambda*T) + + # Backward problem setup + Ns = 1 + yB = N_VNew_Serial(NEQ, sunctx) + yQB = N_VNew_Serial(Ns, sunctx) + N_VConst(1.0, yB) + lsB = SUNLinSol_SPGMR(yB, 0, 0, sunctx) + + status, whichB = CVodeCreateB(cvode.get(), CV_BDF) + assert status == CV_SUCCESS + + def fBS(t, y, yS, yB, yBdot, _): + # Adjoint RHS: -df/dy^T * lambdaB + yarr = N_VGetArrayPointer(y) + yBarr = N_VGetArrayPointer(yB) + yBdotarr = N_VGetArrayPointer(yBdot) + lamb = ode_problem.lamb + yBdotarr[0] = -lamb * yBarr[0] + return 0 + + def fQB(t, y, yS, yB, qBdot, _): + # Smoke test + assert len(yS) == Ns + return 0 + + tB0 = tret + status = CVodeInitBS(cvode.get(), whichB, fBS, tB0, yB) + assert status == CV_SUCCESS + + # Setup backward quadratures + status = CVodeQuadInitBS(cvode.get(), whichB, fQB, yQB) + + status = CVodeSStolerancesB(cvode.get(), whichB, rtol, atol) + assert status == CV_SUCCESS + + status = CVodeSetLinearSolverB(cvode.get(), whichB, lsB, None) + assert status == CV_SUCCESS + + # Set terminal condition for lambda(T) = dG/dy(T) + yBarr = N_VGetArrayPointer(yB) + yBarr[0] = 1.0 + + # Integrate backward + tBout = 0.0 + status = CVodeB(cvode.get(), tBout, CV_NORMAL) + assert status >= CV_SUCCESS + + # Get lambda(0) = dg/dy0 + yB0 = N_VNew_Serial(NEQ, sunctx) + status, tBret = CVodeGetB(cvode.get(), whichB, yB0) + assert status == CV_SUCCESS + + # Analytical result: exp(lambda * T) + lamb = ode_problem.lamb + expected = np.exp(lamb * tret) + sens_val = N_VGetArrayPointer(yB0)[0] + assert np.allclose(sens_val, expected, atol=1e-2) diff --git a/bindings/sundials4py/test/test_erkstep.py b/bindings/sundials4py/test/test_erkstep.py new file mode 100644 index 0000000000..d68f016b7d --- /dev/null +++ b/bindings/sundials4py/test/test_erkstep.py @@ -0,0 +1,48 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.arkode import * +from problems import AnalyticODE + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_erkstep(sunctx): + y = N_VNew_Serial(1, sunctx) + ode_problem = AnalyticODE() + ode_problem.set_init_cond(y) + + def rhs(t, y, ydot, _): + return ode_problem.f(t, y, ydot) + + erk = ERKStepCreate(rhs, 0, y, sunctx) + status = ARKodeSStolerances(erk.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == ARK_SUCCESS + + tout = 10.0 + status, tret = ARKodeEvolve(erk.get(), tout, y, ARK_NORMAL) + assert status == ARK_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) diff --git a/bindings/sundials4py/test/test_forcingstep.py b/bindings/sundials4py/test/test_forcingstep.py new file mode 100644 index 0000000000..29f3fcf355 --- /dev/null +++ b/bindings/sundials4py/test/test_forcingstep.py @@ -0,0 +1,65 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.arkode import * +from problems import AnalyticMultiscaleODE + + +def test_forcingstep(sunctx): + ode_problem = AnalyticMultiscaleODE() + t0, tf = AnalyticMultiscaleODE.T0, 0.01 + + def f_linear(t, y, ydot, _): + return ode_problem.f_linear(t, y, ydot) + + def f_nonlinear(t, y, ydot, _): + return ode_problem.f_nonlinear(t, y, ydot) + + y = N_VNew_Serial(1, sunctx) + y0 = N_VClone(y) + ode_problem.set_init_cond(y) + ode_problem.set_init_cond(y0) + + linear_ark = ERKStepCreate(f_linear, t0, y, sunctx) + status = ARKodeSetFixedStep(linear_ark.get(), 5e-3) + assert status == 0 + + nonlinear_ark = ARKStepCreate(f_nonlinear, None, t0, y, sunctx) + status = ARKodeSetFixedStep(nonlinear_ark.get(), 1e-3) + assert status == 0 + + status, linear_stepper = ARKodeCreateSUNStepper(linear_ark.get()) + status, nonlinear_stepper = ARKodeCreateSUNStepper(nonlinear_ark.get()) + + ark = ForcingStepCreate(linear_stepper, nonlinear_stepper, t0, y, sunctx) + + status = ARKodeSetFixedStep(ark.get(), 1e-2) + assert status == 0 + + tout = tf + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + assert status == 0 + + sol = N_VClone(y) + ode_problem.solution(y0, sol, tf) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) diff --git a/bindings/sundials4py/test/test_idas.py b/bindings/sundials4py/test/test_idas.py new file mode 100644 index 0000000000..e8ee9b2e8c --- /dev/null +++ b/bindings/sundials4py/test/test_idas.py @@ -0,0 +1,225 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from sundials4py.core import * +from sundials4py.idas import * +from problems import AnalyticDAE +from fixtures import * + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_idas_ivp(sunctx): + ode_problem = AnalyticDAE() + + ida = IDACreate(sunctx) + yy = N_VNew_Serial(2, sunctx) + yp = N_VNew_Serial(2, sunctx) + + # y and y' initial conditions + ode_problem.set_init_cond(yy, yp, ode_problem.T0) + + ls = SUNLinSol_SPGMR(yy, SUN_PREC_LEFT, 0, sunctx) + + def resfn(t, yy, yp, rr, _): + return ode_problem.res(t, yy, yp, rr) + + def psolve(t, yy, yp, rr, r, z, cj, delta, _): + return ode_problem.psolve(t, yy, yp, rr, r, z, cj, delta) + + status = IDAInit(ida.get(), resfn, 0.0, yy, yp) + assert status == IDA_SUCCESS + + status = IDASStolerances(ida.get(), 1e-4, 1e-4) + assert status == IDA_SUCCESS + + status = IDASetLinearSolver(ida.get(), ls, None) + assert status == IDA_SUCCESS + + status = IDASetPreconditioner(ida.get(), None, psolve) + assert status == IDA_SUCCESS + + nrtfn = 2 + + def rootfn(t, yy, yp, gout, _): + # just a smoke test of the root finding callback + gout[:] = 1.0 + assert len(gout) == nrtfn + return 0 + + status = IDARootInit(ida.get(), nrtfn, rootfn) + assert status == IDA_SUCCESS + + tout = ode_problem.TF + status, tret = IDASolve(ida.get(), tout, yy, yp, IDA_NORMAL) + assert status == IDA_SUCCESS + + status, num_steps = IDAGetNumSteps(ida.get()) + assert status == IDA_SUCCESS + print("Number of steps: ", num_steps) + + sol_yy = N_VClone(yy) + sol_yp = N_VClone(yp) + + ode_problem.solution(sol_yy, sol_yp, tret) + assert np.allclose(N_VGetArrayPointer(sol_yy), N_VGetArrayPointer(yy), rtol=1e-2) + assert np.allclose(N_VGetArrayPointer(sol_yp), N_VGetArrayPointer(yp), rtol=1e-2) + + +def test_idas_fsa(sunctx): + ode_problem = AnalyticDAE() + + ida = IDACreate(sunctx) + yy = N_VNew_Serial(2, sunctx) + yp = N_VNew_Serial(2, sunctx) + + # y and y' initial conditions + ode_problem.set_init_cond(yy, yp, ode_problem.T0) + + ls = SUNLinSol_SPGMR(yy, SUN_PREC_LEFT, 0, sunctx) + + def resfn(t, yy, yp, rr, _): + return ode_problem.res(t, yy, yp, rr) + + def psolve(t, yy, yp, rr, r, z, cj, delta, _): + return ode_problem.psolve(t, yy, yp, rr, r, z, cj, delta) + + status = IDAInit(ida.get(), resfn, 0.0, yy, yp) + assert status == IDA_SUCCESS + + status = IDASStolerances(ida.get(), 1e-4, 1e-4) + assert status == IDA_SUCCESS + + status = IDASetLinearSolver(ida.get(), ls, None) + assert status == IDA_SUCCESS + + status = IDASetPreconditioner(ida.get(), None, psolve) + assert status == IDA_SUCCESS + + # Sensitivity setup + Ns = 1 # number of sensitivities (wrt yy0) + ism = IDA_SIMULTANEOUS + yyS0 = [N_VClone(yy) for _ in range(Ns)] + ypS0 = [N_VClone(yp) for _ in range(Ns)] + for v in yyS0: + N_VConst(0.0, v) + for v in ypS0: + N_VConst(0.0, v) + + def resS(Ns, t, yy, yp, resval, yS, ypS, resvalS, _, tmp1, tmp2, tmp3): + # Sensitivity residuals: d(res)/d(yy) * yyS + d(res)/d(yp) * ypS + # For smoke test, just zero out + for i in range(Ns): + N_VConst(0.0, resvalS[i]) + return 0 + + status = IDASensInit(ida.get(), Ns, ism, resS, yyS0, ypS0) + assert status == IDA_SUCCESS + + status = IDASensSStolerances(ida.get(), 1e-4, np.array([1e-4], dtype=sunrealtype)) + assert status == IDA_SUCCESS + + tout = ode_problem.TF + yySout = [N_VClone(yy) for _ in range(Ns)] + ypSout = [N_VClone(yp) for _ in range(Ns)] + status, tret = IDASolve(ida.get(), tout, yy, yp, IDA_NORMAL) + assert status == IDA_SUCCESS + + # Get sensitivities (smoke test: just check call and shape) + status, tret_sens = IDAGetSens(ida.get(), yySout) + assert status == IDA_SUCCESS + assert len(yySout) == Ns + + +def test_idas_adjoint(sunctx): + ode_problem = AnalyticDAE() + + ida = IDACreate(sunctx) + yy = N_VNew_Serial(2, sunctx) + yp = N_VNew_Serial(2, sunctx) + + # y and y' initial conditions + ode_problem.set_init_cond(yy, yp, ode_problem.T0) + + ls = SUNLinSol_SPGMR(yy, SUN_PREC_LEFT, 0, sunctx) + + def resfn(t, yy, yp, rr, _): + return ode_problem.res(t, yy, yp, rr) + + def psolve(t, yy, yp, rr, r, z, cj, delta, _): + return ode_problem.psolve(t, yy, yp, rr, r, z, cj, delta) + + status = IDAInit(ida.get(), resfn, 0.0, yy, yp) + assert status == IDA_SUCCESS + + status = IDASStolerances(ida.get(), 1e-4, 1e-4) + assert status == IDA_SUCCESS + + status = IDASetLinearSolver(ida.get(), ls, None) + assert status == IDA_SUCCESS + + status = IDASetPreconditioner(ida.get(), None, psolve) + assert status == IDA_SUCCESS + + # Adjoint (backward) problem setup + steps = 5 + interp = IDA_HERMITE + status = IDAAdjInit(ida.get(), steps, interp) + assert status == IDA_SUCCESS + + # Integrate forward + tout = ode_problem.TF + status, tret, ncheck = IDASolveF(ida.get(), tout, yy, yp, IDA_NORMAL) + assert status == IDA_SUCCESS + + # Create backward problem + yyB = N_VClone(yy) + ypB = N_VClone(yp) + N_VConst(0.0, yyB) + N_VConst(0.0, ypB) + lsB = SUNLinSol_SPGMR(yyB, 0, 0, sunctx) + + status, whichB = IDACreateB(ida.get()) + assert status == IDA_SUCCESS + + def resB(t, y, yp, yB, ypB, resvalB, _): + N_VConst(0.0, resvalB) + return 0 + + status = IDAInitB(ida.get(), whichB, resB, tout, yyB, ypB) + assert status == IDA_SUCCESS + + status = IDASStolerancesB(ida.get(), whichB, 1e-4, 1e-4) + assert status == IDA_SUCCESS + + status = IDASetLinearSolverB(ida.get(), whichB, lsB, None) + assert status == IDA_SUCCESS + + # Integrate backward + tB0 = ode_problem.T0 + status = IDASolveB(ida.get(), tB0, IDA_NORMAL) + assert status >= IDA_SUCCESS + + # Get sensitivities + yB0 = N_VClone(yyB) + ypB0 = N_VClone(ypB) + status, yyB0 = IDAGetB(ida.get(), whichB, yB0, ypB0) + assert status == IDA_SUCCESS diff --git a/bindings/sundials4py/test/test_kinsol.py b/bindings/sundials4py/test/test_kinsol.py new file mode 100644 index 0000000000..cf2deda98b --- /dev/null +++ b/bindings/sundials4py/test/test_kinsol.py @@ -0,0 +1,75 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.kinsol import * +from problems import AnalyticNonlinearSys + + +def test_kinsol(sunctx): + NEQ = 3 + m_aa = 2 + tol = 1e-4 + kin_view = KINCreate(sunctx) + problem = AnalyticNonlinearSys(None) + u = N_VNew_Serial(NEQ, sunctx) + + def fp_function(u, g, _): + return problem.fixed_point_fn(u, g) + + def depth_fn(iter, u_val, g_val, f_val, df, R_mat, depth, _, remove_indices): + if iter < 2: + new_depth = 1 + else: + new_depth = depth + return 0, new_depth + + def damping_fn(iter, u_val, g_val, qt_fn, depth, _): + # smoke test of damping + damping_factor = 1.0 + return 0, damping_factor + + kin_status = KINSetMAA(kin_view.get(), m_aa) + assert kin_status == KIN_SUCCESS + kin_status = KINInit(kin_view.get(), fp_function, u) + assert kin_status == KIN_SUCCESS + kin_status = KINSetFuncNormTol(kin_view.get(), tol) + assert kin_status == KIN_SUCCESS + kin_status = KINSetDepthFn(kin_view.get(), depth_fn) + assert kin_status == KIN_SUCCESS + kin_status = KINSetDampingFn(kin_view.get(), damping_fn) + assert kin_status == KIN_SUCCESS + + # initial guess + u_data = N_VGetArrayPointer(u) + u_data[:] = [0.1, 0.1, -0.1] + + # no scaling used + scale = N_VNew_Serial(NEQ, sunctx) + N_VConst(1.0, scale) + + kin_status = KINSol(kin_view.get(), u, KIN_FP, scale, scale) + assert kin_status == KIN_SUCCESS + + u_expected = N_VNew_Serial(NEQ, sunctx) + u_expected_data = N_VGetArrayPointer(u_expected) + problem.solution(u_expected) + assert np.allclose(u_data, u_expected_data, atol=1e-6) diff --git a/bindings/sundials4py/test/test_lsrkstep.py b/bindings/sundials4py/test/test_lsrkstep.py new file mode 100644 index 0000000000..8e8973868f --- /dev/null +++ b/bindings/sundials4py/test/test_lsrkstep.py @@ -0,0 +1,56 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.arkode import * +from problems import AnalyticODE + + +@pytest.mark.skipif( + sunrealtype == np.float32, reason="Test not supported for sunrealtype=np.float32" +) +def test_lsrkstep(sunctx): + y = N_VNew_Serial(1, sunctx) + ode_problem = AnalyticODE() + ode_problem.set_init_cond(y) + + def rhs(t, y, ydot, _): + return ode_problem.f(t, y, ydot) + + def dom_eig(t, yvec, fnvec, _, tempv1, tempv2, tempv3): + return ode_problem.dom_eig(t, yvec, fnvec, tempv1, tempv2, tempv3) + + lsrk = LSRKStepCreateSTS(rhs, 0, y, sunctx) + status = LSRKStepSetDomEigFn(lsrk.get(), dom_eig) + assert status == 0 + + status = ARKodeSStolerances(lsrk.get(), SUNREALTYPE_RTOL, SUNREALTYPE_ATOL) + assert status == ARK_SUCCESS + + status = ARKodeSetMaxNumSteps(lsrk.get(), 100000) + + tout = 10.0 + status, tret = ARKodeEvolve(lsrk.get(), tout, y, ARK_NORMAL) + assert status == ARK_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) diff --git a/bindings/sundials4py/test/test_mristep.py b/bindings/sundials4py/test/test_mristep.py new file mode 100644 index 0000000000..041f0779a3 --- /dev/null +++ b/bindings/sundials4py/test/test_mristep.py @@ -0,0 +1,91 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + + +import pytest +import weakref +import numpy as np +from sundials4py.core import * +from sundials4py.arkode import * +from problems import AnalyticMultiscaleODE +from fixtures import sunctx + + +def test_multirate(sunctx): + ode_problem = AnalyticMultiscaleODE() + + t0, tf = AnalyticMultiscaleODE.T0, 0.01 + + def fslow(t, y, ydot, _): + return ode_problem.f_linear(t, y, ydot) + + def ffast(t, y, ydot, _): + # # TODO(CJB): fix MRIStepInnerStepper_GetForcingData + # inner_stepper = ode_problem.inner_stepper + # # test MRIStepInnerStepper_GetForcingData + # status, tshift, tscale, forcing, nforcing = MRIStepInnerStepper_GetForcingData(inner_stepper) + # assert status == ARK_SUCCESS + # assert len(forcing) == nforcing + + return ode_problem.f_nonlinear(t, y, ydot) + + y = N_VNew_Serial(1, sunctx) + y0 = N_VClone(y) + + ode_problem.set_init_cond(y) + ode_problem.set_init_cond(y0) + + # create fast integrator + inner_ark = ERKStepCreate(ffast, t0, y, sunctx) + status = ARKodeSetFixedStep(inner_ark.get(), 5e-3) + assert status == ARK_SUCCESS + + status, inner_stepper = ARKodeCreateMRIStepInnerStepper(inner_ark.get()) + assert status == ARK_SUCCESS + + # store inner_stepper in ode_problem so we can access it in ffast + ode_problem.inner_stepper = inner_stepper + + # create slow integrator + ark = MRIStepCreate(fslow, None, t0, y, inner_stepper, sunctx) + status = ARKodeSetFixedStep(ark.get(), 1e-3) + assert status == ARK_SUCCESS + + tout = tf + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + assert status == ARK_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) + + # We must set this to None to ensure inner_stepper can be garbage collected + # If we do not do this, then nanobind will warn that references are leaked. + # This seems to be unavoidable without setting this to None or using a weakref. + # Its possible newer versions of Python may not result in the warning. + ode_problem.inner_stepper = None + + +# Allow the test to be invoked without pytest +def main(): + status, sunctx = SUNContext_Create(SUN_COMM_NULL) + test_multirate(sunctx) + + +if __name__ == "__main__": + main() diff --git a/bindings/sundials4py/test/test_nvector.py b/bindings/sundials4py/test/test_nvector.py new file mode 100644 index 0000000000..316048506c --- /dev/null +++ b/bindings/sundials4py/test/test_nvector.py @@ -0,0 +1,168 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * + + +def test_create_manyvector(sunctx): + y = N_VNew_Serial(5, sunctx) + z = N_VNew_Serial(5, sunctx) + + N_VConst(1.0, y) + N_VConst(2.0, z) + + yz = N_VNew_ManyVector(2, [y, z], sunctx) + + yarr = N_VGetArrayPointer(N_VGetSubvector_ManyVector(yz, 0)) + assert np.allclose(N_VGetArrayPointer(y), 1.0) + + zarr = N_VGetArrayPointer(N_VGetSubvector_ManyVector(yz, 1)) + assert np.allclose(N_VGetArrayPointer(z), 2.0) + + N_VConst(3.0, yz) + assert np.allclose(3.0, yarr) + assert np.allclose(3.0, zarr) + + +@pytest.mark.parametrize("vector_type", ["serial"]) +def test_create_nvector(vector_type, sunctx): + if vector_type == "serial": + nvec = N_VNew_Serial(5, sunctx) + else: + raise ValueError("Unknown vector type") + assert nvec is not None + + arr = N_VGetArrayPointer(nvec) + assert arr.shape[0] == 5 + + arr[:] = np.array([5.0, 4.0, 3.0, 2.0, 1.0], dtype=sunrealtype) + assert np.allclose(N_VGetArrayPointer(nvec), [5.0, 4.0, 3.0, 2.0, 1.0]) + + N_VConst(2.0, nvec) + assert np.allclose(arr, 2.0) + + +@pytest.mark.parametrize("vector_type", ["serial"]) +def test_make_nvector(vector_type, sunctx): + arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=sunrealtype) + if vector_type == "serial": + nvec = N_VMake_Serial(5, arr, sunctx) + else: + raise ValueError("Unknown vector type") + assert nvec is not None + + assert np.allclose(N_VGetArrayPointer(nvec), arr) + + N_VConst(2.0, nvec) + assert np.allclose(arr, 2.0) + + arr[:] = np.array([5.0, 4.0, 3.0, 2.0, 1.0], dtype=sunrealtype) + assert np.allclose(N_VGetArrayPointer(nvec), [5.0, 4.0, 3.0, 2.0, 1.0]) + + +# Test an operation that involves vector arrays +@pytest.mark.parametrize("vector_type", ["serial"]) +def test_nvlinearcombination(vector_type, sunctx): + if vector_type == "serial": + nvec1 = N_VNew_Serial(5, sunctx) + nvec2 = N_VNew_Serial(5, sunctx) + else: + raise ValueError("Unknown vector type") + + arr1 = N_VGetArrayPointer(nvec1) + arr1[:] = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=sunrealtype) + + arr2 = N_VGetArrayPointer(nvec2) + arr2[:] = np.array([10.0, 20.0, 30.0, 40.0, 50.0], dtype=sunrealtype) + + c = np.array([1.0, 0.1], dtype=sunrealtype) + X = [nvec1, nvec2] + + z = N_VNew_Serial(5, sunctx) + N_VConst(0.0, z) + + N_VLinearCombination(2, c, X, z) + + assert np.allclose(N_VGetArrayPointer(z), [2.0, 4.0, 6.0, 8.0, 10.0]) + + +def test_nvscaleaddmultivectorarray_serial(sunctx): + nvec = 2 + nsum = 2 + length = 3 + + # c_1d shape (nsum,) + c_1d = np.array([2.0, 3.0], dtype=sunrealtype) + + # X_1d shape (nvec,) + X_1d = [N_VNew_Serial(length, sunctx) for _ in range(nvec)] + + for i, x in enumerate(X_1d): + N_VConst(float(i + 1), x) + + # Y_2d shape (nsum, nvec) + Y_2d = [[N_VNew_Serial(length, sunctx) for _ in range(nvec)] for _ in range(nsum)] + for s in range(nsum): + for v in range(nvec): + N_VConst(float((s + 1) * 10 + v), Y_2d[s][v]) + + # Z_2d shape (nsum, nvec) + Z_2d = [[N_VNew_Serial(length, sunctx) for _ in range(nvec)] for _ in range(nsum)] + + err = N_VScaleAddMultiVectorArray(nvec, nsum, c_1d, X_1d, Y_2d, Z_2d) + assert err == 0 + + # Check Z_2d[s][v] = c_1d[s] * X_1d[v] + Y_2d[s][v] + for s in range(nsum): + for v in range(nvec): + expected = c_1d[s] * N_VGetArrayPointer(X_1d[v]) + N_VGetArrayPointer(Y_2d[s][v]) + actual = N_VGetArrayPointer(Z_2d[s][v]) + assert np.allclose(actual, expected) + + +def test_nvlinearcombinationvectorarray_serial(sunctx): + nvec = 2 + nsum = 2 + length = 3 + + # c_1d shape (nsum,) + c_1d = np.array([2.0, 3.0], dtype=sunrealtype) + + # X_2d shape (nsum, nvec) + X_2d = [] + for s in range(nsum): + row = [] + for v in range(nvec): + x = N_VNew_Serial(length, sunctx) + N_VConst(float((s + 1) * 10 + v), x) + row.append(x) + X_2d.append(row) + + # Z_1d shape (nvec,) + Z_1d = [N_VNew_Serial(length, sunctx) for _ in range(nvec)] + + err = N_VLinearCombinationVectorArray(nvec, nsum, c_1d, X_2d, Z_1d) + assert err == 0 + + # Check Z_1d[v] = sum_s c_1d[s] * X_2d[s][v] + for v in range(nvec): + expected = sum(c_1d[s] * N_VGetArrayPointer(X_2d[s][v]) for s in range(nsum)) + actual = N_VGetArrayPointer(Z_1d[v]) + assert np.allclose(actual, expected) diff --git a/bindings/sundials4py/test/test_splittingstep.py b/bindings/sundials4py/test/test_splittingstep.py new file mode 100644 index 0000000000..13bdc7a7d5 --- /dev/null +++ b/bindings/sundials4py/test/test_splittingstep.py @@ -0,0 +1,64 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.arkode import * +from problems import AnalyticMultiscaleODE + + +def test_splittingstep(sunctx): + ode_problem = AnalyticMultiscaleODE() + t0, tf = AnalyticMultiscaleODE.T0, 0.01 + + def f_linear(t, y, ydot, _): + return ode_problem.f_linear(t, y, ydot) + + def f_nonlinear(t, y, ydot, _): + return ode_problem.f_nonlinear(t, y, ydot) + + y = N_VNew_Serial(1, sunctx) + y0 = N_VNew_Serial(1, sunctx) + ode_problem.set_init_cond(y) + ode_problem.set_init_cond(y0) + + linear_ark = ERKStepCreate(f_linear, t0, y, sunctx) + status = ARKodeSetFixedStep(linear_ark.get(), 5e-3) + assert status == 0 + + nonlinear_ark = ARKStepCreate(f_nonlinear, None, t0, y, sunctx) + status = ARKodeSetFixedStep(nonlinear_ark.get(), 1e-3) + assert status == 0 + + status, linear_stepper = ARKodeCreateSUNStepper(linear_ark.get()) + status, nonlinear_stepper = ARKodeCreateSUNStepper(nonlinear_ark.get()) + + steppers = [linear_stepper, nonlinear_stepper] + ark = SplittingStepCreate(steppers, len(steppers), t0, y, sunctx) + status = ARKodeSetFixedStep(ark.get(), 1e-2) + assert status == 0 + + tout = tf + status, tret = ARKodeEvolve(ark.get(), tout, y, ARK_NORMAL) + assert status == 0 + + sol = N_VClone(y) + ode_problem.solution(y0, sol, tf) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) diff --git a/bindings/sundials4py/test/test_sprkstep.py b/bindings/sundials4py/test/test_sprkstep.py new file mode 100644 index 0000000000..6afdcd9560 --- /dev/null +++ b/bindings/sundials4py/test/test_sprkstep.py @@ -0,0 +1,54 @@ +#!/bin/python +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * +from sundials4py.arkode import * +from problems import HarmonicOscillatorODE + + +def test_sprkstep(sunctx): + tout = 2 * np.pi + dt = 0.01 + y = N_VNew_Serial(2, sunctx) + ode_problem = HarmonicOscillatorODE() + + def f1(t, y, ydot, _): + return ode_problem.xdot(t, y, ydot) + + def f2(t, y, ydot, _): + return ode_problem.vdot(t, y, ydot) + + ode_problem.set_init_cond(y) + + sprk = SPRKStepCreate(f1, f2, 0, y, sunctx) + + status = ARKodeSetFixedStep(sprk.get(), dt) + assert status == ARK_SUCCESS + + status = ARKodeSetMaxNumSteps(sprk.get(), int(np.ceil(tout / dt))) + assert status == ARK_SUCCESS + + status, tret = ARKodeEvolve(sprk.get(), tout, y, ARK_NORMAL) + assert status == ARK_SUCCESS + + sol = N_VClone(y) + ode_problem.solution(y, sol, tret) + assert np.allclose(N_VGetArrayPointer(sol), N_VGetArrayPointer(y), atol=1e-2) diff --git a/bindings/sundials4py/test/test_sunadaptcontroller.py b/bindings/sundials4py/test/test_sunadaptcontroller.py new file mode 100644 index 0000000000..649f66145d --- /dev/null +++ b/bindings/sundials4py/test/test_sunadaptcontroller.py @@ -0,0 +1,103 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNAdaptController module +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + + +def make_controller(controller_type, sunctx): + if controller_type == "soderlind": + c = SUNAdaptController_Soderlind(sunctx) + return c, None, None + elif controller_type == "imexgus": + c = SUNAdaptController_ImExGus(sunctx) + return c, None, None + elif controller_type == "mrihtol": + c1 = SUNAdaptController_ImExGus(sunctx) + c2 = SUNAdaptController_Soderlind(sunctx) + c = SUNAdaptController_MRIHTol(c1, c2, sunctx) + return c, c1, c2 + else: + raise ValueError("Unknown controller type") + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_create_controller(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + assert c is not None + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_get_type(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + t = SUNAdaptController_GetType(c) + assert isinstance(t, int) + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_estimate_step(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + err, hnew = SUNAdaptController_EstimateStep(c, 1.0, 1, 0.1) + assert isinstance(err, int) + assert isinstance(hnew, float) + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_reset(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + status = SUNAdaptController_Reset(c) + assert status == 0 + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_set_defaults(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + status = SUNAdaptController_SetDefaults(c) + assert status == 0 + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_set_error_bias(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + status = SUNAdaptController_SetErrorBias(c, 1.0) + assert status == 0 + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_update_h(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + status = SUNAdaptController_UpdateH(c, 1.0, 0.1) + assert status == 0 + + +@pytest.mark.parametrize("controller_type", ["soderlind", "imexgus", "mrihtol"]) +def test_estimate_step_tol(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + err, hnew, tolfacnew = SUNAdaptController_EstimateStepTol(c, 1.0, 1.0, 1, 0.1, 0.1) + assert isinstance(err, int) + assert isinstance(hnew, float) + assert isinstance(tolfacnew, float) + + +@pytest.mark.parametrize("controller_type", ["mrihtol"]) +def test_update_mrihtol(controller_type, sunctx): + c, c1, c2 = make_controller(controller_type, sunctx) + status = SUNAdaptController_UpdateMRIHTol(c, 1.0, 1.0, 0.1, 0.1) + assert status == 0 diff --git a/bindings/sundials4py/test/test_sunadjointcheckpointscheme.py b/bindings/sundials4py/test/test_sunadjointcheckpointscheme.py new file mode 100644 index 0000000000..fa47adfa34 --- /dev/null +++ b/bindings/sundials4py/test/test_sunadjointcheckpointscheme.py @@ -0,0 +1,74 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNAdjointCheckpointScheme module +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + + +def make_fixed_scheme(sunctx): + io_mode = SUNDATAIOMODE_INMEM + mem_helper = SUNMemoryHelper_Sys(sunctx) + interval = 1 + estimate = 1 + keep = 0 + status, scheme = SUNAdjointCheckpointScheme_Create_Fixed( + io_mode, mem_helper, interval, estimate, keep, sunctx + ) + # must return mem_helper or it will get cleaned up + return status, scheme, mem_helper + + +def test_needs_saving(sunctx): + scheme_status, scheme, mem_helper = make_fixed_scheme(sunctx) + step_num = 0 + stage_num = 0 + t = 0.0 + status, result = SUNAdjointCheckpointScheme_NeedsSaving(scheme, step_num, stage_num, t) + assert status == 0 + assert isinstance(result, int) + + +def test_insert_vector(sunctx, nvec): + scheme_status, scheme, mem_helper = make_fixed_scheme(sunctx) + step_num = 0 + stage_num = 0 + t = 0.0 + status = SUNAdjointCheckpointScheme_InsertVector(scheme, step_num, stage_num, t, nvec) + assert status == 0 + + +# def test_load_vector(sunctx, nvec): +# scheme_status, scheme, mem_helper = make_fixed_scheme(sunctx) + +# step_num = 0 +# stage_num = 0 +# t = 0.0 +# status = SUNAdjointCheckpointScheme_InsertVector(scheme, step_num, stage_num, t, nvec) +# assert status == 0 + +# peek = False +# status, vec, tout = SUNAdjointCheckpointScheme_LoadVector(scheme, step_num, stage_num, peek) +# assert status == 0 + + +def test_enable_dense(sunctx): + scheme_status, scheme, mem_helper = make_fixed_scheme(sunctx) + status = SUNAdjointCheckpointScheme_EnableDense(scheme, True) + assert status == 0 diff --git a/bindings/sundials4py/test/test_sunadjointstepper.py b/bindings/sundials4py/test/test_sunadjointstepper.py new file mode 100644 index 0000000000..dc4b60b37f --- /dev/null +++ b/bindings/sundials4py/test/test_sunadjointstepper.py @@ -0,0 +1,92 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNAdjointStepper module +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + + +def make_adjoint_stepper(sunctx, sunstepper, nvec): + mem_helper = SUNMemoryHelper_Sys(sunctx) + status, scheme = SUNAdjointCheckpointScheme_Create_Fixed(0, mem_helper, 1, 1, 0, sunctx) + b1 = 0 + b2 = 0 + nsteps = 1 + t0 = 0.0 + y0 = nvec + status, adj_stepper = SUNAdjointStepper_Create( + sunstepper, b1, sunstepper, b2, nsteps, t0, y0, scheme, sunctx + ) + return adj_stepper, scheme, mem_helper + + +def test_create_adjoint_stepper(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + assert adj_stepper is not None + + +def test_adjointstepper_reinit(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + t0 = 0.0 + y0 = nvec + tf = 1.0 + err = SUNAdjointStepper_ReInit(adj_stepper, t0, y0, tf, nvec) + assert isinstance(err, int) + + +def test_adjointstepper_evolve(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + tout = 1.0 + sens = nvec + err, tret = SUNAdjointStepper_Evolve(adj_stepper, tout, sens) + assert isinstance(err, int) + assert isinstance(tret, float) + + +def test_adjointstepper_onestep(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + tout = 1.0 + sens = nvec + err, tret = SUNAdjointStepper_OneStep(adj_stepper, tout, sens) + assert isinstance(err, int) + assert isinstance(tret, float) + + +def test_adjointstepper_recomputefwd(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + start_idx = 0 + t0 = 0.0 + y0 = nvec + tf = 1.0 + err = SUNAdjointStepper_RecomputeFwd(adj_stepper, start_idx, t0, y0, tf) + assert isinstance(err, int) + + +def test_adjointstepper_getnumsteps(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + err, num_steps = SUNAdjointStepper_GetNumSteps(adj_stepper) + assert isinstance(err, int) + assert isinstance(num_steps, int) + + +def test_adjointstepper_getnumrecompute(sunctx, nvec, sunstepper): + adj_stepper, scheme, mem_helper = make_adjoint_stepper(sunctx, sunstepper, nvec) + err, num_recompute = SUNAdjointStepper_GetNumRecompute(adj_stepper) + assert isinstance(err, int) + assert isinstance(num_recompute, int) diff --git a/bindings/sundials4py/test/test_suncontext.py b/bindings/sundials4py/test/test_suncontext.py new file mode 100644 index 0000000000..92148c60c9 --- /dev/null +++ b/bindings/sundials4py/test/test_suncontext.py @@ -0,0 +1,56 @@ +#!/bin/python + +import pytest +import numpy as np +from sundials4py.core import * + + +def test_with_null_comm(): + # Create a new context with a null comm + err, sunctx = SUNContext_Create(SUN_COMM_NULL) + assert err == SUN_SUCCESS + + # Try calling a SUNContext_ function + last_err = SUNContext_GetLastError(sunctx) + + assert last_err == SUN_SUCCESS + + +def test_push_pop_err_handlers(): + # Create a new context with a null comm + err, sunctx = SUNContext_Create(SUN_COMM_NULL) + assert err == SUN_SUCCESS + + called = {"err_fn1": False, "err_fn2": False} + + def err_fn1(line, func_name, file_name, msg, err_code, _, sunctx): + # err_fn2 should already be called since it was pushed second + assert called["err_fn2"] + called["err_fn1"] = True + + def err_fn2(line, func_name, file_name, msg, err_code, _, sunctx): + called["err_fn2"] = True + + status = SUNContext_PushErrHandler(sunctx, err_fn1) + assert status == SUN_SUCCESS + + status = SUNContext_PushErrHandler(sunctx, err_fn2) + assert status == SUN_SUCCESS + + SUNContext_TestErrHandler(sunctx) + assert called["err_fn1"] + + called = {"err_fn1": False, "err_fn2": False} + + status = SUNContext_PopErrHandler(sunctx) + assert status == SUN_SUCCESS + + status = SUNContext_PopErrHandler(sunctx) + assert status == SUN_SUCCESS + + SUNContext_TestErrHandler(sunctx) + assert not called["err_fn1"] + + # Popping again should do nothing + status = SUNContext_PopErrHandler(sunctx) + assert status == SUN_SUCCESS diff --git a/bindings/sundials4py/test/test_sundomeigest.py b/bindings/sundials4py/test/test_sundomeigest.py new file mode 100644 index 0000000000..4d2fee58b6 --- /dev/null +++ b/bindings/sundials4py/test/test_sundomeigest.py @@ -0,0 +1,115 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNDomEigEstimator module +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + +# Note: some of these tests will fail if SUNDIALS error checks are turned on because +# we dont properly mock some of the requirements + + +def make_estimator(estimator_type, sunctx): + if estimator_type == "power": + nvec = N_VNew_Serial(5, sunctx) + e = SUNDomEigEstimator_Power(nvec, 10, 1.0, sunctx) + + def atimes(_, v, z): + # dummy atimes for smoke testing + return 0 + + SUNDomEigEstimator_SetATimes(e, atimes) + + return e, nvec + else: + raise ValueError("Unknown estimator type") + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_create_estimator(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + assert est is not None + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_set_max_iters(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + status = SUNDomEigEstimator_SetMaxIters(est, 10) + assert status == 0 + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_set_num_preprocess_iters(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + status = SUNDomEigEstimator_SetNumPreprocessIters(est, 2) + assert status == 0 + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_set_reltol(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + status = SUNDomEigEstimator_SetRelTol(est, 1e-6) + assert status == 0 + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_set_initial_guess(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + status = SUNDomEigEstimator_SetInitialGuess(est, nvec) + assert status == 0 + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_initialize(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + status = SUNDomEigEstimator_Initialize(est) + assert status == 0 + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_estimate(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + err, lambdaR, lambdaI = SUNDomEigEstimator_Estimate(est) + assert isinstance(err, int) + assert isinstance(lambdaR, float) + assert isinstance(lambdaI, float) + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_get_res(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + err, res = SUNDomEigEstimator_GetRes(est) + assert isinstance(err, int) + assert isinstance(res, float) + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_get_num_iters(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + err, num_iters = SUNDomEigEstimator_GetNumIters(est) + assert isinstance(err, int) + assert isinstance(num_iters, int) + + +@pytest.mark.parametrize("estimator_type", ["power"]) +def test_get_num_atimes_calls(estimator_type, sunctx): + est, nvec = make_estimator(estimator_type, sunctx) + err, num_atimes = SUNDomEigEstimator_GetNumATimesCalls(est) + assert isinstance(err, int) + assert isinstance(num_atimes, int) diff --git a/bindings/sundials4py/test/test_sunlinearsolver.py b/bindings/sundials4py/test/test_sunlinearsolver.py new file mode 100644 index 0000000000..898fe2d8e6 --- /dev/null +++ b/bindings/sundials4py/test/test_sunlinearsolver.py @@ -0,0 +1,144 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + +# Note: some of these tests will fail if SUNDIALS error checks are turned on because +# we dont properly mock some of the requirements + + +def test_create_dense(sunctx, nvec): + A = SUNDenseMatrix(2, 2, sunctx) + LS = SUNLinSol_Dense(nvec, A, sunctx) + assert LS is not None + + +def test_create_band(sunctx, nvec): + A = SUNBandMatrix(2, 1, 1, sunctx) + LS = SUNLinSol_Band(nvec, A, sunctx) + assert LS is not None + + +def test_create_spgmr(sunctx, nvec): + LS = SUNLinSol_SPGMR(nvec, SUN_PREC_NONE, 0, sunctx) + assert LS is not None + + +def test_create_pcg(sunctx, nvec): + LS = SUNLinSol_PCG(nvec, SUN_PREC_NONE, 0, sunctx) + assert LS is not None + + +def test_create_spbcgs(sunctx, nvec): + LS = SUNLinSol_SPBCGS(nvec, SUN_PREC_NONE, 0, sunctx) + assert LS is not None + + +def test_create_sptfqmr(sunctx, nvec): + LS = SUNLinSol_SPTFQMR(nvec, SUN_PREC_NONE, 0, sunctx) + assert LS is not None + + +def test_get_type_and_id(sunctx, nvec): + A = SUNDenseMatrix(2, 2, sunctx) + LS = SUNLinSol_Dense(nvec, A, sunctx) + typ = SUNLinSolGetType(LS) + id_ = SUNLinSolGetID(LS) + assert isinstance(typ, int) + assert isinstance(id_, int) + + +def test_initialize_setup(sunctx, nvec): + A = SUNDenseMatrix(2, 2, sunctx) + LS = SUNLinSol_Dense(nvec, A, sunctx) + ret_init = SUNLinSolInitialize(LS) + ret_setup = SUNLinSolSetup(LS, A) + assert isinstance(ret_init, int) + assert isinstance(ret_setup, int) + + +def test_num_iters_resnorm_lastflag(sunctx, nvec): + LS = SUNLinSol_SPGMR(nvec, 0, 0, sunctx) + niters = SUNLinSolNumIters(LS) + resnorm = SUNLinSolResNorm(LS) + lastflag = SUNLinSolLastFlag(LS) + assert isinstance(niters, int) + assert isinstance(resnorm, float) + assert isinstance(lastflag, int) + + +def test_sunlinsol_set_atimes(sunctx): + x = N_VNew_Serial(1, sunctx) + y = N_VNew_Serial(1, sunctx) + + # Create a simple dense matrix and linear solver + LS = SUNLinSol_SPGMR(x, 0, 0, sunctx) + assert LS is not None + + # Define a dummy ATimes function + called = {"flag": False} + + def atimes(LS, x, y): + called["flag"] = True + return 0 + + # Set the ATimes function + ret = SUNLinSolSetATimes(LS, atimes) + assert isinstance(ret, int) + + SUNLinSolInitialize(LS) + SUNLinSolSetup(LS, None) + SUNLinSolSolve(LS, None, x, y, 1e-2) + assert called["flag"] + + +def test_sunlinsol_set_preconditioner(sunctx): + x = N_VNew_Serial(1, sunctx) + y = N_VNew_Serial(1, sunctx) + + # Create a simple dense matrix and linear solver + LS = SUNLinSol_SPGMR(x, SUN_PREC_LEFT, 0, sunctx) + assert LS is not None + + def atimes(_, x, y): + return 0 + + # Define a dummy preconditioner functions + called = {"psetup": False, "psolve": False} + + def psetup(_): + called["psetup"] = True + return 0 + + def psolve(_, r, z, tol, lr): + called["psolve"] = True + return 0 + + # Set the ATimes function + ret = SUNLinSolSetATimes(LS, atimes) + assert ret == SUN_SUCCESS + + ret = SUNLinSolSetPreconditioner(LS, psetup, psolve) + assert ret == SUN_SUCCESS + + SUNLinSolInitialize(LS) + SUNLinSolSetup(LS, None) + SUNLinSolSolve(LS, None, x, y, 1e-2) + assert called["psetup"] + assert called["psolve"] diff --git a/bindings/sundials4py/test/test_sunmemoryhelper.py b/bindings/sundials4py/test/test_sunmemoryhelper.py new file mode 100644 index 0000000000..53003ff25b --- /dev/null +++ b/bindings/sundials4py/test/test_sunmemoryhelper.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNMemoryHelper module +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + + +def test_create_memory_helper_sys(sunctx): + mem_helper = SUNMemoryHelper_Sys(sunctx) # noqa: F405 + assert mem_helper is not None diff --git a/bindings/sundials4py/test/test_sunnonlinearsolver.py b/bindings/sundials4py/test/test_sunnonlinearsolver.py new file mode 100644 index 0000000000..6dceef3019 --- /dev/null +++ b/bindings/sundials4py/test/test_sunnonlinearsolver.py @@ -0,0 +1,147 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNNonlinearSolver module +# ----------------------------------------------------------------- + +import pytest +import numpy as np +from fixtures import * +from sundials4py.core import * + +# Note: some of these tests will fail if SUNDIALS error checks are turned on because +# we dont properly mock some of the requirements + + +@pytest.mark.parametrize("solver_type", ["fixedpoint", "newton"]) +def test_create_solver(solver_type, sunctx, nvec): + if solver_type == "fixedpoint": + nls = SUNNonlinSol_FixedPoint(nvec, 5, sunctx) + elif solver_type == "newton": + nls = SUNNonlinSol_Newton(nvec, sunctx) + else: + raise ValueError("Unknown solver type") + assert nls is not None + + +@pytest.mark.parametrize("solver_type", ["fixedpoint", "newton"]) +def make_solver(solver_type, sunctx, nvec): + if solver_type == "fixedpoint": + return SUNNonlinSol_FixedPoint(nvec, 5, sunctx) + elif solver_type == "newton": + return SUNNonlinSol_Newton(nvec, sunctx) + else: + raise ValueError("Unknown solver type") + + +@pytest.mark.parametrize( + "solver_type, expected_type", + [("newton", SUNNONLINEARSOLVER_ROOTFIND), ("fixedpoint", SUNNONLINEARSOLVER_FIXEDPOINT)], +) +def test_gettype(solver_type, expected_type, sunctx, nvec): + nls = make_solver(solver_type, sunctx, nvec) + typ = SUNNonlinSolGetType(nls) + assert typ is expected_type + + +@pytest.mark.parametrize("solver_type", ["fixedpoint", "newton"]) +def test_initialize(solver_type, sunctx, nvec): + nls = make_solver(solver_type, sunctx, nvec) + ret = SUNNonlinSolInitialize(nls) + assert ret == 0 + + +@pytest.mark.parametrize("solver_type,max_iters", [("newton", 5), ("fixedpoint", 10)]) +def test_set_max_iters_and_get_num_iters(solver_type, max_iters, sunctx, nvec): + nls = make_solver(solver_type, sunctx, nvec) + ret = SUNNonlinSolSetMaxIters(nls, max_iters) + assert ret == 0 + err, niters = SUNNonlinSolGetNumIters(nls) + assert err == 0 + assert isinstance(niters, int) + + +@pytest.mark.parametrize("solver_type", ["fixedpoint", "newton"]) +def test_get_cur_iter(solver_type, sunctx, nvec): + nls = make_solver(solver_type, sunctx, nvec) + err, cur_iter = SUNNonlinSolGetCurIter(nls) + assert err == 0 + assert isinstance(cur_iter, int) + + +@pytest.mark.parametrize("solver_type", ["fixedpoint", "newton"]) +def test_get_num_conv_fails(solver_type, sunctx, nvec): + nls = make_solver(solver_type, sunctx, nvec) + err, nconvfails = SUNNonlinSolGetNumConvFails(nls) + assert err == 0 + assert isinstance(nconvfails, int) + + +def test_fixedpoint_setup_and_solve(sunctx): + from problems import AnalyticNonlinearSys + + NEQ = AnalyticNonlinearSys.NEQ + ucor = N_VNew_Serial(NEQ, sunctx) + u0 = N_VNew_Serial(NEQ, sunctx) + w = N_VNew_Serial(NEQ, sunctx) + ucur = N_VNew_Serial(NEQ, sunctx) + + # Initial guess + udata = N_VGetArrayPointer(u0) + udata[:] = [0.1, 0.1, -0.1] + + # Initial correction + N_VConst(0.0, ucor) + + # Set the weights + N_VConst(1.0, w) + + # Create the problem + with AnalyticNonlinearSys(u0) as problem: + + # Create the solver + nls = SUNNonlinSol_FixedPoint(u0, 2, sunctx) + + # System function + def g_fn(u, g, _): + return problem.corrector_fp_fn(u, g) + + # Convergence test + def conv_test(nls, u, delta, tol, ewt, _): + return problem.conv_test(nls, u, delta, tol, ewt) + + ret = SUNNonlinSolSetSysFn(nls, g_fn) + assert ret == 0 + + ret = SUNNonlinSolSetConvTestFn(nls, conv_test) + assert ret == 0 + + ret = SUNNonlinSolSetMaxIters(nls, 50) + + ret = SUNNonlinSolSetup(nls, u0) + assert ret == 0 + + tol = 1e-10 + ret = SUNNonlinSolSolve(nls, u0, ucor, w, tol, 0) + assert ret == 0 + + # Update the initial guess with the correction + N_VLinearSum(1.0, u0, 1.0, ucor, ucur) + + # Compare to analytic solution + utrue = N_VNew_Serial(NEQ, sunctx) + problem.solution(utrue) + assert np.allclose(N_VGetArrayPointer(ucur), N_VGetArrayPointer(utrue), atol=1e-2) diff --git a/bindings/sundials4py/test/test_sunstepper.py b/bindings/sundials4py/test/test_sunstepper.py new file mode 100644 index 0000000000..72e62a08d6 --- /dev/null +++ b/bindings/sundials4py/test/test_sunstepper.py @@ -0,0 +1,203 @@ +# ----------------------------------------------------------------- +# Programmer(s): Cody J. Balos @ LLNL +# ----------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025-2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------- +# Unit/smoke tests for SUNStepper module +# ----------------------------------------------------------------- + +import pytest +from fixtures import * +from sundials4py.core import * + + +def make_stepper(sunctx): + # Create am empty stepper + status, s = SUNStepper_Create(sunctx) + assert status == SUN_SUCCESS + return s + + +def test_create_stepper(sunctx): + s = make_stepper(sunctx) + assert s is not None + + +def test_stepper_evolve(sunctx, nvec): + s = make_stepper(sunctx) + vret = nvec + err, tret = SUNStepper_Evolve(s, 1.0, vret) + assert isinstance(err, int) + assert isinstance(tret, float) + + +def test_stepper_one_step(sunctx, nvec): + s = make_stepper(sunctx) + vret = nvec + err, tret = SUNStepper_OneStep(s, 1.0, vret) + assert isinstance(err, int) + assert isinstance(tret, float) + + +def test_stepper_reset(sunctx, nvec): + s = make_stepper(sunctx) + err = SUNStepper_Reset(s, 0.0, nvec) + assert isinstance(err, int) + + +def test_stepper_set_evolve_fn(sunctx, nvec): + s = make_stepper(sunctx) + called = {"flag": False} + + def evolve_fn(stepper, tout, vret, tret): + called["flag"] = True + return 0 + + err = SUNStepper_SetEvolveFn(s, evolve_fn) + assert err == 0 + # Call evolve to trigger callback + SUNStepper_Evolve(s, 1.0, nvec) + assert called["flag"] + + +def test_stepper_set_one_step_fn(sunctx, nvec): + s = make_stepper(sunctx) + called = {"flag": False} + + def one_step_fn(stepper, tout, vret, tret): + called["flag"] = True + return 0 + + err = SUNStepper_SetOneStepFn(s, one_step_fn) + assert err == 0 + SUNStepper_OneStep(s, 1.0, nvec) + assert called["flag"] + + +def test_stepper_set_full_rhs_fn(sunctx, nvec): + s = make_stepper(sunctx) + called = {"flag": False} + + def full_rhs_fn(stepper, t, v, f, mode): + called["flag"] = True + return 0 + + err = SUNStepper_SetFullRhsFn(s, full_rhs_fn) + assert err == 0 + # Call with dummy args + SUNStepper_FullRhs(s, 0.0, nvec, nvec, 0) + assert called["flag"] + + +def test_stepper_set_reinit_fn(sunctx, nvec): + s = make_stepper(sunctx) + called = {"flag": False} + + def reinit_fn(stepper, t, y): + called["flag"] = True + return 0 + + err = SUNStepper_SetReInitFn(s, reinit_fn) + assert err == 0 + SUNStepper_ReInit(s, 0.0, nvec) + assert called["flag"] + + +def test_stepper_set_reset_fn(sunctx, nvec): + s = make_stepper(sunctx) + called = {"flag": False} + + def reset_fn(stepper, t, y): + called["flag"] = True + return 0 + + err = SUNStepper_SetResetFn(s, reset_fn) + assert err == 0 + SUNStepper_Reset(s, 0.0, nvec) + assert called["flag"] + + +def test_stepper_set_reset_ckpt_idx_fn(sunctx): + s = make_stepper(sunctx) + called = {"flag": False} + + def reset_ckpt_idx_fn(stepper, idx): + called["flag"] = True + return 0 + + err = SUNStepper_SetResetCheckpointIndexFn(s, reset_ckpt_idx_fn) + assert err == 0 + SUNStepper_ResetCheckpointIndex(s, 1) + assert called["flag"] + + +def test_stepper_set_stop_time_fn(sunctx): + s = make_stepper(sunctx) + called = {"flag": False} + + def stop_time_fn(stepper, tstop): + called["flag"] = True + return 0 + + err = SUNStepper_SetStopTimeFn(s, stop_time_fn) + assert err == 0 + SUNStepper_SetStopTime(s, 2.0) + assert called["flag"] + + +def test_stepper_set_step_direction_fn(sunctx): + s = make_stepper(sunctx) + called = {"flag": False} + + def step_direction_fn(stepper, direction): + called["flag"] = True + return 0 + + err = SUNStepper_SetStepDirectionFn(s, step_direction_fn) + assert err == 0 + SUNStepper_SetStepDirection(s, 1) + assert called["flag"] + + +def test_stepper_set_forcing_fn(sunctx, nvec): + s = make_stepper(sunctx) + called = {"flag": False} + + def forcing_fn(stepper, tshift, tscale, forcing, nforcing): + assert type(forcing) is list + assert len(forcing) == nforcing + called["flag"] = True + return 0 + + err = SUNStepper_SetForcingFn(s, forcing_fn) + assert err == 0 + SUNStepper_SetForcing(s, 0.0, 1.0, [nvec], 1) + assert called["flag"] + + +def test_stepper_set_get_num_steps_fn(sunctx): + s = make_stepper(sunctx) + called = {"flag": False} + + def get_num_steps_fn(stepper): + called["flag"] = True + nst = 1 + return 0, nst + + err = SUNStepper_SetGetNumStepsFn(s, get_num_steps_fn) + assert err == 0 + status, nst = SUNStepper_GetNumSteps(s) + assert called["flag"] + assert isinstance(nst, int) + assert nst == 1 diff --git a/cmake/SundialsBuildOptionsPre.cmake b/cmake/SundialsBuildOptionsPre.cmake index 6a7c3be6ec..7535662521 100644 --- a/cmake/SundialsBuildOptionsPre.cmake +++ b/cmake/SundialsBuildOptionsPre.cmake @@ -233,6 +233,13 @@ if(BUILD_FORTRAN_MODULE_INTERFACE) sundials_option(Fortran_INSTALL_MODDIR STRING "${DOCSTR}" "fortran") endif() +# --------------------------------------------------------------- +# Options to enable Python interfaces. +# --------------------------------------------------------------- + +set(DOCSTR "Enable Python interfaces") +sundials_option(SUNDIALS_ENABLE_PYTHON BOOL "${DOCSTR}" OFF) + # --------------------------------------------------------------- # Options for benchmark suite # --------------------------------------------------------------- diff --git a/cmake/SundialsSetupCXX.cmake b/cmake/SundialsSetupCXX.cmake index 1891398177..b9ed14b286 100644 --- a/cmake/SundialsSetupCXX.cmake +++ b/cmake/SundialsSetupCXX.cmake @@ -31,7 +31,9 @@ set(CXX_FOUND TRUE) sundials_option(CMAKE_CXX_STANDARD_REQUIRED BOOL "Require C++ standard version" ON) -if(ENABLE_SYCL OR ENABLE_GINKGO) +if(SUNDIALS_ENABLE_PYTHON + OR ENABLE_SYCL + OR ENABLE_GINKGO) set(DOCSTR "The C++ standard to use if C++ is enabled (17, 20, 23)") sundials_option(CMAKE_CXX_STANDARD STRING "${DOCSTR}" "17" OPTIONS "17;20;23") else() @@ -45,7 +47,20 @@ set(DOCSTR "Enable C++ compiler specific extensions") sundials_option(CMAKE_CXX_EXTENSIONS BOOL "${DOCSTR}" ON) message(STATUS "C++ extensions set to ${CMAKE_CXX_EXTENSIONS}") +# Python interface code requires C++17 +if(SUNDIALS_ENABLE_PYTHON AND (CMAKE_CXX_STANDARD LESS "17")) + message( + SEND_ERROR + "CMAKE_CXX_STANDARD must be >= 17 because SUNDIALS_ENABLE_PYTHON=ON") +endif() + # SYCL requires C++17 if(ENABLE_SYCL AND (CMAKE_CXX_STANDARD LESS "17")) message(FATAL_ERROR "CMAKE_CXX_STANDARD must be >= 17 because ENABLE_SYCL=ON") endif() + +# Ginkgo requires C++17 +if(ENABLE_GINKGO AND (CMAKE_CXX_STANDARD LESS "17")) + message( + FATAL_ERROR "CMAKE_CXX_STANDARD must be >= 17 because ENABLE_GINKGO=ON") +endif() diff --git a/cmake/SundialsSetupCompilers.cmake b/cmake/SundialsSetupCompilers.cmake index 232dc7ed8b..5aacd64968 100644 --- a/cmake/SundialsSetupCompilers.cmake +++ b/cmake/SundialsSetupCompilers.cmake @@ -426,6 +426,7 @@ endif() # =============================================================== if(BUILD_BENCHMARKS + OR SUNDIALS_ENABLE_PYTHON OR SUNDIALS_TEST_ENABLE_UNIT_TESTS OR EXAMPLES_ENABLE_CXX OR ENABLE_CUDA diff --git a/doc/arkode/guide/source/Python b/doc/arkode/guide/source/Python new file mode 120000 index 0000000000..7f5da0abae --- /dev/null +++ b/doc/arkode/guide/source/Python @@ -0,0 +1 @@ +../../../shared/Python \ No newline at end of file diff --git a/doc/arkode/guide/source/index.rst b/doc/arkode/guide/source/index.rst index f1af830835..05dc97907a 100644 --- a/doc/arkode/guide/source/index.rst +++ b/doc/arkode/guide/source/index.rst @@ -76,6 +76,7 @@ with support by the `US Department of Energy `_, Constants Butcher Fortran/index.rst + Python/index.rst History_link.rst Changelog_link.rst References diff --git a/doc/cvodes/guide/source/Python b/doc/cvodes/guide/source/Python new file mode 120000 index 0000000000..7f5da0abae --- /dev/null +++ b/doc/cvodes/guide/source/Python @@ -0,0 +1 @@ +../../../shared/Python \ No newline at end of file diff --git a/doc/cvodes/guide/source/index.rst b/doc/cvodes/guide/source/index.rst index 91f73583fc..74aec8a091 100644 --- a/doc/cvodes/guide/source/index.rst +++ b/doc/cvodes/guide/source/index.rst @@ -41,6 +41,7 @@ CVODES Documentation sundials/Install_link.rst Constants Fortran/index.rst + Python/index.rst History_link.rst Changelog_link.rst References diff --git a/doc/idas/guide/source/Python b/doc/idas/guide/source/Python new file mode 120000 index 0000000000..7f5da0abae --- /dev/null +++ b/doc/idas/guide/source/Python @@ -0,0 +1 @@ +../../../shared/Python \ No newline at end of file diff --git a/doc/idas/guide/source/index.rst b/doc/idas/guide/source/index.rst index 62c3c501a6..9604ca84c7 100644 --- a/doc/idas/guide/source/index.rst +++ b/doc/idas/guide/source/index.rst @@ -41,6 +41,7 @@ IDAS Documentation sundials/Install_link.rst Constants Fortran/index.rst + Python/index.rst History_link.rst Changelog_link.rst References diff --git a/doc/kinsol/guide/source/Python b/doc/kinsol/guide/source/Python new file mode 120000 index 0000000000..7f5da0abae --- /dev/null +++ b/doc/kinsol/guide/source/Python @@ -0,0 +1 @@ +../../../shared/Python \ No newline at end of file diff --git a/doc/kinsol/guide/source/index.rst b/doc/kinsol/guide/source/index.rst index 5f61577e58..855233b7b8 100644 --- a/doc/kinsol/guide/source/index.rst +++ b/doc/kinsol/guide/source/index.rst @@ -40,6 +40,7 @@ KINSOL Documentation sundials/Install_link.rst Constants Fortran/index.rst + Python/index.rst History_link.rst Changelog_link.rst References diff --git a/doc/requirements.txt b/doc/requirements.txt index 870c37a91d..4696854cf6 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,8 +1,10 @@ +numpydoc sphinx >=4.0.0 -sphinx-fortran sphinx_rtd_theme -sphinxcontrib.bibtex sphinx-copybutton +sphinx-fortran +sphinx-multitoc-numbering sphinx-toolbox sphinxcontrib-googleanalytics -sphinx-multitoc-numbering +sphinxcontrib.bibtex +git+https://github.com/LLNL/sundials.git@feature/python-nanobind diff --git a/doc/shared/Python/API.rst b/doc/shared/Python/API.rst new file mode 100644 index 0000000000..db7791c21c --- /dev/null +++ b/doc/shared/Python/API.rst @@ -0,0 +1,72 @@ +.. ---------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2025, Lawrence Livermore National Security, + University of Maryland Baltimore County, and the SUNDIALS contributors. + Copyright (c) 2013-2025, Lawrence Livermore National Security + and Southern Methodist University. + Copyright (c) 2002-2013, Lawrence Livermore National Security. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ---------------------------------------------------------------- + +.. _Python.API: + +sundials4py API +=============== + +This section lists the Python API for the ``sundials4py`` module and submodules. + +Core Module +----------- + +.. automodule:: sundials4py.core + :members: + :undoc-members: + :private-members: + +.. include:: ../../../shared/Python/sundials4py-core-functions.rst + + +ARKODE Module +-------------- + +.. automodule:: sundials4py.arkode + :members: + :undoc-members: + :private-members: + +.. include:: ../../../shared/Python/sundials4py-arkode-functions.rst + +CVODES Module +------------- + +.. automodule:: sundials4py.cvodes + :members: + :undoc-members: + :private-members: + +.. include:: ../../../shared/Python/sundials4py-cvodes-functions.rst + +IDAS Module +----------- + +.. automodule:: sundials4py.idas + :members: + :undoc-members: + :private-members: + +.. include:: ../../../shared/Python/sundials4py-idas-functions.rst + +KINSOL Module +------------- + +.. automodule:: sundials4py.kinsol + :members: + :undoc-members: + :private-members: + +.. include:: ../../../shared/Python/sundials4py-kinsol-functions.rst diff --git a/doc/shared/Python/Introduction.rst b/doc/shared/Python/Introduction.rst new file mode 100644 index 0000000000..70a93e15c1 --- /dev/null +++ b/doc/shared/Python/Introduction.rst @@ -0,0 +1,73 @@ +.. ---------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2025, Lawrence Livermore National Security, + University of Maryland Baltimore County, and the SUNDIALS contributors. + Copyright (c) 2013-2025, Lawrence Livermore National Security + and Southern Methodist University. + Copyright (c) 2002-2013, Lawrence Livermore National Security. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ---------------------------------------------------------------- + +.. _Python.Introduction: + +Introduction +============ + +sundials4py is designed to be easy to use from Python in conjunction with ubiquitous libraries in the Python scientific computing and machine learning ecosystems. +To that end, it supports: + +- Python's automatic memory management via "View" classes which wrap the plain SUNDIALS C objects +- Python definitions of user-supplied callback functions +- Zero-copy exchange of arrays (CPU and Device) through DLPack protocol and numpy's ndarray + +sundials4py is built using `nanobind `__ and `litgen `__. +**It requires Python 3.12+**. + + +Installation +------------ + +You can install sundials4py directly from PyPI using pip: + +.. code-block:: bash + + pip install sundials4py + +The default build of sundials4py that is distributed as a binary wheel uses double precision real types and 64-bit indices. +To install SUNDIALS with different precisions and index sizes, you can build from source wheels instead of using the pre-built +binary wheels. When building from source wheels instead of binary wheels, you can customize the SUNDIALS precision (real type) +and index type at build time by passing the CMake arguments in environment variables when running pip. For example: + +.. code-block:: bash + + export CMAKE_ARGS="-DSUNDIALS_PRECISION=SINGLE -DSUNDIALS_INDEX_SIZE=64" + pip install sundials4py --no-binary=sundials4py + +Other SUNDIALS options can also be accessed in this way. Review :numref:`Installation.Options` for more information on the available options. + +.. note:: + + Not all SUNDIALS options are supported by the Python interfaces. In particular, third-party libraries are not yet supported. + +After installation, you can import sundials4py in your Python scripts: + +.. code-block:: python + + import sundials4py + +The modules available are: + +- ``sundials4py.core``: contains all the shared SUNDIALS classes and functions +- ``sundials4py.arkode``: contains all of the ARKODE specific classes and functions +- ``sundials4py.cvodes``: contains all of the CVODES specific classes and functions +- ``sundials4py.idas``: contains all of the IDAS specific classes and functions +- ``sundials4py.kinsol``: contains all of the KINSOL specific classes and functions + +CVODE and IDA dot not have modules because CVODES and IDAS provide all of the same capabilities plus continuous forward and adjoint sensitivity analysis. + +For more information on usage, differences from the C/C++ API and examples, continue to the next sections of this documentation. diff --git a/doc/shared/Python/Usage.rst b/doc/shared/Python/Usage.rst new file mode 100644 index 0000000000..2ce27c59c8 --- /dev/null +++ b/doc/shared/Python/Usage.rst @@ -0,0 +1,174 @@ +.. ---------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2025, Lawrence Livermore National Security, + University of Maryland Baltimore County, and the SUNDIALS contributors. + Copyright (c) 2013-2025, Lawrence Livermore National Security + and Southern Methodist University. + Copyright (c) 2002-2013, Lawrence Livermore National Security. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ---------------------------------------------------------------- + +.. _Python.Usage: + +Using sundials4py +================= + +At a high level, using SUNDIALS from Python via sundials4py looks a lot like using SUNDIALS from C or C++. + +The few notable differences are: + +View Classes and Memory Management +---------------------------------- + +sundials4py provides natural usage of SUNDIALS objects with natural Python object lifetimes managed by the Python garbage collection as with any other Python object. +There is only one caveat that arises with ``void*`` objects/variables due to restrictions in nanobind: +the SUNDIALS integrator/solver memory ``void*`` objects are (behind the scenes) wrapped in "View" classes. +These view objects cannot be implicitly converted to the underlying ``void*``. As such, when calling a function which operates on these ``void*`` objects, one must +extract the ``void*`` "capsule" from the view object by calling the view's ``get`` method: + +.. code-block:: python + + from sundials4py.core import * + from sundials4py.cvode import * + + sunctx = SUNContext_Create(SUN_COMM_NULL) + cvode = CVodeCreate(CV_BDF, sunctx) + # notice we need to call cvode.get() + status = CVodeInit(cvode.get(), lambda t, y, ydot, _: ode_problem.f(t, y, ydot), T0, y) + + +Return-by-Pointer Parameters +---------------------------- + +Functions that return values via pointer arguments in the C API are mapped to Python functions that return a tuple: + +- **First element:** The function's return value (typically an error code). +- **Subsequent elements:** Values that would be returned via pointer arguments in C, in the same order as the C function signature. + +**Example 1: Single Return-by-Pointer Value** + +C: + .. code-block:: C + + int CVodeGetNumSteps(void *cvode_mem, long int *numsteps); + +Python: + .. code-block:: python + + retval, numsteps = CVodeGetNumSteps(cvode_mem.get()) + print(f"Number of steps: {numsteps}") + +**Example 2: Multiple Return-by-Pointer Values** + +C: + .. code-block:: C + + int CVodeGetIntegratorStats(void *cvode_mem, + long int *nsteps, + long int *nfevals, + long int *nlinsetups, + long int *netfails); + +Python: + .. code-block:: python + + retval, nsteps, nfevals, nlinsetups, netfails = CVodeGetIntegratorStats(cvode_mem.get()) + print(f"Steps: {nsteps}, Function evals: {nfevals}, Linear setups: {nlinsetups}, Error test fails: {netfails}") + + +Arrays +------ + +``N_Vector`` objects in sundials4py are compatible with numpy's `ndarray`. Each ``N_Vector`` can work on a numpy arrays without copies, and you can access +and modify the underlying data directly using :py:func:`N_VGetArrayPointer`, which returns a numpy `ndarray` view of the data. + +- SUNDIALS matrix types (dense, banded, sparse) are also exposed as Python objects that provide access to their underlying data as numpy arrays (e.g., via :py:func:`SUNDenseMatrix_Data`). +- Arrays of scalars (e.g., scaling factors passed to :py:func:`N_VLinearCombination`) are also represented as numpy arrays. + +**Example: Accessing and modifying an N_Vector** + +.. code-block:: python + + y_nvec = NVectorView.Create(N_VNew_Serial(10, sunctx.get())) + y = N_VGetArrayPointer(y_nvec.get()) + y[:] = np.linspace(0, 1, 10) # Set values using numpy + +**Example: Using a matrix as a numpy array** + +.. code-block:: python + + mat = SUNMatrixView.Create(SUNDenseMatrix(3, 3, sunctx.get())) + arr = SUNDenseMatrix_Data(mat.get()) + arr = np.eye(3) # Set to identity matrix + +This allows you to use numpy operations for vector and matrix data, and to pass numpy arrays to and from SUNDIALS routines efficiently and without unnecessary copies. + + +User-Supplied Callback Functions +-------------------------------- + +SUNDIALS packages and several modules/classes require user-supplied callback functions to define problem-specific behavior, +such as the right-hand side of an ODE or a nonlinear system function. In sundials4py, you can provide these as standard Python functions or lambdas. +Some things to note: + +- The callback signatures follow the C API. As such, ``N_Vector`` arguments are passed as ``N_Vector`` objects and the underlying ndarray must be extracted in the user code. The only caveat is that return-by-pointer parameters are removed from the signature, and instead become return values (mirroring how return-by-pointer parameters for other functions are handled) +- Most callback signatures include a ``void* user_data`` argument. In Python, this argument must be present in the signature, but it should be ignored. + +**Example: ODE right-hand side for ARKStep** + +.. code-block:: python + + def rhs(t, y_nvector, ydot_nvector, _): + # Compute ydot = f(t, y) + y = N_VGetArrayPointer(y_nvector) + ydot = N_VGetArrayPointer(ydot_nvector) + ydot[:] = -y + return 0 + + ark = ARKodeView.Create(ARKStepCreate(rhs, None, t0, y.get(), sunctx.get())) + +**Example: Nonlinear system for KINSOL** + +.. code-block:: python + + def fp_function(u_nvector, g_nvector, _): + # Compute g = F(u) + u = N_VGetArrayPointer(u_nvector) + g = N_VGetArrayPointer(g_nvector) + g[:] = u**2 - 1 + return 0 + + kin = KINView.Create(KINCreate(sunctx.get())) + KINInit(kin.get(), fp_function, u.get()) + +**Example: ARKODE LSRKStep dominant eigenvalue estimation function with return-by-pointer parameters** + +.. code-block:: python + + # The C signature is: + # int(sunrealtype t, N_Vector y, N_Vector fn, + # sunrealtype* lambdaR, sunrealtype* lambdaI, + # void* user_data, N_Vector temp1, + # N_Vector temp2, N_Vector temp3) + def dom_eig(t, yvec, fnvec, temp1, temp2, temp3, _): + lamdbaR = L + lamdbaI = 0.0 + # lambdaR and lambdaI should be returned in the order that they appear + # as parameters in the C API and follow the error code to return + return 0, lamdbaR, lamdbaI + + +.. warning:: + + The ``user_data`` argument should always be ``None`` or ``_`` on the Python side. If it is listed otherwise then it should be ignored to avoid causing catastrophic errors. + + +Examples +-------- + +Examples can be found in ``bindings/sundials4py/examples``. diff --git a/doc/shared/Python/index.rst b/doc/shared/Python/index.rst new file mode 100644 index 0000000000..29156ba8b5 --- /dev/null +++ b/doc/shared/Python/index.rst @@ -0,0 +1,38 @@ +.. + ----------------------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2025, Lawrence Livermore National Security, + University of Maryland Baltimore County, and the SUNDIALS contributors. + Copyright (c) 2013-2025, Lawrence Livermore National Security + and Southern Methodist University. + Copyright (c) 2002-2013, Lawrence Livermore National Security. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ----------------------------------------------------------------------------- + +.. _Python: + +****** +Python +****** + +sundials4py provides official (supported by the SUNDIALS team) Python bindings to much of the +SUNDIALS library, allowing you to use SUNDIALS directly from Python. + +.. note:: + + New SUNDIALS users should first read the :ref:`General User Guide `. + The Python User Guide focuses on specific aspects of using SUNDIALS from Python and assumes + the user is familiar with SUNDIALS. + + +.. toctree:: + :maxdepth: 1 + + Introduction.rst + Usage.rst + API.rst diff --git a/doc/shared/RecentChanges.rst b/doc/shared/RecentChanges.rst index b0a7913045..ad444a9a19 100644 --- a/doc/shared/RecentChanges.rst +++ b/doc/shared/RecentChanges.rst @@ -3,6 +3,10 @@ **Major Features** +SUNDIALS now has official Python interfaces! With this release, we are shipping a **beta version** of +the sundials4py Python module (created with nanobind and litgen). sundials4py provides explicit +interfaces to most features of SUNDIALS. + **New Features and Enhancements** The functions ``CVodeGetUserDataB`` and ``IDAGetUserDataB`` were added to CVODES diff --git a/doc/shared/generate_autofunctions.py b/doc/shared/generate_autofunctions.py new file mode 100644 index 0000000000..e646910957 --- /dev/null +++ b/doc/shared/generate_autofunctions.py @@ -0,0 +1,47 @@ +# ----------------------------------------------------------------------------- +# SUNDIALS Copyright Start +# Copyright (c) 2025, Lawrence Livermore National Security, +# University of Maryland Baltimore County, and the SUNDIALS contributors. +# Copyright (c) 2013-2025, Lawrence Livermore National Security +# and Southern Methodist University. +# Copyright (c) 2002-2013, Lawrence Livermore National Security. +# All rights reserved. +# +# See the top-level LICENSE and NOTICE files for details. +# +# SPDX-License-Identifier: BSD-3-Clause +# SUNDIALS Copyright End +# ----------------------------------------------------------------------------- +# Script that generates sphinx-autodoc autofunction directives for each +# function in the Python modules. This is necessary due to automodule not +# yet being able to handle nanobind generated functions. +# See https://github.com/sphinx-doc/sphinx/issues/13868. +# ----------------------------------------------------------------------------- + +import os +import importlib +import sundials4py + + +def generate_autofunctions_for_submodule(module_name: str): + module = importlib.import_module(f"sundials4py.{module_name}") + autogen_file = os.path.join( + os.path.dirname(__file__), f"./Python/sundials4py-{module_name}-functions.rst" + ) + with open(autogen_file, "w") as f: + f.write("Functions\n") + f.write("^^^^^^^^^\n\n") + for func_name in dir(module): + obj = getattr(module, func_name) + if type(obj).__name__ == "nb_func": + f.write(f".. autofunction:: sundials4py.{module_name}.{func_name}\n") + f.write(" :no-index:\n\n") + f.write(f" See :c:func:`{func_name}`.\n\n") + + +def generate_autofunctions_for_sundials4py(): + generate_autofunctions_for_submodule("core") + generate_autofunctions_for_submodule("arkode") + generate_autofunctions_for_submodule("cvodes") + generate_autofunctions_for_submodule("idas") + generate_autofunctions_for_submodule("kinsol") diff --git a/doc/shared/sundials_vars.py b/doc/shared/sundials_vars.py index 54c60ef0a1..5bb44a99c1 100644 --- a/doc/shared/sundials_vars.py +++ b/doc/shared/sundials_vars.py @@ -136,4 +136,9 @@ # documentation to use .. cpp:function rather than .. c:function ("c:identifier", "SUNCudaExecPolicy"), ("c:identifier", "SUNHipExecPolicy"), + # Python + ("py:class", "typing_extensions.CapsuleType"), + ("py:class", "types.CapsuleType"), + ("py:class", "collections.abc.Callable"), + ("py:class", "collections.abc.Sequence"), ] diff --git a/doc/superbuild/source/Python b/doc/superbuild/source/Python new file mode 120000 index 0000000000..9a26c4bc26 --- /dev/null +++ b/doc/superbuild/source/Python @@ -0,0 +1 @@ +../../shared/Python \ No newline at end of file diff --git a/doc/superbuild/source/conf.py b/doc/superbuild/source/conf.py index 70971fc743..23dfb5c125 100644 --- a/doc/superbuild/source/conf.py +++ b/doc/superbuild/source/conf.py @@ -18,9 +18,12 @@ sys.path.append(os.path.dirname(os.path.abspath("../../shared/sundials_vars.py"))) from sundials_vars import * +sys.path.append(os.path.dirname(os.path.abspath("../../shared/generate_autofunctions.py"))) +from generate_autofunctions import generate_autofunctions_for_sundials4py + sys.path.append(os.path.dirname(os.path.abspath("../../shared"))) -# Add suntools directory to import python function docstings with autodoc +# Add suntools directory to import python function docstrings with autodoc sys.path.append(os.path.abspath("../../../tools/suntools")) # -- General configuration ---------------------------------------------------- @@ -35,19 +38,20 @@ # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ + "numpydoc", + "sphinx_copybutton", + "sphinx_multitoc_numbering", "sphinx_rtd_theme", + "sphinx_sundials", + "sphinx_toolbox.collapse", + "sphinx.ext.autodoc", "sphinx.ext.extlinks", + "sphinx.ext.graphviz", "sphinx.ext.ifconfig", - "sphinx.ext.mathjax", "sphinx.ext.intersphinx", - "sphinxfortran.fortran_domain", + "sphinx.ext.mathjax", "sphinxcontrib.bibtex", - "sphinx_copybutton", - "sphinx.ext.graphviz", - "sphinx_sundials", - "sphinx_toolbox.collapse", - "sphinx.ext.autodoc", - "sphinx_multitoc_numbering", + "sphinxfortran.fortran_domain", ] extlinks = { @@ -59,7 +63,10 @@ # Where to find cross-references to the Sphinx documentation. intersphinx_mapping = { - "sphinx": ("https://www.sphinx-doc.org/en/master", ("../objects.inv", None)) + "sphinx": ("https://www.sphinx-doc.org/en/master", ("../objects.inv", None)), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), } # Only setup Google analytics for the readthedocs being deployed (not local). @@ -91,7 +98,7 @@ project = "Documentation for SUNDIALS" # RTD adds the first Copyright (c), so we leave it out. copyright = """\ - 2025-{year}, Lawrence Livermore National Security, University of Maryland Baltimore County, and the SUNDIALS contributors. + {year}, Lawrence Livermore National Security, University of Maryland Baltimore County, and the SUNDIALS contributors. Copyright (c) 2013-2025, Lawrence Livermore National Security and Southern Methodist University. Copyright (c) 2002-2013, Lawrence Livermore National Security""".format( year=year @@ -252,3 +259,9 @@ # Output file base name for HTML help builder. htmlhelp_basename = "SUNDIALSdoc" + +# This prevents numpydoc from showing too much detail of the Enum classes +numpydoc_show_class_members = False + +# Generate rst files with autofunction directives for sundials4py functions +generate_autofunctions_for_sundials4py() diff --git a/doc/superbuild/source/developers/commandline/index.rst b/doc/superbuild/source/developers/commandline/index.rst index 6fff46676e..051a5d2899 100644 --- a/doc/superbuild/source/developers/commandline/index.rst +++ b/doc/superbuild/source/developers/commandline/index.rst @@ -1,5 +1,5 @@ .. - Author(s): Daniel R. Reynolds @ UMBC + Author(s): Daniel R. Reynolds @ UMBC and David J. Gardner @ LLNL ----------------------------------------------------------------------------- SUNDIALS Copyright Start Copyright (c) 2025, Lawrence Livermore National Security, diff --git a/doc/superbuild/source/developers/getting_started/Checklist.rst b/doc/superbuild/source/developers/getting_started/Checklist.rst index 25e82a157c..a683200f86 100644 --- a/doc/superbuild/source/developers/getting_started/Checklist.rst +++ b/doc/superbuild/source/developers/getting_started/Checklist.rst @@ -98,3 +98,11 @@ system, etc. developers should adhere to the following checklist. from the SWIG GitHub action that we run on all pull requests. The patch can be found under the job artifacts (if there were in fact changes that required updates to the Fortran). + +#. Similarly, re-run the Python interface generator to generate updated Python + interfaces. This is done by navigating to the ``bindings/sundials4py/`` + directory and running ``python generate.py``. + + * If you added a new user-supplied function, or new module, then there will be manual + changes to make in the ``bindings/sundials4py/`` directory. See the + :ref:`Python` section for more details. diff --git a/doc/superbuild/source/developers/index.rst b/doc/superbuild/source/developers/index.rst index b9a6b8bbf0..b0ccd2ca15 100644 --- a/doc/superbuild/source/developers/index.rst +++ b/doc/superbuild/source/developers/index.rst @@ -43,6 +43,7 @@ meant for SUNDIALS developers. testing/index benchmarks/index pull_requests/index + python/index releases/index packages/index appendix/index diff --git a/doc/superbuild/source/developers/python/index.rst b/doc/superbuild/source/developers/python/index.rst new file mode 100644 index 0000000000..7137eec084 --- /dev/null +++ b/doc/superbuild/source/developers/python/index.rst @@ -0,0 +1,254 @@ +.. + Author(s): Cody J. Balos @ LLNL + ----------------------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2002-2025, Lawrence Livermore National Security + and Southern Methodist University. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ----------------------------------------------------------------------------- + +.. _Developer.Python: + +Python Interfaces +================= + +This chapter covers details developers need to know about the SUNDIALS Python interfaces, distributed as the Python package sundials4py. + +We use `nanobind `__ for the Python bindings. nanobind is a sleeker, faster ``pybind11``. +It is a C++ library, i.e. you write your binding code in C++. Nanobind does have some restrictions: + +- Cannot bind to functions which take double, or more pointer arguments. I.e., it cannot bind to `**` or `***` and so on. These have to be flattened somehow. +- Cannot implicitly convert between a "View" container class and the underlying C type. I.e., it cannot implicitly convert ``ARKodeView`` to ``void*``. + This means that users must explicitly convert from the "View" class by calling the ``get`` member function. + +We use `litgen `__ to generate a large portion of the nanobind code. + +- We have ``generate.yaml`` files designate headers to generate bindings from and functions to exclude. +- A ``generate.py`` script uses litgen to generate the bindings as a C++ header according to the ``generate.yaml``. +- For each generated file, there is at least one hand-coded file that includes the generated header. + +.. note:: + + Litgen itself is licensed under GPLv3. This means the ``generate.py`` script is effectively governed by ``GPLv3``, + **but the binding code generated by the script/litgen falls only under our SUNDIALS license**. + Because of this, the ``generate.py`` script and litgen extensions are kept in a separate git repository as package, + ``sundials4py-generator``. + + +Structure +--------- + +sundials4py code lives in ``bindings/sundials4py``. The main python module and all of its submodules are defined in ``sundials4py.cpp``. +sundials4py consists of 5 modules. Below we list how each one maps to the directory layout: + +- **sundials.arkode**: + - Implements bindings for all of ARKODE. + - Source directory: ``arkode/`` + +- **sundials.cvodes**: + - Provides bindings for all of CVODES. + - Source directory: ``cvodes/`` + +- **sundials.idas**: + - Contains bindings for all of IDAS. + - Source directory: ``idas/`` + +- **sundials.kinsol**: + - Facilitates bindings for all of KINSOL. + - Source directory: ``kinsol/``. + +- **sundials.core**: + - All SUNDIALS shared classes/modules and implementations. + - Source directories: + + - `nvector/` + - `sunadaptcontroller/` + - `sunadjointcheckpointscheme/` + - `sunlinsol/` + - `sunmatrix/` + - `sunmemory/` + - `sunnonlinsol/` + - `sundials/` + - `sundomeigest/` + + +Development +----------- + +sundials4py requires Python 3.12+ and the Interpreter/Development components. E.g., if you were installing Python on a RedHat Linux system, +you could install Python 3.12 with these modules like this: + +.. code-block:: shell + + yum install python3.12 python3.12-devel + +The recommended method for development is to use a typical Python development workflow with ``pip`` rather than invoking CMake directly. + +.. code-block:: shell + + cd sundials_root_directory + python -m venv .venv # create python virtual environment + . .venv/bin/activate # activate the python virtual environment + pip install scikit-build-core[pyproject] hatchling nanobind # this is a prerequisite for the next step + MAKEFLAGS="-j$(nproc)" pip install --no-build-isolation -Ceditable.rebuild=true -ve .[dev] # install sundials4py into the virtual environment + +The last ``pip install`` command will allow automatic incremental builds. It will invoke the SUNDIALS `CMake` build system with the +``-DSUNDIALS_ENABLE_PYTHON=ON`` option through `scikit-build-core `__. +After the initial build, if you make any changes within SUNDIALS a rebuild will be triggered when you import the ``sundials4py`` +module within a Python script. + +Different CMake options can be controlled by passing them through the ``--config-settings`` (or ``-C`` for short) option of ``pip install``. +E.g., + +.. code-block:: shell + + MAKEFLAGS="-j$(nproc)" pip install --no-build-isolation -Ceditable.rebuild=true -ve .[dev] \ + -C cmake.define.SUNDIALS_INDEX_SIZE=32 + +Alternatively, you can set the CMAKE_ARGS environment variable: + +.. code-block:: shell + + export CMAKE_ARGS="-DSUNDIALS_INDEX_SIZE=32" + MAKEFLAGS="-j$(nproc)" pip install --no-build-isolation -Ceditable.rebuild=true -ve .[dev] + + +Tests +----- + +We use pytest for setting up unit/smoke tests of the interfaces. All tests are in ``bindings/sundials4py/test``. The goal is to test the interfacing, +not the correctness of SUNDIALS itself. + + +-------------- + +All user-supplied Python functions have to be wrapped with functions that convert between a ``std::function`` and a raw C function pointer. +This is done by smuggling in a "function table" -- a struct of ``std::function`` members -- in a ``python`` member inside each integrator memory structure, and then storing the integrator memory structure in the ``user_data`` pointer. For the objects which are +not the integrator, we still stuff the function table in the ``python`` member of the struct so it will be available in all of the module/class methods. +The upshot is that every time we add a user-supplied function, we need to add a new member to the function table struct, +and add a wrapper for it. We also have to add a wrapper for the "Set" function that takes the user-supplied function. + +Here is an example for ARKODE: + +In ``bindings/sundials4py/arkode/arkode_usersupplied.hpp``, the function table struct is defined: + +.. code-block:: cpp + + struct arkode_user_supplied_fn_table + { + // common user-supplied function pointers + nb::object rootfn; + nb::object ewtn; + nb::object rwtn; + nb::object adaptfn; + nb::object expstabfn; + nb::object vecresizefn; + nb::object postprocessstepfn; + nb::object postprocessstagefn; + nb::object stagepredictfn; + nb::object relaxfn; + nb::object relaxjacfn; + nb::object nlsfi; + + // truncated ... + }; + + +Then each one of the functions in the table has a wrapper function defined below this struct definition, e.g., + +.. code-block:: cpp + + template + inline int arkode_postprocessstepfn_wrapper(Args... args) + { + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, arkode_user_supplied_fn_table, + ARKodeMem, 1>(&arkode_user_supplied_fn_table::postprocessstepfn, + std::forward(args)...); + } + +Finally, in ``bindings/sundials4py/arkode/arkode.cpp``, the Set function is registered with nanobind: + +.. code-block:: cpp + + BIND_ARKODE_CALLBACK(ARKodeSetPostprocessStepFn, ARKPostProcessFn, + postprocessstepfn, arkode_postprocessstepfn_wrapper, + nb::arg("arkode_mem"), nb::arg("postprocessstep").none()); + +``BIND_ARKODE_CALLBACK`` is a macro which expands to + +.. code-block:: cpp + + m.def( + ARKodeSetPostprocessStepFn, + [](void* ark_mem, std::function> fn) + { + auto fn_table = get_arkode_fn_table(ark_mem); + fn_table->MEMBER = nb::cast(fn); + fntable->postprocessstepfn = nb::cast(fn); + if (fn) { return NAME(ark_mem, &arkode_postprocessstepfn_wrapper); } + else { return NAME(ark_mem, nullptr); } + }, + nb::arg("arkode_mem"), nb::arg("postprocessstep").none()) + +What we are doing is creating a custom nanobind wrapper of :c:func:`ARKodeSetPostprocessStepFn` which takes the user-supplied +Python side function as a ``std::function`` and stores it in the function table (which is stored in user data). +The ``nb::arg`` arguments are needed so that we can make ``postprocessstep`` nullable (or ``None`` from Python). + +Here is another example, but this time for the ``SUNStepper`` and with the ``python`` member instead of ``user_data``. +From ``sundials_stepper_usersupplied.hpp``: + +.. code-block:: cpp + + struct SUNStepperFunctionTable + { + nb::object evolve; + nb::object one_step; + nb::object full_rhs; + nb::object reinit; + nb::object reset; + nb::object reset_ckpt_idx; + nb::object set_stop_time; + nb::object set_step_direction; + nb::object set_forcing; + nb::object get_num_steps; + }; + + template + inline SUNErrCode sunstepper_evolve_wrapper(Args... args) + { + return sundials4py::user_supplied_fn_caller< + std::remove_pointer_t, SUNStepperFunctionTable, + SUNStepper>(&SUNStepperFunctionTable::evolve, std::forward(args)...); + } + +From ``sundials_stepper.cpp``, + +.. code-block:: cpp + + m.def( + "SUNStepper_SetEvolveFn", + [](SUNStepper stepper, + std::function> fn) -> SUNErrCode + { + if (!stepper->python) + { + stepper->python = SUNStepperFunctionTable_Alloc(); + } + auto fntable = static_cast(stepper->python); + fntable->evolve = nb::cast(fn); + if (fn) + { + return SUNStepper_SetEvolveFn(stepper, sunstepper_evolve_wrapper); + } + else { return SUNStepper_SetEvolveFn(stepper, nullptr); } + }, + nb::arg("stepper"), nb::arg("fn").none()); + +We are again creating a nanobind wrapper for :c:func:`SUNStepper_SetEvolveFn`, but this time, +the function table is smuggled inside of the SUNStepper structure's ``python`` member. diff --git a/doc/superbuild/source/developers/source_code/Naming.rst b/doc/superbuild/source/developers/source_code/Naming.rst index b4293ece4d..93b529d16c 100644 --- a/doc/superbuild/source/developers/source_code/Naming.rst +++ b/doc/superbuild/source/developers/source_code/Naming.rst @@ -50,6 +50,48 @@ Variable names Snake case is preferred for local variable names e.g. ``foo_bar``. +Variables which are pointers to an array, and are effectively treated/indexed +as a contiguous array, should use the suffix `_<1|2|3>d`, e.g. + +.. code-block:: c + + sunrealtype my_array[3] = {1.0, 2.0, 3.0}; + sunrealtype* sequence_1d = my_array; + + sunrealtype my_matrix[2][3] = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}; + sunrealtype* my_matrix_rows[2] = { my_matrix[0], my_matrix[1] }; + sunrealtype** sequence_2d = my_matrix_rows; + + +Variables which are purely pointers should use the suffix, ``_ptr``, e.g. + +.. code-block:: c + + N_Vector y = N_VNew_Serial(2, sunctx); + N_Vector y_ptr = &y; + +When combining the two rules, the ``_ptr`` suffix should come last, e.g. + +.. code-block:: c + + sunrealtype my_array[3] = {1.0, 2.0, 3.0}; + sunrealtype* sequence_1d = my_array; + sunrealtype** sequence_1d_ptr = &sequence_1d; + + sunrealtype my_matrix[2][3] = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}; + sunrealtype* my_matrix_rows[2] = { my_matrix[0], my_matrix[1] }; + sunrealtype** sequence_2d = my_matrix_rows; + sunrealtype*** sequence_2d_ptr = &my_matrix_rows; + + +.. warning:: + + The suffixes are **required** for parameters of functions within public header + files because the Python interface generator relies on the suffixes to determine + the proper way to expose the parameter to Python users. It is preferable to follow + this convention within other code, but not required. + + C function names ---------------- @@ -72,15 +114,15 @@ Names for Vectors, Matrices, and Solvers The SUNDIALS vector, matrix, linear solver, and nonlinear solver classes use the naming convention ```` for base class methods where each component of the name uses Pascal case. See -:numref:`Style.Table.OldBaseClassMethodNaming` for examples. +:numref:`SourceCode.Naming.Table.OldBaseClassMethodNaming` for examples. .. note:: This naming convention *only* applies to the vector, matrix, and solver classes. All other classes should follow the naming convention described in - :ref:`Style.Naming.NewClasses`. + :ref:`SourceCode.Naming.NewClasses`. -.. _Style.Table.OldBaseClassMethodNaming: +.. _SourceCode.Naming.Table.OldBaseClassMethodNaming: .. Table:: SUNDIALS base class naming convention examples for vectors, matrices, linear solvers and nonlinear solvers. @@ -99,9 +141,9 @@ each component of the name uses Pascal case. See Derived class implementations of the base class methods should follow the naming convention ``_``. See -:numref:`Style.Table.OldDerivedClassMethodNaming` for examples. +:numref:`SourceCode.Naming.Table.OldDerivedClassMethodNaming` for examples. -.. _Style.Table.OldDerivedClassMethodNaming: +.. _SourceCode.Naming.Table.OldDerivedClassMethodNaming: .. Table:: SUNDIALS derived class naming convention examples for vectors, matrices, linear solvers and nonlinear solvers. @@ -124,16 +166,16 @@ existing class, follow the naming style used within that class. When adding a new derived class, use the same style as above for implementations of the base class method i.e., ``_``. -.. _Style.Naming.NewClasses: +.. _SourceCode.Naming.NewClasses: Names for New Classes --------------------- All new base classes should use the naming convention ``_`` for the base class methods. See -:numref:`Style.Table.NewBaseClassMethodNaming` for examples. +:numref:`SourceCode.Naming.Table.NewBaseClassMethodNaming` for examples. -.. _Style.Table.NewBaseClassMethodNaming: +.. _SourceCode.Naming.Table.NewBaseClassMethodNaming: .. Table:: SUNDIALS naming conventions for methods in new base classes. @@ -145,9 +187,9 @@ for the base class methods. See Derived class implementations of the base class methods should follow the naming convention ``__``. See -:numref:`Style.Table.NewDerivedClassMethodNaming` for examples. +:numref:`SourceCode.Naming.Table.NewDerivedClassMethodNaming` for examples. -.. _Style.Table.NewDerivedClassMethodNaming: +.. _SourceCode.Naming.Table.NewDerivedClassMethodNaming: .. Table:: SUNDIALS naming conventions for derived class implementations of methods in new base classes. @@ -161,7 +203,7 @@ convention ``__``. See For destructor functions, use ``Destroy`` rather than ``Free`` or some other alternative. -.. _Style.Classes.Cpp: +.. _SourceCode.Naming.CppClasses: Naming Convention for C++ Classes --------------------------------- @@ -173,3 +215,12 @@ Private C++ class functions should use camelcase (e.g. ``doSomething``). C++ private class members should use snake case with a trailing underscore (e.g. ``some_var_``). + + +.. _SourceCode.Naming.Enums: + +Enums +----- + +Enum tags/identifiers should follow class naming rules and use Pascal case. +Enum values should follow the rules for macros and constants. diff --git a/doc/superbuild/source/developers/source_code/Rules.rst b/doc/superbuild/source/developers/source_code/Rules.rst index 9ef33192f7..d10165db9d 100644 --- a/doc/superbuild/source/developers/source_code/Rules.rst +++ b/doc/superbuild/source/developers/source_code/Rules.rst @@ -23,6 +23,9 @@ Coding Conventions and Rules These rules should be followed for all new code. Unfortunately, old code might not adhere to all of these rules. + +#. Identifiers should follow our :ref:`Naming Conventions `. + #. Do not use language features that are not compatible with C99, C++14, and MSVC v1900+ (Visual Studio 2015). Examples of such features include variable-length arrays. Exceptions are allowed when interfacing with a @@ -205,12 +208,7 @@ not adhere to all of these rules. SUNDIALS API that will be interfaced to Fortran since the Fortran standard does not include unsigned integers. -#. Use the print functions, format macros, and output guidelines detailed in - :ref:`Style.Output`. - -#. Follow the logging style detailed in :ref:`Style.Logging`. - -#. Use `sizeof(variable)` rather than `sizeof(type)`. E.g., +#. Use ``sizeof(variable)`` rather than ``sizeof(type)``. E.g., .. code-block:: c @@ -218,4 +216,29 @@ not adhere to all of these rules. int array_length = 10; int* array1 = malloc(array_length * sizeof(a)); // Do this int* array2 = malloc(array_length * sizeof(int)); // Don't do this - \ No newline at end of file + +#. Do not use anonymous ``enum`` s in public header files (the Python interface + generator doesn't like it). Wrap typedef statements in SWIG guards, e.g. + + .. code-block:: c + + // Don't do this + typedef enum { + ARK_RELAX_BRENT, + ARK_RELAX_NEWTON + } ARKRelaxSolver; + + // Do this + enum ARKRelaxSolver { + ARK_RELAX_BRENT, + ARK_RELAX_NEWTON + }; + + #ifndef SWIG + typedef enum ARKRelaxSolver ARKRelaxSolver; + #endif + +#. Use the print functions, format macros, and output guidelines detailed in + :ref:`Style.Output`. + +#. Follow the logging style detailed in :ref:`Style.Logging`. diff --git a/doc/superbuild/source/index.rst b/doc/superbuild/source/index.rst index d838bc15a6..078a74c31f 100644 --- a/doc/superbuild/source/index.rst +++ b/doc/superbuild/source/index.rst @@ -176,7 +176,7 @@ SUNDIALS License and Notices .. toctree:: - :caption: USAGE + :caption: GENERAL USER GUIDE :maxdepth: 1 :numbered: :hidden: @@ -208,6 +208,7 @@ SUNDIALS License and Notices :hidden: Fortran/index.rst + Python/index.rst .. toctree:: :caption: EXAMPLES diff --git a/pyproject.toml b/pyproject.toml index 36100c833e..b4583ef4c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,85 @@ +[build-system] +requires = ["scikit-build-core[pyproject] >=0.4.3", "nanobind"] +build-backend = "scikit_build_core.build" + +[project] +name = "sundials4py" +version = "7.6.0" +description = "Official Python bindings for the SUNDIALS suite of nonlinear and differential/algebraic equation solvers." +authors = [ + { name = "SUNDIALS Developers", email = "sundials-users@llnl.gov" } +] +license = { file = "LICENSE" } +readme = "README.md" +requires-python = ">=3.12" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Mathematics" +] +dependencies = [ + "nanobind >= 2.9.2", + "numpy >= 2.0.0", +] + +[project.optional-dependencies] +dev = [ + "black", + "isort", + "litgen@git+https://github.com/sundials-codes/litgen.git", + "pytest", + "pytest-random-order", + "pyyaml", +] + +[project.urls] +Homepage = "https://computing.llnl.gov/projects/sundials" +Documentation = "https://sundials.readthedocs.io/" +Source = "https://github.com/LLNL/sundials" + +[tool.scikit-build] +# Protect the configuration against future changes in scikit-build-core +minimum-version = "0.4" + +# Setuptools-style build caching in a local directory +build-dir = "build/{wheel_tag}" + +# Build stable ABI wheels for CPython 3.12+ +wheel.py-api = "cp312" + +cmake.args = [ + "-DBUILD_SHARED_LIBS:BOOL=OFF", + "-DBUILD_CVODE:BOOL=OFF", + "-DBUILD_IDA:BOOL=OFF", + "-DEXAMPLES_ENABLE_C:BOOL=OFF", + "-DEXAMPLES_ENABLE_CXX:BOOL=OFF", + "-DSUNDIALS_ENABLE_PYTHON:BOOL=ON" +] + +cmake.build-type = "RelWithDebInfo" + +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib" +] +pythonpath = [ + ".", "bindings/sundials4py/test" +] + +[tool.cibuildwheel] +# Necessary to see build output from the actual compilation +build-verbosity = 1 + +# Optional: run pytest to ensure that the package was correctly built +test-command = "pytest --random-order bindings/sundials4py/test" +test-requires = "pytest" + +# Needed for full C++17 support on macOS +[tool.cibuildwheel.macos.environment] +MACOSX_DEPLOYMENT_TARGET = "10.15" + [tool.black] line-length = 99 skip_magic_trailing_comma = true