Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions beets/autotag/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def add_string(self, key: str, str1: str | None, str2: str | None):
dist = string_dist(str1, str2)
self.add(key, dist)

def add_data_source(self, before: str | None, after: str | None) -> None:
def add_data_source(self, before: object, after: str | None) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

if before != after and (
before or len(metadata_plugins.find_metadata_source_plugins()) > 1
):
Expand Down Expand Up @@ -384,11 +384,19 @@ def track_distance(
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = Distance()
dist: Distance = Distance()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is redundant


# Length.
info_length: float | None
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
diff: float = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, redundant

abs(
(item.length - info_length)
if isinstance(item.length, (int, float))
else 0
)
- get_track_length_grace()
)
dist.add_ratio("track_length", diff, get_track_length_max())

# Title.
Expand Down
24 changes: 8 additions & 16 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,7 @@

from ..util import cached_classproperty, functemplate
from . import types
from .query import (
FieldQueryType,
FieldSort,
MatchQuery,
NullSort,
Query,
Sort,
TrueQuery,
)
from .query import FieldSort, MatchQuery, NullSort, Query, Sort, TrueQuery

if TYPE_CHECKING:
from types import TracebackType
Expand Down Expand Up @@ -310,7 +302,7 @@ def _types(cls) -> dict[str, types.Type]:
"""

@cached_classproperty
def _queries(cls) -> dict[str, FieldQueryType]:
def _queries(cls) -> dict[str, type[Query]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these changes from FieldQuery -> Query are incorrect. Why?

"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
Expand All @@ -328,9 +320,9 @@ def _queries(cls) -> dict[str, FieldQueryType]:
"""

@cached_classproperty
def _relation(cls):
def _relation(cls) -> type[Model[D]]:
"""The model that this model is closely related to."""
return cls
return cls # type: ignore[return-value]

@cached_classproperty
def relation_join(cls) -> str:
Expand Down Expand Up @@ -373,7 +365,7 @@ def __init__(self, db: D | None = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
self._db = db
self._db: D | None = db
self._dirty: set[str] = set()
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
Expand Down Expand Up @@ -744,7 +736,7 @@ def __getstate__(self):
AnyModel = TypeVar("AnyModel", bound=Model)


class Results(Generic[AnyModel]):
class Results(Generic[AnyModel, D]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

"""An item query result set. Iterating over the collection lazily
constructs Model objects that reflect database rows.
"""
Expand Down Expand Up @@ -1238,7 +1230,7 @@ def _make_attribute_table(self, flex_table: str):

# Querying.

def _fetch(
def _fetch_model(
self,
model_cls: type[AnyModel],
query: Query | None = None,
Expand Down Expand Up @@ -1304,4 +1296,4 @@ def _get(
"""Get a Model object by its id or None if the id does not
exist.
"""
return self._fetch(model_cls, MatchQuery("id", id)).get()
return self._fetch_model(model_cls, MatchQuery("id", id)).get()
1 change: 0 additions & 1 deletion beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def __hash__(self) -> int:

SQLiteType = Union[str, bytes, float, int, memoryview, None]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
FieldQueryType = type["FieldQuery"]


class FieldQuery(Query, Generic[P]):
Expand Down
72 changes: 44 additions & 28 deletions beets/dbcore/queryparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,44 @@

import itertools
import re
from typing import TYPE_CHECKING
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any improvements in this module except for redundant types and adjustments in function input types which then force you to use cast and type: ignore.

from typing import TYPE_CHECKING, cast

from . import query
from beets.library import Album, Item, LibModel

from . import FieldQuery, Model, Query, query

if TYPE_CHECKING:
import sys
from collections.abc import Collection, Sequence

from ..library import LibModel
from .query import FieldQueryType, Sort
from .query import Sort

if not sys.version_info < (3, 10):
from typing import TypeAlias # pyright: ignore[reportUnreachable]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use mypy, please remove this and setup your editor to use mypy in your environment for this project.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And see #5879 #5869

else:
from typing_extensions import TypeAlias

Prefixes = dict[str, FieldQueryType]
Prefixes: TypeAlias = dict[str, type[FieldQuery]]


PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
r"(?:"
r"(\S+?)" # The field key.
r"(?<!\\):" # Unescaped :
r")?"
r"(.*)", # The term itself.
r"(-|\^)?" # Negation prefixes. # noqa: ISC003
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Please undo

+ r"(?:"
+ r"(\S+?)" # The field key.
+ r"(?<!\\):" # Unescaped :
+ r")?"
+ r"(.*)", # The term itself.
re.I, # Case-insensitive.
)


def parse_query_part(
part: str,
query_classes: dict[str, FieldQueryType] = {},
query_classes: dict[str, type[Query]] = {},
prefixes: Prefixes = {},
default_class: type[query.SubstringQuery] = query.SubstringQuery,
) -> tuple[str | None, str, FieldQueryType, bool]:
) -> tuple[str | None, str, type[Query], bool]:
"""Parse a single *query part*, which is a chunk of a complete query
string representing a single criterion.

Expand Down Expand Up @@ -94,15 +101,17 @@ def parse_query_part(
"""
# Apply the regular expression and extract the components.
part = part.strip()
match = PARSE_QUERY_PART_REGEX.match(part)
match: re.Match[str] | None = PARSE_QUERY_PART_REGEX.match(part)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is redundant


assert match # Regex should always match
negate = bool(match.group(1))
key = match.group(2)
term = match.group(3).replace("\\:", ":")
negate: bool = bool(match.group(1))
key: str = match.group(2)
term: str = match.group(3).replace("\\:", ":")

# Check whether there's a prefix in the query and use the
# corresponding query type.
pre: str
query_class: type[Query]
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre) :], query_class, negate
Expand Down Expand Up @@ -137,26 +146,30 @@ def construct_query_part(

# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
query_classes: dict[str, FieldQueryType] = {}
query_classes: dict[str, type[Query]] = {}
for k, t in itertools.chain(
model_cls._fields.items(), model_cls._types.items()
):
query_classes[k] = t.query
query_classes.update(model_cls._queries) # Non-field queries.
query_classes.update(
model_cls._queries.items() # Non-field queries.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

)

# Parse the string.
key, pattern, query_class, negate = parse_query_part(
query_part, query_classes, prefixes
)
if key is not None:
# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(
key.lower(), pattern, cast(type[FieldQuery], query_class)
)

if key is None:
else:
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(pattern, query_class)
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why switching the order of the conditional?

Why did you replace FieldQuery in the function input and then introduce cast here?

# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(key.lower(), pattern, query_class)

# Apply negation.
if negate:
Expand All @@ -176,7 +189,7 @@ def query_from_strings(
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
"""
subqueries = []
subqueries: list[Query] = []
for part in query_parts:
subqueries.append(construct_query_part(model_cls, prefixes, part))
if not subqueries: # No terms in query.
Expand All @@ -185,7 +198,7 @@ def query_from_strings(


def construct_sort_part(
model_cls: type[LibModel],
model_cls: Model,
part: str,
case_insensitive: bool = True,
) -> Sort:
Expand All @@ -196,6 +209,9 @@ def construct_sort_part(
indicates whether or not the sort should be performed in a case
sensitive manner.
"""
assert isinstance(model_cls, type(Album)) or isinstance(
model_cls, type(Item)
)
assert part, "part must be a field name and + or -"
field = part[:-1]
assert field, "field is missing"
Expand Down Expand Up @@ -224,12 +240,12 @@ def sort_from_strings(
if not sort_parts:
return query.NullSort()
elif len(sort_parts) == 1:
return construct_sort_part(model_cls, sort_parts[0], case_insensitive)
return construct_sort_part(model_cls, sort_parts[0], case_insensitive) # type: ignore[arg-type]
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(
construct_sort_part(model_cls, part, case_insensitive)
construct_sort_part(model_cls, part, case_insensitive) # type: ignore[arg-type]
)
return sort

Expand Down
4 changes: 2 additions & 2 deletions beets/dbcore/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Type(ABC, Generic[T, N]):
"""The SQLite column type for the value.
"""

query: query.FieldQueryType = query.SubstringQuery
query: type[query.FieldQuery] = query.SubstringQuery
"""The `Query` subclass to be used when querying the field.
"""

Expand Down Expand Up @@ -242,7 +242,7 @@ class BaseFloat(Type[float, N]):
"""

sql = "REAL"
query: query.FieldQueryType = query.NumericQuery
query: type[query.FieldQuery] = query.NumericQuery
model_type = float

def __init__(self, digits: int = 1):
Expand Down
30 changes: 25 additions & 5 deletions beets/library/library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING

import platformdirs
Expand All @@ -8,11 +9,17 @@
from beets import dbcore
from beets.util import normpath

from .models import Album, Item
from .models import Album, AnyLibModel, Item
from .queries import PF_KEY_DEFAULT, parse_query_parts, parse_query_string

if TYPE_CHECKING:
from beets.dbcore import Results
from beets.dbcore.query import Query, Sort

if not sys.version_info < (3, 12):
pass # pyright: ignore[reportUnreachable]
else:
pass


class Library(dbcore.Database):
Expand Down Expand Up @@ -79,7 +86,12 @@ def add_album(self, items):

# Querying.

def _fetch(self, model_cls, query, sort=None):
def _fetch(
self,
model_cls: type[AnyLibModel],
query: list[str] | Query | str | tuple[str] | None = None,
sort: Sort | None = None,
) -> Results[AnyLibModel]:
"""Parse a query and fetch.

If an order specification is present in the query string
Expand All @@ -100,7 +112,7 @@ def _fetch(self, model_cls, query, sort=None):
if parsed_sort and not isinstance(parsed_sort, dbcore.query.NullSort):
sort = parsed_sort

return super()._fetch(model_cls, query, sort)
return super()._fetch_model(model_cls, query, sort)

@staticmethod
def get_default_album_sort():
Expand All @@ -116,11 +128,19 @@ def get_default_item_sort():
Item, beets.config["sort_item"].as_str_seq()
)

def albums(self, query=None, sort=None) -> Results[Album]:
def albums(
self,
query: list[str] | Query | str | tuple[str] | None = None,
sort: Sort | None = None,
) -> Results[Album]:
"""Get :class:`Album` objects matching the query."""
return self._fetch(Album, query, sort or self.get_default_album_sort())

def items(self, query=None, sort=None) -> Results[Item]:
def items(
self,
query: list[str] | Query | str | tuple[str] | None = None,
sort: Sort | None = None,
) -> Results[Item]:
"""Get :class:`Item` objects matching the query."""
return self._fetch(Item, query, sort or self.get_default_item_sort())

Expand Down
Loading