Skip to content

Commit f59cce0

Browse files
Fix the jax line search, the zoom may only store two points in lo, hi and rec.
PiperOrigin-RevId: 833328457
1 parent e4cadda commit f59cce0

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

jax/_src/scipy/optimize/line_search.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,16 @@ def body(state):
191191
),
192192
),
193193
)
194+
state = state._replace(
195+
**_binary_replace(
196+
lo_to_j & ~hi_to_lo,
197+
state._asdict(),
198+
dict(
199+
a_rec=state.a_lo,
200+
phi_rec=state.phi_lo,
201+
),
202+
),
203+
)
194204
state = state._replace(
195205
**_binary_replace(
196206
lo_to_j,
@@ -199,8 +209,6 @@ def body(state):
199209
a_lo=a_j,
200210
phi_lo=phi_j,
201211
dphi_lo=dphi_j,
202-
a_rec=state.a_lo,
203-
phi_rec=state.phi_lo,
204212
),
205213
),
206214
)

0 commit comments

Comments
 (0)