
Commit 53950f8

Fix/model provider key injection check (#1799)
* Check available models for type validation
* Semver
* Fix ruff and pyright
* Apply feedback
1 parent e39d869 commit 53950f8

8 files changed: +89 additions, -30 deletions
New semver patch file

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add check for custom model types while config loading"
+}

graphrag/config/models/language_model_config.py

Lines changed: 22 additions & 6 deletions

@@ -15,6 +15,7 @@
     AzureDeploymentNameMissingError,
     ConflictingSettingsError,
 )
+from graphrag.language_model.factory import ModelFactory


 class LanguageModelConfig(BaseModel):

@@ -44,7 +45,7 @@ def _validate_api_key(self) -> None:
             self.api_key is None or self.api_key.strip() == ""
         ):
             raise ApiKeyMissingError(
-                self.type.value,
+                self.type,
                 self.auth_type.value,
             )

@@ -73,10 +74,24 @@ def _validate_auth_type(self) -> None:
         if self.auth_type == AuthType.AzureManagedIdentity and (
             self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
         ):
-            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
+            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
             raise ConflictingSettingsError(msg)

-    type: ModelType = Field(description="The type of LLM model to use.")
+    type: ModelType | str = Field(description="The type of LLM model to use.")
+
+    def _validate_type(self) -> None:
+        """Validate the model type.
+
+        Raises
+        ------
+        KeyError
+            If the model name is not recognized.
+        """
+        # Type should be contained by the registered models
+        if not ModelFactory.is_supported_model(self.type):
+            msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
+            raise KeyError(msg)
+
     model: str = Field(description="The LLM model to use.")
     encoding_model: str = Field(
         description="The encoding model to use",

@@ -141,7 +156,7 @@ def _validate_api_base(self) -> None:
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.api_base is None or self.api_base.strip() == ""):
-            raise AzureApiBaseMissingError(self.type.value)
+            raise AzureApiBaseMissingError(self.type)

     api_version: str | None = Field(
         description="The version of the LLM API to use.",

@@ -162,7 +177,7 @@ def _validate_api_version(self) -> None:
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.api_version is None or self.api_version.strip() == ""):
-            raise AzureApiVersionMissingError(self.type.value)
+            raise AzureApiVersionMissingError(self.type)

     deployment_name: str | None = Field(
         description="The deployment name to use for the LLM service.",

@@ -183,7 +198,7 @@ def _validate_deployment_name(self) -> None:
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
-            raise AzureDeploymentNameMissingError(self.type.value)
+            raise AzureDeploymentNameMissingError(self.type)

     organization: str | None = Field(
         description="The organization to use for the LLM service.",

@@ -251,6 +266,7 @@ def _validate_azure_settings(self) -> None:

     @model_validator(mode="after")
     def _validate_model(self):
+        self._validate_type()
         self._validate_auth_type()
         self._validate_api_key()
         self._validate_azure_settings()
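Taken together, these hunks widen `type` to accept any string while checking it against the factory registries before the other validators run. A minimal sketch of the resulting behavior (all values are placeholders; assumes a graphrag build containing this commit, and that pydantic lets the validator's `KeyError` propagate):

```python
from graphrag.config.models.language_model_config import LanguageModelConfig

try:
    LanguageModelConfig(
        type="definitely_not_registered",  # hypothetical, unregistered type name
        model="gpt-4o",                    # placeholder model name
        api_key="dummy",                   # placeholder key
    )
except KeyError as err:
    # _validate_type runs first in _validate_model, so an unrecognized type is
    # rejected with the list of registered chat and embedding models.
    print(err)
```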

graphrag/language_model/factory.py

Lines changed: 27 additions & 0 deletions

@@ -70,6 +70,33 @@ def create_embedding_model(cls, model_type: str, **kwargs: Any) -> EmbeddingModel:
             raise ValueError(msg)
         return cls._embedding_registry[model_type](**kwargs)

+    @classmethod
+    def get_chat_models(cls) -> list[str]:
+        """Get the registered ChatModel implementations."""
+        return list(cls._chat_registry.keys())
+
+    @classmethod
+    def get_embedding_models(cls) -> list[str]:
+        """Get the registered EmbeddingModel implementations."""
+        return list(cls._embedding_registry.keys())
+
+    @classmethod
+    def is_supported_chat_model(cls, model_type: str) -> bool:
+        """Check if the given model type is supported."""
+        return model_type in cls._chat_registry
+
+    @classmethod
+    def is_supported_embedding_model(cls, model_type: str) -> bool:
+        """Check if the given model type is supported."""
+        return model_type in cls._embedding_registry
+
+    @classmethod
+    def is_supported_model(cls, model_type: str) -> bool:
+        """Check if the given model type is supported."""
+        return cls.is_supported_chat_model(
+            model_type
+        ) or cls.is_supported_embedding_model(model_type)
+

 # --- Register default implementations ---
 ModelFactory.register_chat(
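These helpers make the registries queryable without instantiating a model, which is what `_validate_type` above relies on. A small usage sketch (the `"my_custom_chat"` name and `MyChatModel` class are hypothetical; registering a class directly mirrors the test file further below):

```python
from typing import Any

from graphrag.language_model.factory import ModelFactory


class MyChatModel:
    """Hypothetical stand-in for a class implementing the ChatModel protocol."""

    def __init__(self, **kwargs: Any) -> None:
        self.config = kwargs


ModelFactory.register_chat("my_custom_chat", MyChatModel)

# The new introspection helpers:
assert ModelFactory.is_supported_chat_model("my_custom_chat")
assert ModelFactory.is_supported_model("my_custom_chat")
print(ModelFactory.get_chat_models())  # default chat types plus "my_custom_chat"
```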

graphrag/language_model/protocol/base.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 from typing import TYPE_CHECKING, Any, Protocol

 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator
+    from collections.abc import AsyncGenerator, Generator

     from graphrag.language_model.response.base import ModelResponse

@@ -143,7 +143,7 @@ def chat(

     def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs: Any
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """
         Generate a response for the given text using a streaming interface.
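The return-type fix matters because `chat_stream` is a synchronous method: a plain `def` that yields produces a `Generator` when called, and only an `async def` produces an `AsyncGenerator`. A toy sketch of the distinction (`EchoModel` is illustrative, not a graphrag class):

```python
from collections.abc import AsyncGenerator, Generator


class EchoModel:
    """Toy model showing both streaming shapes."""

    def chat_stream(self, prompt: str) -> Generator[str, None]:
        # A plain generator function: consumed with a normal for loop.
        yield from prompt.split()

    async def achat_stream(self, prompt: str) -> AsyncGenerator[str, None]:
        # An async generator: consumed with `async for`.
        for token in prompt.split():
            yield token


for chunk in EchoModel().chat_stream("streamed one word at a time"):
    print(chunk)
```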

graphrag/language_model/providers/fnllm/models.py

Lines changed: 17 additions & 10 deletions

@@ -3,21 +3,16 @@

 """A module containing fnllm model provider definitions."""

-from collections.abc import AsyncGenerator
+from __future__ import annotations
+
+from typing import TYPE_CHECKING

 from fnllm.openai import (
     create_openai_chat_llm,
     create_openai_client,
     create_openai_embeddings_llm,
 )
-from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
-from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM

-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.models.language_model_config import (
-    LanguageModelConfig,
-)
 from graphrag.language_model.providers.fnllm.events import FNLLMEvents
 from graphrag.language_model.providers.fnllm.utils import (
     _create_cache,

@@ -31,6 +26,18 @@
     ModelResponse,
 )

+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator, Generator
+
+    from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
+    from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM
+
+    from graphrag.cache.pipeline_cache import PipelineCache
+    from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+    from graphrag.config.models.language_model_config import (
+        LanguageModelConfig,
+    )
+

 class OpenAIChatFNLLM:
     """An OpenAI Chat Model provider using the fnllm library."""

@@ -121,7 +128,7 @@ def chat(self, prompt: str, history: list | None = None, **kwargs) -> ModelResponse:

     def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """
         Stream Chat with the Model using the given prompt.

@@ -319,7 +326,7 @@ def chat(self, prompt: str, history: list | None = None, **kwargs) -> ModelResponse:

     def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """
         Stream Chat with the Model using the given prompt.
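This file and `utils.py` below apply the same idiom: with `from __future__ import annotations`, annotations are stored as strings rather than evaluated at runtime, so imports needed only for type hints can move under `TYPE_CHECKING` and vanish from the runtime import graph, which avoids import cycles and satisfies pyright. A self-contained sketch of the pattern:

```python
# Standalone illustration of the TYPE_CHECKING idiom used in this commit.
from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from collections.abc import Generator


def stream_tokens(text: str) -> Generator[str, None]:
    """Yield whitespace-separated tokens one at a time."""
    yield from text.split()


print(list(stream_tokens("a b c")))  # ['a', 'b', 'c']
```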

graphrag/language_model/providers/fnllm/utils.py

Lines changed: 13 additions & 8 deletions

@@ -3,24 +3,29 @@

 """A module containing utils for fnllm."""

+from __future__ import annotations
+
 import asyncio
 import threading
-from collections.abc import Coroutine
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

 from fnllm.base.config import JsonStrategy, RetryStrategy
 from fnllm.openai import AzureOpenAIConfig, OpenAIConfig, PublicOpenAIConfig
 from fnllm.openai.types.chat.parameters import OpenAIChatParameters

 import graphrag.config.defaults as defs
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.models.language_model_config import (
-    LanguageModelConfig,
-)
-from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.language_model.providers.fnllm.cache import FNLLMCacheProvider

+if TYPE_CHECKING:
+    from collections.abc import Coroutine
+
+    from graphrag.cache.pipeline_cache import PipelineCache
+    from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+    from graphrag.config.models.language_model_config import (
+        LanguageModelConfig,
+    )
+    from graphrag.index.typing.error_handler import ErrorHandlerFn
+

 def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider | None:
     """Create an FNLLM cache from a pipeline cache."""

tests/integration/language_model/test_factory.py

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@
 These tests will test the LLMFactory class and the creation of custom and provided LLMs.
 """

-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Generator
 from typing import Any

 from graphrag.language_model.factory import ModelFactory

@@ -40,7 +40,7 @@ async def achat_stream(

        def chat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
-       ) -> AsyncGenerator[str, None]: ...
+       ) -> Generator[str, None]: ...

    ModelFactory.register_chat("custom_chat", CustomChatModel)
    model = ModelManager().get_or_create_chat_model("custom", "custom_chat")

tests/mock_provider.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@

 """A module containing mock model provider definitions."""

-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Generator
 from typing import Any

 from pydantic import BaseModel

@@ -85,7 +85,7 @@ def chat_stream(
         prompt: str,
         history: list | None = None,
         **kwargs,
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """Return the next response in the list."""
         raise NotImplementedError
