Skip to content

Commit bd6544d

Browse files
add new test
1 parent 3f92c1a commit bd6544d

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

ai_transform/engine/abstract_engine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import time
2-
import logging
32
import warnings
43

54
from json import JSONDecodeError
@@ -8,7 +7,7 @@
87

98
from tqdm.auto import tqdm
109

11-
from ai_transform.logger import format_logging_info, ic
10+
from ai_transform.logger import ic
1211
from ai_transform.types import Filter
1312
from ai_transform.dataset.dataset import Dataset
1413
from ai_transform.operator.abstract_operator import AbstractOperator

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@ def full_dataset(test_client: Client) -> Dataset:
6363
test_client.delete_dataset(dataset_id)
6464

6565

66+
@pytest.fixture(scope="class")
67+
def partial_dataset(test_client: Client) -> Dataset:
68+
salt = "".join(random.choices(string.ascii_lowercase, k=10))
69+
dataset_id = f"_sample_dataset_{salt}"
70+
dataset = test_client.Dataset(dataset_id, expire=True)
71+
documents = mock_documents(1000)
72+
for document in documents:
73+
for field in random.choices(document.keys(), k=min(len(document), 5)):
74+
document.pop(field)
75+
dataset.insert_documents(documents)
76+
yield dataset
77+
test_client.delete_dataset(dataset_id)
78+
79+
6680
@pytest.fixture(scope="class")
6781
def mixed_dataset(test_client: Client) -> Dataset:
6882
salt = "".join(random.choices(string.ascii_lowercase, k=10))

tests/core/test_engine/test_stable_engine.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,24 @@
33
from ai_transform.engine.small_batch_stable_engine import SmallBatchStableEngine
44

55
from ai_transform.operator.abstract_operator import AbstractOperator
6-
from ai_transform.workflow.abstract_workflow import AbstractWorkflow
6+
from ai_transform.workflow.abstract_workflow import Workflow
77

88

99
class TestStableEngine:
1010
def test_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator):
1111
engine = StableEngine(full_dataset, test_operator, worker_number=0)
12-
workflow = AbstractWorkflow(name="workflow_test123", engine=engine, job_id="test_job123")
12+
workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123")
1313
workflow.run()
1414
assert engine.success_ratio == 1
1515

1616
def test_small_batch_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator):
1717
engine = SmallBatchStableEngine(full_dataset, test_operator)
18-
workflow = AbstractWorkflow(name="workflow_test123", engine=engine, job_id="test_job123")
18+
workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123")
19+
workflow.run()
20+
assert engine.success_ratio == 1
21+
22+
def test_stable_engine_filters(self, partial_dataset: Dataset, test_operator: AbstractOperator):
23+
engine = StableEngine(partial_dataset, test_operator, select_fields=["sample_1_label"])
24+
workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123")
1925
workflow.run()
2026
assert engine.success_ratio == 1

0 commit comments

Comments
 (0)