File tree Expand file tree Collapse file tree 5 files changed +1
-52
lines changed
Expand file tree Collapse file tree 5 files changed +1
-52
lines changed Original file line number Diff line number Diff line change @@ -188,8 +188,6 @@ def get_compile_options(
188188 assert device_assignment .computation_count () == num_partitions
189189 compile_options .device_assignment = device_assignment
190190
191- build_options .exec_time_optimization_effort = config .exec_time_optimization_effort .value
192- build_options .memory_fitting_effort = config .memory_fitting_effort .value
193191 build_options .optimization_level = config .EffortLevel (
194192 config .optimization_level .value
195193 ).value
@@ -199,11 +197,7 @@ def get_compile_options(
199197
200198 if env_options_overrides is not None :
201199 # Some overrides are passed directly on build_options.
202- overrides_on_build_options = [
203- "exec_time_optimization_effort" , "memory_fitting_effort" ]
204- overrides_on_build_options .extend (
205- ["optimization_level" , "memory_fitting_level" ]
206- )
200+ overrides_on_build_options = ["optimization_level" , "memory_fitting_level" ]
207201
208202 env_options_overrides = dict (env_options_overrides )
209203 for name in overrides_on_build_options :
Original file line number Diff line number Diff line change @@ -2136,18 +2136,6 @@ def _default_pmap_no_rank_reduction(new_val):
21362136 ),
21372137)
21382138
2139- exec_time_optimization_effort = float_state (
2140- name = 'jax_exec_time_optimization_effort' ,
2141- default = 0.0 ,
2142- help = 'Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].'
2143- )
2144-
2145- memory_fitting_effort = float_state (
2146- name = 'jax_memory_fitting_effort' ,
2147- default = 0.0 ,
2148- help = 'Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
2149- )
2150-
21512139optimization_level = enum_state (
21522140 name = 'jax_optimization_level' ,
21532141 enum_values = [
Original file line number Diff line number Diff line change @@ -561,16 +561,6 @@ class ExecutableBuildOptions:
561561 self , arg : bytes , /
562562 ) -> None : ...
563563 @property
564- def exec_time_optimization_effort (self ) -> float : ...
565- @exec_time_optimization_effort .setter
566- def exec_time_optimization_effort (
567- self , arg : float , /
568- ) -> ExecutableBuildOptions : ...
569- @property
570- def memory_fitting_effort (self ) -> float : ...
571- @memory_fitting_effort .setter
572- def memory_fitting_effort (self , arg : float , / ) -> ExecutableBuildOptions : ...
573- @property
574564 def optimization_level (self ) -> int : ...
575565 @optimization_level .setter
576566 def optimization_level (self , arg : int , / ) -> None : ...
Original file line number Diff line number Diff line change @@ -1139,12 +1139,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
11391139 xla::CompilationEnvironments::CreateFromProto (env_proto));
11401140 *options.mutable_comp_envs () = std::move (*comp_envs);
11411141 })
1142- .def_prop_rw (" exec_time_optimization_effort" ,
1143- &ExecutableBuildOptions::exec_time_optimization_effort,
1144- &ExecutableBuildOptions::set_exec_time_optimization_effort)
1145- .def_prop_rw (" memory_fitting_effort" ,
1146- &ExecutableBuildOptions::memory_fitting_effort,
1147- &ExecutableBuildOptions::set_memory_fitting_effort)
11481142 .def_prop_rw (
11491143 " optimization_level" ,
11501144 [](ExecutableBuildOptions& options) {
Original file line number Diff line number Diff line change @@ -1383,23 +1383,6 @@ def f(x):
13831383 "xla_gpu_auto_spmd_partitioning_memory_budget_ratio" : 0.5 ,
13841384 })(1.0 ) # doesn't crash.
13851385
1386- def test_exec_time_optimization_effort_compiler_option (self ):
1387- def f (x ):
1388- return jnp .sqrt (x ** 2 ) + 1.
1389-
1390- f_jit = jit (
1391- f ,
1392- compiler_options = {
1393- "exec_time_optimization_effort" : 0.0 ,
1394- })(1.0 ) # doesn't crash.
1395-
1396- with self .assertRaisesRegex (jax .errors .JaxRuntimeError , "No such" ):
1397- f_jit = jit (
1398- f ,
1399- compiler_options = {
1400- "exec_time_compilation_effort" : 0.0 ,
1401- })(1.0 )
1402-
14031386 def test_optimization_level_compiler_option (self ):
14041387 def f (x ):
14051388 return jnp .sqrt (x ** 2 ) + 1.0
You can’t perform that action at this time.
0 commit comments