Skip to content

Commit d38d55d

Browse files
authored
Merge pull request #2 from wagtail/fix/registry-register-type
Improve DX for `AgentRegistry.register()`
2 parents f53fe1d + 6fee359 commit d38d55d

File tree

3 files changed

+683
-666
lines changed

3 files changed

+683
-666
lines changed

src/django_ai_core/contrib/agents/base.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import inspect
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
4-
from typing import Annotated, get_args, get_origin
4+
from typing import Annotated, Callable, TypeVar, get_args, get_origin
55

66
from django.core.exceptions import ValidationError
77
from django.core.validators import validate_slug
88

99
from .permissions import BasePermission
1010
from .views import AgentExecutionView
1111

12+
AgentT = TypeVar("AgentT", bound="Agent")
13+
1214

1315
@dataclass
1416
class AgentParameter:
@@ -79,15 +81,20 @@ class AgentRegistry:
7981
def __init__(self):
8082
self._agents: dict[str, type[Agent]] = {}
8183

82-
def register(self):
83-
"""Decorator to register an agent."""
84-
85-
def decorator(cls: type[Agent]) -> type[Agent]:
86-
agent_slug = cls.slug
87-
self._agents[agent_slug] = cls
88-
return cls
89-
90-
return decorator
84+
def register(
85+
self, cls: type[AgentT] | None = None
86+
) -> type[AgentT] | Callable[[type[AgentT]], type[AgentT]]:
87+
def decorator(agent_cls: type[AgentT]) -> type[AgentT]:
88+
agent_slug = agent_cls.slug
89+
self._agents[agent_slug] = agent_cls
90+
return agent_cls
91+
92+
if cls is None:
93+
# Called with parentheses: @registry.register()
94+
return decorator
95+
else:
96+
# Called without parentheses: @registry.register
97+
return decorator(cls)
9198

9299
def get(self, slug: str) -> type[Agent]:
93100
if slug not in self._agents:

tests/unit/contrib/agents/test_agent_registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ def test_agent_registry_register_uses_agent_slug():
3131
assert registry._agents["test-one"] is TestAgentOne
3232

3333

34+
def test_agent_registry_register_without_parentheses():
35+
"""Test that register can be called without parentheses."""
36+
registry = AgentRegistry()
37+
38+
decorated = registry.register(TestAgentOne)
39+
assert decorated is TestAgentOne
40+
assert "test-one" in registry._agents
41+
assert registry._agents["test-one"] is TestAgentOne
42+
43+
3444
def test_agent_registry_register_multiple_agents():
3545
"""Test registering multiple agents with the registry."""
3646
registry = AgentRegistry()

0 commit comments

Comments
 (0)