@@ -86,7 +86,6 @@ def test_getitem(shape, dtype, data):
8686 key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
8787
8888 repro_snippet = ph .format_snippet (f"{ x !r} [{ key !r} ]" )
89-
9089 try :
9190 out = x [key ]
9291
@@ -109,6 +108,7 @@ def test_getitem(shape, dtype, data):
109108 exc .add_note (repro_snippet )
110109 raise
111110
111+
112112@pytest .mark .unvectorized
113113@given (
114114 shape = hh .shapes (),
@@ -133,28 +133,34 @@ def test_setitem(shape, dtypes, data):
133133 value = data .draw (value_strat , label = "value" )
134134
135135 res = xp .asarray (x , copy = True )
136- res [key ] = value
137-
138- ph .assert_dtype ("__setitem__" , in_dtype = x .dtype , out_dtype = res .dtype , repr_name = "x.dtype" )
139- ph .assert_shape ("__setitem__" , out_shape = res .shape , expected = x .shape , repr_name = "x.shape" )
140- f_res = sh .fmt_idx ("x" , key )
141- if isinstance (value , get_args (Scalar )):
142- msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
143- if cmath .isnan (value ):
144- assert xp .isnan (res [key ]), msg
136+
137+ repro_snippet = ph .format_snippet (f"{ res !r} [{ key !r} ] = { value !r} " )
138+ try :
139+ res [key ] = value
140+
141+ ph .assert_dtype ("__setitem__" , in_dtype = x .dtype , out_dtype = res .dtype , repr_name = "x.dtype" )
142+ ph .assert_shape ("__setitem__" , out_shape = res .shape , expected = x .shape , repr_name = "x.shape" )
143+ f_res = sh .fmt_idx ("x" , key )
144+ if isinstance (value , get_args (Scalar )):
145+ msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
146+ if cmath .isnan (value ):
147+ assert xp .isnan (res [key ]), msg
148+ else :
149+ assert res [key ] == value , msg
145150 else :
146- assert res [key ] == value , msg
147- else :
148- ph .assert_array_elements ("__setitem__" , out = res [key ], expected = value , out_repr = f_res )
149- unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
150- for idx in unaffected_indices :
151- ph .assert_0d_equals (
152- "__setitem__" ,
153- x_repr = f"old { f_res } " ,
154- x_val = x [idx ],
155- out_repr = f"modified { f_res } " ,
156- out_val = res [idx ],
157- )
151+ ph .assert_array_elements ("__setitem__" , out = res [key ], expected = value , out_repr = f_res )
152+ unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
153+ for idx in unaffected_indices :
154+ ph .assert_0d_equals (
155+ "__setitem__" ,
156+ x_repr = f"old { f_res } " ,
157+ x_val = x [idx ],
158+ out_repr = f"modified { f_res } " ,
159+ out_val = res [idx ],
160+ )
161+ except Exception as exc :
162+ exc .add_note (repro_snippet )
163+ raise
158164
159165
160166@pytest .mark .unvectorized
@@ -178,29 +184,34 @@ def test_getitem_masking(shape, data):
178184 x [key ]
179185 return
180186
181- out = x [key ]
187+ repro_snippet = ph .format_snippet (f"out = { x !r} [{ key !r} ]" )
188+ try :
189+ out = x [key ]
182190
183- ph .assert_dtype ("__getitem__" , in_dtype = x .dtype , out_dtype = out .dtype )
184- if key .ndim == 0 :
185- expected_shape = (1 ,) if key else (0 ,)
186- expected_shape += x .shape
187- else :
188- size = int (xp .sum (xp .astype (key , xp .uint8 )))
189- expected_shape = (size ,) + x .shape [key .ndim :]
190- ph .assert_shape ("__getitem__" , out_shape = out .shape , expected = expected_shape )
191- if not any (s == 0 for s in key .shape ):
192- assume (key .ndim == x .ndim ) # TODO: test key.ndim < x.ndim scenarios
193- out_indices = sh .ndindex (out .shape )
194- for x_idx in sh .ndindex (x .shape ):
195- if key [x_idx ]:
196- out_idx = next (out_indices )
197- ph .assert_0d_equals (
198- "__getitem__" ,
199- x_repr = f"x[{ x_idx } ]" ,
200- x_val = x [x_idx ],
201- out_repr = f"out[{ out_idx } ]" ,
202- out_val = out [out_idx ],
203- )
191+ ph .assert_dtype ("__getitem__" , in_dtype = x .dtype , out_dtype = out .dtype )
192+ if key .ndim == 0 :
193+ expected_shape = (1 ,) if key else (0 ,)
194+ expected_shape += x .shape
195+ else :
196+ size = int (xp .sum (xp .astype (key , xp .uint8 )))
197+ expected_shape = (size ,) + x .shape [key .ndim :]
198+ ph .assert_shape ("__getitem__" , out_shape = out .shape , expected = expected_shape )
199+ if not any (s == 0 for s in key .shape ):
200+ assume (key .ndim == x .ndim ) # TODO: test key.ndim < x.ndim scenarios
201+ out_indices = sh .ndindex (out .shape )
202+ for x_idx in sh .ndindex (x .shape ):
203+ if key [x_idx ]:
204+ out_idx = next (out_indices )
205+ ph .assert_0d_equals (
206+ "__getitem__" ,
207+ x_repr = f"x[{ x_idx } ]" ,
208+ x_val = x [x_idx ],
209+ out_repr = f"out[{ out_idx } ]" ,
210+ out_val = out [out_idx ],
211+ )
212+ except Exception as exc :
213+ exc .add_note (repro_snippet )
214+ raise
204215
205216
206217@pytest .mark .unvectorized
@@ -213,38 +224,44 @@ def test_setitem_masking(shape, data):
213224 )
214225
215226 res = xp .asarray (x , copy = True )
216- res [key ] = value
217-
218- ph .assert_dtype ("__setitem__" , in_dtype = x .dtype , out_dtype = res .dtype , repr_name = "x.dtype" )
219- ph .assert_shape ("__setitem__" , out_shape = res .shape , expected = x .shape , repr_name = "x.dtype" )
220- scalar_type = dh .get_scalar_type (x .dtype )
221- for idx in sh .ndindex (x .shape ):
222- if key [idx ]:
223- if isinstance (value , get_args (Scalar )):
224- ph .assert_scalar_equals (
225- "__setitem__" ,
226- type_ = scalar_type ,
227- idx = idx ,
228- out = scalar_type (res [idx ]),
229- expected = value ,
230- repr_name = "modified x" ,
231- )
227+
228+ repro_snippet = ph .format_snippet (f"{ res } [{ key !r} ] = { value !r} " )
229+ try :
230+ res [key ] = value
231+
232+ ph .assert_dtype ("__setitem__" , in_dtype = x .dtype , out_dtype = res .dtype , repr_name = "x.dtype" )
233+ ph .assert_shape ("__setitem__" , out_shape = res .shape , expected = x .shape , repr_name = "x.dtype" )
234+ scalar_type = dh .get_scalar_type (x .dtype )
235+ for idx in sh .ndindex (x .shape ):
236+ if key [idx ]:
237+ if isinstance (value , get_args (Scalar )):
238+ ph .assert_scalar_equals (
239+ "__setitem__" ,
240+ type_ = scalar_type ,
241+ idx = idx ,
242+ out = scalar_type (res [idx ]),
243+ expected = value ,
244+ repr_name = "modified x" ,
245+ )
246+ else :
247+ ph .assert_0d_equals (
248+ "__setitem__" ,
249+ x_repr = "value" ,
250+ x_val = value ,
251+ out_repr = f"modified x[{ idx } ]" ,
252+ out_val = res [idx ]
253+ )
232254 else :
233255 ph .assert_0d_equals (
234256 "__setitem__" ,
235- x_repr = "value " ,
236- x_val = value ,
257+ x_repr = f"old x[ { idx } ] " ,
258+ x_val = x [ idx ] ,
237259 out_repr = f"modified x[{ idx } ]" ,
238260 out_val = res [idx ]
239261 )
240- else :
241- ph .assert_0d_equals (
242- "__setitem__" ,
243- x_repr = f"old x[{ idx } ]" ,
244- x_val = x [idx ],
245- out_repr = f"modified x[{ idx } ]" ,
246- out_val = res [idx ]
247- )
262+ except Exception as exc :
263+ exc .add_note (repro_snippet )
264+ raise
248265
249266
250267# ### Fancy indexing ###
@@ -309,15 +326,20 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
309326 key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
310327
311328 key = tuple (key )
312- out = x [key ]
329+ repro_snippet = ph .format_snippet (f"out = { x !r} [{ key !r} ]" )
330+ try :
331+ out = x [key ]
313332
314- arrays = [xp .asarray (k ) for k in key ]
315- bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
316- bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
333+ arrays = [xp .asarray (k ) for k in key ]
334+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
335+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
317336
318- for idx in sh .ndindex (bcast_shape ):
319- tpl = tuple (k [idx ] for k in bcast_key )
320- assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
337+ for idx in sh .ndindex (bcast_shape ):
338+ tpl = tuple (k [idx ] for k in bcast_key )
339+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
340+ except Exception as exc :
341+ exc .add_note (repro_snippet )
342+ raise
321343
322344
323345def make_scalar_casting_param (
0 commit comments