diff --git a/beets/autotag/distance.py b/beets/autotag/distance.py index 37c6f84f4a..e1fd6d16d0 100644 --- a/beets/autotag/distance.py +++ b/beets/autotag/distance.py @@ -345,7 +345,7 @@ def add_string(self, key: str, str1: str | None, str2: str | None): dist = string_dist(str1, str2) self.add(key, dist) - def add_data_source(self, before: str | None, after: str | None) -> None: + def add_data_source(self, before: object, after: str | None) -> None: if before != after and ( before or len(metadata_plugins.find_metadata_source_plugins()) > 1 ): @@ -384,11 +384,19 @@ def track_distance( cached because this function is called many times during the matching process and their access comes with a performance overhead. """ - dist = Distance() + dist: Distance = Distance() # Length. + info_length: float | None if info_length := track_info.length: - diff = abs(item.length - info_length) - get_track_length_grace() + diff: float = ( + abs( + (item.length - info_length) + if isinstance(item.length, (int, float)) + else 0 + ) + - get_track_length_grace() + ) dist.add_ratio("track_length", diff, get_track_length_max()) # Title. diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index afae6e9067..556423f97d 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -37,15 +37,7 @@ from ..util import cached_classproperty, functemplate from . import types -from .query import ( - FieldQueryType, - FieldSort, - MatchQuery, - NullSort, - Query, - Sort, - TrueQuery, -) +from .query import FieldSort, MatchQuery, NullSort, Query, Sort, TrueQuery if TYPE_CHECKING: from types import TracebackType @@ -310,7 +302,7 @@ def _types(cls) -> dict[str, types.Type]: """ @cached_classproperty - def _queries(cls) -> dict[str, FieldQueryType]: + def _queries(cls) -> dict[str, type[Query]]: """Named queries that use a field-like `name:value` syntax but which do not relate to any specific field. """ @@ -328,9 +320,9 @@ def _queries(cls) -> dict[str, FieldQueryType]: """ @cached_classproperty - def _relation(cls): + def _relation(cls) -> type[Model[D]]: """The model that this model is closely related to.""" - return cls + return cls # type: ignore[return-value] @cached_classproperty def relation_join(cls) -> str: @@ -373,7 +365,7 @@ def __init__(self, db: D | None = None, **values): """Create a new object with an optional Database association and initial field values. """ - self._db = db + self._db: D | None = db self._dirty: set[str] = set() self._values_fixed = LazyConvertDict(self) self._values_flex = LazyConvertDict(self) @@ -744,7 +736,7 @@ def __getstate__(self): AnyModel = TypeVar("AnyModel", bound=Model) -class Results(Generic[AnyModel]): +class Results(Generic[AnyModel, D]): """An item query result set. Iterating over the collection lazily constructs Model objects that reflect database rows. """ @@ -1238,7 +1230,7 @@ def _make_attribute_table(self, flex_table: str): # Querying. - def _fetch( + def _fetch_model( self, model_cls: type[AnyModel], query: Query | None = None, @@ -1304,4 +1296,4 @@ def _get( """Get a Model object by its id or None if the id does not exist. """ - return self._fetch(model_cls, MatchQuery("id", id)).get() + return self._fetch_model(model_cls, MatchQuery("id", id)).get() diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index dfeb427078..8ea661e26f 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -124,7 +124,6 @@ def __hash__(self) -> int: SQLiteType = Union[str, bytes, float, int, memoryview, None] AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType) -FieldQueryType = type["FieldQuery"] class FieldQuery(Query, Generic[P]): diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index f84ed74365..bd8a6def6d 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -18,37 +18,44 @@ import itertools import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast -from . import query +from beets.library import Album, Item, LibModel + +from . import FieldQuery, Model, Query, query if TYPE_CHECKING: + import sys from collections.abc import Collection, Sequence - from ..library import LibModel - from .query import FieldQueryType, Sort + from .query import Sort + + if not sys.version_info < (3, 10): + from typing import TypeAlias # pyright: ignore[reportUnreachable] + else: + from typing_extensions import TypeAlias - Prefixes = dict[str, FieldQueryType] + Prefixes: TypeAlias = dict[str, type[FieldQuery]] PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. - r"(-|\^)?" # Negation prefixes. - r"(?:" - r"(\S+?)" # The field key. - r"(? tuple[str | None, str, FieldQueryType, bool]: +) -> tuple[str | None, str, type[Query], bool]: """Parse a single *query part*, which is a chunk of a complete query string representing a single criterion. @@ -94,15 +101,17 @@ def parse_query_part( """ # Apply the regular expression and extract the components. part = part.strip() - match = PARSE_QUERY_PART_REGEX.match(part) + match: re.Match[str] | None = PARSE_QUERY_PART_REGEX.match(part) assert match # Regex should always match - negate = bool(match.group(1)) - key = match.group(2) - term = match.group(3).replace("\\:", ":") + negate: bool = bool(match.group(1)) + key: str = match.group(2) + term: str = match.group(3).replace("\\:", ":") # Check whether there's a prefix in the query and use the # corresponding query type. + pre: str + query_class: type[Query] for pre, query_class in prefixes.items(): if term.startswith(pre): return key, term[len(pre) :], query_class, negate @@ -137,26 +146,30 @@ def construct_query_part( # Use `model_cls` to build up a map from field (or query) names to # `Query` classes. - query_classes: dict[str, FieldQueryType] = {} + query_classes: dict[str, type[Query]] = {} for k, t in itertools.chain( model_cls._fields.items(), model_cls._types.items() ): query_classes[k] = t.query - query_classes.update(model_cls._queries) # Non-field queries. + query_classes.update( + model_cls._queries.items() # Non-field queries. + ) # Parse the string. key, pattern, query_class, negate = parse_query_part( query_part, query_classes, prefixes ) + if key is not None: + # Field queries get constructed according to the name of the field + # they are querying. + out_query = model_cls.field_query( + key.lower(), pattern, cast(type[FieldQuery], query_class) + ) - if key is None: + else: # If there's no key (field name) specified, this is a "match anything" # query. out_query = model_cls.any_field_query(pattern, query_class) - else: - # Field queries get constructed according to the name of the field - # they are querying. - out_query = model_cls.field_query(key.lower(), pattern, query_class) # Apply negation. if negate: @@ -176,7 +189,7 @@ def query_from_strings( strings in the format used by parse_query_part. `model_cls` determines how queries are constructed from strings. """ - subqueries = [] + subqueries: list[Query] = [] for part in query_parts: subqueries.append(construct_query_part(model_cls, prefixes, part)) if not subqueries: # No terms in query. @@ -185,7 +198,7 @@ def query_from_strings( def construct_sort_part( - model_cls: type[LibModel], + model_cls: Model, part: str, case_insensitive: bool = True, ) -> Sort: @@ -196,6 +209,9 @@ def construct_sort_part( indicates whether or not the sort should be performed in a case sensitive manner. """ + assert isinstance(model_cls, type(Album)) or isinstance( + model_cls, type(Item) + ) assert part, "part must be a field name and + or -" field = part[:-1] assert field, "field is missing" @@ -224,12 +240,12 @@ def sort_from_strings( if not sort_parts: return query.NullSort() elif len(sort_parts) == 1: - return construct_sort_part(model_cls, sort_parts[0], case_insensitive) + return construct_sort_part(model_cls, sort_parts[0], case_insensitive) # type: ignore[arg-type] else: sort = query.MultipleSort() for part in sort_parts: sort.add_sort( - construct_sort_part(model_cls, part, case_insensitive) + construct_sort_part(model_cls, part, case_insensitive) # type: ignore[arg-type] ) return sort diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index 3b4badd33c..849b22fc6c 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -62,7 +62,7 @@ class Type(ABC, Generic[T, N]): """The SQLite column type for the value. """ - query: query.FieldQueryType = query.SubstringQuery + query: type[query.FieldQuery] = query.SubstringQuery """The `Query` subclass to be used when querying the field. """ @@ -242,7 +242,7 @@ class BaseFloat(Type[float, N]): """ sql = "REAL" - query: query.FieldQueryType = query.NumericQuery + query: type[query.FieldQuery] = query.NumericQuery model_type = float def __init__(self, digits: int = 1): diff --git a/beets/library/library.py b/beets/library/library.py index 7370f7ecd4..d6e7a21f51 100644 --- a/beets/library/library.py +++ b/beets/library/library.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import TYPE_CHECKING import platformdirs @@ -8,11 +9,17 @@ from beets import dbcore from beets.util import normpath -from .models import Album, Item +from .models import Album, AnyLibModel, Item from .queries import PF_KEY_DEFAULT, parse_query_parts, parse_query_string if TYPE_CHECKING: from beets.dbcore import Results + from beets.dbcore.query import Query, Sort + +if not sys.version_info < (3, 12): + pass # pyright: ignore[reportUnreachable] +else: + pass class Library(dbcore.Database): @@ -79,7 +86,12 @@ def add_album(self, items): # Querying. - def _fetch(self, model_cls, query, sort=None): + def _fetch( + self, + model_cls: type[AnyLibModel], + query: list[str] | Query | str | tuple[str] | None = None, + sort: Sort | None = None, + ) -> Results[AnyLibModel]: """Parse a query and fetch. If an order specification is present in the query string @@ -100,7 +112,7 @@ def _fetch(self, model_cls, query, sort=None): if parsed_sort and not isinstance(parsed_sort, dbcore.query.NullSort): sort = parsed_sort - return super()._fetch(model_cls, query, sort) + return super()._fetch_model(model_cls, query, sort) @staticmethod def get_default_album_sort(): @@ -116,11 +128,19 @@ def get_default_item_sort(): Item, beets.config["sort_item"].as_str_seq() ) - def albums(self, query=None, sort=None) -> Results[Album]: + def albums( + self, + query: list[str] | Query | str | tuple[str] | None = None, + sort: Sort | None = None, + ) -> Results[Album]: """Get :class:`Album` objects matching the query.""" return self._fetch(Album, query, sort or self.get_default_album_sort()) - def items(self, query=None, sort=None) -> Results[Item]: + def items( + self, + query: list[str] | Query | str | tuple[str] | None = None, + sort: Sort | None = None, + ) -> Results[Item]: """Get :class:`Item` objects matching the query.""" return self._fetch(Item, query, sort or self.get_default_item_sort()) diff --git a/beets/library/models.py b/beets/library/models.py index cbee2a411e..36d6789f34 100644 --- a/beets/library/models.py +++ b/beets/library/models.py @@ -7,13 +7,14 @@ import unicodedata from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, cast from mediafile import MediaFile, UnreadableFileError +from typing_extensions import override import beets from beets import dbcore, logging, plugins, util -from beets.dbcore import types +from beets.dbcore import Model, Query, types from beets.util import ( MoveOperation, bytestring_path, @@ -24,14 +25,16 @@ ) from beets.util.functemplate import Template, template +from ..dbcore.query import FieldQuery from .exceptions import FileOperationError, ReadError, WriteError from .queries import PF_KEY_DEFAULT, parse_query_string if TYPE_CHECKING: - from ..dbcore.query import FieldQuery, FieldQueryType + from collections.abc import Iterable + from .library import Library # noqa: F401 -log = logging.getLogger("beets") +log: logging.BeetsLogger = logging.getLogger("beets") class LibModel(dbcore.Model["Library"]): @@ -45,13 +48,21 @@ class LibModel(dbcore.Model["Library"]): def _types(cls) -> dict[str, types.Type]: """Return the types of the fields in this model.""" return { - **plugins.types(cls), # type: ignore[arg-type] + **plugins.types(cls), # type: ignore[type-var] "data_source": types.STRING, } @cached_classproperty - def _queries(cls) -> dict[str, FieldQueryType]: - return plugins.named_queries(cls) # type: ignore[arg-type] + def _queries(cls) -> dict[str, type[Query]]: + return plugins.named_queries(cls) # type: ignore[type-var] + + @cached_classproperty + def _relation(cls) -> type[Model]: + return cls # type: ignore[return-value] + + @cached_classproperty + def all_db_fields(cls) -> set[str]: + return cls._fields.keys() | cls._relation._fields.keys() @cached_classproperty def writable_media_fields(cls) -> set[str]: @@ -75,10 +86,10 @@ def remove(self): super().remove() plugins.send("database_change", lib=self._db, model=self) - def add(self, lib=None): + def add(self, db=None): # super().add() calls self.store(), which sends `database_change`, # so don't do it here - super().add(lib) + super().add(db) def __format__(self, spec): if not spec: @@ -96,7 +107,7 @@ def __bytes__(self): @classmethod def field_query( - cls, field: str, pattern: str, query_cls: FieldQueryType + cls, field: str, pattern: str, query_cls: type[FieldQuery] ) -> FieldQuery: """Get a `FieldQuery` for the given field on this model.""" fast = field in cls.all_db_fields @@ -109,9 +120,18 @@ def field_query( return query_cls(field, pattern, fast) @classmethod - def any_field_query(cls, *args, **kwargs) -> dbcore.OrQuery: + def any_field_query( + cls, pattern: str, query_cls: type[Query] + ) -> dbcore.OrQuery: return dbcore.OrQuery( - [cls.field_query(f, *args, **kwargs) for f in cls._search_fields] + [ + cls.field_query( + f, + pattern=pattern, + query_cls=cast(type[FieldQuery], query_cls), + ) + for f in cls._search_fields + ] ) @classmethod @@ -131,6 +151,9 @@ def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery: ) +AnyLibModel = TypeVar("AnyLibModel", bound=LibModel) + + class FormattedItemMapping(dbcore.db.FormattedMapping): """Add lookup for album-level fields. @@ -224,7 +247,7 @@ class Album(LibModel): Reflects the library's "albums" table, including album art. """ - artpath: bytes + artpath: bytes | None _table = "albums" _flex_table = "album_attributes" @@ -334,7 +357,7 @@ def _types(cls) -> dict[str, types.Type]: _format_config_key = "format_album" @cached_classproperty - def _relation(cls) -> type[Item]: + def _relation(cls) -> type[Model]: return Item @cached_classproperty @@ -400,14 +423,14 @@ def remove(self, delete=False, with_items=True): for item in self.items(): item.remove(delete, False) - def move_art(self, operation=MoveOperation.MOVE): + def move_art(self, operation: MoveOperation = MoveOperation.MOVE) -> None: """Move, copy, link or hardlink (depending on `operation`) any existing album art so that it remains in the same directory as the items. `operation` should be an instance of `util.MoveOperation`. """ - old_art = self.artpath + old_art: bytes | None = self.artpath if not old_art: return @@ -431,7 +454,8 @@ def move_art(self, operation=MoveOperation.MOVE): ) if operation == MoveOperation.MOVE: util.move(old_art, new_art) - util.prune_dirs(os.path.dirname(old_art), self._db.directory) + if self._db: + util.prune_dirs(os.path.dirname(old_art), self._db.directory) elif operation == MoveOperation.COPY: util.copy(old_art, new_art) elif operation == MoveOperation.LINK: @@ -446,7 +470,12 @@ def move_art(self, operation=MoveOperation.MOVE): assert False, "unknown MoveOperation" self.artpath = new_art - def move(self, operation=MoveOperation.MOVE, basedir=None, store=True): + def move( + self, + operation: MoveOperation = MoveOperation.MOVE, + basedir: bytes | None = None, + store: bool = True, + ) -> None: """Move, copy, link or hardlink (depending on `operation`) all items to their destination. Any album art moves along with them. @@ -459,7 +488,7 @@ def move(self, operation=MoveOperation.MOVE, basedir=None, store=True): the album is not stored automatically, and it will have to be manually stored after invoking this method. """ - basedir = basedir or self._db.directory + basedir = basedir or self._db.directory if self._db else None # Ensure new metadata is available to items for destination # computation. @@ -467,7 +496,8 @@ def move(self, operation=MoveOperation.MOVE, basedir=None, store=True): self.store() # Move items. - items = list(self.items()) + items: list[Item] = list(self.items()) + item: Item for item in items: item.move(operation, basedir=basedir, with_album=False, store=store) @@ -567,7 +597,10 @@ def set_art(self, path, copy=True): plugins.send("art_set", album=self) - def store(self, fields=None, inherit=True): + @override + def store( + self, fields: Iterable[str] | None = None, inherit: bool = True + ) -> None: """Update the database with the album information. `fields` represents the fields to be stored. If not specified, @@ -589,6 +622,8 @@ def store(self, fields=None, inherit=True): elif key != "id": # is a flexible attribute track_updates[key] = self[key] + if not self._db: + return with self._db.transaction(): super().store(fields) if track_updates: @@ -745,7 +780,7 @@ class Item(LibModel): _sorts = {"artist": dbcore.query.SmartArtistSort} @cached_classproperty - def _queries(cls) -> dict[str, FieldQueryType]: + def _queries(cls) -> dict[str, type[Query]]: return {**super()._queries, "singleton": dbcore.query.SingletonQuery} _format_config_key = "format_item" @@ -754,7 +789,7 @@ def _queries(cls) -> dict[str, FieldQueryType]: __album: Album | None = None @cached_classproperty - def _relation(cls) -> type[Album]: + def _relation(cls) -> type[Model]: return Album @cached_classproperty @@ -1117,11 +1152,11 @@ def remove(self, delete=False, with_album=True): def move( self, - operation=MoveOperation.MOVE, - basedir=None, - with_album=True, - store=True, - ): + operation: MoveOperation = MoveOperation.MOVE, + basedir: bytes | None = None, + with_album: bool = True, + store: bool = True, + ) -> None: """Move the item to its designated location within the library directory (provided by destination()). @@ -1164,7 +1199,7 @@ def move( album.store() # Prune vacated directory. - if operation == MoveOperation.MOVE: + if operation == MoveOperation.MOVE and self._db: util.prune_dirs(os.path.dirname(old_path), self._db.directory) # Templating. diff --git a/beets/library/queries.py b/beets/library/queries.py index 7c9d688cd5..dc17d786ee 100644 --- a/beets/library/queries.py +++ b/beets/library/queries.py @@ -1,20 +1,30 @@ from __future__ import annotations import shlex +from typing import TYPE_CHECKING import beets from beets import dbcore, logging, plugins -log = logging.getLogger("beets") +if TYPE_CHECKING: + from typing_extensions import LiteralString + + from beets.dbcore.query import Sort + from beets.dbcore.queryparse import Prefixes + from beets.library.models import LibModel + +log: logging.BeetsLogger = logging.getLogger("beets") # Special path format key. -PF_KEY_DEFAULT = "default" +PF_KEY_DEFAULT: LiteralString = "default" # Query construction helpers. -def parse_query_parts(parts, model_cls): +def parse_query_parts( + parts: list[str] | tuple[str, ...], model_cls: type[LibModel] +) -> tuple[dbcore.Query, Sort]: """Given a beets query string as a list of components, return the `Query` and `Sort` they represent. @@ -22,7 +32,7 @@ def parse_query_parts(parts, model_cls): ensuring that implicit path queries are made explicit with 'path::' """ # Get query types and their prefix characters. - prefixes = { + prefixes: Prefixes = { ":": dbcore.query.RegexpQuery, "=~": dbcore.query.StringQuery, "=": dbcore.query.MatchQuery, @@ -36,8 +46,10 @@ def parse_query_parts(parts, model_cls): for s in parts ] - case_insensitive = beets.config["sort_case_insensitive"].get(bool) + case_insensitive: bool = beets.config["sort_case_insensitive"].get(bool) + query: dbcore.Query + sort: Sort query, sort = dbcore.parse_sorted_query( model_cls, parts, prefixes, case_insensitive ) @@ -46,7 +58,9 @@ def parse_query_parts(parts, model_cls): return query, sort -def parse_query_string(s, model_cls): +def parse_query_string( + s: str, model_cls: type[LibModel] +) -> tuple[dbcore.Query, Sort]: """Given a beets query string, return the `Query` and `Sort` they represent. diff --git a/beets/logging.py b/beets/logging.py index 3ed5e5a843..4de416f44d 100644 --- a/beets/logging.py +++ b/beets/logging.py @@ -37,7 +37,7 @@ RootLogger, StreamHandler, ) -from typing import TYPE_CHECKING, Any, Mapping, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, TypeVar, Union, overload __all__ = [ "DEBUG", @@ -54,9 +54,11 @@ ] if TYPE_CHECKING: - T = TypeVar("T") + from collections.abc import Mapping from types import TracebackType + T = TypeVar("T") + # see https://github.com/python/typeshed/blob/main/stdlib/logging/__init__.pyi _SysExcInfoType = Union[ tuple[type[BaseException], BaseException, Union[TracebackType, None]], @@ -144,13 +146,13 @@ def _log( class ThreadLocalLevelLogger(Logger): """A version of `Logger` whose level is thread-local instead of shared.""" - def __init__(self, name, level=NOTSET): - self._thread_level = threading.local() - self.default_level = NOTSET + def __init__(self, name: str, level: int = NOTSET) -> None: + self._thread_level: threading.local = threading.local() + self.default_level: int = NOTSET super().__init__(name, level) @property - def level(self): + def level(self) -> int: try: return self._thread_level.level except AttributeError: @@ -158,10 +160,10 @@ def level(self): return self.level @level.setter - def level(self, value): + def level(self, value: int) -> None: self._thread_level.level = value - def set_global_level(self, level): + def set_global_level(self, level: int) -> None: """Set the level on the current thread + the default value for all threads. """ diff --git a/beets/plugins.py b/beets/plugins.py index e10dcf80ca..7a262de92a 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -40,9 +40,8 @@ from confuse import ConfigView - from beets.dbcore import Query - from beets.dbcore.db import FieldQueryType - from beets.dbcore.types import Type + from beets.dbcore import Query, Type + from beets.dbcore.queryparse import Prefixes from beets.importer import ImportSession, ImportTask from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -332,7 +331,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Ret: return wrapper - def queries(self) -> dict[str, type[Query]]: + def queries(self) -> Prefixes: """Return a dict mapping prefixes to Query subclasses.""" return {} @@ -503,22 +502,22 @@ def commands() -> list[Subcommand]: return out -def queries() -> dict[str, type[Query]]: +def queries() -> Prefixes: """Returns a dict mapping prefix strings to Query subclasses all loaded plugins. """ - out: dict[str, type[Query]] = {} + out: Prefixes = {} for plugin in find_plugins(): out.update(plugin.queries()) return out -def types(model_cls: type[AnyModel]) -> dict[str, Type]: +def types(model_cls: AnyModel) -> dict[str, Type]: """Return mapping between flex field names and types for the given model.""" - attr_name = f"{model_cls.__name__.lower()}_types" + attr_name: str = f"{model_cls.__name__.lower()}_types" types: dict[str, Type] = {} for plugin in find_plugins(): - plugin_types = getattr(plugin, attr_name, {}) + plugin_types: dict[str, Type] = getattr(plugin, attr_name, {}) for field in plugin_types: if field in types and plugin_types[field] != types[field]: raise PluginConflictError( @@ -530,9 +529,10 @@ def types(model_cls: type[AnyModel]) -> dict[str, Type]: return types -def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]: +def named_queries(model_cls: AnyModel) -> dict[str, type[Query]]: """Return mapping between field names and queries for the given model.""" - attr_name = f"{model_cls.__name__.lower()}_queries" + attr_name: str = f"{model_cls.__name__.lower()}_queries" + return { field: query for plugin in find_plugins() diff --git a/beets/test/helper.py b/beets/test/helper.py index ea08ec840b..6a177b95d3 100644 --- a/beets/test/helper.py +++ b/beets/test/helper.py @@ -274,8 +274,7 @@ def create_item(self, **values): } values_.update(values) values_["title"] = values_["title"].format(1) - values_["db"] = self.lib - item = Item(**values_) + item = Item(self.lib, **values_) if "path" not in values: item["path"] = f"audio.{item['format'].lower()}" # mtime needs to be set last since other assignments reset it. diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index 60e2014485..a079624aee 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -147,7 +147,7 @@ def print_(*strings: str, end: str = "\n") -> None: # Configuration wrappers. -def _bool_fallback(a, b): +def _bool_fallback(a: bool | None, b: bool | None) -> bool: """Given a boolean or None, return the original value or a fallback.""" if a is None: assert isinstance(b, bool) @@ -157,14 +157,14 @@ def _bool_fallback(a, b): return a -def should_write(write_opt=None): +def should_write(write_opt: bool | None = None) -> bool: """Decide whether a command that updates metadata should also write tags, using the importer configuration as the default. """ return _bool_fallback(write_opt, config["import"]["write"].get(bool)) -def should_move(move_opt=None): +def should_move(move_opt: bool | None = None) -> bool: """Decide whether a command that updates metadata should also move files when they're inside the library, using the importer configuration as the default. @@ -1045,7 +1045,13 @@ def print_newline_layout( FLOAT_EPSILON = 0.01 -def _field_diff(field, old, old_fmt, new, new_fmt): +def _field_diff( + field: str, + old: library.LibModel, + old_fmt: db.FormattedMapping | None, + new: library.LibModel, + new_fmt: db.FormattedMapping, +) -> str | None: """Given two Model objects and their formatted views, format their values for `field` and highlight changes among them. Return a human-readable string. If the value has not changed, return None instead. @@ -1064,8 +1070,8 @@ def _field_diff(field, old, old_fmt, new, new_fmt): return None # Get formatted values for output. - oldstr = old_fmt.get(field, "") - newstr = new_fmt.get(field, "") + oldstr: str = old_fmt.get(field, "") if old_fmt else "" + newstr: str = new_fmt.get(field, "") # For strings, highlight changes. For others, colorize the whole # thing. @@ -1079,8 +1085,12 @@ def _field_diff(field, old, old_fmt, new, new_fmt): def show_model_changes( - new, old=None, fields=None, always=False, print_obj: bool = True -): + new: library.LibModel, + old: library.LibModel | None = None, + fields: list[str] | None = None, + always: bool = False, + print_obj: bool = True, +) -> bool: """Given a Model object, print a list of changes from its pristine version stored in the database. Return a boolean indicating whether any changes were found. @@ -1090,27 +1100,30 @@ def show_model_changes( restrict the detection to. `always` indicates whether the object is always identified, regardless of whether any changes are present. """ - old = old or new._db._get(type(new), new.id) + if not old and new._db: + old = new._db._get(type(new), new.id) # Keep the formatted views around instead of re-creating them in each # iteration step - old_fmt = old.formatted() - new_fmt = new.formatted() - - # Build up lines showing changed fields. - changes = [] - for field in old: - # Subset of the fields. Never show mtime. - if field == "mtime" or (fields and field not in fields): - continue - - # Detect and show difference for this field. - line = _field_diff(field, old, old_fmt, new, new_fmt) - if line: - changes.append(f" {field}: {line}") + old_fmt: db.FormattedMapping | None = old.formatted() if old else None + new_fmt: db.FormattedMapping = new.formatted() + + # Build up lines showing changed fields + field: str + changes: list[str] = [] + if old: + for field in old: + # Subset of the fields. Never show mtime. + if field == "mtime" or (fields and field not in fields): + continue + + # Detect and show difference for this field. + line: str | None = _field_diff(field, old, old_fmt, new, new_fmt) + if line: + changes.append(f" {field}: {line}") # New fields. - for field in set(new) - set(old): + for field in set(new) - set(old or ()): if fields and field not in fields: continue @@ -1213,21 +1226,23 @@ class CommonOptionsParser(optparse.OptionParser): Each method is fully documented in the related method. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._album_flags = False + self._album_flags: set[str] | None = None # this serves both as an indicator that we offer the feature AND allows # us to check whether it has been specified on the CLI - bypassing the # fact that arguments may be in any order - def add_album_option(self, flags=("-a", "--album")): + def add_album_option( + self, flags: tuple[str, str] = ("-a", "--album") + ) -> None: """Add a -a/--album option to match albums instead of tracks. If used then the format option can auto-detect whether we're setting the format for items or albums. Sets the album property on the options extracted from the CLI. """ - album = optparse.Option( + album: optparse.Option = optparse.Option( *flags, action="store_true", help="match albums instead of tracks" ) self.add_option(album) @@ -1235,42 +1250,47 @@ def add_album_option(self, flags=("-a", "--album")): def _set_format( self, - option, - opt_str, - value, - parser, - target=None, - fmt=None, - store_true=False, - ): + option: optparse.Option, + opt_str: str, + value: str, + parser: CommonOptionsParser, + target: type[library.Album | library.Item] | None = None, + fmt: str | None = None, + store_true: bool = False, + ) -> None: """Internal callback that sets the correct format while parsing CLI arguments. """ - if store_true: + if store_true and option.dest: setattr(parser.values, option.dest, True) # Use the explicitly specified format, or the string from the option. value = fmt or value or "" + if parser.values is None: + parser.values = optparse.Values() parser.values.format = value if target: config[target._format_config_key].set(value) - else: - if self._album_flags: - if parser.values.album: + return + + if self._album_flags: + if parser.values.album: + target = library.Album + else: + # the option is either missing either not parsed yet + if self._album_flags & set(parser.rargs or ()): target = library.Album else: - # the option is either missing either not parsed yet - if self._album_flags & set(parser.rargs): - target = library.Album - else: - target = library.Item - config[target._format_config_key].set(value) - else: - config[library.Item._format_config_key].set(value) - config[library.Album._format_config_key].set(value) + target = library.Item + config[target._format_config_key].set(value) + else: + config[library.Item._format_config_key].set(value) + config[library.Album._format_config_key].set(value) - def add_path_option(self, flags=("-p", "--path")): + def add_path_option( + self, flags: tuple[str, str] = ("-p", "--path") + ) -> None: """Add a -p/--path option to display the path instead of the default format. @@ -1290,7 +1310,11 @@ def add_path_option(self, flags=("-p", "--path")): ) self.add_option(path) - def add_format_option(self, flags=("-f", "--format"), target=None): + def add_format_option( + self, + flags: tuple[str, ...] = ("-f", "--format"), + target: str | type[library.LibModel] | None = None, + ) -> None: """Add -f/--format option to print some LibModel instances with a custom format. @@ -1305,7 +1329,7 @@ def add_format_option(self, flags=("-f", "--format"), target=None): Sets the format property on the options extracted from the CLI. """ - kwargs = {} + kwargs: dict[str, type[library.LibModel]] = {} if target: if isinstance(target, str): target = {"item": library.Item, "album": library.Album}[target] @@ -1319,8 +1343,9 @@ def add_format_option(self, flags=("-f", "--format"), target=None): help="print with custom format", ) self.add_option(opt) + return None - def add_all_common_options(self): + def add_all_common_options(self) -> None: """Add album, path and format options.""" self.add_album_option() self.add_path_option() diff --git a/beetsplug/mbsync.py b/beetsplug/mbsync.py index 3f7daec6c9..0971e99301 100644 --- a/beetsplug/mbsync.py +++ b/beetsplug/mbsync.py @@ -14,18 +14,35 @@ """Synchronise library metadata with metadata source backends.""" +from __future__ import annotations + +import sys from collections import defaultdict +from typing import TYPE_CHECKING + +from typing_extensions import override from beets import autotag, library, metadata_plugins, ui, util from beets.plugins import BeetsPlugin, apply_item_changes +if TYPE_CHECKING: + from optparse import Values + +if not sys.version_info < (3, 12): + from typing import override # pyright: ignore[reportUnreachable] +else: + from typing_extensions import override + class MBSyncPlugin(BeetsPlugin): - def __init__(self): + def __init__(self) -> None: super().__init__() - def commands(self): - cmd = ui.Subcommand("mbsync", help="update metadata from musicbrainz") + @override + def commands(self) -> list[ui.Subcommand]: + cmd: ui.Subcommand = ui.Subcommand( + "mbsync", help="update metadata from musicbrainz" + ) cmd.parser.add_option( "-p", "--pretend", @@ -58,19 +75,27 @@ def commands(self): cmd.func = self.func return [cmd] - def func(self, lib, opts, args): + def func(self, lib: library.Library, opts: Values, args: list[str]) -> None: """Command handler for the mbsync function.""" - move = ui.should_move(opts.move) - pretend = opts.pretend - write = ui.should_write(opts.write) + move: bool = ui.should_move(opts.move) + pretend: bool = opts.pretend + write: bool = ui.should_write(opts.write) self.singletons(lib, args, move, pretend, write) self.albums(lib, args, move, pretend, write) - def singletons(self, lib, query, move, pretend, write): + def singletons( + self, + lib: library.Library, + query: list[str], + move: bool, + pretend: bool, + write: bool, + ) -> None: """Retrieve and apply info from the autotagger for items matched by query. """ + item: library.Item for item in lib.items(query + ["singleton:true"]): if not item.mb_trackid: self._log.info( @@ -78,6 +103,7 @@ def singletons(self, lib, query, move, pretend, write): ) continue + track_info: autotag.TrackInfo | None if not ( track_info := metadata_plugins.track_for_id(item.mb_trackid) ): @@ -91,16 +117,25 @@ def singletons(self, lib, query, move, pretend, write): autotag.apply_item_metadata(item, track_info) apply_item_changes(lib, item, move, pretend, write) - def albums(self, lib, query, move, pretend, write): + def albums( + self, + lib: library.Library, + query: list[str], + move: bool, + pretend: bool, + write: bool, + ): """Retrieve and apply info from the autotagger for albums matched by query and their items. """ # Process matching albums. + album: library.Album for album in lib.albums(query): if not album.mb_albumid: self._log.info("Skipping album with no mb_albumid: {}", album) continue + album_info: autotag.AlbumInfo | None if not ( album_info := metadata_plugins.album_for_id(album.mb_albumid) ): @@ -112,17 +147,22 @@ def albums(self, lib, query, move, pretend, write): # Map release track and recording MBIDs to their information. # Recordings can appear multiple times on a release, so each MBID # maps to a list of TrackInfo objects. - releasetrack_index = {} - track_index = defaultdict(list) + releasetrack_index: dict[str, autotag.TrackInfo] = {} + track_index: defaultdict[str, list[autotag.TrackInfo]] = ( + defaultdict(list) + ) for track_info in album_info.tracks: - releasetrack_index[track_info.release_track_id] = track_info - track_index[track_info.track_id].append(track_info) + releasetrack_index[track_info.release_track_id or ""] = ( + track_info + ) + + track_index[track_info.track_id or ""].append(track_info) # Construct a track mapping according to MBIDs (release track MBIDs # first, if available, and recording MBIDs otherwise). This should # work for albums that have missing or extra tracks. - mapping = {} - items = list(album.items()) + mapping: dict[library.Item, autotag.TrackInfo] = {} + items: list[library.Item] = list(album.items()) for item in items: if ( item.mb_releasetrackid @@ -130,7 +170,9 @@ def albums(self, lib, query, move, pretend, write): ): mapping[item] = releasetrack_index[item.mb_releasetrackid] else: - candidates = track_index[item.mb_trackid] + candidates: list[autotag.TrackInfo] = track_index[ + item.mb_trackid + ] if len(candidates) == 1: mapping[item] = candidates[0] else: @@ -148,11 +190,11 @@ def albums(self, lib, query, move, pretend, write): self._log.debug("applying changes to {}", album) with lib.transaction(): autotag.apply_metadata(album_info, mapping) - changed = False + changed: bool = False # Find any changed item to apply changes to album. - any_changed_item = items[0] + any_changed_item: library.Item = items[0] for item in items: - item_changed = ui.show_model_changes(item) + item_changed: bool = ui.show_model_changes(item) changed |= item_changed if item_changed: any_changed_item = item @@ -164,6 +206,7 @@ def albums(self, lib, query, move, pretend, write): if not pretend: # Update album structure to reflect an item in it. + key: str for key in library.Album.item_keys: album[key] = any_changed_item[key] album.store() diff --git a/extra/release.py b/extra/release.py old mode 100755 new mode 100644 diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 653adf2987..b4e9f1bdc6 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -728,16 +728,16 @@ def tearDown(self): self.db._connection().close() def test_iterate_once(self): - objs = self.db._fetch(ModelFixture1) + objs = self.db._fetch_model(ModelFixture1) assert len(list(objs)) == 2 def test_iterate_twice(self): - objs = self.db._fetch(ModelFixture1) + objs = self.db._fetch_model(ModelFixture1) list(objs) assert len(list(objs)) == 2 def test_concurrent_iterators(self): - results = self.db._fetch(ModelFixture1) + results = self.db._fetch_model(ModelFixture1) it1 = iter(results) it2 = iter(results) next(it1) @@ -746,44 +746,44 @@ def test_concurrent_iterators(self): def test_slow_query(self): q = dbcore.query.SubstringQuery("foo", "ba", False) - objs = self.db._fetch(ModelFixture1, q) + objs = self.db._fetch_model(ModelFixture1, q) assert len(list(objs)) == 2 def test_slow_query_negative(self): q = dbcore.query.SubstringQuery("foo", "qux", False) - objs = self.db._fetch(ModelFixture1, q) + objs = self.db._fetch_model(ModelFixture1, q) assert len(list(objs)) == 0 def test_iterate_slow_sort(self): s = dbcore.query.SlowFieldSort("foo") - res = self.db._fetch(ModelFixture1, sort=s) + res = self.db._fetch_model(ModelFixture1, sort=s) objs = list(res) assert objs[0].foo == "bar" assert objs[1].foo == "baz" def test_unsorted_subscript(self): - objs = self.db._fetch(ModelFixture1) + objs = self.db._fetch_model(ModelFixture1) assert objs[0].foo == "baz" assert objs[1].foo == "bar" def test_slow_sort_subscript(self): s = dbcore.query.SlowFieldSort("foo") - objs = self.db._fetch(ModelFixture1, sort=s) + objs = self.db._fetch_model(ModelFixture1, sort=s) assert objs[0].foo == "bar" assert objs[1].foo == "baz" def test_length(self): - objs = self.db._fetch(ModelFixture1) + objs = self.db._fetch_model(ModelFixture1) assert len(objs) == 2 def test_out_of_range(self): - objs = self.db._fetch(ModelFixture1) + objs = self.db._fetch_model(ModelFixture1) with pytest.raises(IndexError): objs[100] def test_no_results(self): assert ( - self.db._fetch(ModelFixture1, dbcore.query.FalseQuery()).get() + self.db._fetch_model(ModelFixture1, dbcore.query.FalseQuery()).get() is None )