
Commit 920f973

Commit message: version
1 parent bd6544d commit 920f973

File tree: 4 files changed (+107 −30 lines)

ai_transform/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.31.3"
+__version__ = "0.32.0"
 
 from ai_transform.timer import Timer
 

ai_transform/engine/abstract_engine.py

Lines changed: 22 additions & 23 deletions
@@ -106,8 +106,10 @@ def __init__(
             filters = []
         assert isinstance(filters, list), "Filters must be applied as a list of Dictionaries"
 
-        if not refresh:
-            filters += self._get_refresh_filter(select_fields, dataset)
+        self._refresh = refresh
+        self._after_id = after_id
+
+        filters += self._get_refresh_filter()
         filters += self._get_workflow_filter()
 
         self._filters = filters
@@ -117,9 +119,6 @@ def __init__(
         else:
             self._size = dataset.len(filters=filters) if self._limit_documents is None else self._limit_documents
 
-        self._refresh = refresh
-        self._after_id = after_id
-
         self._successful_documents = 0
         self._success_ratio = None
 
@@ -205,36 +204,36 @@ def _operate(self, mini_batch):
         self._successful_documents += len(mini_batch)
         return transformed_batch
 
-    def _get_refresh_filter(self, select_fields: List[str], dataset: Dataset):
+    def _get_refresh_filter(self):
         # initialize the refresh filter container
-        refresh_filters = {"filter_type": "or", "condition_value": []}
+        input_field_filters = {"filter_type": "or", "condition_value": []}
 
         # initialize where the filters are going
-        input_field_filters = []
         output_field_filters = {"filter_type": "or", "condition_value": []}
 
-        # We want documents where all select_fields exists
+        # We want documents where any of the select_fields exists
         # as these are needed for operator ...
-        for field in select_fields:
-            input_field_filters += dataset[field].exists()
-
-        # ... and where any of its output_fields dont exist
-        for operator in self.operators:
-            if operator.output_fields is not None:
-                for output_field in operator.output_fields:
-                    output_field_filters["condition_value"] += dataset[output_field].not_exists()
-
         # We construct this as:
         #
-        # input_field1 and input_field2 and (not output_field1 or not output_field2)
+        # (input_field1 or input_field2) and (not output_field1 or not output_field2)
        #
         # This use case here is for two input fields and two output fields
         # tho this extends to arbitrarily many.
-        refresh_filters["condition_value"] = input_field_filters
-        refresh_filters["condition_value"] += [output_field_filters]
+        for field in self._select_fields:
+            input_field_filters["condition_value"] += self.dataset[field].exists()
+
+        # ... and where any of its output_fields dont exist
+        if not self._refresh:
+            for operator in self.operators:
+                if operator.output_fields is not None:
+                    for output_field in operator.output_fields:
+                        output_field_filters["condition_value"] += self.dataset[output_field].not_exists()
+
+            return [input_field_filters, output_field_filters]
 
-        # Wrap in list at end
-        return [refresh_filters]
+        else:
+            # Wrap in list at end
+            return [input_field_filters]
 
     def _get_workflow_filter(self, field: str = "_id"):
         # Get the required workflow filter as an environment variable
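For reference, a minimal sketch of the filter list the reworked _get_refresh_filter now produces. The exists/not_exists helpers below are hypothetical stand-ins for dataset[field].exists() and dataset[field].not_exists(), each assumed to return a list of filter dicts; the entries of the outer filter list are ANDed by the engine, which is how the comment's "(input_field1 or input_field2) and (not output_field1 or not output_field2)" shape arises.

    # Sketch only: exists()/not_exists() are hypothetical stand-ins for
    # dataset[field].exists() / dataset[field].not_exists().
    def exists(field):
        return [{"field": field, "filter_type": "exists"}]

    def not_exists(field):
        return [{"field": field, "filter_type": "not_exists"}]

    # OR together the input fields (any input present qualifies a document).
    input_field_filters = {"filter_type": "or", "condition_value": []}
    for field in ["input_field1", "input_field2"]:
        input_field_filters["condition_value"] += exists(field)

    # OR together the missing outputs (any missing output means work to do).
    output_field_filters = {"filter_type": "or", "condition_value": []}
    for field in ["output_field1", "output_field2"]:
        output_field_filters["condition_value"] += not_exists(field)

    # refresh=False: skip documents whose outputs already all exist
    # -> (input1 OR input2) AND (NOT output1 OR NOT output2)
    filters = [input_field_filters, output_field_filters]

    # refresh=True: recompute every document that has any input field
    # -> (input1 OR input2)
    filters = [input_field_filters]

Note the behavioral shift: before this commit, no refresh filter at all was applied when refresh=True; now the input-field filter is always applied, and only the output-field filter is conditional on refresh being off.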

tests/conftest.py

Lines changed: 41 additions & 1 deletion
@@ -69,14 +69,34 @@ def partial_dataset(test_client: Client) -> Dataset:
     dataset_id = f"_sample_dataset_{salt}"
     dataset = test_client.Dataset(dataset_id, expire=True)
     documents = mock_documents(1000)
+    fields = ["sample_1_label", "sample_2_label", "sample_3_label"]
     for document in documents:
-        for field in random.choices(document.keys(), k=min(len(document), 5)):
+        for field in random.sample(fields, k=random.randint(1, 3)):
             document.pop(field)
     dataset.insert_documents(documents)
     yield dataset
     test_client.delete_dataset(dataset_id)
 
 
+@pytest.fixture(scope="class")
+def partial_dataset_with_outputs(test_client: Client) -> Dataset:
+    salt = "".join(random.choices(string.ascii_lowercase, k=10))
+    dataset_id = f"_sample_dataset_{salt}"
+    dataset = test_client.Dataset(dataset_id, expire=True)
+    documents = mock_documents(1000)
+    fields = ["sample_1_label", "sample_2_label", "sample_3_label"]
+    for document in documents:
+        for field in random.sample(fields, k=random.randint(1, 3)):
+            document.pop(field)
+    for document in documents:
+        for field in fields:
+            if document.get(field) and random.random() < 0.5:
+                document[field + "_output"] = document[field] + "_output"
+    dataset.insert_documents(documents)
+    yield dataset
+    test_client.delete_dataset(dataset_id)
+
+
 @pytest.fixture(scope="class")
 def mixed_dataset(test_client: Client) -> Dataset:
     salt = "".join(random.choices(string.ascii_lowercase, k=10))
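Worth noting on the partial_dataset change: random.choices draws with replacement, so the old loop could select the same key twice and fail on the second document.pop(field); random.sample draws distinct items, so each popped field is guaranteed unique. A quick stdlib illustration:

    import random

    fields = ["sample_1_label", "sample_2_label", "sample_3_label"]

    # With replacement: duplicates are possible, and popping a repeated
    # key would raise KeyError on the second pop.
    picks = random.choices(fields, k=3)

    # Without replacement: always distinct, safe to pop each exactly once.
    dropped = random.sample(fields, k=random.randint(1, 3))
    assert len(set(dropped)) == len(dropped)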
@@ -164,6 +184,26 @@ def transform(self, documents: DocumentList) -> DocumentList:
         return ExampleOperator()
 
 
+@pytest.fixture(scope="function")
+def test_partial_operator() -> AbstractOperator:
+    class PartialOperator(AbstractOperator):
+        def __init__(self, fields):
+            super().__init__(input_fields=fields, output_fields=[field + "_output" for field in fields])
+
+        def transform(self, documents: DocumentList) -> DocumentList:
+            """
+            Main transform function
+            """
+            for input_field, output_field in zip(self.input_fields, self.output_fields):
+                for document in documents:
+                    if document.get(input_field):
+                        document[output_field] = document[input_field] + "_output"
+
+            return documents
+
+    return PartialOperator
+
+
 @pytest.fixture(scope="function")
 def test_paid_operator() -> AbstractOperator:
     class ExampleOperator(AbstractOperator):
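Unlike the other operator fixtures, test_partial_operator returns the PartialOperator class itself rather than an instance, so each test can construct it with its own field list. A usage sketch (field names match the partial-dataset fixtures above):

    # Inside a test that receives the fixture:
    operator = test_partial_operator(["sample_1_label", "sample_2_label"])
    assert operator.input_fields == ["sample_1_label", "sample_2_label"]
    assert operator.output_fields == ["sample_1_label_output", "sample_2_label_output"]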
Lines changed: 43 additions & 5 deletions
@@ -1,3 +1,7 @@
+import uuid
+
+from typing import Type
+
 from ai_transform.dataset.dataset import Dataset
 from ai_transform.engine.stable_engine import StableEngine
 from ai_transform.engine.small_batch_stable_engine import SmallBatchStableEngine
@@ -6,21 +10,55 @@
 from ai_transform.workflow.abstract_workflow import Workflow
 
 
+def _random_id():
+    return str(uuid.uuid4())
+
+
 class TestStableEngine:
     def test_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator):
         engine = StableEngine(full_dataset, test_operator, worker_number=0)
-        workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123")
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
         workflow.run()
         assert engine.success_ratio == 1
 
     def test_small_batch_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator):
         engine = SmallBatchStableEngine(full_dataset, test_operator)
-        workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123")
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
         workflow.run()
         assert engine.success_ratio == 1
 
-    def test_stable_engine_filters(self, partial_dataset: Dataset, test_operator: AbstractOperator):
-        engine = StableEngine(partial_dataset, test_operator, select_fields=["sample_1_label"])
-        workflow = Workflow(name="workflow_test123", engine=engine, job_id="test_job123")
+
+class TestStableEngineFilters:
+    _SELECTED_FIELDS = ["sample_1_label", "sample_2_label", "sample_3_label"]
+
+    def test_stable_engine_filters1(self, partial_dataset: Dataset, test_partial_operator: Type[AbstractOperator]):
+        prev_health = partial_dataset.health()
+        operator = test_partial_operator(self._SELECTED_FIELDS)
+
+        engine = StableEngine(partial_dataset, operator, select_fields=self._SELECTED_FIELDS)
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
         workflow.run()
+
+        post_health = partial_dataset.health()
+        for input_field, output_field in zip(operator.input_fields, operator.output_fields):
+            assert prev_health[input_field]["exists"] == post_health[output_field]["exists"]
+
+        assert engine.success_ratio == 1
+
+    def test_stable_engine_filters2(
+        self, partial_dataset_with_outputs: Dataset, test_partial_operator: Type[AbstractOperator]
+    ):
+        prev_health = partial_dataset_with_outputs.health()
+        operator = test_partial_operator(self._SELECTED_FIELDS)
+
+        engine = StableEngine(
+            partial_dataset_with_outputs, operator, select_fields=self._SELECTED_FIELDS, refresh=False
+        )
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
+        workflow.run()
+
+        post_health = partial_dataset_with_outputs.health()
+        for input_field, output_field in zip(operator.input_fields, operator.output_fields):
+            assert prev_health[input_field]["exists"] == post_health[output_field]["exists"]
+
         assert engine.success_ratio == 1
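The health-based assertions encode the invariant both tests rely on: after the workflow runs, every document that had an input field should carry the matching output field, so the "exists" counts line up. A sketch with hypothetical counts (assuming Dataset.health() maps each field to stats that include an "exists" count, as the assertions above use it):

    # Hypothetical counts for one field pair:
    prev_health = {"sample_1_label": {"exists": 640}}
    post_health = {"sample_1_label_output": {"exists": 640}}
    assert prev_health["sample_1_label"]["exists"] == post_health["sample_1_label_output"]["exists"]

In test_stable_engine_filters2 this holds even though some outputs are pre-seeded and refresh=False: the refresh filter only skips documents whose outputs already exist, and the operator fills in the remainder, so the final counts still match.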
