From f2641b3016a5b47354f61821728736488009ed66 Mon Sep 17 00:00:00 2001 From: Rishipal Singh Bhatia Date: Mon, 20 Oct 2025 17:01:03 -0700 Subject: [PATCH] Delete the floating point optimization level and memory fitting level settings in favor of the EffortLevel enum. ##### Why? Simplifies flag input handling by removing multiple ways to specify the same thing. PiperOrigin-RevId: 821862460 --- jax/_src/compiler.py | 8 +------- jax/_src/config.py | 12 ------------ jaxlib/_jax/__init__.pyi | 10 ---------- jaxlib/xla_compiler.cc | 6 ------ tests/api_test.py | 17 ----------------- 5 files changed, 1 insertion(+), 52 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 1b604d2abb52..6f1b9a0ed9c5 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -188,8 +188,6 @@ def get_compile_options( assert device_assignment.computation_count() == num_partitions compile_options.device_assignment = device_assignment - build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value - build_options.memory_fitting_effort = config.memory_fitting_effort.value build_options.optimization_level = config.EffortLevel( config.optimization_level.value ).value @@ -199,11 +197,7 @@ def get_compile_options( if env_options_overrides is not None: # Some overrides are passed directly on build_options. - overrides_on_build_options = [ - "exec_time_optimization_effort", "memory_fitting_effort"] - overrides_on_build_options.extend( - ["optimization_level", "memory_fitting_level"] - ) + overrides_on_build_options = ["optimization_level", "memory_fitting_level"] env_options_overrides = dict(env_options_overrides) for name in overrides_on_build_options: diff --git a/jax/_src/config.py b/jax/_src/config.py index 212902268d78..045a7481e97e 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -2136,18 +2136,6 @@ def _default_pmap_no_rank_reduction(new_val): ), ) -exec_time_optimization_effort = float_state( - name='jax_exec_time_optimization_effort', - default=0.0, - help='Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].' -) - -memory_fitting_effort = float_state( - name='jax_memory_fitting_effort', - default=0.0, - help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].' -) - optimization_level = enum_state( name='jax_optimization_level', enum_values=[ diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 831f44ea2e8e..72a653060d47 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -561,16 +561,6 @@ class ExecutableBuildOptions: self, arg: bytes, / ) -> None: ... @property - def exec_time_optimization_effort(self) -> float: ... - @exec_time_optimization_effort.setter - def exec_time_optimization_effort( - self, arg: float, / - ) -> ExecutableBuildOptions: ... - @property - def memory_fitting_effort(self) -> float: ... - @memory_fitting_effort.setter - def memory_fitting_effort(self, arg: float, /) -> ExecutableBuildOptions: ... - @property def optimization_level(self) -> int: ... @optimization_level.setter def optimization_level(self, arg: int, /) -> None: ... diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index eaed15c07487..4a9b88827033 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -1139,12 +1139,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { xla::CompilationEnvironments::CreateFromProto(env_proto)); *options.mutable_comp_envs() = std::move(*comp_envs); }) - .def_prop_rw("exec_time_optimization_effort", - &ExecutableBuildOptions::exec_time_optimization_effort, - &ExecutableBuildOptions::set_exec_time_optimization_effort) - .def_prop_rw("memory_fitting_effort", - &ExecutableBuildOptions::memory_fitting_effort, - &ExecutableBuildOptions::set_memory_fitting_effort) .def_prop_rw( "optimization_level", [](ExecutableBuildOptions& options) { diff --git a/tests/api_test.py b/tests/api_test.py index 8bfb5bc3d146..7bc49df84170 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1383,23 +1383,6 @@ def f(x): "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5, })(1.0) # doesn't crash. - def test_exec_time_optimization_effort_compiler_option(self): - def f(x): - return jnp.sqrt(x ** 2) + 1. - - f_jit = jit( - f, - compiler_options={ - "exec_time_optimization_effort": 0.0, - })(1.0) # doesn't crash. - - with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "No such"): - f_jit = jit( - f, - compiler_options={ - "exec_time_compilation_effort": 0.0, - })(1.0) - def test_optimization_level_compiler_option(self): def f(x): return jnp.sqrt(x**2) + 1.0