Skip to content

Commit 397d703

Browse files
authored
Add sample probabilities (#3410)
1 parent ac4c629 commit 397d703

File tree

8 files changed

+102
-34
lines changed

8 files changed

+102
-34
lines changed

docs/release-notes/3410.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer`

src/scanpy/get/_aggregated.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,7 @@ def aggregate(
263263
if axis is None:
264264
axis = 1 if varm else 0
265265
axis, axis_name = _resolve_axis(axis)
266-
if mask is not None:
267-
mask = _check_mask(adata, mask, axis_name)
266+
mask = _check_mask(adata, mask, axis_name)
268267
data = adata.X
269268
if sum(p is not None for p in [varm, obsm, layer]) > 1:
270269
raise TypeError("Please only provide one (or none) of varm, obsm, or layer")

src/scanpy/get/get.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, TypeVar
66

77
import numpy as np
88
import pandas as pd
99
from anndata import AnnData
10+
from numpy.typing import NDArray
1011
from packaging.version import Version
1112
from scipy.sparse import spmatrix
1213

@@ -16,7 +17,11 @@
1617

1718
from anndata._core.sparse_dataset import BaseCompressedSparseDataset
1819
from anndata._core.views import ArrayView
19-
from numpy.typing import NDArray
20+
from scipy.sparse import csc_matrix, csr_matrix
21+
22+
from .._compat import DaskArray
23+
24+
CSMatrix = csr_matrix | csc_matrix
2025

2126
# --------------------------------------------------------------------------------
2227
# Plotting data helpers
@@ -485,42 +490,62 @@ def _set_obs_rep(
485490
raise AssertionError(msg)
486491

487492

493+
M = TypeVar("M", bound=NDArray[np.bool_] | NDArray[np.floating] | pd.Series | None)
494+
495+
488496
def _check_mask(
489-
data: AnnData | np.ndarray,
490-
mask: NDArray[np.bool_] | str,
497+
data: AnnData | np.ndarray | CSMatrix | DaskArray,
498+
mask: str | M,
491499
dim: Literal["obs", "var"],
492-
) -> NDArray[np.bool_]: # Could also be a series, but should be one or the other
500+
*,
501+
allow_probabilities: bool = False,
502+
) -> M: # Could also be a series, but should be one or the other
493503
"""
494504
Validate mask argument
495505
Params
496506
------
497507
data
498508
Annotated data matrix or numpy array.
499509
mask
500-
The mask. Either an appropriatley sized boolean array, or name of a column which will be used to mask.
510+
Mask (or probabilities if `allow_probabilities=True`).
511+
Either an appropriatley sized array, or name of a column.
501512
dim
502513
The dimension being masked.
514+
allow_probabilities
515+
Whether to allow probabilities as `mask`
503516
"""
517+
if mask is None:
518+
return mask
519+
desc = "mask/probabilities" if allow_probabilities else "mask"
520+
504521
if isinstance(mask, str):
505522
if not isinstance(data, AnnData):
506-
msg = "Cannot refer to mask with string without providing anndata object as argument"
523+
msg = f"Cannot refer to {desc} with string without providing anndata object as argument"
507524
raise ValueError(msg)
508525

509526
annot: pd.DataFrame = getattr(data, dim)
510527
if mask not in annot.columns:
511528
msg = (
512529
f"Did not find `adata.{dim}[{mask!r}]`. "
513-
f"Either add the mask first to `adata.{dim}`"
514-
"or consider using the mask argument with a boolean array."
530+
f"Either add the {desc} first to `adata.{dim}`"
531+
f"or consider using the {desc} argument with an array."
515532
)
516533
raise ValueError(msg)
517534
mask_array = annot[mask].to_numpy()
518535
else:
519536
if len(mask) != data.shape[0 if dim == "obs" else 1]:
520-
raise ValueError("The shape of the mask do not match the data.")
537+
msg = f"The shape of the {desc} do not match the data."
538+
raise ValueError(msg)
521539
mask_array = mask
522540

523-
if not pd.api.types.is_bool_dtype(mask_array.dtype):
524-
raise ValueError("Mask array must be boolean.")
541+
is_bool = pd.api.types.is_bool_dtype(mask_array.dtype)
542+
if not allow_probabilities and not is_bool:
543+
msg = "Mask array must be boolean."
544+
raise ValueError(msg)
545+
elif allow_probabilities and not (
546+
is_bool or pd.api.types.is_float_dtype(mask_array.dtype)
547+
):
548+
msg = f"{desc} array must be boolean or floating point."
549+
raise ValueError(msg)
525550

526551
return mask_array

src/scanpy/plotting/_tools/scatterplots.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ def embedding(
150150
# Checking the mask format and if used together with groups
151151
if groups is not None and mask_obs is not None:
152152
raise ValueError("Groups and mask arguments are incompatible.")
153-
if mask_obs is not None:
154-
mask_obs = _check_mask(adata, mask_obs, "obs")
153+
mask_obs = _check_mask(adata, mask_obs, "obs")
155154

156155
# Figure out if we're using raw
157156
if use_raw is None:

src/scanpy/preprocessing/_scale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ def scale_array(
164164
):
165165
if copy:
166166
X = X.copy()
167+
mask_obs = _check_mask(X, mask_obs, "obs")
167168
if mask_obs is not None:
168-
mask_obs = _check_mask(X, mask_obs, "obs")
169169
scale_rv = scale_array(
170170
X[mask_obs, :],
171171
zero_center=zero_center,

src/scanpy/preprocessing/_simple.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
sanitize_anndata,
3131
view_to_actual,
3232
)
33-
from ..get import _get_obs_rep, _set_obs_rep
33+
from ..get import _check_mask, _get_obs_rep, _set_obs_rep
3434
from ._distributed import materialize_as_ndarray
3535
from ._utils import _to_dense
3636

@@ -838,6 +838,7 @@ def sample(
838838
copy: Literal[False] = False,
839839
replace: bool = False,
840840
axis: Literal["obs", 0, "var", 1] = "obs",
841+
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
841842
) -> None: ...
842843
@overload
843844
def sample(
@@ -849,6 +850,7 @@ def sample(
849850
copy: Literal[True],
850851
replace: bool = False,
851852
axis: Literal["obs", 0, "var", 1] = "obs",
853+
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
852854
) -> AnnData: ...
853855
@overload
854856
def sample(
@@ -860,6 +862,7 @@ def sample(
860862
copy: bool = False,
861863
replace: bool = False,
862864
axis: Literal["obs", 0, "var", 1] = "obs",
865+
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
863866
) -> tuple[A, NDArray[np.int64]]: ...
864867
def sample(
865868
data: AnnData | np.ndarray | CSMatrix | DaskArray,
@@ -870,6 +873,7 @@ def sample(
870873
copy: bool = False,
871874
replace: bool = False,
872875
axis: Literal["obs", 0, "var", 1] = "obs",
876+
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
873877
) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]:
874878
"""\
875879
Sample observations or variables with or without replacement.
@@ -881,6 +885,7 @@ def sample(
881885
Rows correspond to cells and columns to genes.
882886
fraction
883887
Sample to this `fraction` of the number of observations or variables.
888+
(All of them, even if there are `0`s/`False`s in `p`.)
884889
This can be larger than 1.0, if `replace=True`.
885890
See `axis` and `replace`.
886891
n
@@ -894,6 +899,10 @@ def sample(
894899
If True, samples are drawn with replacement.
895900
axis
896901
Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1).
902+
p
903+
Drawing probabilities (floats) or mask (bools).
904+
Either an `axis`-sized array, or the name of a column.
905+
If `p` is an array of probabilities, it must sum to 1.
897906
898907
Returns
899908
-------
@@ -910,6 +919,9 @@ def sample(
910919
msg = "Inplace sampling (`copy=False`) is not implemented for backed objects."
911920
raise NotImplementedError(msg)
912921
axis, axis_name = _resolve_axis(axis)
922+
p = _check_mask(data, p, dim=axis_name, allow_probabilities=True)
923+
if p is not None and p.dtype == bool:
924+
p = p.astype(np.float64) / p.sum()
913925
old_n = data.shape[axis]
914926
match (fraction, n):
915927
case (None, None):
@@ -933,7 +945,7 @@ def sample(
933945

934946
# actually do subsampling
935947
rng = np.random.default_rng(rng)
936-
indices = rng.choice(old_n, size=n, replace=replace)
948+
indices = rng.choice(old_n, size=n, replace=replace, p=p)
937949

938950
# overload 1: inplace AnnData subset
939951
if not copy and isinstance(data, AnnData):

src/scanpy/tools/_rank_genes_groups.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,8 +594,7 @@ def rank_genes_groups(
594594
>>> # to visualize the results
595595
>>> sc.pl.rank_genes_groups(adata)
596596
"""
597-
if mask_var is not None:
598-
mask_var = _check_mask(adata, mask_var, "var")
597+
mask_var = _check_mask(adata, mask_var, "var")
599598

600599
if use_raw is None:
601600
use_raw = adata.raw is not None

tests/test_preprocessing.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from collections.abc import Callable
3030
from typing import Any, Literal
3131

32+
from numpy.typing import NDArray
33+
3234
CSMatrix = sp.csc_matrix | sp.csr_matrix
3335

3436

@@ -144,31 +146,55 @@ def test_normalize_per_cell():
144146
assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist()
145147

146148

149+
def _random_probs(n: int, frac_zero: float) -> NDArray[np.float64]:
150+
"""
151+
Generate a random probability distribution of `n` values between 0 and 1.
152+
"""
153+
probs = np.random.randint(0, 10000, n).astype(np.float64)
154+
probs[probs < np.quantile(probs, frac_zero)] = 0
155+
probs /= probs.sum()
156+
np.testing.assert_almost_equal(probs.sum(), 1)
157+
return probs
158+
159+
147160
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
148161
@pytest.mark.parametrize("which", ["copy", "inplace", "array"])
149162
@pytest.mark.parametrize(
150-
("axis", "fraction", "n", "replace", "expected"),
163+
("axis", "f_or_n", "replace"),
164+
[
165+
pytest.param(0, 40, False, id="obs-40-no_replace"),
166+
pytest.param(0, 0.1, False, id="obs-0.1-no_replace"),
167+
pytest.param(0, 201, True, id="obs-201-replace"),
168+
pytest.param(0, 1, True, id="obs-1-replace"),
169+
pytest.param(1, 10, False, id="var-10-no_replace"),
170+
pytest.param(1, 11, True, id="var-11-replace"),
171+
pytest.param(1, 2.0, True, id="var-2.0-replace"),
172+
],
173+
)
174+
@pytest.mark.parametrize(
175+
"ps",
151176
[
152-
pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"),
153-
pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"),
154-
pytest.param(0, None, 201, True, 201, id="obs-201-replace"),
155-
pytest.param(0, None, 1, True, 1, id="obs-1-replace"),
156-
pytest.param(1, None, 10, False, 10, id="var-10-no_replace"),
157-
pytest.param(1, None, 11, True, 11, id="var-11-replace"),
158-
pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"),
177+
dict(obs=None, var=None),
178+
dict(obs=np.tile([True, False], 100), var=np.tile([True, False], 5)),
179+
dict(obs=_random_probs(200, 0.3), var=_random_probs(10, 0.7)),
159180
],
181+
ids=["all", "mask", "p"],
160182
)
161183
def test_sample(
162184
*,
185+
request: pytest.FixtureRequest,
163186
array_type: Callable[[np.ndarray], np.ndarray | CSMatrix],
164187
which: Literal["copy", "inplace", "array"],
165188
axis: Literal[0, 1],
166-
fraction: float | None,
167-
n: int | None,
189+
f_or_n: float | int, # noqa: PYI041
168190
replace: bool,
169-
expected: int,
191+
ps: dict[Literal["obs", "var"], NDArray[np.bool_] | None],
170192
):
171193
adata = AnnData(array_type(np.ones((200, 10))))
194+
p = ps["obs" if axis == 0 else "var"]
195+
expected = int(adata.shape[axis] * f_or_n) if isinstance(f_or_n, float) else f_or_n
196+
if p is not None and not replace and expected > (n_possible := (p != 0).sum()):
197+
request.applymarker(pytest.xfail(f"Can’t draw {expected} out of {n_possible}"))
172198

173199
# ignoring this warning declaratively is a pain so do it here
174200
if find_spec("dask"):
@@ -182,12 +208,13 @@ def test_sample(
182208
)
183209
rv = sc.pp.sample(
184210
adata.X if which == "array" else adata,
185-
fraction,
186-
n=n,
211+
f_or_n if isinstance(f_or_n, float) else None,
212+
n=f_or_n if isinstance(f_or_n, int) else None,
187213
replace=replace,
188214
axis=axis,
189215
# `copy` only effects AnnData inputs
190216
copy=dict(copy=True, inplace=False, array=False)[which],
217+
p=p,
191218
)
192219

193220
match which:
@@ -232,6 +259,12 @@ def test_sample(
232259
r"`fraction=-0\.3` needs to be nonnegative",
233260
id="frac<0",
234261
),
262+
pytest.param(
263+
dict(n=3, p=np.ones(200, dtype=np.int32)),
264+
ValueError,
265+
r"mask/probabilities array must be boolean or floating point",
266+
id="type(p)",
267+
),
235268
],
236269
)
237270
def test_sample_error(args: dict[str, Any], exc: type[Exception], pattern: str):

0 commit comments

Comments
 (0)