|
1 | 1 | import inspect |
2 | 2 | from abc import ABC, abstractmethod |
3 | 3 | from dataclasses import dataclass |
4 | | -from typing import Annotated, get_args, get_origin |
| 4 | +from typing import Annotated, Callable, TypeVar, get_args, get_origin |
5 | 5 |
|
6 | 6 | from django.core.exceptions import ValidationError |
7 | 7 | from django.core.validators import validate_slug |
8 | 8 |
|
9 | 9 | from .permissions import BasePermission |
10 | 10 | from .views import AgentExecutionView |
11 | 11 |
|
| 12 | +AgentT = TypeVar("AgentT", bound="Agent") |
| 13 | + |
12 | 14 |
|
13 | 15 | @dataclass |
14 | 16 | class AgentParameter: |
@@ -79,15 +81,20 @@ class AgentRegistry: |
79 | 81 | def __init__(self): |
80 | 82 | self._agents: dict[str, type[Agent]] = {} |
81 | 83 |
|
82 | | - def register(self): |
83 | | - """Decorator to register an agent.""" |
84 | | - |
85 | | - def decorator(cls: type[Agent]) -> type[Agent]: |
86 | | - agent_slug = cls.slug |
87 | | - self._agents[agent_slug] = cls |
88 | | - return cls |
89 | | - |
90 | | - return decorator |
| 84 | + def register( |
| 85 | + self, cls: type[AgentT] | None = None |
| 86 | + ) -> type[AgentT] | Callable[[type[AgentT]], type[AgentT]]: |
| 87 | + def decorator(agent_cls: type[AgentT]) -> type[AgentT]: |
| 88 | + agent_slug = agent_cls.slug |
| 89 | + self._agents[agent_slug] = agent_cls |
| 90 | + return agent_cls |
| 91 | + |
| 92 | + if cls is None: |
| 93 | + # Called with parentheses: @registry.register() |
| 94 | + return decorator |
| 95 | + else: |
| 96 | + # Called without parentheses: @registry.register |
| 97 | + return decorator(cls) |
91 | 98 |
|
92 | 99 | def get(self, slug: str) -> type[Agent]: |
93 | 100 | if slug not in self._agents: |
|
0 commit comments