Skip to content

Commit 62ef560

Browse files
committed
Derive agent parameters from types
1 parent a2a5703 commit 62ef560

File tree

6 files changed

+231
-55
lines changed

6 files changed

+231
-55
lines changed

docs/modules/core.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@ It uses the [`any-llm`](https://mozilla-ai.github.io/any-llm/) package to provid
1818

1919
To use the `LLMService`:
2020

21-
````python
21+
```python
2222

2323
from django_ai_core.llm import LLMService
2424

2525
service = LLMService.create(
2626
provider="openai",
2727
model="gpt-4o"
28-
)```
28+
)
29+
```
2930

3031
You can also alternatively instantiate `LLMService` with your own client instance:
3132

32-
````
33-
33+
```python
3434
from any_llm import AnyLLM
3535

3636
client = AnyLLM.create(
@@ -39,14 +39,22 @@ model="gpt-4o"
3939
)
4040

4141
service = LLMService(client=client)
42-
4342
```
4443

4544
# Completions
46-
response = service.completion("What is the airspeed velocity of an unladen swallow?")
45+
46+
```python
47+
response = service.completion(
48+
"What is the airspeed velocity of an unladen swallow?"
49+
)
50+
```
4751

4852
# Embeddings
49-
response = service.embedding("What's the speed on that bird when it's not hauling stuff?")
53+
54+
```python
55+
response = service.embedding(
56+
"What's the speed on that bird when it's not hauling stuff?"
57+
)
5058
```
5159

5260
All keyword arguments are passed to the underlying `any-llm` [`completion`](https://mozilla-ai.github.io/any-llm/api/completion/) and [`embedding`](https://mozilla-ai.github.io/any-llm/api/embedding/) APIs.

src/django_ai_core/contrib/agents/base.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import inspect
12
from abc import ABC, abstractmethod
23
from dataclasses import dataclass
4+
from typing import get_args, get_origin, Annotated
35

46
from django.core.validators import validate_slug
57
from django.core.exceptions import ValidationError
@@ -11,13 +13,39 @@ class AgentParameter:
1113
type: type
1214
description: str
1315

16+
def as_dict(self):
17+
return {
18+
"name": self.name,
19+
"type": self.type.__name__,
20+
"description": self.description,
21+
}
22+
1423

1524
class Agent(ABC):
1625
"""Base class for agents."""
1726

1827
slug: str
1928
description: str
20-
parameters: list[AgentParameter]
29+
parameters: list[AgentParameter] | None
30+
31+
@classmethod
32+
def _derive_parameters_from_signature(cls) -> list[AgentParameter]:
33+
"""Derive parameters from `execute` type signature"""
34+
parameters = []
35+
annotations = inspect.get_annotations(cls.execute)
36+
for name, annotation in annotations.items():
37+
if name == "return":
38+
continue
39+
description: str = ""
40+
base_type = annotation
41+
if get_origin(annotation) is Annotated:
42+
base_type, *metadata = get_args(annotation)
43+
if metadata and isinstance(metadata[0], str):
44+
description = metadata[0]
45+
parameters.append(
46+
AgentParameter(name=name, type=base_type, description=description)
47+
)
48+
return parameters
2149

2250
@abstractmethod
2351
def execute(self, *args, **kwargs) -> str:
@@ -26,6 +54,9 @@ def execute(self, *args, **kwargs) -> str:
2654
def __init_subclass__(cls, **kwargs):
2755
super().__init_subclass__(**kwargs)
2856

57+
if "parameters" not in cls.__dict__:
58+
cls.parameters = cls._derive_parameters_from_signature()
59+
2960
if hasattr(cls, "slug"):
3061
try:
3162
validate_slug(cls.slug)
@@ -39,11 +70,11 @@ class AgentRegistry:
3970
def __init__(self):
4071
self._agents: dict[str, type[Agent]] = {}
4172

42-
def register(self, slug: str | None = None):
73+
def register(self):
4374
"""Decorator to register an agent."""
4475

4576
def decorator(cls: type[Agent]) -> type[Agent]:
46-
agent_slug = slug or getattr(cls, "slug", cls.__name__.lower())
77+
agent_slug = getattr(cls, "slug")
4778
self._agents[agent_slug] = cls
4879
return cls
4980

src/django_ai_core/contrib/agents/views.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Type
1+
from typing import Any
22
import json
33

44
from django.http import JsonResponse
@@ -9,10 +9,38 @@
99
from . import registry, Agent
1010

1111

12+
class AgentExecutionException(Exception):
13+
pass
14+
15+
16+
class AgentNotFound(AgentExecutionException):
17+
code = "agent_not_found"
18+
19+
1220
@method_decorator(csrf_exempt, name="dispatch")
1321
class AgentExecutionView(View):
1422
agent_slug: str = ""
1523

24+
def get(self, request):
25+
try:
26+
agent = self._get_agent()
27+
except AgentNotFound as e:
28+
return JsonResponse(
29+
{
30+
"error": f"Agent not found: {self.agent_slug}",
31+
"code": e.code,
32+
},
33+
status=404,
34+
)
35+
36+
return JsonResponse(
37+
{
38+
"slug": agent.slug,
39+
"description": agent.description,
40+
"parameters": [param.as_dict() for param in agent.parameters or []],
41+
}
42+
)
43+
1644
def post(self, request):
1745
"""
1846
Execute the agent with provided input data.
@@ -35,12 +63,12 @@ def post(self, request):
3563
)
3664

3765
try:
38-
agent = registry.get(self.agent_slug)
39-
except KeyError:
66+
agent = self._get_agent()
67+
except AgentNotFound as e:
4068
return JsonResponse(
4169
{
4270
"error": f"Agent not found: {self.agent_slug}",
43-
"code": "agent_not_found",
71+
"code": e.code,
4472
},
4573
status=404,
4674
)
@@ -49,7 +77,11 @@ def post(self, request):
4977

5078
return JsonResponse({"status": "completed", "data": result})
5179

52-
def _execute_agent(self, agent_cls: Type[Agent], arguments: dict[str, Any]) -> Any:
53-
agent = agent_cls()
80+
def _get_agent(self) -> Agent:
81+
try:
82+
return registry.get(self.agent_slug)()
83+
except KeyError:
84+
raise AgentNotFound
5485

86+
def _execute_agent(self, agent: Agent, arguments: dict[str, Any]) -> Any:
5587
return agent.execute(**arguments)

tests/testapp/agents.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
from django_ai_core.contrib.agents import Agent, AgentParameter, registry
1+
from django_ai_core.contrib.agents import Agent, registry
2+
from typing import Annotated
23

34

45
@registry.register()
56
class BasicAgent(Agent):
6-
name = "basic"
7+
slug = "basic"
78
description = "Basic agent that just takes a prompt and returns a response."
8-
parameters = [
9-
AgentParameter(
10-
name="prompt",
11-
type=str,
12-
description="The prompt to use for the agent",
13-
),
14-
]
159

16-
def execute(self, *, prompt: str):
10+
def execute(self, *, prompt: Annotated[str, "The prompt to use for the agent"]):
1711
return prompt
12+
13+
14+
@registry.register()
15+
class StubAgent(Agent):
16+
slug = "stub"
17+
description = "Basic agent that just takes a prompt and returns a response."
18+
19+
def execute(self):
20+
return ""

tests/unit/contrib/agents/test_agent_base.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from typing import Annotated
23

34
from django_ai_core.contrib.agents.base import Agent, AgentParameter
45

@@ -29,6 +30,7 @@ def test_agent_initialization():
2930

3031
assert agent.slug == "test-agent"
3132
assert agent.description == "Test agent for testing"
33+
assert agent.parameters
3234
assert len(agent.parameters) == 2
3335

3436
# Check parameter definitions
@@ -86,3 +88,115 @@ def test_agent_parameter_dataclass():
8688
assert param.name == "test"
8789
assert param.type is bool
8890
assert param.description == "A test parameter"
91+
92+
93+
def test_agent_parameter_as_dict():
94+
"""Test AgentParameter.as_dict() method."""
95+
param = AgentParameter(
96+
name="test_param",
97+
type=str,
98+
description="A test parameter",
99+
)
100+
101+
result = param.as_dict()
102+
103+
assert result == {
104+
"name": "test_param",
105+
"type": "str",
106+
"description": "A test parameter",
107+
}
108+
109+
110+
def test_derive_parameters_from_signature():
111+
"""Test _derive_parameters_from_signature method."""
112+
113+
class TestAgentWithAnnotations(Agent):
114+
slug = "test-annotated"
115+
description = "Test agent with type annotations"
116+
117+
def execute(
118+
self,
119+
*,
120+
name: Annotated[str, "The name parameter"],
121+
count: Annotated[int, "The count parameter"],
122+
):
123+
return f"Hello {name}, count is {count}"
124+
125+
agent = TestAgentWithAnnotations()
126+
127+
assert agent.parameters
128+
assert len(agent.parameters) == 2
129+
130+
name_param = agent.parameters[0]
131+
assert name_param.name == "name"
132+
assert name_param.type is str
133+
assert name_param.description == "The name parameter"
134+
135+
count_param = agent.parameters[1]
136+
assert count_param.name == "count"
137+
assert count_param.type is int
138+
assert count_param.description == "The count parameter"
139+
140+
141+
def test_derive_parameters_without_descriptions():
142+
"""Test parameter derivation without Annotated descriptions."""
143+
144+
class TestAgentAnnotatedWithoutDescriptions(Agent):
145+
slug = "test-no-desc"
146+
description = "Test agent without descriptions"
147+
148+
def execute(self, *, value: str, flag: bool):
149+
return "result"
150+
151+
agent = TestAgentAnnotatedWithoutDescriptions()
152+
153+
assert agent.parameters
154+
assert len(agent.parameters) == 2
155+
156+
value_param = agent.parameters[0]
157+
assert value_param.name == "value"
158+
assert value_param.type is str
159+
assert value_param.description == ""
160+
161+
flag_param = agent.parameters[1]
162+
assert flag_param.name == "flag"
163+
assert flag_param.type is bool
164+
assert flag_param.description == ""
165+
166+
167+
def test_explicit_parameters_override_derived():
168+
"""Test that explicit parameters take precedence over derived ones."""
169+
170+
class AgentWithExplicitParams(Agent):
171+
slug = "test-explicit"
172+
description = "Test agent with explicit parameters"
173+
parameters = [
174+
AgentParameter(
175+
name="custom",
176+
type=str,
177+
description="Custom parameter",
178+
),
179+
]
180+
181+
def execute(self, *, name: Annotated[str, "Should be ignored"]):
182+
return "result"
183+
184+
agent = AgentWithExplicitParams()
185+
186+
assert agent.parameters
187+
assert len(agent.parameters) == 1
188+
assert agent.parameters[0].name == "custom"
189+
assert agent.parameters[0].description == "Custom parameter"
190+
191+
192+
def test_agent_without_parameter_schema():
193+
class AgentWithNoParams(Agent):
194+
slug = "test-explicit"
195+
description = "Test agent with explicit parameters"
196+
197+
def execute(self, *, name):
198+
return "result"
199+
200+
agent = AgentWithNoParams()
201+
202+
assert not agent.parameters

0 commit comments

Comments
 (0)