Skip to content

Commit 9f7d773

Browse files
rishipalGoogle-ML-Automation
authored andcommitted
Delete the floating point optimization level and memory fitting level settings in favor of the EffortLevel enum.
##### Why? Simplifies flag input handling by removing multiple ways to specify the same thing. PiperOrigin-RevId: 821862460
1 parent 680045b commit 9f7d773

File tree

5 files changed

+1
-52
lines changed

5 files changed

+1
-52
lines changed

jax/_src/compiler.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,6 @@ def get_compile_options(
188188
assert device_assignment.computation_count() == num_partitions
189189
compile_options.device_assignment = device_assignment
190190

191-
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
192-
build_options.memory_fitting_effort = config.memory_fitting_effort.value
193191
build_options.optimization_level = config.EffortLevel(
194192
config.optimization_level.value
195193
).value
@@ -199,11 +197,7 @@ def get_compile_options(
199197

200198
if env_options_overrides is not None:
201199
# Some overrides are passed directly on build_options.
202-
overrides_on_build_options = [
203-
"exec_time_optimization_effort", "memory_fitting_effort"]
204-
overrides_on_build_options.extend(
205-
["optimization_level", "memory_fitting_level"]
206-
)
200+
overrides_on_build_options = ["optimization_level", "memory_fitting_level"]
207201

208202
env_options_overrides = dict(env_options_overrides)
209203
for name in overrides_on_build_options:

jax/_src/config.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,18 +2136,6 @@ def _default_pmap_no_rank_reduction(new_val):
21362136
),
21372137
)
21382138

2139-
exec_time_optimization_effort = float_state(
2140-
name='jax_exec_time_optimization_effort',
2141-
default=0.0,
2142-
help='Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].'
2143-
)
2144-
2145-
memory_fitting_effort = float_state(
2146-
name='jax_memory_fitting_effort',
2147-
default=0.0,
2148-
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
2149-
)
2150-
21512139
optimization_level = enum_state(
21522140
name='jax_optimization_level',
21532141
enum_values=[

jaxlib/_jax/__init__.pyi

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -561,16 +561,6 @@ class ExecutableBuildOptions:
561561
self, arg: bytes, /
562562
) -> None: ...
563563
@property
564-
def exec_time_optimization_effort(self) -> float: ...
565-
@exec_time_optimization_effort.setter
566-
def exec_time_optimization_effort(
567-
self, arg: float, /
568-
) -> ExecutableBuildOptions: ...
569-
@property
570-
def memory_fitting_effort(self) -> float: ...
571-
@memory_fitting_effort.setter
572-
def memory_fitting_effort(self, arg: float, /) -> ExecutableBuildOptions: ...
573-
@property
574564
def optimization_level(self) -> int: ...
575565
@optimization_level.setter
576566
def optimization_level(self, arg: int, /) -> None: ...

jaxlib/xla_compiler.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,12 +1139,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
11391139
xla::CompilationEnvironments::CreateFromProto(env_proto));
11401140
*options.mutable_comp_envs() = std::move(*comp_envs);
11411141
})
1142-
.def_prop_rw("exec_time_optimization_effort",
1143-
&ExecutableBuildOptions::exec_time_optimization_effort,
1144-
&ExecutableBuildOptions::set_exec_time_optimization_effort)
1145-
.def_prop_rw("memory_fitting_effort",
1146-
&ExecutableBuildOptions::memory_fitting_effort,
1147-
&ExecutableBuildOptions::set_memory_fitting_effort)
11481142
.def_prop_rw(
11491143
"optimization_level",
11501144
[](ExecutableBuildOptions& options) {

tests/api_test.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,23 +1383,6 @@ def f(x):
13831383
"xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5,
13841384
})(1.0) # doesn't crash.
13851385

1386-
def test_exec_time_optimization_effort_compiler_option(self):
1387-
def f(x):
1388-
return jnp.sqrt(x ** 2) + 1.
1389-
1390-
f_jit = jit(
1391-
f,
1392-
compiler_options={
1393-
"exec_time_optimization_effort": 0.0,
1394-
})(1.0) # doesn't crash.
1395-
1396-
with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "No such"):
1397-
f_jit = jit(
1398-
f,
1399-
compiler_options={
1400-
"exec_time_compilation_effort": 0.0,
1401-
})(1.0)
1402-
14031386
def test_optimization_level_compiler_option(self):
14041387
def f(x):
14051388
return jnp.sqrt(x**2) + 1.0

0 commit comments

Comments
 (0)