3 changes: 3 additions & 0 deletions core/init.go
@@ -20,6 +20,7 @@ import (

"github.com/dolthub/doltgresql/core/conflicts"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/typecollection"
"github.com/dolthub/doltgresql/server/plpgsql"

gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -35,4 +36,6 @@ func Init() {
conflicts.ClearContextValues = ClearContextValues
plpgsql.GetTypesCollectionFromContext = GetTypesCollectionFromContext
id.RegisterListener(sequenceIDListener{}, id.Section_Table)
typecollection.GetSqlTableFromContext = GetSqlTableFromContext
Review comment (Member):
This is less defensible when it's not cross-codebase. Probably need to refactor core a bit to remove the necessity of this. My guess is that these utility methods aren't really part of the core package, and the core and type collection packages could both depend on them in a new home.

typecollection.GetSchemaName = GetSchemaName
}
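
A possible shape for the refactor suggested in the review comment above, sketched under assumptions: the two context helpers would move into a small leaf package (the name coreutil is invented here) that both core and typecollection import directly, so Init no longer has to wire up forward-declared function variables. Only stubs are shown; the real bodies would move over unchanged.

// Hypothetical sketch, not part of this PR: a leaf package that both core and
// typecollection could depend on, removing the need for forward-declared variables.
package coreutil

import (
	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/go-mysql-server/sql"
)

// GetSqlTableFromContext keeps its current signature; the body would move here from core.
func GetSqlTableFromContext(ctx *sql.Context, databaseName string, tableName doltdb.TableName) (sql.Table, error) {
	panic("sketch only: the existing core implementation would move here unchanged")
}

// GetSchemaName keeps its current signature; the body would move here from core.
func GetSchemaName(ctx *sql.Context, db sql.Database, schemaName string) (string, error) {
	panic("sketch only: the existing core implementation would move here unchanged")
}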
75 changes: 71 additions & 4 deletions core/typecollection/typecollection.go
@@ -25,6 +25,7 @@
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/prolly"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/rootobject/objinterface"
@@ -101,8 +102,10 @@
}

// GetAllTypes returns a map containing all types in the collection, grouped by the schema they're contained in.
// Each type array is also sorted by the type name. It includes built-in types.
// Each type array is also sorted by the type name. It includes built-in types, but does not include types referring to
// a table's row type.
func (pgs *TypeCollection) GetAllTypes(ctx context.Context) (typeMap map[string][]*pgtypes.DoltgresType, schemaNames []string, totalCount int, err error) {
// TODO: this should probably get tables as well since tables create composite types matching their rows
schemaNamesMap := make(map[string]struct{})
typeMap = make(map[string][]*pgtypes.DoltgresType)
err = pgs.IterateTypes(ctx, func(t *pgtypes.DoltgresType) (stop bool, err error) {
@@ -158,9 +161,21 @@
}
// The initial load is from the internal map
h, err := pgs.underlyingMap.Get(ctx, string(name))
if err != nil || h.IsEmpty() {
if err != nil {
return nil, err
}
if h.IsEmpty() {
// If it's not a built-in type or created type, then check if it's a composite table row type
sqlCtx, ok := ctx.(*sql.Context)
if !ok {
return nil, nil
}
tbl, schema, err := pgs.getTable(sqlCtx, name.SchemaName(), name.TypeName())
if err != nil || tbl == nil {
return nil, err
}
return pgs.tableToType(sqlCtx, tbl, schema)
}
data, err := pgs.ns.ReadBytes(ctx, h)
if err != nil {
return nil, err
@@ -180,19 +195,26 @@
if _, ok := pgtypes.IDToBuiltInDoltgresType[name]; ok {
return true
}

// Now we'll check our created types
if _, ok := pgs.accessedMap[name]; ok {
return true
}
ok, err := pgs.underlyingMap.Has(ctx, string(name))
if err == nil && ok {
return true
}
return false
// If it's not a built-in type or created type, then check if it's a composite table row type
sqlCtx, ok := ctx.(*sql.Context)
if !ok {
return false
}
tbl, _, err := pgs.getTable(sqlCtx, name.SchemaName(), name.TypeName())
return err == nil && tbl != nil
}

// resolveName returns the fully resolved name of the given type. Returns an error if the name is ambiguous.
func (pgs *TypeCollection) resolveName(ctx context.Context, schemaName string, typeName string) (id.Type, error) {
// TODO: this should probably check table names as well since tables create composite types matching their rows
// First check for an exact match in the built-in types
inputID := id.NewType(schemaName, typeName)
if _, ok := pgtypes.IDToBuiltInDoltgresType[inputID]; ok {
@@ -251,6 +273,7 @@

// IterateTypes iterates over all types in the collection.
func (pgs *TypeCollection) IterateTypes(ctx context.Context, f func(typ *pgtypes.DoltgresType) (stop bool, err error)) error {
// TODO: this should probably iterate tables as well since tables create composite types matching their rows
// We can iterate the built-in types first
for _, t := range pgtypes.GetAllBuitInTypes() {
stop, err := f(t)
@@ -368,3 +391,47 @@
clear(pgs.accessedMap)
return nil
}

// getTable returns the SQL table that matches the given schema and table name. Returns a nil table if one is not found.
// This is intended for use with tableToType.
func (*TypeCollection) getTable(ctx *sql.Context, schema string, tblName string) (tbl sql.Table, actualSchema string, err error) {
actualSchema, err = GetSchemaName(ctx, nil, schema)

Check failure on line 398 in core/typecollection/typecollection.go (GitHub Actions / Run Staticcheck): this value of err is never used (SA4006)
tbl, err = GetSqlTableFromContext(ctx, "", doltdb.TableName{
Name: tblName,
Schema: actualSchema,
})
if err != nil || tbl == nil {
return nil, "", err
}
if schTbl, ok := tbl.(sql.DatabaseSchemaTable); ok {
actualSchema = schTbl.DatabaseSchema().SchemaName()
}
return tbl, actualSchema, nil
}

// tableToType handles type creation related to a table's composite row type.
// https://www.postgresql.org/docs/15/sql-createtable.html
func (*TypeCollection) tableToType(ctx *sql.Context, tbl sql.Table, schema string) (*pgtypes.DoltgresType, error) {
tblName := tbl.Name()
tblSch := tbl.Schema()
typeID := id.NewType(schema, tblName)
relID := id.NewTable(schema, tblName).AsId()
arrayID := id.NewType(schema, "_"+tblName)
attrs := make([]pgtypes.CompositeAttribute, len(tblSch))
for i, col := range tblSch {
collation := "" // TODO: what should we use for the collation?
colType, ok := col.Type.(*pgtypes.DoltgresType)
if !ok {
// TODO: perhaps we should use a better error message stating that it uses a non-Doltgres type?
return nil, pgtypes.ErrTypeDoesNotExist.New(tblName)
}
attrs[i] = pgtypes.NewCompositeAttribute(ctx, relID, col.Name, colType.ID, int16(i+1), collation)
}
return pgtypes.NewCompositeType(ctx, relID, arrayID, typeID, attrs), nil
}

// GetSqlTableFromContext is a forward declaration to get around import cycles
var GetSqlTableFromContext func(ctx *sql.Context, databaseName string, tableName doltdb.TableName) (sql.Table, error)

// GetSchemaName is a forward declaration to get around import cycles
var GetSchemaName func(ctx *sql.Context, db sql.Database, schemaName string) (string, error)
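
For a concrete picture of the mapping, here is an illustrative sketch of roughly what getTable plus tableToType would produce for a hypothetical table public.points(x int4, y int4). The table, its columns, the helper name, and the use of pgtypes.Int32 as the built-in int4 type are assumptions for the example; it relies on the sql, id, and pgtypes imports this file already has.

// exampleRowType is illustrative only: it approximates the composite row type that
// tableToType would derive for a hypothetical table public.points(x int4, y int4).
func exampleRowType(ctx *sql.Context) *pgtypes.DoltgresType {
	relID := id.NewTable("public", "points").AsId() // the composite type's relation is the table itself
	typeID := id.NewType("public", "points")
	arrayID := id.NewType("public", "_points") // the matching array type takes the "_" prefix
	attrs := []pgtypes.CompositeAttribute{
		pgtypes.NewCompositeAttribute(ctx, relID, "x", pgtypes.Int32.ID, 1, ""), // attnum is 1-based
		pgtypes.NewCompositeAttribute(ctx, relID, "y", pgtypes.Int32.ID, 2, ""),
	}
	return pgtypes.NewCompositeType(ctx, relID, arrayID, typeID, attrs)
}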
44 changes: 44 additions & 0 deletions server/expression/explicit_cast.go
@@ -23,6 +23,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
@@ -98,6 +99,49 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
if castFunction == nil {
if fromType.ID == pgtypes.Unknown.ID {
castFunction = framework.UnknownLiteralCast
} else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too?
// Casting to a record type will always work for any composite type.
// TODO: is the above statement true for all cases?
// When casting to a composite type, we must match the arity and have valid casts for every position.
if c.castToType.IsRecordType() {
castFunction = framework.IdentityCast
} else {
castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
vals, ok := val.([]pgtypes.RecordValue)
if !ok {
// TODO: better error message
return nil, errors.New("casting input error from record type")
}
if len(targetType.CompositeAttrs) != len(vals) {
return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name())
}
typeCollection, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
outputVals := make([]pgtypes.RecordValue, len(vals))
for i := range vals {
valType, ok := vals[i].Type.(*pgtypes.DoltgresType)
if !ok {
// TODO: if this is a GMS type, then we should cast to a Doltgres type here
return nil, errors.New("cannot cast record containing GMS type")
}
outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID)
if err != nil {
return nil, err
}
innerExplicit := ExplicitCast{
sqlChild: NewUnsafeLiteral(vals[i].Value, valType),
castToType: outputVals[i].Type.(*pgtypes.DoltgresType),
}
outputVals[i].Value, err = innerExplicit.Eval(ctx, nil)
if err != nil {
return nil, err
}
}
return outputVals, nil
}
}
} else {
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
fromType.String(), c.castToType.String(), c.sqlChild.String())
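
Distilled from the branch above, a minimal sketch of the rule being enforced: the arity must match and every position must cast. The helper name and the injected castField function are illustrative only, and the sketch assumes the errors and pgtypes imports this file already uses; in the real code the per-position cast is the inner ExplicitCast evaluation.

// castRecordToComposite is an illustrative distillation of the composite cast rule above.
func castRecordToComposite(vals []pgtypes.RecordValue, target *pgtypes.DoltgresType,
	castField func(v pgtypes.RecordValue, attr pgtypes.CompositeAttribute) (pgtypes.RecordValue, error)) ([]pgtypes.RecordValue, error) {
	if len(target.CompositeAttrs) != len(vals) {
		return nil, errors.Newf("cannot cast record of %d fields to %s", len(vals), target.Name())
	}
	out := make([]pgtypes.RecordValue, len(vals))
	for i := range vals {
		cast, err := castField(vals[i], target.CompositeAttrs[i])
		if err != nil {
			return nil, err
		}
		out[i] = cast
	}
	return out, nil
}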
10 changes: 5 additions & 5 deletions server/functions/framework/cast.go
@@ -134,7 +134,7 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
// parameters). If one of the types are a string type, then we do not use the identity, and use the I/O conversions
// below.
if fromType.ID == toType.ID && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes {
return identityCast
return IdentityCast
}
// All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html
if fromType.TypCategory == pgtypes.TypeCategory_StringTypes {
@@ -175,7 +175,7 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT
// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
// parameters). If the "to" type is a string type, then we do not use the identity, and use the I/O conversion below.
if fromType.ID == toType.ID && fromType.TypCategory != pgtypes.TypeCategory_StringTypes {
return identityCast
return IdentityCast
}
// All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html
if toType.TypCategory == pgtypes.TypeCategory_StringTypes {
@@ -202,7 +202,7 @@ func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
// parameters).
if fromType.ID == toType.ID {
return identityCast
return IdentityCast
}
return nil
}
@@ -282,8 +282,8 @@ func getCast(mutex *sync.RWMutex,
return nil
}

// identityCast returns the input value.
func identityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
// IdentityCast returns the input value.
func IdentityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
return val, nil
}

4 changes: 2 additions & 2 deletions server/functions/framework/compiled_function.go
@@ -489,7 +489,7 @@ func (c *CompiledFunction) resolveOperator(argTypes []*pgtypes.DoltgresType, ove
rightUnknownType := argTypes[1].ID == pgtypes.Unknown.ID
if (leftUnknownType && !rightUnknownType) || (!leftUnknownType && rightUnknownType) {
var typ *pgtypes.DoltgresType
casts := []pgtypes.TypeCastFunction{identityCast, identityCast}
casts := []pgtypes.TypeCastFunction{IdentityCast, IdentityCast}
if leftUnknownType {
casts[0] = UnknownLiteralCast
typ = argTypes[1]
@@ -577,7 +577,7 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy
for i := range argTypes {
paramType := overload.argTypes[i]
if paramType.IsValidForPolymorphicType(argTypes[i]) {
overloadCasts[i] = identityCast
overloadCasts[i] = IdentityCast
polymorphicParameters = append(polymorphicParameters, paramType)
polymorphicTargets = append(polymorphicTargets, argTypes[i])
} else {
2 changes: 2 additions & 0 deletions server/functions/record.go
@@ -81,6 +81,8 @@ var record_send = framework.Function1{
if !ok {
return nil, fmt.Errorf("expected []RecordValue, but got %T", val)
}
// TODO: converting from a string back to the record doesn't work as we lose type information, so we need to
// figure out how to retain this information
output, err := pgtypes.RecordToString(ctx, values)
if err != nil {
return nil, err
5 changes: 4 additions & 1 deletion server/node/create_type.go
@@ -128,7 +128,10 @@ func (c *CreateType) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
newType = types.NewEnumType(ctx, arrayID, typeID, enumLabelMap)
// TODO: store labels somewhere
case types.TypeType_Composite:
relID := id.Null // TODO: create relation with c.AsTypes
// TODO: non-composite types have a zero oid for their relID, which for us would be a null ID.
// We need to find a way to distinguish a null ID from a composite type that does not reference a table
// (which is what relID points to if it represents a table row's composite type)
relID := id.Null
attrs := make([]types.CompositeAttribute, len(c.AsTypes))
for i, a := range c.AsTypes {
attrs[i] = types.NewCompositeAttribute(ctx, relID, a.AttrName, a.Typ.ID, int16(i+1), a.Collation)
14 changes: 13 additions & 1 deletion server/plpgsql/interpreter_stack.go
@@ -28,7 +28,7 @@ import (
// interacted with. InterpreterVariableReferences are, instead, the avenue of interaction as a variable may be an
// aggregate type (such as a record).
type interpreterVariable struct {
Record sql.Schema
Record sql.Schema // TODO: all records carry their type information alongside the value, so this is redundant
Type *pgtypes.DoltgresType
Value any
}
@@ -119,6 +119,18 @@ func (is *InterpreterStack) GetVariable(name string) InterpreterVariableReferenc
Type: iv.Record[fieldIdx].Type.(*pgtypes.DoltgresType),
Value: &(iv.Value.(sql.Row)[fieldIdx]),
}
} else if iv.Type.IsCompositeType() {
for fieldIdx := range iv.Type.CompositeAttrs {
if iv.Type.CompositeAttrs[fieldIdx].Name == fieldName {
vals := iv.Value.([]pgtypes.RecordValue)
return InterpreterVariableReference{
Type: vals[fieldIdx].Type.(*pgtypes.DoltgresType),
Value: &(vals[fieldIdx].Value),
}
}
}
// The field could not be found
return InterpreterVariableReference{}
} else {
// Can't access fields on an empty record
return InterpreterVariableReference{}
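
The composite branch above amounts to a name lookup over the type's attributes followed by returning a reference into the record's values. A standalone sketch, assuming RecordValue.Value is an any-typed field (consistent with the address-of above) and with an invented helper name:

// fieldRefByName mirrors the lookup above: find the attribute whose name matches and
// return a pointer to the corresponding RecordValue slot, or false if no field matches.
func fieldRefByName(typ *pgtypes.DoltgresType, vals []pgtypes.RecordValue, fieldName string) (*any, bool) {
	for i, attr := range typ.CompositeAttrs {
		if attr.Name == fieldName {
			return &vals[i].Value, true
		}
	}
	return nil, false
}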
17 changes: 6 additions & 11 deletions server/plpgsql/statements.go
@@ -357,17 +357,12 @@ func substituteVariableReferences(expression string, stack *InterpreterStack) (n
token := scanResult.Tokens[i]
substring := expression[token.Start:token.End]
// varMap lowercases everything, so we'll lowercase our substring to enable case-insensitivity
if fieldNames, ok := varMap[strings.ToLower(substring)]; ok {
// If there's a '.', then we'll check if this is accessing a record's field (`NEW.val1` for example)
if len(fieldNames) > 0 && i+2 < len(scanResult.Tokens) && scanResult.Tokens[i+1].Token == '.' {
possibleFieldSubstring := expression[scanResult.Tokens[i+2].Start:scanResult.Tokens[i+2].End]
for _, fieldName := range fieldNames {
if fieldName == strings.ToLower(possibleFieldSubstring) {
substring += "." + fieldName
i += 2
break
}
}
if _, ok := varMap[strings.ToLower(substring)]; ok {
// If there's a '.', then we'll assume this is accessing a record's field (`NEW.val1` for example)
for i+2 < len(scanResult.Tokens) && scanResult.Tokens[i+1].Token == '.' {
nextFieldSubstring := expression[scanResult.Tokens[i+2].Start:scanResult.Tokens[i+2].End]
substring += "." + nextFieldSubstring
i += 2
}
// Variables cannot have a '(' after their name as that would classify them as functions, so we have to
// explicitly check for that. This is because variables and functions can share names, for example:
22 changes: 11 additions & 11 deletions server/types/composite.go
@@ -62,20 +62,20 @@ func NewCompositeType(ctx *sql.Context, relID id.Id, arrayID, typeID id.Type, at
// CompositeAttribute represents a composite type attribute.
// This is a partial pg_attribute row entry.
type CompositeAttribute struct {
relID id.Id // ID of the relation it belongs to
name string
typeID id.Type // ID of DoltgresType
num int16 // number of the column in relation
collation string
RelID id.Id // ID of the relation it belongs to
Name string
TypeID id.Type // ID of DoltgresType
Num int16 // 1-based number of the column in relation
Collation string
}

// NewCompositeAttribute creates new instance of composite type attribute.
// NewCompositeAttribute creates a new instance of a composite type attribute. `num` is 1-based rather than 0-based.
func NewCompositeAttribute(ctx *sql.Context, relID id.Id, name string, typeID id.Type, num int16, collation string) CompositeAttribute {
return CompositeAttribute{
relID: relID,
name: name,
typeID: typeID,
num: num,
collation: collation,
RelID: relID,
Name: name,
TypeID: typeID,
Num: num,
Collation: collation,
}
}
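
With the attribute fields exported, code outside server/types can read them directly, which is what the typecollection and expression changes above rely on. A hedged sketch of the kind of consumer this enables, written from an external package's point of view (helper name and output format invented; assumes fmt plus the usual pgtypes import):

// describeComposite is illustrative only: it renders a pg_attribute-like listing from the
// now-exported CompositeAttribute fields.
func describeComposite(typ *pgtypes.DoltgresType) []string {
	out := make([]string, len(typ.CompositeAttrs))
	for i, attr := range typ.CompositeAttrs {
		// attr.Num mirrors pg_attribute.attnum, so it is 1-based
		out[i] = fmt.Sprintf("%d: %s %s", attr.Num, attr.Name, attr.TypeID.TypeName())
	}
	return out
}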
20 changes: 10 additions & 10 deletions server/types/serialization.go
@@ -122,11 +122,11 @@ func DeserializeType(serializedType []byte) (sql.ExtendedType, error) {
num := reader.Int16()
collation := reader.String()
typ.CompositeAttrs[k] = CompositeAttribute{
relID: relID,
name: name,
typeID: id.Type(typeID),
num: num,
collation: collation,
RelID: relID,
Name: name,
TypeID: id.Type(typeID),
Num: num,
Collation: collation,
}
}
}
@@ -196,11 +196,11 @@ func (t *DoltgresType) Serialize() []byte {
writer.VariableUint(uint64(len(t.CompositeAttrs)))
if len(t.CompositeAttrs) > 0 {
for _, l := range t.CompositeAttrs {
writer.Id(l.relID)
writer.String(l.name)
writer.Id(l.typeID.AsId())
writer.Int16(l.num)
writer.String(l.collation)
writer.Id(l.RelID)
writer.String(l.Name)
writer.Id(l.TypeID.AsId())
writer.Int16(l.Num)
writer.String(l.Collation)
}
}
writer.String(t.InternalName)