Skip to content

Commit ec67c99

Browse files
authored
schemadiff: support views with CTE (#18893)
Signed-off-by: Shlomi Noach <[email protected]>
1 parent 79af4c1 commit ec67c99

File tree

2 files changed

+93
-16
lines changed

2 files changed

+93
-16
lines changed

go/vt/schemadiff/schema.go

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,31 @@ func getForeignKeyParentTableNames(createTable *sqlparser.CreateTable) (names []
134134
}
135135

136136
// getViewDependentTableNames analyzes a CREATE VIEW definition and extracts all tables/views read by this view
137-
func getViewDependentTableNames(createView *sqlparser.CreateView) (names []string) {
137+
func getViewDependentTableNames(createView *sqlparser.CreateView) (names []string, cteNames []string) {
138+
cteMap := make(map[string]bool)
138139
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
139140
switch node := node.(type) {
141+
case *sqlparser.CommonTableExpr:
142+
if !cteMap[node.ID.String()] {
143+
cteNames = append(cteNames, node.ID.String())
144+
cteMap[node.ID.String()] = true
145+
}
140146
case *sqlparser.TableName:
141-
names = append(names, node.Name.String())
147+
if _, isCte := cteMap[node.Name.String()]; !isCte {
148+
names = append(names, node.Name.String())
149+
}
142150
case *sqlparser.AliasedTableExpr:
143151
if tableName, ok := node.Expr.(sqlparser.TableName); ok {
144-
names = append(names, tableName.Name.String())
152+
if _, isCte := cteMap[tableName.Name.String()]; !isCte {
153+
names = append(names, tableName.Name.String())
154+
}
145155
}
146156
// or, this could be a more complex expression, like a derived table `(select * from v1) as derived`,
147157
// in which case further Walk-ing will eventually find the "real" table name
148158
}
149159
return true, nil
150160
}, createView)
151-
return names
161+
return names, cteNames
152162
}
153163

154164
// normalize is called as part of Schema creation process. The user may only get a hold of normalized schema.
@@ -310,7 +320,7 @@ func (s *Schema) normalize(hints *DiffHints) error {
310320
continue
311321
}
312322
// Not handled. Is this view dependent on already handled objects?
313-
dependentNames := getViewDependentTableNames(v.CreateView)
323+
dependentNames, _ := getViewDependentTableNames(v.CreateView)
314324
if allNamesFoundInLowerLevel(dependentNames, iterationLevel) {
315325
s.sorted = append(s.sorted, v)
316326
dependencyLevels[v.Name()] = iterationLevel
@@ -341,7 +351,7 @@ func (s *Schema) normalize(hints *DiffHints) error {
341351
if _, ok := dependencyLevels[v.Name()]; !ok {
342352
// We _know_ that in this iteration, at least one view is found unassigned a dependency level.
343353
// We gather all the errors.
344-
dependentNames := getViewDependentTableNames(v.CreateView)
354+
dependentNames, _ := getViewDependentTableNames(v.CreateView)
345355
missingReferencedEntities := []string{}
346356
for _, name := range dependentNames {
347357
if _, ok := dependencyLevels[name]; !ok {
@@ -974,12 +984,16 @@ func (s *Schema) SchemaDiff(other *Schema, hints *DiffHints) (*SchemaDiff, error
974984
for _, diff := range schemaDiff.UnorderedDiffs() {
975985
switch diff := diff.(type) {
976986
case *CreateViewEntityDiff:
977-
checkDependencies(diff, getViewDependentTableNames(diff.createView))
987+
dependentNames, _ := getViewDependentTableNames(diff.createView)
988+
checkDependencies(diff, dependentNames)
978989
case *AlterViewEntityDiff:
979-
checkDependencies(diff, getViewDependentTableNames(diff.from.CreateView))
980-
checkDependencies(diff, getViewDependentTableNames(diff.to.CreateView))
990+
fromDependentNames, _ := getViewDependentTableNames(diff.from.CreateView)
991+
checkDependencies(diff, fromDependentNames)
992+
toDependentNames, _ := getViewDependentTableNames(diff.to.CreateView)
993+
checkDependencies(diff, toDependentNames)
981994
case *DropViewEntityDiff:
982-
checkDependencies(diff, getViewDependentTableNames(diff.from.CreateView))
995+
dependentNames, _ := getViewDependentTableNames(diff.from.CreateView)
996+
checkDependencies(diff, dependentNames)
983997
case *CreateTableEntityDiff:
984998
checkDependencies(diff, getForeignKeyParentTableNames(diff.CreateTable()))
985999
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
@@ -1142,13 +1156,15 @@ func (s *Schema) getViewColumnNames(v *CreateViewEntity, schemaInformation *decl
11421156
for _, node := range v.Select.GetColumns() {
11431157
switch node := node.(type) {
11441158
case *sqlparser.StarExpr:
1159+
dependentNames, cteNames := getViewDependentTableNames(v.CreateView)
11451160
if tableName := node.TableName.Name.String(); tableName != "" {
1146-
for _, col := range schemaInformation.Tables[tableName].Columns {
1147-
name := sqlparser.Clone(col.Name)
1148-
columnNames = append(columnNames, &name)
1161+
if tbl, ok := schemaInformation.Tables[tableName]; ok {
1162+
for _, col := range tbl.Columns {
1163+
name := sqlparser.Clone(col.Name)
1164+
columnNames = append(columnNames, &name)
1165+
}
11491166
}
11501167
} else {
1151-
dependentNames := getViewDependentTableNames(v.CreateView)
11521168
// add all columns from all referenced tables and views
11531169
for _, entityName := range dependentNames {
11541170
if schemaInformation.Tables[entityName] != nil { // is nil for dual/DUAL
@@ -1159,7 +1175,10 @@ func (s *Schema) getViewColumnNames(v *CreateViewEntity, schemaInformation *decl
11591175
}
11601176
}
11611177
}
1162-
if len(columnNames) == 0 {
1178+
if len(columnNames) == 0 && len(cteNames) == 0 {
1179+
// *-expressions that do not resolve to any columns are invalid in views.
1180+
// For CTEs, schemadiff does not analyze the list of columns returned by the CTE (even if the CTE defines it).
1181+
// TODO(shlomi): analyze CTE columns as well.
11631182
return nil, &InvalidStarExprInViewError{View: v.Name()}
11641183
}
11651184
case *sqlparser.AliasedExpr:

go/vt/schemadiff/schema_test.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,52 @@ func TestNewSchemaFromQueriesViewFromDualImplicit(t *testing.T) {
176176
assert.NoError(t, err)
177177
}
178178

179+
func TestNewSchemaFromQueriesViewWithCTEFail(t *testing.T) {
180+
queries := []string{"create view v30 as with vcte as (select 1) select * from vcte2"}
181+
_, err := NewSchemaFromQueries(NewTestEnv(), queries)
182+
assert.Error(t, err)
183+
assert.EqualError(t, err, (&ViewDependencyUnresolvedError{View: "v30", MissingReferencedEntities: []string{"dual", "vcte2"}}).Error())
184+
}
185+
186+
func TestNewSchemaFromQueriesViewWithCTE(t *testing.T) {
187+
tcases := []struct {
188+
name string
189+
queries []string
190+
}{
191+
{
192+
"no table",
193+
[]string{"create view v20 as with vcte as (select 1) select * from vcte"},
194+
},
195+
{
196+
"with table",
197+
[]string{
198+
"create table orders (id int primary key, info int not null)",
199+
"create view v21 as with vcte as (select * from orders) select * from vcte",
200+
},
201+
},
202+
{
203+
"with table and column aliasing",
204+
[]string{
205+
"create table orders (id int primary key, info int not null)",
206+
"create view v22 as with vcte as (select id, info as val from orders) select * from vcte",
207+
},
208+
},
209+
{
210+
"with table and select all from cte",
211+
[]string{
212+
"create table orders (id int primary key, info int not null)",
213+
"create view v22 as with vcte as (select id, info as val from orders) select vcte.* from vcte",
214+
},
215+
},
216+
}
217+
for _, tc := range tcases {
218+
t.Run(tc.name, func(t *testing.T) {
219+
_, err := NewSchemaFromQueries(NewTestEnv(), tc.queries)
220+
assert.NoError(t, err)
221+
})
222+
}
223+
}
224+
179225
func TestNewSchemaFromQueriesLoop(t *testing.T) {
180226
// v7 and v8 depend on each other
181227
queries := append(schemaTestCreateQueries,
@@ -213,6 +259,7 @@ func TestGetViewDependentTableNames(t *testing.T) {
213259
name string
214260
view string
215261
tables []string
262+
ctes []string
216263
}{
217264
{
218265
view: "create view v6 as select * from v4",
@@ -242,6 +289,16 @@ func TestGetViewDependentTableNames(t *testing.T) {
242289
view: "create view v9 as select 1",
243290
tables: []string{"dual"},
244291
},
292+
{
293+
view: "create view v20 as with vcte as (select 1) select * from vcte",
294+
tables: []string{"dual"},
295+
ctes: []string{"vcte"},
296+
},
297+
{
298+
view: "create view v21 as with vcte as (select * from orders) select * from vcte",
299+
tables: []string{"orders"},
300+
ctes: []string{"vcte"},
301+
},
245302
}
246303
for _, ts := range tt {
247304
t.Run(ts.view, func(t *testing.T) {
@@ -250,8 +307,9 @@ func TestGetViewDependentTableNames(t *testing.T) {
250307
createView, ok := stmt.(*sqlparser.CreateView)
251308
require.True(t, ok)
252309

253-
tables := getViewDependentTableNames(createView)
310+
tables, ctes := getViewDependentTableNames(createView)
254311
assert.Equal(t, ts.tables, tables)
312+
assert.Equal(t, ts.ctes, ctes)
255313
})
256314
}
257315
}

0 commit comments

Comments
 (0)