@@ -153,7 +153,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any], call_id: str) -> BaseMess
153153 return HumanMessage (content = _dict .get ("content" , "" ), id = id_ , name = name )
154154 if role == "assistant" :
155155 content = _dict .get ("content" , "" ) or ""
156- additional_kwargs : dict = {}
156+ additional_kwargs : dict [ str , Any ] = {}
157157 if function_call := _dict .get ("function_call" ):
158158 additional_kwargs ["function_call" ] = dict (function_call )
159159 tool_calls = []
@@ -242,7 +242,7 @@ def _format_message_content(content: Any) -> Any:
242242 return formatted_content
243243
244244
245- def _convert_message_to_dict (message : BaseMessage ) -> dict :
245+ def _convert_message_to_dict (message : BaseMessage ) -> dict [ str , Any ] :
246246 """Convert a LangChain message to a dictionary.
247247
248248 Args:
@@ -327,7 +327,7 @@ def _convert_delta_to_message_chunk(
327327 id_ = call_id
328328 role = cast ("str" , _dict .get ("role" ))
329329 content = cast ("str" , _dict .get ("content" ) or "" )
330- additional_kwargs : dict = {}
330+ additional_kwargs : dict [ str , Any ] = {}
331331 if _dict .get ("function_call" ):
332332 function_call = dict (_dict ["function_call" ])
333333 if "name" in function_call and function_call ["name" ] is None :
@@ -379,7 +379,7 @@ def _convert_delta_to_message_chunk(
379379
380380
381381def _convert_chunk_to_generation_chunk (
382- chunk : dict ,
382+ chunk : dict [ str , Any ] ,
383383 default_chunk_class : type ,
384384 * ,
385385 is_first_tool_chunk : bool ,
@@ -939,7 +939,7 @@ class Joke(BaseModel):
939939 version : SecretStr | None = None
940940 """Version of the CPD instance."""
941941
942- params : dict | TextChatParameters | None = None
942+ params : dict [ str , Any ] | TextChatParameters | None = None
943943 """Model parameters to use during request generation.
944944
945945 !!! note
@@ -988,7 +988,7 @@ class Joke(BaseModel):
988988
989989 We generally recommend altering this or top_p but not both."""
990990
991- response_format : dict | None = None
991+ response_format : dict [ str , Any ] | None = None
992992 """The chat response format parameters."""
993993
994994 top_p : float | None = None
@@ -1003,7 +1003,7 @@ class Joke(BaseModel):
10031003 """Time limit in milliseconds - if not completed within this time,
10041004 generation will stop."""
10051005
1006- logit_bias : dict | None = None
1006+ logit_bias : dict [ str , int ] | None = None
10071007 """Increasing or decreasing probability of tokens being selected
10081008 during generation."""
10091009
@@ -1015,7 +1015,7 @@ class Joke(BaseModel):
10151015 """Stop sequences are one or more strings which will cause the text generation
10161016 to stop if/when they are produced as part of the output."""
10171017
1018- chat_template_kwargs : dict | None = None
1018+ chat_template_kwargs : dict [ str , Any ] | None = None
10191019 """Additional chat template parameters."""
10201020
10211021 verify : str | bool | None = None
@@ -1187,7 +1187,9 @@ def validate_environment(self) -> Self:
11871187 return self
11881188
11891189 @gateway_error_handler
1190- def _call_model_gateway (self , * , model : str , messages : list , ** params : Any ) -> Any :
1190+ def _call_model_gateway (
1191+ self , * , model : str , messages : list [dict [str , Any ]], ** params : Any
1192+ ) -> Any :
11911193 return self .watsonx_model_gateway .chat .completions .create (
11921194 model = model ,
11931195 messages = messages ,
@@ -1199,7 +1201,7 @@ async def _acall_model_gateway(
11991201 self ,
12001202 * ,
12011203 model : str ,
1202- messages : list ,
1204+ messages : list [ dict [ str , Any ]] ,
12031205 ** params : Any ,
12041206 ) -> Any :
12051207 return await self .watsonx_model_gateway .chat .completions .acreate (
@@ -1405,7 +1407,7 @@ async def _astream(
14051407 yield generation_chunk
14061408
14071409 @staticmethod
1408- def _merge_params (params : dict , kwargs : dict ) -> dict :
1410+ def _merge_params (params : dict [ str , Any ], kwargs : dict [ str , Any ] ) -> dict [ str , Any ] :
14091411 param_updates = {}
14101412 for k in ChatWatsonx ._get_supported_chat_params ():
14111413 if kwargs .get (k ) is not None :
@@ -1438,8 +1440,8 @@ def _create_message_dicts(
14381440
14391441 def _create_chat_result (
14401442 self ,
1441- response : dict ,
1442- generation_info : dict | None = None ,
1443+ response : dict [ str , Any ] ,
1444+ generation_info : dict [ str , Any ] | None = None ,
14431445 ) -> ChatResult :
14441446 generations = []
14451447
@@ -1496,9 +1498,9 @@ def _get_supported_chat_params() -> list[str]:
14961498
14971499 def bind_tools (
14981500 self ,
1499- tools : Sequence [dict [str , Any ] | type | Callable | BaseTool ],
1501+ tools : Sequence [dict [str , Any ] | type | Callable [..., Any ] | BaseTool ],
15001502 * ,
1501- tool_choice : dict | str | bool | None = None ,
1503+ tool_choice : dict [ str , Any ] | str | bool | None = None ,
15021504 strict : bool | None = None ,
15031505 ** kwargs : Any ,
15041506 ) -> Runnable [LanguageModelInput , AIMessage ]:
@@ -1580,7 +1582,7 @@ def bind_tools(
15801582 @override
15811583 def with_structured_output (
15821584 self ,
1583- schema : dict | type | None = None ,
1585+ schema : dict [ str , Any ] | type | None = None ,
15841586 * ,
15851587 method : Literal [
15861588 "function_calling" ,
@@ -1590,7 +1592,7 @@ def with_structured_output(
15901592 include_raw : bool = False ,
15911593 strict : bool | None = None ,
15921594 ** kwargs : Any ,
1593- ) -> Runnable [LanguageModelInput , dict | BaseModel ]:
1595+ ) -> Runnable [LanguageModelInput , dict [ str , Any ] | BaseModel ]:
15941596 r"""Model wrapper that returns outputs formatted to match the given schema.
15951597
15961598 Args:
@@ -1947,7 +1949,7 @@ class AnswerWithJustification(BaseModel):
19471949 },
19481950 )
19491951 if is_pydantic_schema :
1950- output_parser : Runnable = PydanticToolsParser (
1952+ output_parser : Runnable [ Any , Any ] = PydanticToolsParser (
19511953 tools = [schema ],
19521954 first_tool_only = True ,
19531955 )
@@ -2016,7 +2018,7 @@ def _is_pydantic_class(obj: Any) -> bool:
20162018 return isinstance (obj , type ) and is_basemodel_subclass (obj )
20172019
20182020
2019- def _lc_tool_call_to_watsonx_tool_call (tool_call : ToolCall ) -> dict :
2021+ def _lc_tool_call_to_watsonx_tool_call (tool_call : ToolCall ) -> dict [ str , Any ] :
20202022 return {
20212023 "type" : "function" ,
20222024 "id" : tool_call ["id" ],
@@ -2029,7 +2031,7 @@ def _lc_tool_call_to_watsonx_tool_call(tool_call: ToolCall) -> dict:
20292031
20302032def _lc_invalid_tool_call_to_watsonx_tool_call (
20312033 invalid_tool_call : InvalidToolCall ,
2032- ) -> dict :
2034+ ) -> dict [ str , Any ] :
20332035 return {
20342036 "type" : "function" ,
20352037 "id" : invalid_tool_call ["id" ],
@@ -2041,7 +2043,7 @@ def _lc_invalid_tool_call_to_watsonx_tool_call(
20412043
20422044
20432045def _create_usage_metadata (
2044- oai_token_usage : dict ,
2046+ oai_token_usage : dict [ str , Any ] ,
20452047 * ,
20462048 _prompt_tokens_included : bool ,
20472049) -> UsageMetadata :
@@ -2059,7 +2061,7 @@ def _create_usage_metadata(
20592061
20602062def _convert_to_openai_response_format (
20612063 schema : dict [str , Any ] | type , * , strict : bool | None = None
2062- ) -> dict | TypeBaseModel :
2064+ ) -> dict [ str , Any ] | TypeBaseModel :
20632065 if isinstance (schema , type ) and is_basemodel_subclass (schema ):
20642066 return schema
20652067
0 commit comments