Skip to content

Commit 204cda5

Browse files
author
mengqian
committed
feat: (ocr service)add confidence score and update package path
1 parent abb8f35 commit 204cda5

File tree

10 files changed

+47
-31
lines changed

10 files changed

+47
-31
lines changed

examples/handwriting_ocr_example.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
api_key=""
66
)
77

8+
89
def handwriting_ocr_example():
910
"""
1011
Full Example: Submit image for recognition and wait for the result to be returned.
@@ -14,15 +15,17 @@ def handwriting_ocr_example():
1415
file_path = 'Your image path'
1516
with open(file_path, 'rb') as f:
1617
print("Submitting a handwriting recognition task ...")
17-
response = client.file_parser.handwriting_ocr(
18+
response = client.ocr.handwriting_ocr(
1819
file=f,
1920
tool_type="hand_write",
21+
probability=True
2022
)
2123
print("Task created successfully. Response:")
2224
print(response)
2325

2426
print("Handwriting OCR demo completed.")
2527

28+
2629
if __name__ == "__main__":
2730
print("=== Handwriting recognition quick demo ===\n")
28-
handwriting_ocr_example()
31+
handwriting_ocr_example()

src/__init__.py

Whitespace-only changes.

src/zai/_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from zai.api_resource.web_search import WebSearchApi
2525
from zai.api_resource.web_reader import WebReaderApi
2626
from zai.api_resource.file_parser import FileParser
27-
from zai.api_resource.file_parser import HandwritingOCR
27+
from zai.api_resource.ocr import HandwritingOCR
2828

2929
from .core import (
3030
NOT_GIVEN,
@@ -202,8 +202,8 @@ def file_parser(self) -> FileParser:
202202
return FileParser(self)
203203

204204
@cached_property
205-
def file_parser(self) -> HandwritingOCR:
206-
from zai.api_resource.file_parser import HandwritingOCR
205+
def ocr(self) -> HandwritingOCR:
206+
from zai.api_resource.ocr import HandwritingOCR
207207
return HandwritingOCR(self)
208208

209209
@property
@@ -214,7 +214,7 @@ def auth_headers(self) -> dict[str, str]:
214214
if self.disable_token_cache:
215215
return {
216216
'Authorization': f'Bearer {api_key}',
217-
'x-source-channel': source_channel,
217+
'x-source-channel': source_channel,
218218
}
219219
else:
220220
return {

src/zai/api_resource/__init__.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,35 @@
1010
Completions,
1111
)
1212
from .embeddings import Embeddings
13+
from .file_parser import FileParser
1314
from .files import Files, FilesWithRawResponse
1415
from .images import Images
1516
from .moderations import Moderations
17+
from .ocr import HandwritingOCR
1618
from .tools import Tools
1719
from .videos import (
1820
Videos,
1921
)
20-
from .web_search import WebSearchApi
2122
from .web_reader import WebReaderApi
22-
from .file_parser import FileParser, HandwritingOCR
23+
from .web_search import WebSearchApi
2324

2425
__all__ = [
25-
'Videos',
26-
'AsyncCompletions',
27-
'Chat',
28-
'Completions',
29-
'Images',
30-
'Embeddings',
31-
'Files',
32-
'FilesWithRawResponse',
33-
'Batches',
34-
'Tools',
35-
'Assistant',
36-
'Audio',
37-
'Moderations',
38-
'WebSearchApi',
39-
'WebReaderApi',
40-
'Agents',
41-
'FileParser',
42-
'HandwritingOCR'
26+
'Videos',
27+
'AsyncCompletions',
28+
'Chat',
29+
'Completions',
30+
'Images',
31+
'Embeddings',
32+
'Files',
33+
'FilesWithRawResponse',
34+
'Batches',
35+
'Tools',
36+
'Assistant',
37+
'Audio',
38+
'Moderations',
39+
'WebSearchApi',
40+
'WebReaderApi',
41+
'Agents',
42+
'FileParser',
43+
'HandwritingOCR'
4344
]
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .file_parser import FileParser
2-
from .handwriting_ocr import HandwritingOCR
32

4-
__all__ = ['FileParser', "HandwritingOCR"]
3+
__all__ = ['FileParser']
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from src.zai.api_resource.ocr.handwriting_ocr import HandwritingOCR
2+
3+
__all__ = ["HandwritingOCR"]

src/zai/api_resource/file_parser/handwriting_ocr.py renamed to src/zai/api_resource/ocr/handwriting_ocr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
extract_files,
1717
make_request_options
1818
)
19-
from zai.types.file_parser.handwriting_ocr_params import HandwritingOCRParams
20-
from zai.types.file_parser.handwriting_ocr_resp import HandwritingOCRResp
19+
from zai.types.ocr.handwriting_ocr_params import HandwritingOCRParams
20+
from zai.types.ocr.handwriting_ocr_resp import HandwritingOCRResp
2121

2222
if TYPE_CHECKING:
2323
from zai._client import ZaiClient
@@ -35,7 +35,8 @@ def handwriting_ocr(
3535
*,
3636
file: FileTypes,
3737
tool_type: Literal["hand_write"],
38-
language_type: str = None, # optional
38+
language_type: str = None, # optional,
39+
probability: bool = None,
3940
extra_headers: Headers | None = None,
4041
extra_body: Body | None = None,
4142
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
@@ -49,6 +50,7 @@ def handwriting_ocr(
4950
"file": file,
5051
"tool_type": tool_type,
5152
"language_type": language_type,
53+
"probability": probability
5254
}
5355
)
5456
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])

src/zai/types/ocr/__init__.py

Whitespace-only changes.

src/zai/types/file_parser/handwriting_ocr_params.py renamed to src/zai/types/ocr/handwriting_ocr_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ class HandwritingOCRParams(TypedDict, total=False):
1010
file: FileTypes # Required
1111
tool_type: Literal["hand_write"] # Required
1212
language_type: str # Optional
13+
probability: bool

src/zai/types/file_parser/handwriting_ocr_resp.py renamed to src/zai/types/ocr/handwriting_ocr_resp.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@ class Location(BaseModel):
1313
height: int
1414

1515

16+
class Probability(BaseModel):
17+
average: float
18+
variance: float
19+
min: float
20+
21+
1622
class WordsResult(BaseModel):
1723
location: Location
1824
words: str
25+
probability: Probability
1926

2027

2128
class HandwritingOCRResp(BaseModel):

0 commit comments

Comments
 (0)