
Commit d7db0f2

Merge pull request #441 from RelevanceAI/development
v0.32.1
2 parents: f9a32cd + 0992528

File tree

15 files changed: +131 -64 lines changed

ai_transform/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "0.32.0"
+__version__ = "0.32.1"
 
 from ai_transform.timer import Timer
 

ai_transform/api/api.py

Lines changed: 4 additions & 3 deletions

@@ -162,9 +162,10 @@ def _list_datasets(self):
     def _create_dataset(
         self, dataset_id: str, schema: Optional[Schema] = None, upsert: bool = True, expire: bool = False
     ) -> Any:
-        response = self.post(
-            suffix=f"/datasets/create", json=dict(id=dataset_id, schema=schema, upsert=upsert, expire=expire)
-        )
+        obj = dict(id=dataset_id, upsert=upsert, expire=expire)
+        if schema:
+            obj["schema"] = schema
+        response = self.post(suffix=f"/datasets/create", json=obj)
         return get_response(response)
 
     def _delete_dataset(self, dataset_id: str) -> Any:
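
With this change, a dataset-create request with no schema omits the key entirely instead of sending "schema": null. A minimal sketch of the payload difference (dataset id hypothetical):

    # Before: the key was always present, even when schema was None
    body = dict(id="demo-dataset", schema=None, upsert=True, expire=False)
    # wire body: {"id": "demo-dataset", "schema": null, "upsert": true, "expire": false}

    # After: the key is attached only when a schema is actually supplied
    schema = None
    obj = dict(id="demo-dataset", upsert=True, expire=False)
    if schema:
        obj["schema"] = schema
    # wire body: {"id": "demo-dataset", "upsert": true, "expire": false}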

ai_transform/dataset/field.py

Lines changed: 3 additions & 4 deletions

@@ -107,11 +107,10 @@ def contains(self, other: str) -> Filter:
 
     def exists(self) -> Filter:
         if "_chunk_" in self._field:
-            count = self._field.count(".")
-            if count:
-                parent_field = self._field.split(".")[0]
-            else:
+            if self._field.endswith("_chunk_"):
                 parent_field = self._field
+            else:
+                parent_field = self._field.split(".")[0]
 
             return [{"chunk": {"path": parent_field, "filters": [{"fieldExists": {"field": self._field}}]}}]
         return [{"field": self._field, "filter_type": "exists", "condition": "==", "condition_value": " "}]

ai_transform/engine/abstract_engine.py

Lines changed: 3 additions & 3 deletions

@@ -63,7 +63,7 @@ def __init__(
                 chunk_index = field.index("_chunk_") + len("_chunk_")
                 chunk_field = field[:chunk_index]
                 fields_to_add += [chunk_field]
-            select_fields += fields_to_add
+            select_fields = select_fields + fields_to_add
             select_fields = list(set(select_fields))
         else:
             select_fields = []
@@ -109,8 +109,8 @@ def __init__(
         self._refresh = refresh
         self._after_id = after_id
 
-        filters += self._get_refresh_filter()
-        filters += self._get_workflow_filter()
+        filters = filters + self._get_refresh_filter()
+        filters = filters + self._get_workflow_filter()
 
         self._filters = filters
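
Both hunks replace in-place += with rebinding. On a list, += extends the very object the caller passed in, so a caller-owned filters list (or a shared default) was silently mutated; building a new list with + leaves it alone. A small illustration:

    caller_filters = [{"field": "a"}]

    filters = caller_filters
    filters += [{"field": "b"}]           # extends in place; caller_filters grows too

    filters = caller_filters
    filters = filters + [{"field": "c"}]  # new list; caller_filters is unchanged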

ai_transform/engine/dense_output_engine.py

Lines changed: 11 additions & 1 deletion

@@ -21,6 +21,8 @@
 
 
 class DenseOutputEngine(AbstractEngine):
+    operator: DenseOperator
+
     def __init__(
         self,
         dataset: Dataset = None,
@@ -72,6 +74,7 @@ def apply(self) -> None:
         for mega_batch in self.api_progress(iterator):
             for mini_batch in AbstractEngine.chunk_documents(self._transform_chunksize, mega_batch):
                 document_mapping = self._operate(mini_batch)
+
                 for dataset_id, documents in document_mapping.items():
                     output_dataset_ids.append(dataset_id)
                     dataset = Dataset.from_details(dataset_id, self.token)
@@ -81,7 +84,14 @@ def apply(self) -> None:
         self.operator.post_hooks(self._dataset)
 
         output_datasets = self.datasets_from_ids(output_dataset_ids)
-        self.operator.store_dataset_relationship(self.dataset, output_datasets)
+        self.store_dataset_relationship(output_datasets)
 
     def datasets_from_ids(self, dataset_ids: Sequence[str]) -> Sequence[Dataset]:
         return [Dataset.from_details(dataset_id, self.token) for dataset_id in dataset_ids]
+
+    def store_dataset_relationship(self, output_datasets: Sequence[Dataset]):
+        self.dataset.update_metadata(
+            {"_child_datasets_": [output_dataset.dataset_id for output_dataset in output_datasets]}
+        )
+        for output_dataset in output_datasets:
+            output_dataset.update_metadata({"_parent_dataset_": self.dataset.dataset_id})
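
store_dataset_relationship moves from the operator onto the engine, which already holds the source dataset and token. After apply() completes, the metadata linkage looks roughly like this (dataset ids hypothetical):

    # on the source dataset:
    {"_child_datasets_": ["sentiment-output", "topics-output"]}
    # on each output dataset:
    {"_parent_dataset_": "source-dataset"}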

ai_transform/operator/abstract_operator.py

Lines changed: 7 additions & 2 deletions

@@ -75,7 +75,6 @@ def __init__(
         output_fields: Optional[Union[Dict[str, str], List[str]]] = None,
         enable_postprocess: Optional[bool] = True,
     ):
-
         if input_fields is not None and output_fields is not None:
             if any(input_field in output_fields for input_field in input_fields):
                 detected_fields = [input_field for input_field in input_fields if input_field in output_fields]
@@ -171,6 +170,12 @@ def transform_for_playground(
         from ai_transform.api.client import Client
 
         output = self.transform(documents=documents)
+        if hasattr(documents, "to_json"):
+            output = output.to_json()
+        else:
+            for index in range(len(output)):
+                if hasattr(output[index], "to_json"):
+                    output[index] = output[index].to_json()
         client = Client(authorization_token)
         return client.api._set_workflow_status(
             job_id=job_id,
@@ -180,7 +185,7 @@ def transform_for_playground(
             status=status,
             send_email=send_email,
             worker_number=worker_number,
-            output=output,
+            output={"output": output},
         )
 
     def pre_hooks(self, dataset: Dataset):
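
The playground path now converts Document-style results to plain JSON before posting, and wraps them under an "output" key. A rough sketch of the per-item conversion (class and data hypothetical):

    class FakeDocument:
        def __init__(self, data):
            self.data = data

        def to_json(self):
            return self.data

    output = [FakeDocument({"text": "hello"}), {"already": "plain"}]
    for index in range(len(output)):
        if hasattr(output[index], "to_json"):
            output[index] = output[index].to_json()
    # output is now JSON-serializable: [{"text": "hello"}, {"already": "plain"}]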

ai_transform/operator/dense_operator.py

Lines changed: 4 additions & 10 deletions

@@ -5,7 +5,6 @@
 from typing import Dict, Sequence
 
 from ai_transform.operator.abstract_operator import AbstractOperator
-from ai_transform.dataset.dataset import Dataset
 from ai_transform.utils.document import Document
 from ai_transform.utils.document_list import DocumentList
 
@@ -24,18 +23,13 @@
 class DenseOperator(AbstractOperator):
     def __call__(self, old_documents: DocumentList) -> DenseOperatorOutput:
         datum = self.transform(old_documents)
-        assert isinstance(datum, dict), BAD_OPERATOR_MESSAGE
+        if not isinstance(datum, dict):
+            raise ValueError(BAD_OPERATOR_MESSAGE)
         for _, documents in datum.items():
-            assert isinstance(documents, Sequence)
+            if not isinstance(documents, Sequence):
+                raise ValueError(BAD_OPERATOR_MESSAGE)
         return datum
 
     @abstractmethod
     def transform(self, documents: DocumentList) -> DenseOperatorOutput:
         raise NotImplementedError
-
-    def store_dataset_relationship(self, input_dataset: Dataset, output_datasets: Sequence[Dataset]):
-        input_dataset.update_metadata(
-            {"_child_datasets_": [output_dataset.dataset_id for output_dataset in output_datasets]}
-        )
-        for output_dataset in output_datasets:
-            output_dataset.update_metadata({"_parent_dataset_": input_dataset.dataset_id})
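
Swapping assert for an explicit raise matters because asserts are stripped when Python runs with -O, so the shape check would silently disappear. A small demonstration:

    datum = ["not", "a", "dict"]  # a malformed operator return value

    # under "python -O" this line compiles to nothing and bad data flows on:
    # assert isinstance(datum, dict), "must be a dict"

    # an explicit raise survives every interpreter flag:
    if not isinstance(datum, dict):
        raise ValueError("operator output must map dataset ids to documents")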

ai_transform/utils/document.py

Lines changed: 1 addition & 11 deletions

@@ -141,17 +141,7 @@ def list_chunks(self):
         return [k for k in self.keys() if k.endswith("_chunk_")]
 
     def get_chunk(self, chunk_field: str, field: str = None, default: str = None):
-        """
-        Returns a list of values.
-        """
-        # provide a recursive implementation for getting chunks
-        from ai_transform.utils.document_list import DocumentList
-
-        document_list = DocumentList(self.get(chunk_field, default=default))
-        # Get the field across chunks
-        if field is None:
-            return document_list
-        return [d.get(field, default=default) for d in document_list.data]
+        return [document.get(field, default) for document in self.get(chunk_field, default=default)]
 
     def _create_chunk_documents(self, field: str, values: list, generate_id: bool = False):
         """

ai_transform/utils/example_documents.py

Lines changed: 10 additions & 2 deletions

@@ -52,7 +52,6 @@ def create_id():
 
 
 def generate_random_string(string_length: int = 5) -> str:
-
     """Generate a random string of letters and numbers"""
     return "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(string_length))
 
@@ -92,7 +91,16 @@ def vector_document(vector_length: int) -> Document:
 
 
 def mock_documents(n: int = 100, vector_length: int = 5) -> DocumentList:
-    return DocumentList([vector_document(vector_length) for _ in range(n)])
+    documents = [vector_document(vector_length) for _ in range(n)]
+    return DocumentList(documents)
+
+
+def incomplete_documents(n: int = 100, vector_length: int = 5) -> DocumentList:
+    documents = [vector_document(vector_length).data for _ in range(n)]
+    for document in documents:
+        for key in random.sample(document.keys(), 3):
+            document.pop(key)
+    return DocumentList(documents)
 
 
 def static_documents(n: int = 100) -> DocumentList:
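
incomplete_documents builds fixtures for code that must tolerate sparse data: each mock document starts as a full vector_document and loses three randomly chosen top-level keys. A usage sketch:

    docs = incomplete_documents(n=10, vector_length=5)
    # every underlying dict is missing 3 random fields, so any workflow
    # exercised against docs has to handle absent keys gracefully
    # (note: on Python 3.11+, random.sample needs list(document.keys()),
    # since sampling from a non-sequence raises TypeError)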

examples/workflows/clustering_example.py

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ def execute(token: str, logger: Callable, worker_number: int = 0, *args, **kwarg
     engine = InMemoryEngine(
         dataset=dataset,
         operator=operator,
-        chunksize=16,
+        chunksize=8,
         select_fields=[vector_field],
         filters=filters,
         worker_number=worker_number,
