def handwriting_ocr_example():
    """
    Quick demo: send one local image to the handwriting OCR endpoint and
    print whatever the client returns.

    The image is opened in binary mode and handed to
    ``client.ocr.handwriting_ocr`` with ``tool_type="hand_write"`` and
    ``probability=True``. Progress messages are printed before and after
    the call.
    """
    # Replace this placeholder with the path of a real image on disk.
    file_path = 'Your image path'

    with open(file_path, 'rb') as image:
        print("Submitting a handwriting recognition task ...")
        result = client.ocr.handwriting_ocr(
            file=image,
            tool_type="hand_write",
            probability=True,
        )
        print("Task created successfully. Response:")
        print(result)

    # The file handle is closed by the time we get here.
    print("Handwriting OCR demo completed.")
+ @property @override def auth_headers(self) -> dict[str, str]: @@ -208,7 +214,7 @@ def auth_headers(self) -> dict[str, str]: if self.disable_token_cache: return { 'Authorization': f'Bearer {api_key}', - 'x-source-channel': source_channel, + 'x-source-channel': source_channel, } else: return { diff --git a/src/zai/_version.py b/src/zai/_version.py index 823a735..1718685 100644 --- a/src/zai/_version.py +++ b/src/zai/_version.py @@ -1,2 +1,2 @@ __title__ = 'Z.ai' -__version__ = '0.0.4.2' +__version__ = '0.0.4.3' diff --git a/src/zai/api_resource/__init__.py b/src/zai/api_resource/__init__.py index beeb7bd..f2bd160 100644 --- a/src/zai/api_resource/__init__.py +++ b/src/zai/api_resource/__init__.py @@ -10,34 +10,35 @@ Completions, ) from .embeddings import Embeddings +from .file_parser import FileParser from .files import Files, FilesWithRawResponse from .images import Images from .moderations import Moderations +from .ocr import HandwritingOCR from .tools import Tools from .videos import ( Videos, ) -from .web_search import WebSearchApi from .web_reader import WebReaderApi -from .file_parser import FileParser - +from .web_search import WebSearchApi __all__ = [ - 'Videos', - 'AsyncCompletions', - 'Chat', - 'Completions', - 'Images', - 'Embeddings', - 'Files', - 'FilesWithRawResponse', - 'Batches', - 'Tools', - 'Assistant', - 'Audio', - 'Moderations', - 'WebSearchApi', - 'WebReaderApi', - 'Agents', - 'FileParser', + 'Videos', + 'AsyncCompletions', + 'Chat', + 'Completions', + 'Images', + 'Embeddings', + 'Files', + 'FilesWithRawResponse', + 'Batches', + 'Tools', + 'Assistant', + 'Audio', + 'Moderations', + 'WebSearchApi', + 'WebReaderApi', + 'Agents', + 'FileParser', + 'HandwritingOCR' ] diff --git a/src/zai/api_resource/file_parser/__init__.py b/src/zai/api_resource/file_parser/__init__.py index b263267..71dd743 100644 --- a/src/zai/api_resource/file_parser/__init__.py +++ b/src/zai/api_resource/file_parser/__init__.py @@ -1,3 +1,3 @@ from .file_parser import 
class HandwritingOCR(BaseAPI):
    """API resource that posts images to the ``/files/ocr`` endpoint."""

    def __init__(self, client: "ZaiClient") -> None:
        super().__init__(client)

    def handwriting_ocr(
        self,
        *,
        file: FileTypes,
        tool_type: Literal["hand_write"],
        language_type: str | None = None,
        probability: bool | None = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> HandwritingOCRResp:
        """Upload an image for handwriting recognition.

        Args:
            file: The image to recognise. Required.
            tool_type: Recognition tool selector; only ``"hand_write"`` is
                accepted by the signature. Required.
            language_type: Optional value forwarded as the ``language_type``
                form field. Omitted from the request when ``None``.
                (Accepted values are not visible here -- check the API docs.)
            probability: Optional flag forwarded as the ``probability`` form
                field. Omitted from the request when ``None``.
                (Presumably requests per-result confidence figures -- confirm
                against the API docs.)
            extra_headers: Additional request headers. These take precedence
                over the ``Content-Type`` header set by this method.
            extra_body: Additional body content passed through to
                ``make_request_options``.
            timeout: Per-request timeout passed through to
                ``make_request_options``.

        Returns:
            The response, cast to ``HandwritingOCRResp``.

        Raises:
            ValueError: If ``file`` or ``tool_type`` is falsy.
        """
        if not file:
            raise ValueError("`file` must be provided.")
        if not tool_type:
            raise ValueError("`tool_type` must be provided.")

        fields = {
            "file": file,
            "tool_type": tool_type,
            "language_type": language_type,
            "probability": probability,
        }
        # Drop optional fields the caller left unset so the request does not
        # depend on how the transport layer serialises ``None``.
        body = deepcopy_minimal(
            {key: value for key, value in fields.items() if value is not None}
        )

        # ``body`` is a private copy, so extracting the file entries from it
        # cannot affect the caller's objects.
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        if files:
            # Caller-supplied headers are spread last so they can override
            # the default content type.
            extra_headers = {
                "Content-Type": "multipart/form-data",
                **(extra_headers or {}),
            }

        return self._post(
            "/files/ocr",
            body=maybe_transform(body, HandwritingOCRParams),
            files=files,
            options=make_request_options(
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_type=HandwritingOCRResp,
        )