2929 from collections .abc import Callable
3030 from typing import Any , Literal
3131
32+ from numpy .typing import NDArray
33+
3234 CSMatrix = sp .csc_matrix | sp .csr_matrix
3335
3436
@@ -144,31 +146,55 @@ def test_normalize_per_cell():
144146 assert adata .X .sum (axis = 1 ).tolist () == adata_sparse .X .sum (axis = 1 ).A1 .tolist ()
145147
146148
149+ def _random_probs (n : int , frac_zero : float ) -> NDArray [np .float64 ]:
150+ """
151+ Generate a random probability distribution of `n` values between 0 and 1.
152+ """
153+ probs = np .random .randint (0 , 10000 , n ).astype (np .float64 )
154+ probs [probs < np .quantile (probs , frac_zero )] = 0
155+ probs /= probs .sum ()
156+ np .testing .assert_almost_equal (probs .sum (), 1 )
157+ return probs
158+
159+
147160@pytest .mark .parametrize ("array_type" , ARRAY_TYPES )
148161@pytest .mark .parametrize ("which" , ["copy" , "inplace" , "array" ])
149162@pytest .mark .parametrize (
150- ("axis" , "fraction" , "n" , "replace" , "expected" ),
163+ ("axis" , "f_or_n" , "replace" ),
164+ [
165+ pytest .param (0 , 40 , False , id = "obs-40-no_replace" ),
166+ pytest .param (0 , 0.1 , False , id = "obs-0.1-no_replace" ),
167+ pytest .param (0 , 201 , True , id = "obs-201-replace" ),
168+ pytest .param (0 , 1 , True , id = "obs-1-replace" ),
169+ pytest .param (1 , 10 , False , id = "var-10-no_replace" ),
170+ pytest .param (1 , 11 , True , id = "var-11-replace" ),
171+ pytest .param (1 , 2.0 , True , id = "var-2.0-replace" ),
172+ ],
173+ )
174+ @pytest .mark .parametrize (
175+ "ps" ,
151176 [
152- pytest .param (0 , None , 40 , False , 40 , id = "obs-40-no_replace" ),
153- pytest .param (0 , 0.1 , None , False , 20 , id = "obs-0.1-no_replace" ),
154- pytest .param (0 , None , 201 , True , 201 , id = "obs-201-replace" ),
155- pytest .param (0 , None , 1 , True , 1 , id = "obs-1-replace" ),
156- pytest .param (1 , None , 10 , False , 10 , id = "var-10-no_replace" ),
157- pytest .param (1 , None , 11 , True , 11 , id = "var-11-replace" ),
158- pytest .param (1 , 2.0 , None , True , 20 , id = "var-2.0-replace" ),
177+ dict (obs = None , var = None ),
178+ dict (obs = np .tile ([True , False ], 100 ), var = np .tile ([True , False ], 5 )),
179+ dict (obs = _random_probs (200 , 0.3 ), var = _random_probs (10 , 0.7 )),
159180 ],
181+ ids = ["all" , "mask" , "p" ],
160182)
161183def test_sample (
162184 * ,
185+ request : pytest .FixtureRequest ,
163186 array_type : Callable [[np .ndarray ], np .ndarray | CSMatrix ],
164187 which : Literal ["copy" , "inplace" , "array" ],
165188 axis : Literal [0 , 1 ],
166- fraction : float | None ,
167- n : int | None ,
189+ f_or_n : float | int , # noqa: PYI041
168190 replace : bool ,
169- expected : int ,
191+ ps : dict [ Literal [ "obs" , "var" ], NDArray [ np . bool_ ] | None ] ,
170192):
171193 adata = AnnData (array_type (np .ones ((200 , 10 ))))
194+ p = ps ["obs" if axis == 0 else "var" ]
195+ expected = int (adata .shape [axis ] * f_or_n ) if isinstance (f_or_n , float ) else f_or_n
196+ if p is not None and not replace and expected > (n_possible := (p != 0 ).sum ()):
197+ request .applymarker (pytest .xfail (f"Can’t draw { expected } out of { n_possible } " ))
172198
173199 # ignoring this warning declaratively is a pain so do it here
174200 if find_spec ("dask" ):
@@ -182,12 +208,13 @@ def test_sample(
182208 )
183209 rv = sc .pp .sample (
184210 adata .X if which == "array" else adata ,
185- fraction ,
186- n = n ,
211+ f_or_n if isinstance ( f_or_n , float ) else None ,
212+ n = f_or_n if isinstance ( f_or_n , int ) else None ,
187213 replace = replace ,
188214 axis = axis ,
189215 # `copy` only effects AnnData inputs
190216 copy = dict (copy = True , inplace = False , array = False )[which ],
217+ p = p ,
191218 )
192219
193220 match which :
@@ -232,6 +259,12 @@ def test_sample(
232259 r"`fraction=-0\.3` needs to be nonnegative" ,
233260 id = "frac<0" ,
234261 ),
262+ pytest .param (
263+ dict (n = 3 , p = np .ones (200 , dtype = np .int32 )),
264+ ValueError ,
265+ r"mask/probabilities array must be boolean or floating point" ,
266+ id = "type(p)" ,
267+ ),
235268 ],
236269)
237270def test_sample_error (args : dict [str , Any ], exc : type [Exception ], pattern : str ):
0 commit comments