Skip to content

Commit 00caa38

Browse files
authored
feat(ingestion/sqlglot): preserve CTEs when extracting SELECT from INSERT statements and add corresponding unit test (#14898)
1 parent 8248999 commit 00caa38

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,12 @@ def _try_extract_select(
11761176
statement = sqlglot.exp.Select().select("*").from_(statement)
11771177
elif isinstance(statement, sqlglot.exp.Insert):
11781178
# TODO Need to map column renames in the expressions part of the statement.
1179-
statement = statement.expression
1179+
# Preserve CTEs when extracting the SELECT expression from INSERT
1180+
original_ctes = statement.ctes
1181+
statement = statement.expression # Get the SELECT expression from the INSERT
1182+
if isinstance(statement, sqlglot.exp.Query) and original_ctes:
1183+
for cte in original_ctes:
1184+
statement = statement.with_(alias=cte.alias, as_=cte.this)
11801185
elif isinstance(statement, sqlglot.exp.Update):
11811186
# Assumption: the output table is already captured in the modified tables list.
11821187
statement = _extract_select_from_update(statement)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{
2+
"query_type": "INSERT",
3+
"query_type_props": {},
4+
"query_fingerprint": "195448498ded7a1b4df767cf0a5ec53e2fa4c7b011234bafe0a60ff9d7d11c1d",
5+
"in_tables": [
6+
"urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)"
7+
],
8+
"out_tables": [
9+
"urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)"
10+
],
11+
"column_lineage": [
12+
{
13+
"downstream": {
14+
"table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)",
15+
"column": "id",
16+
"column_type": null,
17+
"native_column_type": null
18+
},
19+
"upstreams": [
20+
{
21+
"table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)",
22+
"column": "id"
23+
}
24+
],
25+
"logic": {
26+
"is_direct_copy": true,
27+
"column_logic": "[source_table].[id] AS [id]"
28+
}
29+
},
30+
{
31+
"downstream": {
32+
"table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)",
33+
"column": "name",
34+
"column_type": null,
35+
"native_column_type": null
36+
},
37+
"upstreams": [
38+
{
39+
"table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)",
40+
"column": "name"
41+
}
42+
],
43+
"logic": {
44+
"is_direct_copy": true,
45+
"column_logic": "[source_table].[name] AS [name]"
46+
}
47+
},
48+
{
49+
"downstream": {
50+
"table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)",
51+
"column": "value",
52+
"column_type": null,
53+
"native_column_type": null
54+
},
55+
"upstreams": [
56+
{
57+
"table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)",
58+
"column": "value"
59+
}
60+
],
61+
"logic": {
62+
"is_direct_copy": true,
63+
"column_logic": "[source_table].[value] AS [value]"
64+
}
65+
}
66+
],
67+
"joins": [],
68+
"debug_info": {
69+
"confidence": 0.2,
70+
"generalized_statement": "WITH temp_cte AS (SELECT id AS id, name AS name, value AS value FROM db.schema.source_table) INSERT INTO db.schema.target_table (id, name, value) SELECT id, name, value FROM temp_cte"
71+
}
72+
}

metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,21 @@ def test_insert_with_column_list() -> None:
199199
)
200200

201201

202+
# Test case for an INSERT ... SELECT whose SELECT reads from a CTE (`temp_cte`)
# declared in a WITH clause that precedes the INSERT keyword.
# The INSERT also names its target columns explicitly: (id, name, value).
# NOTE(review): `assert_sql_result` and `RESOURCE_DIR` are defined outside this
# view; presumably the helper parses the SQL and compares the result with the
# JSON file passed as `expected_file` -- confirm against the helper.
def test_insert_with_cte() -> None:
203+
assert_sql_result(
204+
"""
205+
WITH temp_cte AS (
206+
SELECT id, name, value
207+
FROM db.schema.source_table
208+
)
209+
INSERT INTO db.schema.target_table (id, name, value)
210+
SELECT id, name, value FROM temp_cte
211+
""",
# The statement is parsed with the T-SQL dialect; the expected result is the
# fixture file named after this test.
212+
dialect="tsql",
213+
expected_file=RESOURCE_DIR / "test_insert_with_cte.json",
214+
)
215+
216+
202217
def test_select_with_full_col_name() -> None:
203218
# In this case, `widget` is a struct column.
204219
# This also tests the `default_db` functionality.

0 commit comments

Comments
 (0)