Skip to content

Commit 7234d74

Browse files
committed
Upgrade musicbrainz and discogs to SearchApiMetadataSourcePlugin
And centralise common search functionality inside the parent class
1 parent 17bc110 commit 7234d74

File tree

6 files changed

+143
-244
lines changed

6 files changed

+143
-244
lines changed

beets/metadata_plugins.py

Lines changed: 62 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,32 @@
1111
import inspect
1212
import re
1313
import warnings
14-
from typing import TYPE_CHECKING, Generic, Literal, Sequence, TypedDict, TypeVar
14+
from typing import (
15+
TYPE_CHECKING,
16+
Generic,
17+
Literal,
18+
NamedTuple,
19+
TypedDict,
20+
TypeVar,
21+
)
1522

1623
import unidecode
17-
from typing_extensions import NotRequired
1824

1925
from beets.util import cached_classproperty
2026
from beets.util.id_extractors import extract_release_id
2127

2228
from .plugins import BeetsPlugin, find_plugins, notify_info_yielded, send
2329

2430
if TYPE_CHECKING:
25-
from collections.abc import Iterable
31+
from collections.abc import Iterable, Sequence
2632

2733
from confuse import ConfigView
2834

2935
from .autotag import Distance
3036
from .autotag.hooks import AlbumInfo, Item, TrackInfo
3137

38+
QueryType = Literal["album", "track"]
39+
3240

3341
def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
3442
"""Returns a list of MetadataSourcePlugin subclass instances
@@ -203,7 +211,7 @@ def item_candidates(
203211
"""
204212
raise NotImplementedError
205213

206-
def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:
214+
def albums_for_ids(self, ids: Iterable[str]) -> Iterable[AlbumInfo | None]:
207215
"""Batch lookup of album metadata for a list of album IDs.
208216
209217
Given a list of album identifiers, yields corresponding AlbumInfo objects.
@@ -214,7 +222,7 @@ def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:
214222

215223
return (self.album_for_id(id) for id in ids)
216224

217-
def tracks_for_ids(self, ids: Sequence[str]) -> Iterable[TrackInfo | None]:
225+
def tracks_for_ids(self, ids: Iterable[str]) -> Iterable[TrackInfo | None]:
218226
"""Batch lookup of track metadata for a list of track IDs.
219227
220228
Given a list of track identifiers, yields corresponding TrackInfo objects.
@@ -320,12 +328,13 @@ class IDResponse(TypedDict):
320328
id: str
321329

322330

323-
class SearchFilter(TypedDict):
324-
artist: NotRequired[str]
325-
album: NotRequired[str]
331+
R = TypeVar("R", bound=IDResponse)
326332

327333

328-
R = TypeVar("R", bound=IDResponse)
334+
class SearchParams(NamedTuple):
335+
query_type: QueryType
336+
query: str
337+
filters: dict[str, str]
329338

330339

331340
class SearchApiMetadataSourcePlugin(
@@ -348,12 +357,26 @@ def __init__(self, *args, **kwargs) -> None:
348357
}
349358
)
350359

360+
def get_search_filters(
361+
self,
362+
query_type: QueryType,
363+
items: Sequence[Item],
364+
artist: str,
365+
name: str,
366+
va_likely: bool,
367+
) -> tuple[str, dict[str, str]]:
368+
query = f'album:"{name}"' if query_type == "album" else name
369+
if query_type == "track" or not va_likely:
370+
query += f' artist:"{artist}"'
371+
372+
return query, {}
373+
351374
@abc.abstractmethod
375+
def get_search_response(self, params: SearchParams) -> Sequence[R]:
376+
raise NotImplementedError
377+
352378
def _search_api(
353-
self,
354-
query_type: Literal["album", "track"],
355-
filters: SearchFilter,
356-
query_string: str = "",
379+
self, query_type: QueryType, query: str, filters: dict[str, str]
357380
) -> Sequence[R]:
358381
"""Perform a search on the API.
359382
@@ -363,7 +386,28 @@ def _search_api(
363386
364387
Should return a list of identifiers for the requested type (album or track).
365388
"""
366-
raise NotImplementedError
389+
if self.config["search_query_ascii"].get():
390+
query = unidecode.unidecode(query)
391+
392+
filters["limit"] = str(self.config["search_limit"].get())
393+
params = SearchParams(query_type, query, filters)
394+
395+
self._log.debug("Searching for '{}' with {}", query, filters)
396+
try:
397+
response_data = self.get_search_response(params)
398+
except Exception:
399+
self._log.error("Error fetching data", exc_info=True)
400+
return ()
401+
402+
self._log.debug("Found {} result(s)", len(response_data))
403+
return response_data
404+
405+
def _get_candidates(
406+
self, query_type: QueryType, *args, **kwargs
407+
) -> Sequence[R]:
408+
return self._search_api(
409+
query_type, *self.get_search_filters(query_type, *args, **kwargs)
410+
)
367411

368412
def candidates(
369413
self,
@@ -372,55 +416,14 @@ def candidates(
372416
album: str,
373417
va_likely: bool,
374418
) -> Iterable[AlbumInfo]:
375-
query_filters: SearchFilter = {"album": album}
376-
if not va_likely:
377-
query_filters["artist"] = artist
378-
379-
results = self._search_api("album", query_filters)
380-
if not results:
381-
return []
382-
383-
return filter(
384-
None, self.albums_for_ids([result["id"] for result in results])
385-
)
419+
results = self._get_candidates("album", items, artist, album, va_likely)
420+
return filter(None, self.albums_for_ids(r["id"] for r in results))
386421

387422
def item_candidates(
388423
self, item: Item, artist: str, title: str
389424
) -> Iterable[TrackInfo]:
390-
results = self._search_api(
391-
"track", {"artist": artist}, query_string=title
392-
)
393-
if not results:
394-
return []
395-
396-
return filter(
397-
None,
398-
self.tracks_for_ids([result["id"] for result in results if result]),
399-
)
400-
401-
def _construct_search_query(
402-
self, filters: SearchFilter, query_string: str
403-
) -> str:
404-
"""Construct a query string with the specified filters and keywords to
405-
be provided to the spotify (or similar) search API.
406-
407-
The returned format was initially designed for spotify's search API but
408-
we found is also useful with other APIs that support similar query structures.
409-
see `spotify <https://developer.spotify.com/documentation/web-api/reference/search>`_
410-
and `deezer <https://developers.deezer.com/api/search>`_.
411-
412-
:param filters: Field filters to apply.
413-
:param query_string: Query keywords to use.
414-
:return: Query string to be provided to the search API.
415-
"""
416-
417-
components = [query_string, *(f'{k}:"{v}"' for k, v in filters.items())]
418-
query = " ".join(filter(None, components))
419-
420-
if self.config["search_query_ascii"].get():
421-
query = unidecode.unidecode(query)
422-
423-
return query
425+
results = self._get_candidates("track", [item], artist, title, False)
426+
return filter(None, self.tracks_for_ids(r["id"] for r in results))
424427

425428

426429
# Dynamically copy methods to BeetsPlugin for legacy support

beetsplug/deezer.py

Lines changed: 9 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,18 @@
1818

1919
import collections
2020
import time
21-
from typing import TYPE_CHECKING, Literal, Sequence
21+
from typing import TYPE_CHECKING, Sequence
2222

2323
import requests
2424

2525
from beets import ui
2626
from beets.autotag import AlbumInfo, TrackInfo
2727
from beets.dbcore import types
28-
from beets.metadata_plugins import (
29-
IDResponse,
30-
SearchApiMetadataSourcePlugin,
31-
SearchFilter,
32-
)
28+
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
3329

3430
if TYPE_CHECKING:
3531
from beets.library import Item, Library
32+
from beets.metadata_plugins import SearchParams
3633

3734
from ._typing import JSONDict
3835

@@ -218,58 +215,12 @@ def _get_track(self, track_data: JSONDict) -> TrackInfo:
218215
deezer_updated=time.time(),
219216
)
220217

221-
def _search_api(
222-
self,
223-
query_type: Literal[
224-
"album",
225-
"track",
226-
"artist",
227-
"history",
228-
"playlist",
229-
"podcast",
230-
"radio",
231-
"user",
232-
],
233-
filters: SearchFilter,
234-
query_string: str = "",
235-
) -> Sequence[IDResponse]:
236-
"""Query the Deezer Search API for the specified ``query_string``, applying
237-
the provided ``filters``.
238-
239-
:param filters: Field filters to apply.
240-
:param query_string: Additional query to include in the search.
241-
:return: JSON data for the class:`Response <Response>` object or None
242-
if no search results are returned.
243-
"""
244-
query = self._construct_search_query(
245-
query_string=query_string, filters=filters
246-
)
247-
self._log.debug("Searching {.data_source} for '{}'", self, query)
248-
try:
249-
response = requests.get(
250-
f"{self.search_url}{query_type}",
251-
params={
252-
"q": query,
253-
"limit": self.config["search_limit"].get(),
254-
},
255-
timeout=10,
256-
)
257-
response.raise_for_status()
258-
except requests.exceptions.RequestException as e:
259-
self._log.error(
260-
"Error fetching data from {.data_source} API\n Error: {}",
261-
self,
262-
e,
263-
)
264-
return ()
265-
response_data: Sequence[IDResponse] = response.json().get("data", [])
266-
self._log.debug(
267-
"Found {} result(s) from {.data_source} for '{}'",
268-
len(response_data),
269-
self,
270-
query,
271-
)
272-
return response_data
218+
def get_search_response(self, params: SearchParams) -> list[IDResponse]:
219+
return requests.get(
220+
f"{self.search_url}{params.query_type}",
221+
params={**params.filters, "q": params.query},
222+
timeout=10,
223+
).json()["data"]
273224

274225
def deezerupdate(self, items: Sequence[Item], write: bool):
275226
"""Obtain rank information from Deezer."""

beetsplug/discogs.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@
4040
from beets import config
4141
from beets.autotag.distance import string_dist
4242
from beets.autotag.hooks import AlbumInfo, TrackInfo
43-
from beets.metadata_plugins import MetadataSourcePlugin
43+
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
4444

4545
if TYPE_CHECKING:
46-
from collections.abc import Callable, Iterable
46+
from collections.abc import Callable, Iterator
4747

4848
from beets.library import Item
49+
from beets.metadata_plugins import QueryType, SearchParams
4950

5051
USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
5152
API_KEY = "rAzVUQYRaoFjeBjyWuWZ"
@@ -83,7 +84,7 @@ class ReleaseFormat(TypedDict):
8384
descriptions: list[str] | None
8485

8586

86-
class DiscogsPlugin(MetadataSourcePlugin):
87+
class DiscogsPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
8788
def __init__(self):
8889
super().__init__()
8990
self.config.add(
@@ -167,11 +168,6 @@ def authenticate(self, c_key, c_secret):
167168

168169
return token, secret
169170

170-
def candidates(
171-
self, items: Sequence[Item], artist: str, album: str, va_likely: bool
172-
) -> Iterable[AlbumInfo]:
173-
return self.get_albums(f"{artist} {album}" if va_likely else album)
174-
175171
def get_track_from_album(
176172
self, album_info: AlbumInfo, compare: Callable[[TrackInfo], float]
177173
) -> TrackInfo | None:
@@ -188,21 +184,19 @@ def get_track_from_album(
188184

189185
def item_candidates(
190186
self, item: Item, artist: str, title: str
191-
) -> Iterable[TrackInfo]:
187+
) -> Iterator[TrackInfo]:
192188
albums = self.candidates([item], artist, title, False)
193189

194190
def compare_func(track_info: TrackInfo) -> float:
195191
return string_dist(track_info.title, title)
196192

197193
tracks = (self.get_track_from_album(a, compare_func) for a in albums)
198-
return list(filter(None, tracks))
194+
return filter(None, tracks)
199195

200196
def album_for_id(self, album_id: str) -> AlbumInfo | None:
201197
"""Fetches an album by its Discogs ID and returns an AlbumInfo object
202198
or None if the album is not found.
203199
"""
204-
self._log.debug("Searching for release {}", album_id)
205-
206200
discogs_id = self._extract_id(album_id)
207201

208202
if not discogs_id:
@@ -236,29 +230,25 @@ def track_for_id(self, track_id: str) -> TrackInfo | None:
236230

237231
return None
238232

239-
def get_albums(self, query: str) -> Iterable[AlbumInfo]:
240-
"""Returns a list of AlbumInfo objects for a discogs search query."""
241-
# Strip non-word characters from query. Things like "!" and "-" can
242-
# cause a query to return no results, even if they match the artist or
243-
# album title. Use `re.UNICODE` flag to avoid stripping non-english
244-
# word characters.
245-
query = re.sub(r"(?u)\W+", " ", query)
246-
# Strip medium information from query, Things like "CD1" and "disk 1"
247-
# can also negate an otherwise positive result.
248-
query = re.sub(r"(?i)\b(CD|disc|vinyl)\s*\d+", "", query)
233+
def _get_search_filters(
234+
self,
235+
query_type: QueryType,
236+
items: Sequence[Item],
237+
artist: str,
238+
name: str,
239+
va_likely: bool,
240+
) -> tuple[str, dict[str, str]]:
241+
if va_likely:
242+
artist = items[0].artist
249243

250-
try:
251-
results = self.discogs_client.search(query, type="release")
252-
results.per_page = self.config["search_limit"].get()
253-
releases = results.page(1)
254-
except CONNECTION_ERRORS:
255-
self._log.debug(
256-
"Communication error while searching for {0!r}",
257-
query,
258-
exc_info=True,
259-
)
260-
return []
261-
return filter(None, map(self.get_album_info, releases))
244+
return f"{artist} - {name}", {"type": "release"}
245+
246+
def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]:
247+
"""Returns a list of AlbumInfo objects for a discogs search query."""
248+
limit = params.filters.pop("limit")
249+
results = self.discogs_client.search(params.query, **params.filters)
250+
results.per_page = limit
251+
return [r.data for r in results.page(1)]
262252

263253
@cache
264254
def get_master_year(self, master_id: str) -> int | None:

0 commit comments

Comments
 (0)