Skip to content

Commit fc600e6

Browse files
committed
Added a proxy to catch and handle exceptions in metadataplugins during
the autotag process.
1 parent beda6fc commit fc600e6

File tree

2 files changed

+240
-9
lines changed

2 files changed

+240
-9
lines changed

beets/metadata_plugins.py

Lines changed: 143 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,26 @@
88
from __future__ import annotations
99

1010
import abc
11+
import inspect
1112
import re
12-
from functools import cache, cached_property
13-
from typing import TYPE_CHECKING, Generic, Literal, Sequence, TypedDict, TypeVar
13+
from functools import cache, cached_property, wraps
14+
from typing import (
15+
TYPE_CHECKING,
16+
Callable,
17+
ClassVar,
18+
Generic,
19+
Literal,
20+
Sequence,
21+
TypedDict,
22+
TypeVar,
23+
overload,
24+
)
1425

1526
import unidecode
1627
from confuse import NotFoundError
17-
from typing_extensions import NotRequired
28+
from typing_extensions import NotRequired, ParamSpec
1829

30+
from beets import config, logging
1931
from beets.util import cached_classproperty
2032
from beets.util.id_extractors import extract_release_id
2133

@@ -26,12 +38,18 @@
2638

2739
from .autotag.hooks import AlbumInfo, Item, TrackInfo
2840

41+
P = ParamSpec("P")
42+
R = TypeVar("R")
43+
44+
# Global logger.
45+
log = logging.getLogger("beets")
46+
2947

3048
@cache
3149
def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
3250
"""Return a list of all loaded metadata source plugins."""
3351
# TODO: Make this an isinstance(MetadataSourcePlugin, ...) check in v3.0.0
34-
return [p for p in find_plugins() if hasattr(p, "data_source")] # type: ignore[misc]
52+
return [SafeProxy(p) for p in find_plugins() if hasattr(p, "data_source")] # type: ignore[misc,arg-type]
3553

3654

3755
@notify_info_yielded("albuminfo_received")
@@ -43,7 +61,7 @@ def candidates(*args, **kwargs) -> Iterable[AlbumInfo]:
4361

4462
@notify_info_yielded("trackinfo_received")
4563
def item_candidates(*args, **kwargs) -> Iterable[TrackInfo]:
46-
"""Return matching track candidates fromm all metadata source plugins."""
64+
"""Return matching track candidates from all metadata source plugins."""
4765
for plugin in find_metadata_source_plugins():
4866
yield from plugin.item_candidates(*args, **kwargs)
4967

@@ -54,7 +72,7 @@ def album_for_id(_id: str) -> AlbumInfo | None:
5472
A single ID can yield just a single album, so we return the first match.
5573
"""
5674
for plugin in find_metadata_source_plugins():
57-
if info := plugin.album_for_id(album_id=_id):
75+
if info := plugin.album_for_id(_id):
5876
send("albuminfo_received", info=info)
5977
return info
6078

@@ -259,11 +277,11 @@ class SearchFilter(TypedDict):
259277
album: NotRequired[str]
260278

261279

262-
R = TypeVar("R", bound=IDResponse)
280+
Res = TypeVar("Res", bound=IDResponse)
263281

264282

265283
class SearchApiMetadataSourcePlugin(
266-
Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta
284+
Generic[Res], MetadataSourcePlugin, metaclass=abc.ABCMeta
267285
):
268286
"""Helper class to implement a metadata source plugin with an API.
269287
@@ -288,7 +306,7 @@ def _search_api(
288306
query_type: Literal["album", "track"],
289307
filters: SearchFilter,
290308
query_string: str = "",
291-
) -> Sequence[R]:
309+
) -> Sequence[Res]:
292310
"""Perform a search on the API.
293311
294312
:param query_type: The type of query to perform.
@@ -357,3 +375,119 @@ def _construct_search_query(
357375
query = unidecode.unidecode(query)
358376

359377
return query
378+
379+
380+
# To have proper typing for the proxy class below, we need to
381+
# trick mypy into thinking that SafeProxy is a subclass of
382+
# MetadataSourcePlugin.
383+
# https://stackoverflow.com/questions/71365594/how-to-make-a-proxy-object-with-typing-as-underlying-object-in-python
384+
Proxied = TypeVar("Proxied", bound=MetadataSourcePlugin)
385+
if TYPE_CHECKING:
386+
base = MetadataSourcePlugin
387+
else:
388+
base = object
389+
390+
391+
class SafeProxy(base):
392+
"""A proxy class that forwards all attribute access to the wrapped
393+
MetadataSourcePlugin instance.
394+
395+
We use this to catch and log exceptions from metadata source plugins
396+
without crashing beets. E.g. on long running autotag operations.
397+
"""
398+
399+
_plugin: MetadataSourcePlugin
400+
_SAFE_METHODS: ClassVar[set[str]] = {
401+
"candidates",
402+
"item_candidates",
403+
"album_for_id",
404+
"track_for_id",
405+
}
406+
407+
def __init__(self, plugin: MetadataSourcePlugin):
408+
self._plugin = plugin
409+
410+
def __getattribute__(self, name):
411+
if (
412+
name == "_plugin"
413+
or name == "_handle_exception"
414+
or name == "_SAFE_METHODS"
415+
or name == "_safe_execute"
416+
):
417+
return super().__getattribute__(name)
418+
419+
attr = getattr(self._plugin, name)
420+
421+
if callable(attr) and name in SafeProxy._SAFE_METHODS:
422+
return self._safe_execute(attr)
423+
return attr
424+
425+
def __setattr__(self, name, value):
426+
if name == "_plugin":
427+
super().__setattr__(name, value)
428+
else:
429+
self._plugin.__setattr__(name, value)
430+
431+
@overload
432+
def _safe_execute(
433+
self,
434+
func: Callable[P, Iterable[R]],
435+
) -> Callable[P, Iterable[R]]: ...
436+
@overload
437+
def _safe_execute(self, func: Callable[P, R]) -> Callable[P, R | None]: ...
438+
def _safe_execute(
439+
self, func: Callable[P, R]
440+
) -> Callable[P, R | Iterable[R] | None]:
441+
"""Wrap any function (generator or regular) and safely execute it.
442+
443+
Limitation: This does not work on properties!
444+
"""
445+
446+
@wraps(func)
447+
def wrapper(
448+
*args: P.args, **kwargs: P.kwargs
449+
) -> R | Iterable[R] | None:
450+
try:
451+
result = func(*args, **kwargs)
452+
except Exception as e:
453+
self._handle_exception(func, e)
454+
455+
return None
456+
457+
if inspect.isgenerator(result):
458+
try:
459+
yield from result
460+
except Exception as e:
461+
self._handle_exception(func, e)
462+
return None
463+
else:
464+
return result
465+
466+
return wrapper
467+
468+
def _handle_exception(self, func: Callable[P, R], e: Exception) -> None:
469+
"""Helper function to log exceptions from metadata source plugins."""
470+
if config["raise_on_error"].get(bool):
471+
raise e
472+
log.error(
473+
"Error in '{}.{}': {}",
474+
self._plugin.data_source,
475+
func.__name__,
476+
e,
477+
)
478+
log.debug("Exception details:", exc_info=True)
479+
480+
# Implement abstract methods to satisfy the ABC
481+
# this is only needed because of the typing hack above.
482+
483+
def album_for_id(self, album_id: str):
484+
raise NotImplementedError
485+
486+
def track_for_id(self, track_id: str):
487+
raise NotImplementedError
488+
489+
def candidates(self, *args, **kwargs):
490+
raise NotImplementedError
491+
492+
def item_candidates(self, *args, **kwargs):
493+
raise NotImplementedError

test/test_metadata_plugins.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from typing import Iterable
2+
3+
import pytest
4+
5+
from beets import metadata_plugins
6+
from beets.test.helper import PluginMixin
7+
8+
9+
class ErrorMetadataMockPlugin(metadata_plugins.MetadataSourcePlugin):
10+
"""A metadata source plugin that raises errors in all its methods."""
11+
12+
data_source = "ErrorMetadataMockPlugin"
13+
14+
def candidates(self, *args, **kwargs):
15+
raise ValueError("Mocked error")
16+
17+
def item_candidates(self, *args, **kwargs):
18+
for i in range(3):
19+
raise ValueError("Mocked error")
20+
yield # This is just to make this a generator
21+
22+
def album_for_id(self, *args, **kwargs):
23+
raise ValueError("Mocked error")
24+
25+
def track_for_id(self, *args, **kwargs):
26+
raise ValueError("Mocked error")
27+
28+
def track_distance(self, *args, **kwargs):
29+
raise ValueError("Mocked error")
30+
31+
def album_distance(self, *args, **kwargs):
32+
raise ValueError("Mocked error")
33+
34+
35+
class TestMetadataPluginsException(PluginMixin):
36+
"""Check that errors during the metadata plugins do not crash beets.
37+
They should be logged as errors instead.
38+
"""
39+
40+
@pytest.fixture(autouse=True)
41+
def setup(self):
42+
self.register_plugin(ErrorMetadataMockPlugin)
43+
yield
44+
self.unload_plugins()
45+
46+
@pytest.mark.parametrize(
47+
"method_name,args",
48+
[
49+
("candidates", ()),
50+
("item_candidates", ()),
51+
("album_for_id", ("some_id",)),
52+
("track_for_id", ("some_id",)),
53+
],
54+
)
55+
def test_logging(
56+
self,
57+
caplog,
58+
method_name,
59+
args,
60+
):
61+
self.config["raise_on_error"] = False
62+
with caplog.at_level("ERROR"):
63+
# Call the method to trigger the error
64+
ret = getattr(metadata_plugins, method_name)(*args)
65+
if isinstance(ret, Iterable):
66+
list(ret)
67+
68+
# Check that an error was logged
69+
assert len(caplog.records) >= 1
70+
logs = [record.getMessage() for record in caplog.records]
71+
for msg in logs:
72+
assert (
73+
msg
74+
== f"Error in 'ErrorMetadataMockPlugin.{method_name}': Mocked error"
75+
)
76+
77+
caplog.clear()
78+
79+
@pytest.mark.parametrize(
80+
"method_name,args",
81+
[
82+
("candidates", ()),
83+
("item_candidates", ()),
84+
("album_for_id", ("some_id",)),
85+
("track_for_id", ("some_id",)),
86+
],
87+
)
88+
def test_raising(
89+
self,
90+
method_name,
91+
args,
92+
):
93+
self.config["raise_on_error"] = True
94+
with pytest.raises(ValueError, match="Mocked error"):
95+
getattr(metadata_plugins, method_name)(*args) if not isinstance(
96+
args, Iterable
97+
) else list(getattr(metadata_plugins, method_name)(*args))

0 commit comments

Comments
 (0)