Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/handwriting_ocr_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from zai import ZaiClient

# Module-level client used by handwriting_ocr_example() below.
# Both arguments are empty strings here — presumably placeholders to be
# replaced with a real endpoint and key before running; confirm with the SDK docs.
client = ZaiClient(
    base_url="",
    api_key=""
)


def handwriting_ocr_example():
    """Send one local image to the handwriting OCR API and print the response.

    The image path below is a placeholder string; point it at a real file
    before running. The file is opened in binary mode and stays open for the
    duration of the request.
    """
    image_path = 'Your image path'
    # Fixed request options for this demo; the opened file is supplied separately.
    request_options = {"tool_type": "hand_write", "probability": True}

    with open(image_path, 'rb') as image_file:
        print("Submitting a handwriting recognition task ...")
        response = client.ocr.handwriting_ocr(file=image_file, **request_options)
        print("Task created successfully. Response:")
        print(response)

    print("Handwriting OCR demo completed.")


# Script entry point: print a banner, then run the demo once.
if __name__ == "__main__":
    print("=== Handwriting recognition quick demo ===\n")
    handwriting_ocr_example()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zai-sdk"
version = "0.0.4.2"
version = "0.0.4.3"
description = "A SDK library for accessing big model apis from Z.ai"
authors = ["Z.ai"]
readme = "README.md"
Expand Down
Empty file added src/__init__.py
Empty file.
8 changes: 7 additions & 1 deletion src/zai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from zai.api_resource.web_search import WebSearchApi
from zai.api_resource.web_reader import WebReaderApi
from zai.api_resource.file_parser import FileParser
from zai.api_resource.ocr import HandwritingOCR

from .core import (
NOT_GIVEN,
Expand Down Expand Up @@ -200,6 +201,11 @@ def file_parser(self) -> FileParser:
from zai.api_resource.file_parser import FileParser
return FileParser(self)

@cached_property
def ocr(self) -> HandwritingOCR:
    """Return a ``HandwritingOCR`` resource constructed with this client.

    The resource class is imported inside the accessor body, the same way
    the ``file_parser`` accessor above imports ``FileParser``.
    """
    from zai.api_resource.ocr import HandwritingOCR as _HandwritingOCR

    resource = _HandwritingOCR(self)
    return resource

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -208,7 +214,7 @@ def auth_headers(self) -> dict[str, str]:
if self.disable_token_cache:
return {
'Authorization': f'Bearer {api_key}',
'x-source-channel': source_channel,
'x-source-channel': source_channel,
}
else:
return {
Expand Down
2 changes: 1 addition & 1 deletion src/zai/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__title__ = 'Z.ai'
__version__ = '0.0.4.2'
__version__ = '0.0.4.3'
41 changes: 21 additions & 20 deletions src/zai/api_resource/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,35 @@
Completions,
)
from .embeddings import Embeddings
from .file_parser import FileParser
from .files import Files, FilesWithRawResponse
from .images import Images
from .moderations import Moderations
from .ocr import HandwritingOCR
from .tools import Tools
from .videos import (
Videos,
)
from .web_search import WebSearchApi
from .web_reader import WebReaderApi
from .file_parser import FileParser

from .web_search import WebSearchApi

__all__ = [
'Videos',
'AsyncCompletions',
'Chat',
'Completions',
'Images',
'Embeddings',
'Files',
'FilesWithRawResponse',
'Batches',
'Tools',
'Assistant',
'Audio',
'Moderations',
'WebSearchApi',
'WebReaderApi',
'Agents',
'FileParser',
'Videos',
'AsyncCompletions',
'Chat',
'Completions',
'Images',
'Embeddings',
'Files',
'FilesWithRawResponse',
'Batches',
'Tools',
'Assistant',
'Audio',
'Moderations',
'WebSearchApi',
'WebReaderApi',
'Agents',
'FileParser',
'HandwritingOCR'
]
2 changes: 1 addition & 1 deletion src/zai/api_resource/file_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .file_parser import FileParser

__all__ = ['FileParser']
__all__ = ['FileParser']
3 changes: 3 additions & 0 deletions src/zai/api_resource/ocr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .handwriting_ocr import HandwritingOCR

__all__ = ["HandwritingOCR"]
67 changes: 67 additions & 0 deletions src/zai/api_resource/ocr/handwriting_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Mapping, cast

import httpx
from typing_extensions import Literal
from zai.core import (
BaseAPI,
maybe_transform,
NOT_GIVEN,
Body,
Headers,
NotGiven,
FileTypes,
deepcopy_minimal,
extract_files,
make_request_options
)
from zai.types.ocr.handwriting_ocr_params import HandwritingOCRParams
from zai.types.ocr.handwriting_ocr_resp import HandwritingOCRResp

if TYPE_CHECKING:
from zai._client import ZaiClient

__all__ = ["HandwritingOCR"]


class HandwritingOCR(BaseAPI):
    """API resource that posts an image to the ``/files/ocr`` endpoint."""

    def __init__(self, client: "ZaiClient") -> None:
        """Bind the resource to ``client``; all setup is delegated to ``BaseAPI``."""
        super().__init__(client)

    def handwriting_ocr(
        self,
        *,
        file: FileTypes,
        tool_type: Literal["hand_write"],
        language_type: str | None = None,
        probability: bool | None = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> HandwritingOCRResp:
        """Submit ``file`` for handwriting recognition and return the response.

        Args:
            file: The image to recognise, in any form accepted by ``FileTypes``.
            tool_type: Recognition tool selector; only ``"hand_write"`` is typed.
            language_type: Optional language hint, placed in the body as given
                (``None`` when omitted).
            probability: Optional flag, placed in the body as given (``None``
                when omitted). Presumably asks the service to include
                per-result probability statistics — confirm against the API docs.
            extra_headers: Additional request headers; these take precedence
                over the ``Content-Type`` header set here for file uploads.
            extra_body: Additional body properties forwarded through
                ``make_request_options``.
            timeout: Per-request timeout override.

        Returns:
            The response cast to ``HandwritingOCRResp``.

        Raises:
            ValueError: If ``file`` or ``tool_type`` is falsy.
        """
        if not file:
            raise ValueError("`file` must be provided.")
        if not tool_type:
            raise ValueError("`tool_type` must be provided.")

        # The payload goes through deepcopy_minimal before extract_files runs
        # on it (presumably so the caller's objects are not mutated — confirm
        # against zai.core).
        body = deepcopy_minimal(
            {
                "file": file,
                "tool_type": tool_type,
                "language_type": language_type,
                "probability": probability,
            }
        )
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        if files:
            # Caller-supplied headers are spread last, so they override this default.
            extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
        return self._post(
            "/files/ocr",
            body=maybe_transform(body, HandwritingOCRParams),
            files=files,
            options=make_request_options(
                extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
            ),
            cast_type=HandwritingOCRResp,
        )
Empty file added src/zai/types/ocr/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions src/zai/types/ocr/handwriting_ocr_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

from typing_extensions import Literal, TypedDict
from zai.core import FileTypes

__all__ = ["HandwritingOCRParams"]


class HandwritingOCRParams(TypedDict, total=False):
    """Typed shape of the request body for the handwriting OCR call.

    NOTE(review): ``total=False`` makes every key optional to type checkers,
    even though the inline comments label ``file`` and ``tool_type`` as
    required — the requirement is not enforced by this type.
    """
    file: FileTypes  # Required
    tool_type: Literal["hand_write"]  # Required
    language_type: str  # Optional
    probability: bool  # Not labelled in the original; presumably optional — confirm
33 changes: 33 additions & 0 deletions src/zai/types/ocr/handwriting_ocr_resp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import List
from typing import Optional

from zai.core import BaseModel

__all__ = ["HandwritingOCRResp"]


class Location(BaseModel):
    """Rectangle given as a left/top origin plus a width and height (all ints).

    Used as the ``location`` field of ``WordsResult``.
    NOTE(review): units and coordinate origin are not stated here — presumably
    pixels in the submitted image; confirm against the API docs.
    """
    left: int
    top: int
    width: int
    height: int


class Probability(BaseModel):
    """Three float statistics (average, variance, min) attached to a ``WordsResult``.

    NOTE(review): presumably recognition-confidence statistics for one
    result — confirm against the API docs.
    """
    average: float
    variance: float
    min: float


class WordsResult(BaseModel):
    """One recognition entry: its text, bounding rectangle and probability stats."""
    location: Location
    words: str
    # NOTE(review): declared without a default although the request's
    # `probability` flag is optional — confirm the API always returns this field.
    probability: Probability


class HandwritingOCRResp(BaseModel):
    """Response model for the handwriting OCR call (``cast_type`` of the request)."""
    task_id: str  # Task ID or Result ID
    message: str  # Status message
    status: str  # OCR task status
    words_result_num: int  # Number of recognition results
    words_result: Optional[List[WordsResult]] = None  # List of recognition result details (if any)