Skip to content

Commit ca31df6

Browse files
authored
fix!: Update AlloyDBModel based on ml extension v1.5.2 (#500)
* chore: Remove generate_headers_fn from model manager
  Removed the generate_headers_fn parameter from the model manager.
* Update test_vectorstore_embeddings.py
* Refactor embeddings_service to use synchronous methods
* Corrected model manager instantiation
  Fix async call to create_sync for model manager
* Convert embeddings_service to async fixture
* Update test_embeddings.py
* Update model qualified name in tests
* Update test_embeddings.py
* Update model_manager.py
* Add batch transform functions and update version check
* Fix input_batch_transform_fn assignment syntax
* Fix syntax error in model_manager.py
* Update google_ml_integration extension version check
* Update google_ml_integration version requirement to 1.3
* Update google_ml_integration extension version to 1.5.2
1 parent c98f428 commit ca31df6

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

src/langchain_google_alloydb_pg/model_manager.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ def __init__(
3636
input_transform_fn: Optional[str],
3737
output_transform_fn: Optional[str],
3838
generate_headers_fn: Optional[str] = None,
39+
input_batch_transform_fn: Optional[str] = None,
40+
output_batch_transform_fn: Optional[str] = None,
3941
**kwargs: Any,
4042
):
4143
self.model_id = model_id
@@ -49,6 +51,8 @@ def __init__(
4951
self.output_transform_fn = output_transform_fn
5052
# List models is returning column name "header_gen_fn"
5153
self.generate_headers_fn = generate_headers_fn or kwargs.get("header_gen_fn")
54+
self.input_batch_transform_fn = input_batch_transform_fn
55+
self.output_batch_transform_fn = output_batch_transform_fn
5256

5357

5458
class AlloyDBModelManager:
@@ -170,14 +174,14 @@ async def __avalidate(self) -> None:
170174
"""Private async function to validate prerequisites.
171175
172176
Raises:
173-
Exception if google_ml_integration EXTENSION is not 1.3.
177+
Exception if google_ml_integration EXTENSION is not 1.5.2.
174178
Exception if google_ml_integration.enable_model_support DB Flag not set.
175179
"""
176180
extension_version = await self.__fetch_google_ml_extension()
177181
db_flag = await self.__fetch_db_flag()
178-
if extension_version < "1.3":
182+
if extension_version < "1.5.2":
179183
raise Exception(
180-
"Please upgrade google_ml_integration EXTENSION to version 1.3 or above."
184+
"Please upgrade google_ml_integration EXTENSION to version 1.5.2 or above."
181185
)
182186
if db_flag != "on":
183187
raise Exception(
@@ -214,13 +218,15 @@ async def __aget_model(self, model_id: str) -> Optional[AlloyDBModel]:
214218
model_qualified_name VARCHAR,
215219
model_auth_type google_ml.auth_type,
216220
model_auth_id VARCHAR,
217-
generate_headers_fn VARCHAR,
221+
header_gen_fn VARCHAR,
218222
input_transform_fn VARCHAR,
219-
output_transform_fn VARCHAR)"""
223+
output_transform_fn VARCHAR,
224+
input_batch_transform_fn VARCHAR,
225+
output_batch_transform_fn VARCHAR)"""
220226

221227
try:
222228
result = await self.__query_db(query)
223-
except Exception:
229+
except Exception as e:
224230
return None
225231
data_class = self.__convert_dict_to_dataclass(result)[0]
226232
return data_class
@@ -285,13 +291,13 @@ async def __adrop_model(self, model_id: str) -> None:
285291
await conn.commit()
286292

287293
async def __fetch_google_ml_extension(self) -> str:
288-
"""Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.3)."""
294+
"""Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.5.2)."""
289295
create_extension_query = """
290296
DO $$
291297
BEGIN
292298
IF NOT EXISTS (
293299
SELECT 1 FROM pg_extension WHERE extname = 'google_ml_integration' )
294-
THEN CREATE EXTENSION google_ml_integration VERSION '1.3' CASCADE;
300+
THEN CREATE EXTENSION google_ml_integration VERSION '1.5.2' CASCADE;
295301
END IF;
296302
END
297303
$$;

tests/test_embeddings.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
instance_id = os.environ["INSTANCE_ID"]
3232
db_name = os.environ["DATABASE_ID"]
3333
table_name = "test-table" + str(uuid.uuid4())
34+
embedding_model = "text-embedding-005" + str(uuid.uuid4()).replace("-", "_")
3435

3536

3637
@pytest.mark.asyncio
@@ -66,7 +67,7 @@ async def sync_engine(self):
6667

6768
@pytest.fixture(scope="module")
6869
def model_id(self) -> str:
69-
return "text-embedding-005"
70+
return embedding_model
7071

7172
@pytest_asyncio.fixture
7273
async def embeddings(self, engine, model_id):
@@ -77,7 +78,7 @@ async def embeddings(self, engine, model_id):
7778
await model_manager.acreate_model(
7879
model_id=model_id,
7980
model_provider="google",
80-
model_qualified_name=model_id, # assuming model is built-in
81+
model_qualified_name="text-embedding-005", # assuming model is built-in
8182
model_type="text_embedding",
8283
)
8384
return AlloyDBEmbeddings.create_sync(engine=engine, model_id=model_id)

tests/test_vectorstore_embeddings.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
3333
DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_")
3434
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
35-
DEFAULT_EMBEDDING_MODEL = "text-embedding-005"
35+
DEFAULT_EMBEDDING_MODEL = "text-embedding-005" + str(uuid.uuid4()).replace("-", "_")
3636
VECTOR_SIZE = 768
3737

3838

@@ -114,10 +114,11 @@ async def embeddings_service(self, engine):
114114
await model_manager.acreate_model(
115115
model_id=DEFAULT_EMBEDDING_MODEL,
116116
model_provider="google",
117-
model_qualified_name=DEFAULT_EMBEDDING_MODEL, # assuming model is built-in
117+
model_qualified_name="text-embedding-005", # assuming model is built-in
118118
model_type="text_embedding",
119119
)
120-
return await AlloyDBEmbeddings.create(engine, DEFAULT_EMBEDDING_MODEL)
120+
yield await AlloyDBEmbeddings.create(engine, DEFAULT_EMBEDDING_MODEL)
121+
await model_manager.adrop_model(DEFAULT_EMBEDDING_MODEL)
121122

122123
@pytest_asyncio.fixture(scope="class")
123124
async def vs(self, engine, embeddings_service):
@@ -308,8 +309,19 @@ async def engine_sync(
308309
await engine.close()
309310

310311
@pytest_asyncio.fixture(scope="class")
311-
def embeddings_service(self, engine_sync):
312+
async def embeddings_service(self, engine_sync):
313+
model_manager = AlloyDBModelManager.create_sync(engine=engine_sync)
314+
model = await model_manager.aget_model(model_id=DEFAULT_EMBEDDING_MODEL)
315+
if not model:
316+
# create model if not exists
317+
await model_manager.acreate_model(
318+
model_id=DEFAULT_EMBEDDING_MODEL,
319+
model_provider="google",
320+
model_qualified_name="text-embedding-005", # assuming model is built-in
321+
model_type="text_embedding",
322+
)
312323
return AlloyDBEmbeddings.create_sync(engine_sync, DEFAULT_EMBEDDING_MODEL)
324+
await model_manager.adrop_model(DEFAULT_EMBEDDING_MODEL)
313325

314326
@pytest_asyncio.fixture(scope="class")
315327
async def vs_custom(self, engine_sync, embeddings_service):

0 commit comments

Comments
 (0)