|
3 | 3 | from ai_transform.engine.small_batch_stable_engine import SmallBatchStableEngine |
4 | 4 |
|
5 | 5 | from ai_transform.operator.abstract_operator import AbstractOperator |
6 | | -from ai_transform.workflow.abstract_workflow import AbstractWorkflow |
| 6 | +from ai_transform.workflow.abstract_workflow import Workflow |
7 | 7 |
|
8 | 8 |
|
9 | 9 | class TestStableEngine: |
10 | 10 | def test_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator): |
11 | 11 | engine = StableEngine(full_dataset, test_operator, worker_number=0) |
12 | | - workflow = AbstractWorkflow(name="workflow_test123", engine=engine, job_id="test_job123") |
| 12 | + workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123") |
13 | 13 | workflow.run() |
14 | 14 | assert engine.success_ratio == 1 |
15 | 15 |
|
16 | 16 | def test_small_batch_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator): |
17 | 17 | engine = SmallBatchStableEngine(full_dataset, test_operator) |
18 | | - workflow = AbstractWorkflow(name="workflow_test123", engine=engine, job_id="test_job123") |
| 18 | + workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123") |
| 19 | + workflow.run() |
| 20 | + assert engine.success_ratio == 1 |
| 21 | + |
| 22 | + def test_stable_engine_filters(self, partial_dataset: Dataset, test_operator: AbstractOperator): |
| 23 | + engine = StableEngine(partial_dataset, test_operator, select_fields=["sample_1_label"]) |
| 24 | + workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123") |
19 | 25 | workflow.run() |
20 | 26 | assert engine.success_ratio == 1 |
0 commit comments