1818
1919import itertools
2020import re
21- from typing import TYPE_CHECKING
21+ from typing import TYPE_CHECKING , cast
2222
23- from . import query
23+ from beets .library import Album , Item , LibModel
24+
25+ from . import FieldQuery , Model , Query , query
2426
2527if TYPE_CHECKING :
28+ import sys
2629 from collections .abc import Collection , Sequence
2730
28- from ..library import LibModel
29- from .query import FieldQueryType , Sort
31+ from .query import Sort
32+
33+ if not sys .version_info < (3 , 10 ):
34+ from typing import TypeAlias # pyright: ignore[reportUnreachable]
35+ else :
36+ from typing_extensions import TypeAlias
3037
31- Prefixes = dict [str , FieldQueryType ]
38+ Prefixes : TypeAlias = dict [str , type [ FieldQuery ] ]
3239
3340
3441PARSE_QUERY_PART_REGEX = re .compile (
3542 # Non-capturing optional segment for the keyword.
36- r"(-|\^)?" # Negation prefixes.
37- r"(?:"
38- r"(\S+?)" # The field key.
39- r"(?<!\\):" # Unescaped :
40- r")?"
41- r"(.*)" , # The term itself.
43+ r"(-|\^)?" # Negation prefixes. # noqa: ISC003
44+ + r"(?:"
45+ + r"(\S+?)" # The field key.
46+ + r"(?<!\\):" # Unescaped :
47+ + r")?"
48+ + r"(.*)" , # The term itself.
4249 re .I , # Case-insensitive.
4350)
4451
4552
4653def parse_query_part (
4754 part : str ,
48- query_classes : dict [str , FieldQueryType ] = {},
55+ query_classes : dict [str , type [ Query ] ] = {},
4956 prefixes : Prefixes = {},
5057 default_class : type [query .SubstringQuery ] = query .SubstringQuery ,
51- ) -> tuple [str | None , str , FieldQueryType , bool ]:
58+ ) -> tuple [str | None , str , type [ Query ] , bool ]:
5259 """Parse a single *query part*, which is a chunk of a complete query
5360 string representing a single criterion.
5461
@@ -94,15 +101,17 @@ def parse_query_part(
94101 """
95102 # Apply the regular expression and extract the components.
96103 part = part .strip ()
97- match = PARSE_QUERY_PART_REGEX .match (part )
104+ match : re . Match [ str ] | None = PARSE_QUERY_PART_REGEX .match (part )
98105
99106 assert match # Regex should always match
100- negate = bool (match .group (1 ))
101- key = match .group (2 )
102- term = match .group (3 ).replace ("\\ :" , ":" )
107+ negate : bool = bool (match .group (1 ))
108+ key : str = match .group (2 )
109+ term : str = match .group (3 ).replace ("\\ :" , ":" )
103110
104111 # Check whether there's a prefix in the query and use the
105112 # corresponding query type.
113+ pre : str
114+ query_class : type [Query ]
106115 for pre , query_class in prefixes .items ():
107116 if term .startswith (pre ):
108117 return key , term [len (pre ) :], query_class , negate
@@ -137,26 +146,30 @@ def construct_query_part(
137146
138147 # Use `model_cls` to build up a map from field (or query) names to
139148 # `Query` classes.
140- query_classes : dict [str , FieldQueryType ] = {}
149+ query_classes : dict [str , type [ Query ] ] = {}
141150 for k , t in itertools .chain (
142151 model_cls ._fields .items (), model_cls ._types .items ()
143152 ):
144153 query_classes [k ] = t .query
145- query_classes .update (model_cls ._queries ) # Non-field queries.
154+ query_classes .update (
155+ model_cls ._queries .items () # Non-field queries.
156+ )
146157
147158 # Parse the string.
148159 key , pattern , query_class , negate = parse_query_part (
149160 query_part , query_classes , prefixes
150161 )
162+ if key is not None :
163+ # Field queries get constructed according to the name of the field
164+ # they are querying.
165+ out_query = model_cls .field_query (
166+ key .lower (), pattern , cast (type [FieldQuery ], query_class )
167+ )
151168
152- if key is None :
169+ else :
153170 # If there's no key (field name) specified, this is a "match anything"
154171 # query.
155172 out_query = model_cls .any_field_query (pattern , query_class )
156- else :
157- # Field queries get constructed according to the name of the field
158- # they are querying.
159- out_query = model_cls .field_query (key .lower (), pattern , query_class )
160173
161174 # Apply negation.
162175 if negate :
@@ -176,7 +189,7 @@ def query_from_strings(
176189 strings in the format used by parse_query_part. `model_cls`
177190 determines how queries are constructed from strings.
178191 """
179- subqueries = []
192+ subqueries : list [ Query ] = []
180193 for part in query_parts :
181194 subqueries .append (construct_query_part (model_cls , prefixes , part ))
182195 if not subqueries : # No terms in query.
@@ -185,7 +198,7 @@ def query_from_strings(
185198
186199
187200def construct_sort_part (
188- model_cls : type [ LibModel ] ,
201+ model_cls : Model ,
189202 part : str ,
190203 case_insensitive : bool = True ,
191204) -> Sort :
@@ -196,6 +209,9 @@ def construct_sort_part(
196209 indicates whether or not the sort should be performed in a case
197210 sensitive manner.
198211 """
212+ assert isinstance (model_cls , type (Album )) or isinstance (
213+ model_cls , type (Item )
214+ )
199215 assert part , "part must be a field name and + or -"
200216 field = part [:- 1 ]
201217 assert field , "field is missing"
@@ -224,12 +240,12 @@ def sort_from_strings(
224240 if not sort_parts :
225241 return query .NullSort ()
226242 elif len (sort_parts ) == 1 :
227- return construct_sort_part (model_cls , sort_parts [0 ], case_insensitive )
243+ return construct_sort_part (model_cls , sort_parts [0 ], case_insensitive ) # type: ignore[arg-type]
228244 else :
229245 sort = query .MultipleSort ()
230246 for part in sort_parts :
231247 sort .add_sort (
232- construct_sort_part (model_cls , part , case_insensitive )
248+ construct_sort_part (model_cls , part , case_insensitive ) # type: ignore[arg-type]
233249 )
234250 return sort
235251
0 commit comments