diff --git a/beets/metadata_plugins.py b/beets/metadata_plugins.py index f42e8f690a..039691518a 100644 --- a/beets/metadata_plugins.py +++ b/beets/metadata_plugins.py @@ -10,11 +10,17 @@ import abc import re from functools import cache, cached_property -from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Generic, + Literal, + NamedTuple, + TypedDict, + TypeVar, +) import unidecode from confuse import NotFoundError -from typing_extensions import NotRequired from beets.util import cached_classproperty from beets.util.id_extractors import extract_release_id @@ -26,6 +32,8 @@ from .autotag.hooks import AlbumInfo, Item, TrackInfo + QueryType = Literal["album", "track"] + @cache def find_metadata_source_plugins() -> list[MetadataSourcePlugin]: @@ -169,7 +177,7 @@ def item_candidates( """ raise NotImplementedError - def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]: + def albums_for_ids(self, ids: Iterable[str]) -> Iterable[AlbumInfo | None]: """Batch lookup of album metadata for a list of album IDs. Given a list of album identifiers, yields corresponding AlbumInfo objects. @@ -180,7 +188,7 @@ def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]: return (self.album_for_id(id) for id in ids) - def tracks_for_ids(self, ids: Sequence[str]) -> Iterable[TrackInfo | None]: + def tracks_for_ids(self, ids: Iterable[str]) -> Iterable[TrackInfo | None]: """Batch lookup of track metadata for a list of track IDs. Given a list of track identifiers, yields corresponding TrackInfo objects. @@ -254,12 +262,13 @@ class IDResponse(TypedDict): id: str -class SearchFilter(TypedDict): - artist: NotRequired[str] - album: NotRequired[str] +R = TypeVar("R", bound=IDResponse) -R = TypeVar("R", bound=IDResponse) +class SearchParams(NamedTuple): + query_type: QueryType + query: str + filters: dict[str, str] class SearchApiMetadataSourcePlugin( @@ -282,12 +291,26 @@ def __init__(self, *args, **kwargs) -> None: } ) + def get_search_filters( + self, + query_type: QueryType, + items: Sequence[Item], + artist: str, + name: str, + va_likely: bool, + ) -> tuple[str, dict[str, str]]: + query = f'album:"{name}"' if query_type == "album" else name + if query_type == "track" or not va_likely: + query += f' artist:"{artist}"' + + return query, {} + @abc.abstractmethod + def get_search_response(self, params: SearchParams) -> Sequence[R]: + raise NotImplementedError + def _search_api( - self, - query_type: Literal["album", "track"], - filters: SearchFilter, - query_string: str = "", + self, query_type: QueryType, query: str, filters: dict[str, str] ) -> Sequence[R]: """Perform a search on the API. @@ -297,7 +320,28 @@ def _search_api( Should return a list of identifiers for the requested type (album or track). """ - raise NotImplementedError + if self.config["search_query_ascii"].get(): + query = unidecode.unidecode(query) + + filters["limit"] = str(self.config["search_limit"].get()) + params = SearchParams(query_type, query, filters) + + self._log.debug("Searching for '{}' with {}", query, filters) + try: + response_data = self.get_search_response(params) + except Exception: + self._log.error("Error fetching data", exc_info=True) + return () + + self._log.debug("Found {} result(s)", len(response_data)) + return response_data + + def _get_candidates( + self, query_type: QueryType, *args, **kwargs + ) -> Sequence[R]: + return self._search_api( + query_type, *self.get_search_filters(query_type, *args, **kwargs) + ) def candidates( self, @@ -306,54 +350,11 @@ def candidates( album: str, va_likely: bool, ) -> Iterable[AlbumInfo]: - query_filters: SearchFilter = {} - if album: - query_filters["album"] = album - if not va_likely: - query_filters["artist"] = artist - - results = self._search_api("album", query_filters) - if not results: - return [] - - return filter( - None, self.albums_for_ids([result["id"] for result in results]) - ) + results = self._get_candidates("album", items, artist, album, va_likely) + return filter(None, self.albums_for_ids(r["id"] for r in results)) def item_candidates( self, item: Item, artist: str, title: str ) -> Iterable[TrackInfo]: - results = self._search_api( - "track", {"artist": artist}, query_string=title - ) - if not results: - return [] - - return filter( - None, - self.tracks_for_ids([result["id"] for result in results if result]), - ) - - def _construct_search_query( - self, filters: SearchFilter, query_string: str - ) -> str: - """Construct a query string with the specified filters and keywords to - be provided to the spotify (or similar) search API. - - The returned format was initially designed for spotify's search API but - we found is also useful with other APIs that support similar query structures. - see `spotify `_ - and `deezer `_. - - :param filters: Field filters to apply. - :param query_string: Query keywords to use. - :return: Query string to be provided to the search API. - """ - - components = [query_string, *(f"{k}:'{v}'" for k, v in filters.items())] - query = " ".join(filter(None, components)) - - if self.config["search_query_ascii"].get(): - query = unidecode.unidecode(query) - - return query + results = self._get_candidates("track", [item], artist, title, False) + return filter(None, self.tracks_for_ids(r["id"] for r in results)) diff --git a/beetsplug/deezer.py b/beetsplug/deezer.py index ef27dddc7a..62dd1038b5 100644 --- a/beetsplug/deezer.py +++ b/beetsplug/deezer.py @@ -18,23 +18,20 @@ import collections import time -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING import requests from beets import ui from beets.autotag import AlbumInfo, TrackInfo from beets.dbcore import types -from beets.metadata_plugins import ( - IDResponse, - SearchApiMetadataSourcePlugin, - SearchFilter, -) +from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin if TYPE_CHECKING: from collections.abc import Sequence from beets.library import Item, Library + from beets.metadata_plugins import SearchParams from ._typing import JSONDict @@ -220,58 +217,12 @@ def _get_track(self, track_data: JSONDict) -> TrackInfo: deezer_updated=time.time(), ) - def _search_api( - self, - query_type: Literal[ - "album", - "track", - "artist", - "history", - "playlist", - "podcast", - "radio", - "user", - ], - filters: SearchFilter, - query_string: str = "", - ) -> Sequence[IDResponse]: - """Query the Deezer Search API for the specified ``query_string``, applying - the provided ``filters``. - - :param filters: Field filters to apply. - :param query_string: Additional query to include in the search. - :return: JSON data for the class:`Response ` object or None - if no search results are returned. - """ - query = self._construct_search_query( - query_string=query_string, filters=filters - ) - self._log.debug("Searching {.data_source} for '{}'", self, query) - try: - response = requests.get( - f"{self.search_url}{query_type}", - params={ - "q": query, - "limit": self.config["search_limit"].get(), - }, - timeout=10, - ) - response.raise_for_status() - except requests.exceptions.RequestException as e: - self._log.error( - "Error fetching data from {.data_source} API\n Error: {}", - self, - e, - ) - return () - response_data: Sequence[IDResponse] = response.json().get("data", []) - self._log.debug( - "Found {} result(s) from {.data_source} for '{}'", - len(response_data), - self, - query, - ) - return response_data + def get_search_response(self, params: SearchParams) -> list[IDResponse]: + return requests.get( + f"{self.search_url}{params.query_type}", + params={**params.filters, "q": params.query}, + timeout=10, + ).json()["data"] def deezerupdate(self, items: Sequence[Item], write: bool): """Obtain rank information from Deezer.""" diff --git a/beetsplug/discogs.py b/beetsplug/discogs.py index 29600a6760..5361dad9f6 100644 --- a/beetsplug/discogs.py +++ b/beetsplug/discogs.py @@ -40,12 +40,13 @@ from beets import config from beets.autotag.distance import string_dist from beets.autotag.hooks import AlbumInfo, TrackInfo -from beets.metadata_plugins import MetadataSourcePlugin +from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Sequence + from collections.abc import Callable, Iterable, Iterator, Sequence from beets.library import Item + from beets.metadata_plugins import QueryType, SearchParams USER_AGENT = f"beets/{beets.__version__} +https://beets.io/" API_KEY = "rAzVUQYRaoFjeBjyWuWZ" @@ -121,7 +122,7 @@ def __init__( super().__init__(**kwargs) -class DiscogsPlugin(MetadataSourcePlugin): +class DiscogsPlugin(SearchApiMetadataSourcePlugin[IDResponse]): def __init__(self): super().__init__() self.config.add( @@ -211,11 +212,6 @@ def authenticate(self, c_key: str, c_secret: str) -> tuple[str, str]: return token, secret - def candidates( - self, items: Sequence[Item], artist: str, album: str, va_likely: bool - ) -> Iterable[AlbumInfo]: - return self.get_albums(f"{artist} {album}" if va_likely else album) - def get_track_from_album( self, album_info: AlbumInfo, compare: Callable[[TrackInfo], float] ) -> TrackInfo | None: @@ -232,21 +228,19 @@ def get_track_from_album( def item_candidates( self, item: Item, artist: str, title: str - ) -> Iterable[TrackInfo]: + ) -> Iterator[TrackInfo]: albums = self.candidates([item], artist, title, False) def compare_func(track_info: TrackInfo) -> float: return string_dist(track_info.title, title) tracks = (self.get_track_from_album(a, compare_func) for a in albums) - return list(filter(None, tracks)) + return filter(None, tracks) def album_for_id(self, album_id: str) -> AlbumInfo | None: """Fetches an album by its Discogs ID and returns an AlbumInfo object or None if the album is not found. """ - self._log.debug("Searching for release {}", album_id) - discogs_id = self._extract_id(album_id) if not discogs_id: @@ -280,29 +274,25 @@ def track_for_id(self, track_id: str) -> TrackInfo | None: return None - def get_albums(self, query: str) -> Iterable[AlbumInfo]: + def get_search_filters( + self, + query_type: QueryType, + items: Sequence[Item], + artist: str, + name: str, + va_likely: bool, + ) -> tuple[str, dict[str, str]]: + if va_likely: + artist = items[0].artist + + return f"{artist} - {name}", {"type": "release"} + + def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]: """Returns a list of AlbumInfo objects for a discogs search query.""" - # Strip non-word characters from query. Things like "!" and "-" can - # cause a query to return no results, even if they match the artist or - # album title. Use `re.UNICODE` flag to avoid stripping non-english - # word characters. - query = re.sub(r"(?u)\W+", " ", query) - # Strip medium information from query, Things like "CD1" and "disk 1" - # can also negate an otherwise positive result. - query = re.sub(r"(?i)\b(CD|disc|vinyl)\s*\d+", "", query) - - try: - results = self.discogs_client.search(query, type="release") - results.per_page = self.config["search_limit"].get() - releases = results.page(1) - except CONNECTION_ERRORS: - self._log.debug( - "Communication error while searching for {0!r}", - query, - exc_info=True, - ) - return [] - return filter(None, map(self.get_album_info, releases)) + limit = params.filters.pop("limit") + results = self.discogs_client.search(params.query, **params.filters) + results.per_page = limit + return [r.data for r in results.page(1)] @cache def get_master_year(self, master_id: str) -> int | None: diff --git a/beetsplug/musicbrainz.py b/beetsplug/musicbrainz.py index 3b49107ad6..afafda38cd 100644 --- a/beetsplug/musicbrainz.py +++ b/beetsplug/musicbrainz.py @@ -30,14 +30,14 @@ import beets import beets.autotag.hooks from beets import config, plugins, util -from beets.metadata_plugins import MetadataSourcePlugin +from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin from beets.util.id_extractors import extract_release_id if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - from typing import Literal + from collections.abc import Sequence from beets.library import Item + from beets.metadata_plugins import QueryType, SearchParams from ._typing import JSONDict @@ -369,7 +369,7 @@ def _merge_pseudo_and_actual_album( return merged -class MusicBrainzPlugin(MetadataSourcePlugin): +class MusicBrainzPlugin(SearchApiMetadataSourcePlugin[IDResponse]): def __init__(self): """Set up the python-musicbrainz-ngs module according to settings from the beets configuration. This should be called at startup. @@ -798,52 +798,27 @@ def get_album_criteria( return criteria - def _search_api( - self, - query_type: Literal["recording", "release"], - filters: dict[str, str], - ) -> list[JSONDict]: - """Perform MusicBrainz API search and return results. - - Execute a search against the MusicBrainz API for recordings or releases - using the provided criteria. Handles API errors by converting them into - MusicBrainzAPIError exceptions with contextual information. - """ - filters = { - k: _v for k, v in filters.items() if (_v := v.lower().strip()) - } - self._log.debug( - "Searching for MusicBrainz {}s with: {!r}", query_type, filters - ) - try: - method = getattr(musicbrainzngs, f"search_{query_type}s") - res = method(limit=self.config["search_limit"].get(), **filters) - except musicbrainzngs.MusicBrainzError as exc: - raise MusicBrainzAPIError( - exc, f"{query_type} search", filters, traceback.format_exc() - ) - return res[f"{query_type}-list"] - - def candidates( + def get_search_filters( self, + query_type: QueryType, items: Sequence[Item], artist: str, - album: str, + name: str, va_likely: bool, - ) -> Iterable[beets.autotag.hooks.AlbumInfo]: - criteria = self.get_album_criteria(items, artist, album, va_likely) - release_ids = (r["id"] for r in self._search_api("release", criteria)) - - yield from filter(None, map(self.album_for_id, release_ids)) + ) -> tuple[str, dict[str, str]]: + if query_type == "album": + criteria = self.get_album_criteria(items, artist, name, va_likely) + else: + criteria = {"artist": artist, "recording": name, "alias": name} - def item_candidates( - self, item: Item, artist: str, title: str - ) -> Iterable[beets.autotag.hooks.TrackInfo]: - criteria = {"artist": artist, "recording": title, "alias": title} + return "", { + k: _v for k, v in criteria.items() if (_v := v.lower().strip()) + } - yield from filter( - None, map(self.track_info, self._search_api("recording", criteria)) - ) + def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]: + mb_entity = "release" if params.query_type == "album" else "recording" + method = getattr(musicbrainzngs, f"search_{mb_entity}s") + return method(**params.filters)[f"{mb_entity}-list"] def album_for_id( self, album_id: str diff --git a/beetsplug/spotify.py b/beetsplug/spotify.py index b3c653682d..74252719f0 100644 --- a/beetsplug/spotify.py +++ b/beetsplug/spotify.py @@ -39,7 +39,7 @@ from beets.metadata_plugins import ( IDResponse, SearchApiMetadataSourcePlugin, - SearchFilter, + SearchParams, ) if TYPE_CHECKING: @@ -447,11 +447,8 @@ def track_for_id(self, track_id: str) -> None | TrackInfo: track.medium_total = medium_total return track - def _search_api( - self, - query_type: Literal["album", "track"], - filters: SearchFilter, - query_string: str = "", + def get_search_response( + self, params: SearchParams ) -> Sequence[SearchResponseAlbums | SearchResponseTracks]: """Query the Spotify Search API for the specified ``query_string``, applying the provided ``filters``. @@ -460,34 +457,27 @@ def _search_api( 'artist', 'playlist', and 'track'. :param filters: Field filters to apply. :param query_string: Additional query to include in the search. - """ - query = self._construct_search_query( - filters=filters, query_string=query_string + response = requests.get( + self.search_url, + headers={"Authorization": f"Bearer {self.access_token}"}, + params={ + **params.filters, + "q": params.query, + "type": params.query_type, + }, + timeout=10, ) - - self._log.debug("Searching {.data_source} for '{}'", self, query) try: - response = self._handle_response( - "get", - self.search_url, - params={ - "q": query, - "type": query_type, - "limit": self.config["search_limit"].get(), - }, - ) - except APIError as e: - self._log.debug("Spotify API error: {}", e) - return () - response_data = response.get(f"{query_type}s", {}).get("items", []) - self._log.debug( - "Found {} result(s) from {.data_source} for '{}'", - len(response_data), - self, - query, - ) - return response_data + response.raise_for_status() + except requests.exceptions.HTTPError: + if response.status_code == 401: + self._authenticate() + return self.get_search_response(params) + + raise + + return response.json().get(f"{params.query_type}s", {}).get("items", []) def commands(self) -> list[ui.Subcommand]: # autotagger import command @@ -600,22 +590,14 @@ def _match_library_tracks(self, library: Library, keywords: str): query_string = item["title"] # Query the Web API for each track, look for the items' JSON data - query_filters: SearchFilter = {} + query = query_string if artist: - query_filters["artist"] = artist + query += f" artist:'{artist}'" if album: - query_filters["album"] = album + query += f" album:'{album}'" - response_data_tracks = self._search_api( - query_type="track", - query_string=query_string, - filters=query_filters, - ) + response_data_tracks = self._search_api("track", query, {}) if not response_data_tracks: - query = self._construct_search_query( - query_string=query_string, filters=query_filters - ) - failures.append(query) continue diff --git a/test/plugins/test_musicbrainz.py b/test/plugins/test_musicbrainz.py index 844b2ad4ef..15ba1c90ac 100644 --- a/test/plugins/test_musicbrainz.py +++ b/test/plugins/test_musicbrainz.py @@ -990,7 +990,7 @@ class TestMusicBrainzPlugin(PluginMixin): plugin = "musicbrainz" mbid = "d2a6f856-b553-40a0-ac54-a321e8e2da99" - RECORDING = {"title": "foo", "id": "bar", "length": 42} + RECORDING = {"title": "foo", "id": mbid, "length": 42} @pytest.fixture def plugin_config(self): @@ -1035,6 +1035,10 @@ def test_item_candidates(self, monkeypatch, mb): "musicbrainzngs.search_recordings", lambda *_, **__: {"recording-list": [self.RECORDING]}, ) + monkeypatch.setattr( + "musicbrainzngs.get_recording_by_id", + lambda *_, **__: {"recording": self.RECORDING}, + ) candidates = list(mb.item_candidates(Item(), "hello", "there")) diff --git a/test/plugins/test_spotify.py b/test/plugins/test_spotify.py index bc55485c67..9960ae5d88 100644 --- a/test/plugins/test_spotify.py +++ b/test/plugins/test_spotify.py @@ -81,6 +81,7 @@ def test_missing_request(self): params = _params(responses.calls[0].request.url) query = params["q"][0] + print(query) assert "duifhjslkef" in query assert "artist:'ujydfsuihse'" in query assert "album:'lkajsdflakjsd'" in query