Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 62 additions & 59 deletions beets/metadata_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,32 @@
import inspect
import re
import warnings
from typing import TYPE_CHECKING, Generic, Literal, Sequence, TypedDict, TypeVar
from typing import (
TYPE_CHECKING,
Generic,
Literal,
NamedTuple,
TypedDict,
TypeVar,
)

import unidecode
from typing_extensions import NotRequired

from beets.util import cached_classproperty
from beets.util.id_extractors import extract_release_id

from .plugins import BeetsPlugin, find_plugins, notify_info_yielded, send

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Sequence

from confuse import ConfigView

from .autotag import Distance
from .autotag.hooks import AlbumInfo, Item, TrackInfo

QueryType = Literal["album", "track"]


def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
"""Returns a list of MetadataSourcePlugin subclass instances
Expand Down Expand Up @@ -203,7 +211,7 @@ def item_candidates(
"""
raise NotImplementedError

def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:
def albums_for_ids(self, ids: Iterable[str]) -> Iterable[AlbumInfo | None]:
"""Batch lookup of album metadata for a list of album IDs.
Given a list of album identifiers, yields corresponding AlbumInfo objects.
Expand All @@ -214,7 +222,7 @@ def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:

return (self.album_for_id(id) for id in ids)

def tracks_for_ids(self, ids: Sequence[str]) -> Iterable[TrackInfo | None]:
def tracks_for_ids(self, ids: Iterable[str]) -> Iterable[TrackInfo | None]:
"""Batch lookup of track metadata for a list of track IDs.
Given a list of track identifiers, yields corresponding TrackInfo objects.
Expand Down Expand Up @@ -320,12 +328,13 @@ class IDResponse(TypedDict):
id: str


class SearchFilter(TypedDict):
artist: NotRequired[str]
album: NotRequired[str]
R = TypeVar("R", bound=IDResponse)


R = TypeVar("R", bound=IDResponse)
class SearchParams(NamedTuple):
    """Bundle of arguments describing a single search request.

    Built by ``_search_api`` and handed to ``get_search_response``.
    """

    # Kind of entity being searched for: "album" or "track".
    query_type: QueryType
    # Query text; ASCII-folded when ``search_query_ascii`` is enabled.
    query: str
    # Extra request parameters; ``_search_api`` adds a "limit" entry.
    filters: dict[str, str]


class SearchApiMetadataSourcePlugin(
Expand All @@ -348,12 +357,26 @@ def __init__(self, *args, **kwargs) -> None:
}
)

def get_search_filters(
    self,
    query_type: QueryType,
    items: Sequence[Item],
    artist: str,
    name: str,
    va_likely: bool,
) -> tuple[str, dict[str, str]]:
    """Build the default search query and additional request filters.

    Album searches wrap the name in an ``album:"..."`` field, track
    searches use the bare name. The ``artist:"..."`` field is appended for
    every track search and for album searches that are not likely to be
    a various-artists release.

    :param query_type: Either ``"album"`` or ``"track"``.
    :param items: Items being matched; unused by this default implementation.
    :param artist: Artist name to search for.
    :param name: Album or track title to search for.
    :param va_likely: Whether the release is likely a compilation.
    :return: The query string and an (empty) dictionary of extra filters.
    """
    if query_type == "album":
        pieces = [f'album:"{name}"']
    else:
        pieces = [name]

    # Compilations are searched by album name only; tracks always keep
    # their artist.
    skip_artist = va_likely and query_type != "track"
    if not skip_artist:
        pieces.append(f' artist:"{artist}"')

    return "".join(pieces), {}

@abc.abstractmethod
def get_search_response(self, params: SearchParams) -> Sequence[R]:
    """Run the search described by ``params`` against the source's API.

    Subclasses must override this. ``_search_api`` calls it inside a
    ``try`` block, so any exception raised here is logged and turned into
    an empty result.

    :param params: Query type, query string and filters for the request.
    :return: Sequence of raw results, each carrying at least an ``id``.
    """
    raise NotImplementedError

def _search_api(
self,
query_type: Literal["album", "track"],
filters: SearchFilter,
query_string: str = "",
self, query_type: QueryType, query: str, filters: dict[str, str]
) -> Sequence[R]:
"""Perform a search on the API.
Expand All @@ -363,7 +386,28 @@ def _search_api(
Should return a list of identifiers for the requested type (album or track).
"""
raise NotImplementedError
if self.config["search_query_ascii"].get():
query = unidecode.unidecode(query)

filters["limit"] = str(self.config["search_limit"].get())
params = SearchParams(query_type, query, filters)

self._log.debug("Searching for '{}' with {}", query, filters)
try:
response_data = self.get_search_response(params)
except Exception:
self._log.error("Error fetching data", exc_info=True)
return ()

self._log.debug("Found {} result(s)", len(response_data))
return response_data

def _get_candidates(
    self, query_type: QueryType, *args, **kwargs
) -> Sequence[R]:
    """Build the query for ``query_type`` and run it through the search API.

    Remaining arguments are forwarded untouched to ``get_search_filters``;
    the ``(query, filters)`` pair it returns is passed on to ``_search_api``.
    """
    return self._search_api(
        query_type, *self.get_search_filters(query_type, *args, **kwargs)
    )

def candidates(
self,
Expand All @@ -372,55 +416,14 @@ def candidates(
album: str,
va_likely: bool,
) -> Iterable[AlbumInfo]:
query_filters: SearchFilter = {"album": album}
if not va_likely:
query_filters["artist"] = artist

results = self._search_api("album", query_filters)
if not results:
return []

return filter(
None, self.albums_for_ids([result["id"] for result in results])
)
results = self._get_candidates("album", items, artist, album, va_likely)
return filter(None, self.albums_for_ids(r["id"] for r in results))

def item_candidates(
self, item: Item, artist: str, title: str
) -> Iterable[TrackInfo]:
results = self._search_api(
"track", {"artist": artist}, query_string=title
)
if not results:
return []

return filter(
None,
self.tracks_for_ids([result["id"] for result in results if result]),
)

def _construct_search_query(
self, filters: SearchFilter, query_string: str
) -> str:
"""Construct a query string with the specified filters and keywords to
be provided to the spotify (or similar) search API.
The returned format was initially designed for spotify's search API but
we found is also useful with other APIs that support similar query structures.
see `spotify <https://developer.spotify.com/documentation/web-api/reference/search>`_
and `deezer <https://developers.deezer.com/api/search>`_.
:param filters: Field filters to apply.
:param query_string: Query keywords to use.
:return: Query string to be provided to the search API.
"""

components = [query_string, *(f'{k}:"{v}"' for k, v in filters.items())]
query = " ".join(filter(None, components))

if self.config["search_query_ascii"].get():
query = unidecode.unidecode(query)

return query
results = self._get_candidates("track", [item], artist, title, False)
return filter(None, self.tracks_for_ids(r["id"] for r in results))


# Dynamically copy methods to BeetsPlugin for legacy support
Expand Down
67 changes: 9 additions & 58 deletions beetsplug/deezer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,18 @@

import collections
import time
from typing import TYPE_CHECKING, Literal, Sequence
from typing import TYPE_CHECKING, Sequence

import requests

from beets import ui
from beets.autotag import AlbumInfo, TrackInfo
from beets.dbcore import types
from beets.metadata_plugins import (
IDResponse,
SearchApiMetadataSourcePlugin,
SearchFilter,
)
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin

if TYPE_CHECKING:
from beets.library import Item, Library
from beets.metadata_plugins import SearchParams

from ._typing import JSONDict

Expand Down Expand Up @@ -218,58 +215,12 @@ def _get_track(self, track_data: JSONDict) -> TrackInfo:
deezer_updated=time.time(),
)

def _search_api(
self,
query_type: Literal[
"album",
"track",
"artist",
"history",
"playlist",
"podcast",
"radio",
"user",
],
filters: SearchFilter,
query_string: str = "",
) -> Sequence[IDResponse]:
"""Query the Deezer Search API for the specified ``query_string``, applying
the provided ``filters``.
:param filters: Field filters to apply.
:param query_string: Additional query to include in the search.
:return: JSON data for the class:`Response <Response>` object or None
if no search results are returned.
"""
query = self._construct_search_query(
query_string=query_string, filters=filters
)
self._log.debug("Searching {.data_source} for '{}'", self, query)
try:
response = requests.get(
f"{self.search_url}{query_type}",
params={
"q": query,
"limit": self.config["search_limit"].get(),
},
timeout=10,
)
response.raise_for_status()
except requests.exceptions.RequestException as e:
self._log.error(
"Error fetching data from {.data_source} API\n Error: {}",
self,
e,
)
return ()
response_data: Sequence[IDResponse] = response.json().get("data", [])
self._log.debug(
"Found {} result(s) from {.data_source} for '{}'",
len(response_data),
self,
query,
)
return response_data
def get_search_response(self, params: SearchParams) -> list[IDResponse]:
    """Query the Deezer search endpoint and return the matching entries.

    :param params: Search parameters; ``query_type`` selects the endpoint,
        ``filters`` are sent as request parameters and ``query`` is sent
        as ``q``.
    :return: Contents of the ``data`` field of the JSON response.
    :raises requests.exceptions.RequestException: On connection problems
        or when the server answers with an HTTP error status.
    """
    response = requests.get(
        f"{self.search_url}{params.query_type}",
        params={**params.filters, "q": params.query},
        timeout=10,
    )
    # Report HTTP failures as such rather than as a JSON decoding error or
    # a missing "data" key further down.
    response.raise_for_status()
    return response.json()["data"]

def deezerupdate(self, items: Sequence[Item], write: bool):
"""Obtain rank information from Deezer."""
Expand Down
58 changes: 24 additions & 34 deletions beetsplug/discogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@
from beets import config
from beets.autotag.distance import string_dist
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.metadata_plugins import MetadataSourcePlugin
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin

if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterator

from beets.library import Item
from beets.metadata_plugins import QueryType, SearchParams

USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
API_KEY = "rAzVUQYRaoFjeBjyWuWZ"
Expand Down Expand Up @@ -83,7 +84,7 @@ class ReleaseFormat(TypedDict):
descriptions: list[str] | None


class DiscogsPlugin(MetadataSourcePlugin):
class DiscogsPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
def __init__(self):
super().__init__()
self.config.add(
Expand Down Expand Up @@ -167,11 +168,6 @@ def authenticate(self, c_key, c_secret):

return token, secret

def candidates(
self, items: Sequence[Item], artist: str, album: str, va_likely: bool
) -> Iterable[AlbumInfo]:
return self.get_albums(f"{artist} {album}" if va_likely else album)

def get_track_from_album(
self, album_info: AlbumInfo, compare: Callable[[TrackInfo], float]
) -> TrackInfo | None:
Expand All @@ -188,21 +184,19 @@ def get_track_from_album(

def item_candidates(
self, item: Item, artist: str, title: str
) -> Iterable[TrackInfo]:
) -> Iterator[TrackInfo]:
albums = self.candidates([item], artist, title, False)

def compare_func(track_info: TrackInfo) -> float:
return string_dist(track_info.title, title)

tracks = (self.get_track_from_album(a, compare_func) for a in albums)
return list(filter(None, tracks))
return filter(None, tracks)

def album_for_id(self, album_id: str) -> AlbumInfo | None:
"""Fetches an album by its Discogs ID and returns an AlbumInfo object
or None if the album is not found.
"""
self._log.debug("Searching for release {}", album_id)

discogs_id = self._extract_id(album_id)

if not discogs_id:
Expand Down Expand Up @@ -236,29 +230,25 @@ def track_for_id(self, track_id: str) -> TrackInfo | None:

return None

def get_albums(self, query: str) -> Iterable[AlbumInfo]:
"""Returns a list of AlbumInfo objects for a discogs search query."""
# Strip non-word characters from query. Things like "!" and "-" can
# cause a query to return no results, even if they match the artist or
# album title. Use `re.UNICODE` flag to avoid stripping non-english
# word characters.
query = re.sub(r"(?u)\W+", " ", query)
# Strip medium information from query, Things like "CD1" and "disk 1"
# can also negate an otherwise positive result.
query = re.sub(r"(?i)\b(CD|disc|vinyl)\s*\d+", "", query)
def get_search_filters(
self,
query_type: QueryType,
items: Sequence[Item],
artist: str,
name: str,
va_likely: bool,
) -> tuple[str, dict[str, str]]:
if va_likely:
artist = items[0].artist
Comment on lines +241 to +242
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (bug_risk): Possible unintended override of artist for VA releases.

Using items[0].artist for VA releases could be inaccurate if items include multiple artists. Please verify that this logic aligns with the expected handling of VA releases.


try:
results = self.discogs_client.search(query, type="release")
results.per_page = self.config["search_limit"].get()
releases = results.page(1)
except CONNECTION_ERRORS:
self._log.debug(
"Communication error while searching for {0!r}",
query,
exc_info=True,
)
return []
return filter(None, map(self.get_album_info, releases))
return f"{artist} - {name}", {"type": "release"}

def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]:
    """Return the raw data of the first page of Discogs search results.

    :param params: Search parameters; ``params.filters`` must contain a
        ``"limit"`` entry, which is used as the page size, while the
        remaining entries are forwarded to the Discogs client as keyword
        arguments.
    :return: List with the ``data`` attribute of every result on page 1.
    """
    # Work on a copy so the filters stored in ``params`` are left intact
    # for the caller.
    filters = dict(params.filters)
    limit = filters.pop("limit")
    results = self.discogs_client.search(params.query, **filters)
    results.per_page = limit
    return [r.data for r in results.page(1)]

@cache
def get_master_year(self, master_id: str) -> int | None:
Expand Down
Loading
Loading