Skip to content

Commit a2a5703

Browse files
authored
Merge pull request #1 from wagtail/fix/llmservice
Pass model instead of model_id to AnyLLM
2 parents 55042d5 + 738d413 commit a2a5703

File tree

3 files changed

+9
-14
lines changed

3 files changed

+9
-14
lines changed

src/django_ai_core/llm/base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@ def service_id(self) -> str:
2121
return f"{self.__class__.__name__}:{self.client.PROVIDER_NAME}:{self.model}"
2222

2323
def completion(self, messages, **kwargs):
24-
return self.client.completion(model_id=self.model, messages=messages, **kwargs)
24+
return self.client.completion(model=self.model, messages=messages, **kwargs)
2525

2626
def responses(self, input_data, **kwargs):
27-
return self.client.responses(
28-
model_id=self.model, input_data=input_data, **kwargs
29-
)
27+
return self.client.responses(model=self.model, input_data=input_data, **kwargs)
3028

3129
def embedding(self, inputs, **kwargs):
32-
return self.client._embedding(model_id=self.model, inputs=inputs, **kwargs)
30+
return self.client._embedding(model=self.model, inputs=inputs, **kwargs)

tests/testapp/indexes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515

1616

1717
class MockAnyLLM(AnyLLM):
18-
def completion(self, *, model_id, messages):
18+
def completion(self, *, model, messages):
1919
return "completion"
2020

21-
def responses(self, *, model_id, input_data):
21+
def responses(self, *, model, input_data):
2222
return "responses"
2323

24-
def _embedding(self, *, model_id, inputs):
24+
def _embedding(self, *, model, inputs):
2525
return [0, 1, 2]
2626

2727

tests/unit/llm/test_llm_service.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ def test_llm_service_completion_wraps_anyllm(mock_any_llm):
2525
]
2626
service = LLMService(client=mock_any_llm, model="mock-model")
2727
service.completion(messages)
28-
print(mock_any_llm.completion.call_args_list)
2928
mock_any_llm.completion.assert_called_once_with(
30-
model_id="mock-model", messages=messages
29+
model="mock-model", messages=messages
3130
)
3231

3332

@@ -36,14 +35,12 @@ def test_llm_service_responses_wraps_anyllm(mock_any_llm):
3635
service = LLMService(client=mock_any_llm, model="mock-model")
3736
service.responses(prompt)
3837
mock_any_llm.responses.assert_called_once_with(
39-
model_id="mock-model", input_data=prompt
38+
model="mock-model", input_data=prompt
4039
)
4140

4241

4342
def test_llm_service_embedding_wraps_anyllm(mock_any_llm):
4443
prompt = "What is the airspeed velocity of an unladen swallow?"
4544
service = LLMService(client=mock_any_llm, model="mock-model")
4645
service.embedding(prompt)
47-
mock_any_llm._embedding.assert_called_once_with(
48-
model_id="mock-model", inputs=prompt
49-
)
46+
mock_any_llm._embedding.assert_called_once_with(model="mock-model", inputs=prompt)

0 commit comments

Comments
 (0)