Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion draccus/choice_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,33 @@ def __init_subclass__(cls, discover_packages_path: Optional[str] = None, **kwarg
@classmethod
def get_choice_class(cls, name: str) -> Any:
cls._discover_packages()
return cls._choice_registry[name]
try:
return cls._choice_registry[name]
except KeyError:
if "." not in name:
raise

module_path, _, class_name = name.rpartition(".")
if not module_path:
# rpartition returns an empty string when no separator is found, so
# we shouldn't get here, but this keeps mypy happy and guards
# against malformed inputs.
raise KeyError(name)

try:
module = importlib.import_module(module_path)
except ModuleNotFoundError as e:
raise KeyError(name) from e

try:
choice_cls = getattr(module, class_name)
except AttributeError as e:
raise KeyError(name) from e

if not isinstance(choice_cls, type) or not issubclass(choice_cls, cls):
raise KeyError(name)

return choice_cls

@classmethod
def get_known_choices(cls) -> Dict[str, Any]:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_plugin_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: MIT
# Copyright 2025 The Board of Trustees of the Leland Stanford Junior University

# SPDX-License-Identifier: MIT

import draccus


def test_plugin_registry_decode_fully_qualified_name():
from tests.draccus_choice_plugins.gpt import GptConfig
from tests.draccus_choice_plugins.model_config import ModelConfig

config = draccus.decode(
ModelConfig,
{
"type": "tests.draccus_choice_plugins.gpt.GptConfig",
"layers": 12,
"attn_pdrop": 0.2,
},
)

assert config == GptConfig(12, 0.2)