Skip to content

Commit 40b9ad5

Browse files
committed
add lots of type hints
1 parent 043581e commit 40b9ad5

File tree

15 files changed

+346
-193
lines changed

15 files changed

+346
-193
lines changed

beets/autotag/distance.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def add_string(self, key: str, str1: str | None, str2: str | None):
345345
dist = string_dist(str1, str2)
346346
self.add(key, dist)
347347

348-
def add_data_source(self, before: str | None, after: str | None) -> None:
348+
def add_data_source(self, before: object, after: str | None) -> None:
349349
if before != after and (
350350
before or len(metadata_plugins.find_metadata_source_plugins()) > 1
351351
):
@@ -384,11 +384,19 @@ def track_distance(
384384
cached because this function is called many times during the matching
385385
process and their access comes with a performance overhead.
386386
"""
387-
dist = Distance()
387+
dist: Distance = Distance()
388388

389389
# Length.
390+
info_length: float | None
390391
if info_length := track_info.length:
391-
diff = abs(item.length - info_length) - get_track_length_grace()
392+
diff: float = (
393+
abs(
394+
(item.length - info_length)
395+
if isinstance(item.length, (int, float))
396+
else 0
397+
)
398+
- get_track_length_grace()
399+
)
392400
dist.add_ratio("track_length", diff, get_track_length_max())
393401

394402
# Title.

beets/dbcore/db.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,7 @@
3737

3838
from ..util import cached_classproperty, functemplate
3939
from . import types
40-
from .query import (
41-
FieldQueryType,
42-
FieldSort,
43-
MatchQuery,
44-
NullSort,
45-
Query,
46-
Sort,
47-
TrueQuery,
48-
)
40+
from .query import FieldSort, MatchQuery, NullSort, Query, Sort, TrueQuery
4941

5042
if TYPE_CHECKING:
5143
from types import TracebackType
@@ -310,7 +302,7 @@ def _types(cls) -> dict[str, types.Type]:
310302
"""
311303

312304
@cached_classproperty
313-
def _queries(cls) -> dict[str, FieldQueryType]:
305+
def _queries(cls) -> dict[str, type[Query]]:
314306
"""Named queries that use a field-like `name:value` syntax but which
315307
do not relate to any specific field.
316308
"""
@@ -328,7 +320,7 @@ def _queries(cls) -> dict[str, FieldQueryType]:
328320
"""
329321

330322
@cached_classproperty
331-
def _relation(cls):
323+
def _relation(cls) -> type[Model[D]]:
332324
"""The model that this model is closely related to."""
333325
return cls
334326

@@ -373,7 +365,7 @@ def __init__(self, db: D | None = None, **values):
373365
"""Create a new object with an optional Database association and
374366
initial field values.
375367
"""
376-
self._db = db
368+
self._db: D | None = db
377369
self._dirty: set[str] = set()
378370
self._values_fixed = LazyConvertDict(self)
379371
self._values_flex = LazyConvertDict(self)
@@ -744,7 +736,7 @@ def __getstate__(self):
744736
AnyModel = TypeVar("AnyModel", bound=Model)
745737

746738

747-
class Results(Generic[AnyModel]):
739+
class Results(Generic[AnyModel, D]):
748740
"""An item query result set. Iterating over the collection lazily
749741
constructs Model objects that reflect database rows.
750742
"""
@@ -1238,7 +1230,7 @@ def _make_attribute_table(self, flex_table: str):
12381230

12391231
# Querying.
12401232

1241-
def _fetch(
1233+
def _fetch_model(
12421234
self,
12431235
model_cls: type[AnyModel],
12441236
query: Query | None = None,
@@ -1304,4 +1296,4 @@ def _get(
13041296
"""Get a Model object by its id or None if the id does not
13051297
exist.
13061298
"""
1307-
return self._fetch(model_cls, MatchQuery("id", id)).get()
1299+
return self._fetch_model(model_cls, MatchQuery("id", id)).get()

beets/dbcore/query.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def __hash__(self) -> int:
124124

125125
SQLiteType = Union[str, bytes, float, int, memoryview, None]
126126
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
127-
FieldQueryType = type["FieldQuery"]
128127

129128

130129
class FieldQuery(Query, Generic[P]):

beets/dbcore/queryparse.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,44 @@
1818

1919
import itertools
2020
import re
21-
from typing import TYPE_CHECKING
21+
from typing import TYPE_CHECKING, cast
2222

23-
from . import query
23+
from beets.library import Album, Item, LibModel
24+
25+
from . import FieldQuery, Model, Query, query
2426

2527
if TYPE_CHECKING:
28+
import sys
2629
from collections.abc import Collection, Sequence
2730

28-
from ..library import LibModel
29-
from .query import FieldQueryType, Sort
31+
from .query import Sort
32+
33+
if not sys.version_info < (3, 10):
34+
from typing import TypeAlias # pyright: ignore[reportUnreachable]
35+
else:
36+
from typing_extensions import TypeAlias
3037

31-
Prefixes = dict[str, FieldQueryType]
38+
Prefixes: TypeAlias = dict[str, type[FieldQuery]]
3239

3340

3441
PARSE_QUERY_PART_REGEX = re.compile(
3542
# Non-capturing optional segment for the keyword.
36-
r"(-|\^)?" # Negation prefixes.
37-
r"(?:"
38-
r"(\S+?)" # The field key.
39-
r"(?<!\\):" # Unescaped :
40-
r")?"
41-
r"(.*)", # The term itself.
43+
r"(-|\^)?" # Negation prefixes. # noqa: ISC003
44+
+ r"(?:"
45+
+ r"(\S+?)" # The field key.
46+
+ r"(?<!\\):" # Unescaped :
47+
+ r")?"
48+
+ r"(.*)", # The term itself.
4249
re.I, # Case-insensitive.
4350
)
4451

4552

4653
def parse_query_part(
4754
part: str,
48-
query_classes: dict[str, FieldQueryType] = {},
55+
query_classes: dict[str, type[Query]] = {},
4956
prefixes: Prefixes = {},
5057
default_class: type[query.SubstringQuery] = query.SubstringQuery,
51-
) -> tuple[str | None, str, FieldQueryType, bool]:
58+
) -> tuple[str | None, str, type[Query], bool]:
5259
"""Parse a single *query part*, which is a chunk of a complete query
5360
string representing a single criterion.
5461
@@ -94,15 +101,17 @@ def parse_query_part(
94101
"""
95102
# Apply the regular expression and extract the components.
96103
part = part.strip()
97-
match = PARSE_QUERY_PART_REGEX.match(part)
104+
match: re.Match[str] | None = PARSE_QUERY_PART_REGEX.match(part)
98105

99106
assert match # Regex should always match
100-
negate = bool(match.group(1))
101-
key = match.group(2)
102-
term = match.group(3).replace("\\:", ":")
107+
negate: bool = bool(match.group(1))
108+
key: str = match.group(2)
109+
term: str = match.group(3).replace("\\:", ":")
103110

104111
# Check whether there's a prefix in the query and use the
105112
# corresponding query type.
113+
pre: str
114+
query_class: type[Query]
106115
for pre, query_class in prefixes.items():
107116
if term.startswith(pre):
108117
return key, term[len(pre) :], query_class, negate
@@ -137,26 +146,30 @@ def construct_query_part(
137146

138147
# Use `model_cls` to build up a map from field (or query) names to
139148
# `Query` classes.
140-
query_classes: dict[str, FieldQueryType] = {}
149+
query_classes: dict[str, type[Query]] = {}
141150
for k, t in itertools.chain(
142151
model_cls._fields.items(), model_cls._types.items()
143152
):
144153
query_classes[k] = t.query
145-
query_classes.update(model_cls._queries) # Non-field queries.
154+
query_classes.update(
155+
model_cls._queries.items() # Non-field queries.
156+
)
146157

147158
# Parse the string.
148159
key, pattern, query_class, negate = parse_query_part(
149160
query_part, query_classes, prefixes
150161
)
162+
if key is not None:
163+
# Field queries get constructed according to the name of the field
164+
# they are querying.
165+
out_query = model_cls.field_query(
166+
key.lower(), pattern, cast(type[FieldQuery], query_class)
167+
)
151168

152-
if key is None:
169+
else:
153170
# If there's no key (field name) specified, this is a "match anything"
154171
# query.
155172
out_query = model_cls.any_field_query(pattern, query_class)
156-
else:
157-
# Field queries get constructed according to the name of the field
158-
# they are querying.
159-
out_query = model_cls.field_query(key.lower(), pattern, query_class)
160173

161174
# Apply negation.
162175
if negate:
@@ -176,7 +189,7 @@ def query_from_strings(
176189
strings in the format used by parse_query_part. `model_cls`
177190
determines how queries are constructed from strings.
178191
"""
179-
subqueries = []
192+
subqueries: list[Query] = []
180193
for part in query_parts:
181194
subqueries.append(construct_query_part(model_cls, prefixes, part))
182195
if not subqueries: # No terms in query.
@@ -185,7 +198,7 @@ def query_from_strings(
185198

186199

187200
def construct_sort_part(
188-
model_cls: type[LibModel],
201+
model_cls: Model,
189202
part: str,
190203
case_insensitive: bool = True,
191204
) -> Sort:
@@ -196,6 +209,9 @@ def construct_sort_part(
196209
indicates whether or not the sort should be performed in a case
197210
sensitive manner.
198211
"""
212+
assert isinstance(model_cls, type(Album)) or isinstance(
213+
model_cls, type(Item)
214+
)
199215
assert part, "part must be a field name and + or -"
200216
field = part[:-1]
201217
assert field, "field is missing"
@@ -224,12 +240,12 @@ def sort_from_strings(
224240
if not sort_parts:
225241
return query.NullSort()
226242
elif len(sort_parts) == 1:
227-
return construct_sort_part(model_cls, sort_parts[0], case_insensitive)
243+
return construct_sort_part(model_cls, sort_parts[0], case_insensitive) # type: ignore[arg-type]
228244
else:
229245
sort = query.MultipleSort()
230246
for part in sort_parts:
231247
sort.add_sort(
232-
construct_sort_part(model_cls, part, case_insensitive)
248+
construct_sort_part(model_cls, part, case_insensitive) # type: ignore[arg-type]
233249
)
234250
return sort
235251

beets/dbcore/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class Type(ABC, Generic[T, N]):
6262
"""The SQLite column type for the value.
6363
"""
6464

65-
query: query.FieldQueryType = query.SubstringQuery
65+
query: type[query.FieldQuery] = query.SubstringQuery
6666
"""The `Query` subclass to be used when querying the field.
6767
"""
6868

@@ -242,7 +242,7 @@ class BaseFloat(Type[float, N]):
242242
"""
243243

244244
sql = "REAL"
245-
query: query.FieldQueryType = query.NumericQuery
245+
query: type[query.FieldQuery] = query.NumericQuery
246246
model_type = float
247247

248248
def __init__(self, digits: int = 1):

beets/library/library.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
from typing import TYPE_CHECKING
45

56
import platformdirs
@@ -8,11 +9,17 @@
89
from beets import dbcore
910
from beets.util import normpath
1011

11-
from .models import Album, Item
12+
from .models import Album, AnyLibModel, Item
1213
from .queries import PF_KEY_DEFAULT, parse_query_parts, parse_query_string
1314

1415
if TYPE_CHECKING:
1516
from beets.dbcore import Results
17+
from beets.dbcore.query import Query, Sort
18+
19+
if not sys.version_info < (3, 12):
20+
pass # pyright: ignore[reportUnreachable]
21+
else:
22+
pass
1623

1724

1825
class Library(dbcore.Database):
@@ -79,7 +86,12 @@ def add_album(self, items):
7986

8087
# Querying.
8188

82-
def _fetch(self, model_cls, query, sort=None):
89+
def _fetch(
90+
self,
91+
model_cls: type[AnyLibModel],
92+
query: list[str] | Query | str | tuple[str] | None = None,
93+
sort: Sort | None = None,
94+
) -> Results[AnyLibModel]:
8395
"""Parse a query and fetch.
8496
8597
If an order specification is present in the query string
@@ -100,7 +112,7 @@ def _fetch(self, model_cls, query, sort=None):
100112
if parsed_sort and not isinstance(parsed_sort, dbcore.query.NullSort):
101113
sort = parsed_sort
102114

103-
return super()._fetch(model_cls, query, sort)
115+
return super()._fetch_model(model_cls, query, sort)
104116

105117
@staticmethod
106118
def get_default_album_sort():
@@ -116,11 +128,19 @@ def get_default_item_sort():
116128
Item, beets.config["sort_item"].as_str_seq()
117129
)
118130

119-
def albums(self, query=None, sort=None) -> Results[Album]:
131+
def albums(
132+
self,
133+
query: list[str] | Query | str | tuple[str] | None = None,
134+
sort: Sort | None = None,
135+
) -> Results[Album]:
120136
"""Get :class:`Album` objects matching the query."""
121137
return self._fetch(Album, query, sort or self.get_default_album_sort())
122138

123-
def items(self, query=None, sort=None) -> Results[Item]:
139+
def items(
140+
self,
141+
query: list[str] | Query | str | tuple[str] | None = None,
142+
sort: Sort | None = None,
143+
) -> Results[Item]:
124144
"""Get :class:`Item` objects matching the query."""
125145
return self._fetch(Item, query, sort or self.get_default_item_sort())
126146

0 commit comments

Comments
 (0)