Skip to content

Commit 4ece0b5

Browse files
authored
chore(langchain-ibm): disallow Any generics (#147)
1 parent d90697f commit 4ece0b5

24 files changed

+95
-77
lines changed

libs/ibm/langchain_ibm/agent_toolkits/sql/tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
136136
"""Use an LLM to check if a query is correct."""
137137

138138
template: str = QUERY_CHECKER
139-
llm: BaseLanguageModel
139+
llm: BaseLanguageModel[Any]
140140
llm_chain: Any = Field(init=False)
141141
name: str = "sql_db_query_checker"
142142
description: str = """
@@ -154,7 +154,7 @@ def initialize_llm_chain(cls, values: dict[str, Any]) -> Any:
154154
template=QUERY_CHECKER,
155155
input_variables=["query", "schema"],
156156
)
157-
llm = cast("BaseLanguageModel", values.get("llm"))
157+
llm = cast("BaseLanguageModel[Any]", values.get("llm"))
158158

159159
values["llm_chain"] = prompt | llm
160160

libs/ibm/langchain_ibm/agent_toolkits/sql/toolkit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""IBM watsonx.ai SQL Toolkit wrapper."""
22

3+
from typing import Any
4+
35
from langchain_core.language_models import BaseLanguageModel
46
from langchain_core.tools import BaseTool
57
from langchain_core.tools.base import BaseToolkit
@@ -20,7 +22,7 @@ class WatsonxSQLDatabaseToolkit(BaseToolkit):
2022
db: WatsonxSQLDatabase = Field(exclude=True)
2123
"""Instance of the watsonx SQL database."""
2224

23-
llm: BaseLanguageModel = Field(exclude=True)
25+
llm: BaseLanguageModel[Any] = Field(exclude=True)
2426
"""Instance of the LLM."""
2527

2628
model_config = ConfigDict(
@@ -70,7 +72,7 @@ def get_tools(self) -> list[BaseTool]:
7072
query_sql_checker_tool,
7173
]
7274

73-
def get_context(self) -> dict:
75+
def get_context(self) -> dict[str, Any]:
7476
"""Return db context that you may want in agent prompt."""
7577
return self.db.get_context()
7678

libs/ibm/langchain_ibm/agent_toolkits/utility/toolkit.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@ class WatsonxTool(BaseTool):
4343
"""The precise instruction to agent LLMs
4444
and should be treated as part of the system prompt."""
4545

46-
tool_input_schema: dict | None = None
46+
tool_input_schema: dict[str, Any] | None = None
4747
"""Schema of the input that is provided when running the tool if applicable."""
4848

49-
tool_config_schema: dict | None = None
49+
tool_config_schema: dict[str, Any] | None = None
5050
"""Schema of the config that can be provided when running the tool if applicable."""
5151

52-
tool_config: dict | None = None
52+
tool_config: dict[str, Any] | None = None
5353
"""Config properties to be used when running a tool if applicable."""
5454

5555
args_schema: type[BaseModel] = BaseModel
5656

57-
_watsonx_tool: Tool | None = PrivateAttr(default=None) #: :meta private:
57+
_watsonx_tool: Tool #: :meta private:
5858

5959
watsonx_client: APIClient = Field(exclude=True)
6060

@@ -84,7 +84,7 @@ def _run(
8484
*args: Any,
8585
run_manager: CallbackManagerForToolRun | None = None,
8686
**kwargs: Any,
87-
) -> dict:
87+
) -> Any:
8888
"""Run the tool."""
8989
if self.tool_input_schema is None:
9090
input_data = kwargs.get("input") or args[0]
@@ -95,9 +95,9 @@ def _run(
9595
if k in self.tool_input_schema["properties"]
9696
}
9797

98-
return cast("dict", self._watsonx_tool.run(input_data, self.tool_config)) # type: ignore[union-attr]
98+
return self._watsonx_tool.run(input_data, self.tool_config)
9999

100-
def set_tool_config(self, tool_config: dict) -> None:
100+
def set_tool_config(self, tool_config: dict[str, Any]) -> None:
101101
"""Set tool config properties.
102102
103103
???+ example "Example"

libs/ibm/langchain_ibm/agent_toolkits/utility/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Utility helpers."""
22

33
from copy import deepcopy
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55

66
if TYPE_CHECKING:
77
from langchain_ibm.agent_toolkits.utility.toolkit import WatsonxTool
88

99

10-
def convert_to_watsonx_tool(tool: "WatsonxTool") -> dict:
10+
def convert_to_watsonx_tool(tool: "WatsonxTool") -> dict[str, Any]:
1111
"""Convert `WatsonxTool` to watsonx tool structure.
1212
1313
Args:
@@ -53,7 +53,7 @@ def convert_to_watsonx_tool(tool: "WatsonxTool") -> dict:
5353
```
5454
"""
5555

56-
def parse_parameters(input_schema: dict | None) -> dict:
56+
def parse_parameters(input_schema: dict[str, Any] | None) -> dict[str, Any]:
5757
if input_schema:
5858
parameters = deepcopy(input_schema)
5959
else:

libs/ibm/langchain_ibm/chat_models.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any], call_id: str) -> BaseMess
153153
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
154154
if role == "assistant":
155155
content = _dict.get("content", "") or ""
156-
additional_kwargs: dict = {}
156+
additional_kwargs: dict[str, Any] = {}
157157
if function_call := _dict.get("function_call"):
158158
additional_kwargs["function_call"] = dict(function_call)
159159
tool_calls = []
@@ -242,7 +242,7 @@ def _format_message_content(content: Any) -> Any:
242242
return formatted_content
243243

244244

245-
def _convert_message_to_dict(message: BaseMessage) -> dict:
245+
def _convert_message_to_dict(message: BaseMessage) -> dict[str, Any]:
246246
"""Convert a LangChain message to a dictionary.
247247
248248
Args:
@@ -327,7 +327,7 @@ def _convert_delta_to_message_chunk(
327327
id_ = call_id
328328
role = cast("str", _dict.get("role"))
329329
content = cast("str", _dict.get("content") or "")
330-
additional_kwargs: dict = {}
330+
additional_kwargs: dict[str, Any] = {}
331331
if _dict.get("function_call"):
332332
function_call = dict(_dict["function_call"])
333333
if "name" in function_call and function_call["name"] is None:
@@ -379,7 +379,7 @@ def _convert_delta_to_message_chunk(
379379

380380

381381
def _convert_chunk_to_generation_chunk(
382-
chunk: dict,
382+
chunk: dict[str, Any],
383383
default_chunk_class: type,
384384
*,
385385
is_first_tool_chunk: bool,
@@ -939,7 +939,7 @@ class Joke(BaseModel):
939939
version: SecretStr | None = None
940940
"""Version of the CPD instance."""
941941

942-
params: dict | TextChatParameters | None = None
942+
params: dict[str, Any] | TextChatParameters | None = None
943943
"""Model parameters to use during request generation.
944944
945945
!!! note
@@ -988,7 +988,7 @@ class Joke(BaseModel):
988988
989989
We generally recommend altering this or top_p but not both."""
990990

991-
response_format: dict | None = None
991+
response_format: dict[str, Any] | None = None
992992
"""The chat response format parameters."""
993993

994994
top_p: float | None = None
@@ -1003,7 +1003,7 @@ class Joke(BaseModel):
10031003
"""Time limit in milliseconds - if not completed within this time,
10041004
generation will stop."""
10051005

1006-
logit_bias: dict | None = None
1006+
logit_bias: dict[str, int] | None = None
10071007
"""Increasing or decreasing probability of tokens being selected
10081008
during generation."""
10091009

@@ -1015,7 +1015,7 @@ class Joke(BaseModel):
10151015
"""Stop sequences are one or more strings which will cause the text generation
10161016
to stop if/when they are produced as part of the output."""
10171017

1018-
chat_template_kwargs: dict | None = None
1018+
chat_template_kwargs: dict[str, Any] | None = None
10191019
"""Additional chat template parameters."""
10201020

10211021
verify: str | bool | None = None
@@ -1187,7 +1187,9 @@ def validate_environment(self) -> Self:
11871187
return self
11881188

11891189
@gateway_error_handler
1190-
def _call_model_gateway(self, *, model: str, messages: list, **params: Any) -> Any:
1190+
def _call_model_gateway(
1191+
self, *, model: str, messages: list[dict[str, Any]], **params: Any
1192+
) -> Any:
11911193
return self.watsonx_model_gateway.chat.completions.create(
11921194
model=model,
11931195
messages=messages,
@@ -1199,7 +1201,7 @@ async def _acall_model_gateway(
11991201
self,
12001202
*,
12011203
model: str,
1202-
messages: list,
1204+
messages: list[dict[str, Any]],
12031205
**params: Any,
12041206
) -> Any:
12051207
return await self.watsonx_model_gateway.chat.completions.acreate(
@@ -1405,7 +1407,7 @@ async def _astream(
14051407
yield generation_chunk
14061408

14071409
@staticmethod
1408-
def _merge_params(params: dict, kwargs: dict) -> dict:
1410+
def _merge_params(params: dict[str, Any], kwargs: dict[str, Any]) -> dict[str, Any]:
14091411
param_updates = {}
14101412
for k in ChatWatsonx._get_supported_chat_params():
14111413
if kwargs.get(k) is not None:
@@ -1438,8 +1440,8 @@ def _create_message_dicts(
14381440

14391441
def _create_chat_result(
14401442
self,
1441-
response: dict,
1442-
generation_info: dict | None = None,
1443+
response: dict[str, Any],
1444+
generation_info: dict[str, Any] | None = None,
14431445
) -> ChatResult:
14441446
generations = []
14451447

@@ -1496,9 +1498,9 @@ def _get_supported_chat_params() -> list[str]:
14961498

14971499
def bind_tools(
14981500
self,
1499-
tools: Sequence[dict[str, Any] | type | Callable | BaseTool],
1501+
tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
15001502
*,
1501-
tool_choice: dict | str | bool | None = None,
1503+
tool_choice: dict[str, Any] | str | bool | None = None,
15021504
strict: bool | None = None,
15031505
**kwargs: Any,
15041506
) -> Runnable[LanguageModelInput, AIMessage]:
@@ -1580,7 +1582,7 @@ def bind_tools(
15801582
@override
15811583
def with_structured_output(
15821584
self,
1583-
schema: dict | type | None = None,
1585+
schema: dict[str, Any] | type | None = None,
15841586
*,
15851587
method: Literal[
15861588
"function_calling",
@@ -1590,7 +1592,7 @@ def with_structured_output(
15901592
include_raw: bool = False,
15911593
strict: bool | None = None,
15921594
**kwargs: Any,
1593-
) -> Runnable[LanguageModelInput, dict | BaseModel]:
1595+
) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]:
15941596
r"""Model wrapper that returns outputs formatted to match the given schema.
15951597
15961598
Args:
@@ -1947,7 +1949,7 @@ class AnswerWithJustification(BaseModel):
19471949
},
19481950
)
19491951
if is_pydantic_schema:
1950-
output_parser: Runnable = PydanticToolsParser(
1952+
output_parser: Runnable[Any, Any] = PydanticToolsParser(
19511953
tools=[schema],
19521954
first_tool_only=True,
19531955
)
@@ -2016,7 +2018,7 @@ def _is_pydantic_class(obj: Any) -> bool:
20162018
return isinstance(obj, type) and is_basemodel_subclass(obj)
20172019

20182020

2019-
def _lc_tool_call_to_watsonx_tool_call(tool_call: ToolCall) -> dict:
2021+
def _lc_tool_call_to_watsonx_tool_call(tool_call: ToolCall) -> dict[str, Any]:
20202022
return {
20212023
"type": "function",
20222024
"id": tool_call["id"],
@@ -2029,7 +2031,7 @@ def _lc_tool_call_to_watsonx_tool_call(tool_call: ToolCall) -> dict:
20292031

20302032
def _lc_invalid_tool_call_to_watsonx_tool_call(
20312033
invalid_tool_call: InvalidToolCall,
2032-
) -> dict:
2034+
) -> dict[str, Any]:
20332035
return {
20342036
"type": "function",
20352037
"id": invalid_tool_call["id"],
@@ -2041,7 +2043,7 @@ def _lc_invalid_tool_call_to_watsonx_tool_call(
20412043

20422044

20432045
def _create_usage_metadata(
2044-
oai_token_usage: dict,
2046+
oai_token_usage: dict[str, Any],
20452047
*,
20462048
_prompt_tokens_included: bool,
20472049
) -> UsageMetadata:
@@ -2059,7 +2061,7 @@ def _create_usage_metadata(
20592061

20602062
def _convert_to_openai_response_format(
20612063
schema: dict[str, Any] | type, *, strict: bool | None = None
2062-
) -> dict | TypeBaseModel:
2064+
) -> dict[str, Any] | TypeBaseModel:
20632065
if isinstance(schema, type) and is_basemodel_subclass(schema):
20642066
return schema
20652067

libs/ibm/langchain_ibm/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
177177
version: SecretStr | None = None
178178
"""Version of the CPD instance."""
179179

180-
params: dict | None = None
180+
params: dict[str, Any] | None = None
181181
"""Model parameters to use during request generation."""
182182

183183
verify: str | bool | None = None

libs/ibm/langchain_ibm/llms.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class WatsonxLLM(BaseLLM):
209209
version: SecretStr | None = None
210210
"""Version of the CPD instance."""
211211

212-
params: dict | None = None
212+
params: dict[str, Any] | None = None
213213
"""Model parameters to use during request generation."""
214214

215215
verify: str | bool | None = None
@@ -343,7 +343,9 @@ def validate_environment(self) -> Self:
343343
return self
344344

345345
@gateway_error_handler
346-
def _call_model_gateway(self, *, model: str, prompt: list, **params: Any) -> Any:
346+
def _call_model_gateway(
347+
self, *, model: str, prompt: str | list[str] | list[int], **params: Any
348+
) -> Any:
347349
return self.watsonx_model_gateway.completions.create(
348350
model=model,
349351
prompt=prompt,
@@ -355,7 +357,7 @@ async def _acall_model_gateway(
355357
self,
356358
*,
357359
model: str,
358-
prompt: list,
360+
prompt: str | list[str] | list[int],
359361
**params: Any,
360362
) -> Any:
361363
return await self.watsonx_model_gateway.completions.acreate(
@@ -455,7 +457,7 @@ def _get_chat_params(
455457
params = (params or {}) | {"stop_sequences": stop}
456458
return params, kwargs
457459

458-
def _create_llm_result(self, response: list[dict]) -> LLMResult:
460+
def _create_llm_result(self, response: list[dict[str, Any]]) -> LLMResult:
459461
"""Create the LLMResult from the choices and prompts."""
460462
generations = [
461463
[
@@ -480,7 +482,7 @@ def _create_llm_result(self, response: list[dict]) -> LLMResult:
480482
}
481483
return LLMResult(generations=generations, llm_output=llm_output)
482484

483-
def _create_llm_gateway_result(self, response: dict) -> LLMResult:
485+
def _create_llm_gateway_result(self, response: dict[str, Any]) -> LLMResult:
484486
"""Create the LLMResult from the choices and prompts."""
485487
choices = response["choices"]
486488

libs/ibm/langchain_ibm/rerank.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class WatsonxRerank(BaseDocumentCompressor):
148148
version: SecretStr | None = None
149149
"""Version of the CPD instance."""
150150

151-
params: dict | RerankParameters | None = None
151+
params: dict[str, Any] | RerankParameters | None = None
152152
"""Model parameters to use during request generation."""
153153

154154
verify: str | bool | None = None
@@ -235,7 +235,7 @@ def validate_environment(self) -> Self:
235235

236236
def rerank(
237237
self,
238-
documents: Sequence[str | Document | dict],
238+
documents: Sequence[str | Document | dict[str, Any]],
239239
query: str,
240240
**kwargs: Any,
241241
) -> list[dict[str, Any]]:

0 commit comments

Comments (0)