diff --git a/draccus/choice_types.py b/draccus/choice_types.py index 091031d..e1f89e5 100644 --- a/draccus/choice_types.py +++ b/draccus/choice_types.py @@ -182,7 +182,33 @@ def __init_subclass__(cls, discover_packages_path: Optional[str] = None, **kwarg @classmethod def get_choice_class(cls, name: str) -> Any: cls._discover_packages() - return cls._choice_registry[name] + try: + return cls._choice_registry[name] + except KeyError: + if "." not in name: + raise + + module_path, _, class_name = name.rpartition(".") + if not module_path: + # rpartition returns an empty string when no separator is found, so + # we shouldn't get here, but this keeps mypy happy and guards + # against malformed inputs. + raise KeyError(name) + + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + raise KeyError(name) from e + + try: + choice_cls = getattr(module, class_name) + except AttributeError as e: + raise KeyError(name) from e + + if not isinstance(choice_cls, type) or not issubclass(choice_cls, cls): + raise KeyError(name) + + return choice_cls @classmethod def get_known_choices(cls) -> Dict[str, Any]: diff --git a/tests/test_plugin_registry.py b/tests/test_plugin_registry.py new file mode 100644 index 0000000..997c763 --- /dev/null +++ b/tests/test_plugin_registry.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: MIT +# Copyright 2025 The Board of Trustees of the Leland Stanford Junior University + +# SPDX-License-Identifier: MIT + +import draccus + + +def test_plugin_registry_decode_fully_qualified_name(): + from tests.draccus_choice_plugins.gpt import GptConfig + from tests.draccus_choice_plugins.model_config import ModelConfig + + config = draccus.decode( + ModelConfig, + { + "type": "tests.draccus_choice_plugins.gpt.GptConfig", + "layers": 12, + "attn_pdrop": 0.2, + }, + ) + + assert config == GptConfig(12, 0.2)