From aa3a5e35c51ff60dfe5443bd1adf9a279a6e23b3 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 31 Mar 2023 13:02:10 -0500 Subject: [PATCH 1/3] refactored literal type casting into flyteidl Signed-off-by: Daniel Rammer --- go.mod | 6 +- go.sum | 14 +- pkg/compiler/transformers/k8s/inputs.go | 6 +- pkg/compiler/validators/bindings.go | 7 +- pkg/compiler/validators/condition.go | 3 +- pkg/compiler/validators/typing.go | 360 -------- pkg/compiler/validators/typing_test.go | 868 ------------------ pkg/compiler/validators/utils.go | 142 --- pkg/compiler/validators/utils_test.go | 44 - .../task/catalog/datacatalog/transformer.go | 7 +- 10 files changed, 22 insertions(+), 1435 deletions(-) delete mode 100644 pkg/compiler/validators/typing.go delete mode 100644 pkg/compiler/validators/typing_test.go diff --git a/go.mod b/go.mod index 7a5caddbf..46d413d82 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( github.com/spf13/cobra v1.4.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.2 - golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 google.golang.org/grpc v1.46.0 @@ -122,7 +121,8 @@ require ( github.com/stretchr/objx v0.3.0 // indirect github.com/subosito/gotenv v1.2.0 // indirect go.opencensus.io v0.23.0 // indirect - golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect + golang.org/x/crypto v0.1.0 // indirect + golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 // indirect golang.org/x/sys v0.5.0 // indirect @@ -147,3 +147,5 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d + +replace github.com/flyteorg/flyteidl => ../flyteidl diff --git a/go.sum b/go.sum index 6b4e8cba5..b594a7d5e 100644 --- a/go.sum +++ b/go.sum @@ -260,8 +260,6 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8= -github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= github.com/flyteorg/flyteplugins v1.0.44 h1:uKizng+i0vfXslyPBlrsfecInhvy71fTB4kRg7eiifE= github.com/flyteorg/flyteplugins v1.0.44/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= @@ -846,8 +844,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -858,8 +856,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= -golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= +golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= +golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -886,7 +884,7 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1163,7 +1161,7 @@ golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10-0.20220218145154-897bd77cd717/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/compiler/transformers/k8s/inputs.go b/pkg/compiler/transformers/k8s/inputs.go index 1886f4906..af86022d8 100644 --- a/pkg/compiler/transformers/k8s/inputs.go +++ b/pkg/compiler/transformers/k8s/inputs.go @@ -1,10 +1,10 @@ package k8s import ( + "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/errors" - "github.com/flyteorg/flytepropeller/pkg/compiler/validators" "k8s.io/apimachinery/pkg/util/sets" ) @@ -34,8 +34,8 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor continue } - inputType := validators.LiteralTypeForLiteral(inputVal) - if !validators.AreTypesCastable(inputType, v.Type) { + inputType := coreutils.LiteralTypeForLiteral(inputVal) + if !coreutils.AreTypesCastable(inputType, v.Type) { errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String())) continue } diff --git a/pkg/compiler/validators/bindings.go b/pkg/compiler/validators/bindings.go index 2942424fe..1c349d9fc 100644 --- a/pkg/compiler/validators/bindings.go +++ b/pkg/compiler/validators/bindings.go @@ -5,6 +5,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/compiler/typing" + "github.com/flyteorg/flyteidl/clients/go/coreutils" flyte "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" c "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/errors" @@ -125,7 +126,7 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin } } - if !validateParamTypes || AreTypesCastable(sourceType, expectedType) { + if !validateParamTypes || coreutils.AreTypesCastable(sourceType, expectedType) { val.Promise.NodeId = upNode.GetId() return param.GetType(), []c.NodeID{val.Promise.NodeId}, true } @@ -141,10 +142,10 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin return nil, nil, !errs.HasErrors() } - literalType := literalTypeForScalar(val.Scalar) + literalType := coreutils.LiteralTypeForScalar(val.Scalar) if literalType == nil { errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(val.Scalar.GetValue()).String())) - } else if validateParamTypes && !AreTypesCastable(literalType, expectedType) { + } else if validateParamTypes && !coreutils.AreTypesCastable(literalType, expectedType) { errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, literalType.String(), expectedType.String())) } diff --git a/pkg/compiler/validators/condition.go b/pkg/compiler/validators/condition.go index a70c5dcb2..5ef75d2b4 100644 --- a/pkg/compiler/validators/condition.go +++ b/pkg/compiler/validators/condition.go @@ -3,6 +3,7 @@ package validators import ( "fmt" + "github.com/flyteorg/flyteidl/clients/go/coreutils" flyte "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" c "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/errors" @@ -14,7 +15,7 @@ func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operan errs.Collect(errors.NewValueRequiredErr(node.GetId(), paramName)) } else if operand.GetPrimitive() != nil { // no validation - literalType = literalTypeForPrimitive(operand.GetPrimitive()) + literalType = coreutils.LiteralTypeForPrimitive(operand.GetPrimitive()) } else if len(operand.GetVar()) > 0 { if node.GetInterface() != nil { if param, paramOk := validateInputVar(node, operand.GetVar(), requireParamType, errs.NewScope()); paramOk { diff --git a/pkg/compiler/validators/typing.go b/pkg/compiler/validators/typing.go deleted file mode 100644 index 09268359d..000000000 --- a/pkg/compiler/validators/typing.go +++ /dev/null @@ -1,360 +0,0 @@ -package validators - -import ( - "strings" - - flyte "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - structpb "github.com/golang/protobuf/ptypes/struct" -) - -type typeChecker interface { - CastsFrom(*flyte.LiteralType) bool -} - -type trivialChecker struct { - literalType *flyte.LiteralType -} - -// CastsFrom is a trivial type checker merely checks if types match exactly. -func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - // If upstream is an enum, it can be consumed as a string downstream - if upstreamType.GetEnumType() != nil { - if t.literalType.GetSimple() == flyte.SimpleType_STRING { - return true - } - } - // If t is an enum, it can be created from a string as Enums as just constrained String aliases - if t.literalType.GetEnumType() != nil { - if upstreamType.GetSimple() == flyte.SimpleType_STRING { - return true - } - } - - if GetTagForType(upstreamType) != "" && GetTagForType(t.literalType) != GetTagForType(upstreamType) { - return false - } - - // Ignore metadata when comparing types. - upstreamTypeCopy := *upstreamType - downstreamTypeCopy := *t.literalType - upstreamTypeCopy.Structure = &flyte.TypeStructure{} - downstreamTypeCopy.Structure = &flyte.TypeStructure{} - upstreamTypeCopy.Metadata = &structpb.Struct{} - downstreamTypeCopy.Metadata = &structpb.Struct{} - upstreamTypeCopy.Annotation = &flyte.TypeAnnotation{} - downstreamTypeCopy.Annotation = &flyte.TypeAnnotation{} - return upstreamTypeCopy.String() == downstreamTypeCopy.String() -} - -type noneTypeChecker struct{} - -// CastsFrom matches only void -func (t noneTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - return isNoneType(upstreamType) -} - -type mapTypeChecker struct { - literalType *flyte.LiteralType -} - -// CastsFrom checks that the target map type can be cast to the current map type. We need to ensure both the key types -// and value types match. -func (t mapTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - // Empty maps should match any collection. - mapLiteralType := upstreamType.GetMapValueType() - if isNoneType(mapLiteralType) { - return true - } else if mapLiteralType != nil { - return getTypeChecker(t.literalType.GetMapValueType()).CastsFrom(mapLiteralType) - } - - return false -} - -type collectionTypeChecker struct { - literalType *flyte.LiteralType -} - -// CastsFrom checks whether two collection types match. We need to ensure that the nesting is correct and the final -// subtypes match. -func (t collectionTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - // Empty collections should match any collection. - collectionType := upstreamType.GetCollectionType() - if isNoneType(upstreamType.GetCollectionType()) { - return true - } else if collectionType != nil { - return getTypeChecker(t.literalType.GetCollectionType()).CastsFrom(collectionType) - } - - return false -} - -type schemaTypeChecker struct { - literalType *flyte.LiteralType -} - -// CastsFrom handles type casting to the underlying schema type. -// Schemas are more complex types in the Flyte ecosystem. A schema is considered castable in the following -// cases. -// -// 1. The downstream schema has no column types specified. In such a case, it accepts all schema input since it is -// generic. -// -// 2. The downstream schema has a subset of the upstream columns and they match perfectly. -// -// 3. The upstream type can be Schema type or structured dataset type -func (t schemaTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - schemaType := upstreamType.GetSchema() - structuredDatasetType := upstreamType.GetStructuredDatasetType() - if structuredDatasetType == nil && schemaType == nil { - return false - } - - if schemaType != nil { - return schemaCastFromSchema(schemaType, t.literalType.GetSchema()) - } - - // Flyte Schema can only be serialized to parquet - if len(structuredDatasetType.Format) != 0 && !strings.EqualFold(structuredDatasetType.Format, "parquet") { - return false - } - - return schemaCastFromStructuredDataset(structuredDatasetType, t.literalType.GetSchema()) -} - -type structuredDatasetChecker struct { - literalType *flyte.LiteralType -} - -// CastsFrom for Structured dataset are more complex types in the Flyte ecosystem. A structured dataset is considered -// castable in the following cases: -// -// 1. The downstream structured dataset has no column types specified. In such a case, it accepts all structured dataset input since it is -// generic. -// -// 2. The downstream structured dataset has a subset of the upstream structured dataset columns and they match perfectly. -// -// 3. The upstream type can be Schema type or structured dataset type -func (t structuredDatasetChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - // structured datasets are nullable - if isNoneType(upstreamType) { - return true - } - structuredDatasetType := upstreamType.GetStructuredDatasetType() - schemaType := upstreamType.GetSchema() - if structuredDatasetType == nil && schemaType == nil { - return false - } - if schemaType != nil { - // Flyte Schema can only be serialized to parquet - format := t.literalType.GetStructuredDatasetType().Format - if len(format) != 0 && !strings.EqualFold(format, "parquet") { - return false - } - return structuredDatasetCastFromSchema(schemaType, t.literalType.GetStructuredDatasetType()) - } - return structuredDatasetCastFromStructuredDataset(structuredDatasetType, t.literalType.GetStructuredDatasetType()) -} - -// Upstream (schema) -> downstream (schema) -func schemaCastFromSchema(upstream *flyte.SchemaType, downstream *flyte.SchemaType) bool { - if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 { - return true - } - - nameToTypeMap := make(map[string]flyte.SchemaType_SchemaColumn_SchemaColumnType) - for _, column := range upstream.Columns { - nameToTypeMap[column.Name] = column.Type - } - - // Check that the downstream schema is a strict sub-set of the upstream schema. - for _, column := range downstream.Columns { - upstreamType, ok := nameToTypeMap[column.Name] - if !ok { - return false - } - if upstreamType != column.Type { - return false - } - } - return true -} - -type unionTypeChecker struct { - literalType *flyte.LiteralType -} - -func (t unionTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { - unionType := t.literalType.GetUnionType() - - upstreamUnionType := upstreamType.GetUnionType() - if upstreamUnionType != nil { - // For each upstream variant we must find a compatible downstream variant - for _, u := range upstreamUnionType.GetVariants() { - found := false - for _, d := range unionType.GetVariants() { - if AreTypesCastable(u, d) { - found = true - break - } - } - if !found { - return false - } - } - - return true - } - - // Matches iff we can unambiguously select a variant - foundOne := false - for _, x := range unionType.GetVariants() { - if AreTypesCastable(upstreamType, x) { - if foundOne { - return false - } - foundOne = true - } - } - - return foundOne -} - -// Upstream (structuredDatasetType) -> downstream (structuredDatasetType) -func structuredDatasetCastFromStructuredDataset(upstream *flyte.StructuredDatasetType, downstream *flyte.StructuredDatasetType) bool { - // Skip the format check here when format is empty. https://github.com/flyteorg/flyte/issues/2864 - if len(upstream.Format) != 0 && len(downstream.Format) != 0 && !strings.EqualFold(upstream.Format, downstream.Format) { - return false - } - - if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 { - return true - } - - nameToTypeMap := make(map[string]*flyte.LiteralType) - for _, column := range upstream.Columns { - nameToTypeMap[column.Name] = column.LiteralType - } - - // Check that the downstream structured dataset is a strict sub-set of the upstream structured dataset. - for _, column := range downstream.Columns { - upstreamType, ok := nameToTypeMap[column.Name] - if !ok { - return false - } - if !getTypeChecker(column.LiteralType).CastsFrom(upstreamType) { - return false - } - } - return true -} - -// Upstream (schemaType) -> downstream (structuredDatasetType) -func structuredDatasetCastFromSchema(upstream *flyte.SchemaType, downstream *flyte.StructuredDatasetType) bool { - if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 { - return true - } - nameToTypeMap := make(map[string]flyte.SchemaType_SchemaColumn_SchemaColumnType) - for _, column := range upstream.Columns { - nameToTypeMap[column.Name] = column.GetType() - } - - // Check that the downstream structuredDataset is a strict sub-set of the upstream schema. - for _, column := range downstream.Columns { - upstreamType, ok := nameToTypeMap[column.Name] - if !ok { - return false - } - if !schemaTypeIsMatchStructuredDatasetType(upstreamType, column.LiteralType.GetSimple()) { - return false - } - } - return true -} - -// Upstream (structuredDatasetType) -> downstream (schemaType) -func schemaCastFromStructuredDataset(upstream *flyte.StructuredDatasetType, downstream *flyte.SchemaType) bool { - if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 { - return true - } - nameToTypeMap := make(map[string]flyte.SimpleType) - for _, column := range upstream.Columns { - nameToTypeMap[column.Name] = column.LiteralType.GetSimple() - } - - // Check that the downstream schema is a strict sub-set of the upstream structuredDataset. - for _, column := range downstream.Columns { - upstreamType, ok := nameToTypeMap[column.Name] - if !ok { - return false - } - if !schemaTypeIsMatchStructuredDatasetType(column.GetType(), upstreamType) { - return false - } - } - return true -} - -func schemaTypeIsMatchStructuredDatasetType(schemaType flyte.SchemaType_SchemaColumn_SchemaColumnType, structuredDatasetType flyte.SimpleType) bool { - switch schemaType { - case flyte.SchemaType_SchemaColumn_INTEGER: - return structuredDatasetType == flyte.SimpleType_INTEGER - case flyte.SchemaType_SchemaColumn_FLOAT: - return structuredDatasetType == flyte.SimpleType_FLOAT - case flyte.SchemaType_SchemaColumn_STRING: - return structuredDatasetType == flyte.SimpleType_STRING - case flyte.SchemaType_SchemaColumn_BOOLEAN: - return structuredDatasetType == flyte.SimpleType_BOOLEAN - case flyte.SchemaType_SchemaColumn_DATETIME: - return structuredDatasetType == flyte.SimpleType_DATETIME - case flyte.SchemaType_SchemaColumn_DURATION: - return structuredDatasetType == flyte.SimpleType_DURATION - } - return false -} - -func isNoneType(t *flyte.LiteralType) bool { - switch t.GetType().(type) { - case *flyte.LiteralType_Simple: - return t.GetSimple() == flyte.SimpleType_NONE - default: - return false - } -} - -func getTypeChecker(t *flyte.LiteralType) typeChecker { - switch t.GetType().(type) { - case *flyte.LiteralType_CollectionType: - return collectionTypeChecker{ - literalType: t, - } - case *flyte.LiteralType_MapValueType: - return mapTypeChecker{ - literalType: t, - } - case *flyte.LiteralType_Schema: - return schemaTypeChecker{ - literalType: t, - } - case *flyte.LiteralType_UnionType: - return unionTypeChecker{ - literalType: t, - } - case *flyte.LiteralType_StructuredDatasetType: - return structuredDatasetChecker{ - literalType: t, - } - default: - if isNoneType(t) { - return noneTypeChecker{} - } - - return trivialChecker{ - literalType: t, - } - } -} - -func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool { - return getTypeChecker(downstreamType).CastsFrom(upstreamType) -} diff --git a/pkg/compiler/validators/typing_test.go b/pkg/compiler/validators/typing_test.go deleted file mode 100644 index 8344339f0..000000000 --- a/pkg/compiler/validators/typing_test.go +++ /dev/null @@ -1,868 +0,0 @@ -package validators - -import ( - "testing" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/stretchr/testify/assert" -) - -func TestSimpleLiteralCasting(t *testing.T) { - t.Run("BaseCase_Integer", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - ) - assert.True(t, castable, "Integers should be castable to other integers") - }) - - t.Run("IntegerToFloat", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, - }, - ) - assert.False(t, castable, "Integers should not be castable to floats") - }) - - t.Run("FloatToInteger", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - ) - assert.False(t, castable, "Floats should not be castable to integers") - }) - - t.Run("VoidToInteger", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - ) - assert.False(t, castable, "Non-optional types are non-nullable") - }) - - t.Run("IgnoreMetadata", func(t *testing.T) { - s := structpb.Struct{ - Fields: map[string]*structpb.Value{ - "a": {}, - }, - } - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - Metadata: &s, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - ) - assert.True(t, castable, "Metadata should be ignored") - }) - - t.Run("EnumToString", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ - Values: []string{"x", "y"}, - }}, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - }, - ) - assert.True(t, castable, "Enum should be castable to string") - }) - - t.Run("EnumToEnum", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ - Values: []string{"x", "y"}, - }}, - }, - &core.LiteralType{ - Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ - Values: []string{"x", "y"}, - }}, - }, - ) - assert.True(t, castable, "Enum should be castable to Enums if they are identical") - }) - - t.Run("EnumToEnum", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ - Values: []string{"x", "y"}, - }}, - }, - &core.LiteralType{ - Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ - Values: []string{"m", "n"}, - }}, - }, - ) - assert.False(t, castable, "Enum should not be castable to non matching enums") - }) - - t.Run("StringToEnum", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - }, - &core.LiteralType{ - Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ - Values: []string{"x", "y"}, - }}, - }, - ) - assert.True(t, castable, "Strings should be castable to enums - may result in runtime failure") - }) -} - -func TestUnionCasting(t *testing.T) { - t.Run("StringToUnionUnambiguously", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - }, - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - Structure: &core.TypeStructure{ - Tag: "int", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str", - }, - }, - }, - }, - }, - }, - ) - assert.True(t, castable, "Strings should be castable to (str | int)") - }) - - t.Run("StringToUnionAmbiguously", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - }, - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str1", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str2", - }, - }, - }, - }, - }, - }, - ) - assert.False(t, castable, "Raw string literals should not be ambiguously castable to (str | str)") - }) - - t.Run("UnionToUnionSuperset", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str1", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str2", - }, - }, - }, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str1", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - Structure: &core.TypeStructure{ - Tag: "int1", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str2", - }, - }, - }, - }, - }, - }, - ) - assert.True(t, castable, "Union types can be cast to a union that contains a superset of variants") - }) - - t.Run("UnionToUnionTagMismatch", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str1", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str2", - }, - }, - }, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - Structure: &core.TypeStructure{ - Tag: "str2", - }, - }, - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "str3", - }, - }, - }, - }, - }, - }, - ) - assert.False(t, castable, "Union types can only be cast to a union that contains a superset of variants") - }) - - t.Run("UnionToUnionTypeMismatch", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, - Structure: &core.TypeStructure{ - Tag: "test", - }, - }, - }, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - { - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - Structure: &core.TypeStructure{ - Tag: "test", - }, - }, - }, - }, - }, - }, - ) - assert.False(t, castable, "Union types can only be cast to a union that contains a superset of variants") - }) -} - -func TestCollectionCasting(t *testing.T) { - t.Run("BaseCase_SingleIntegerCollection", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - ) - assert.True(t, castable, "[Integer] should be castable to [Integer].") - }) - - t.Run("Empty collection", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - ) - assert.True(t, castable, "[] should be castable to [Integer].") - }) - - t.Run("SingleIntegerCollectionToSingleFloatCollection", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, - }, - }, - }, - ) - assert.False(t, castable, "[Integer] should not be castable to [Float]") - }) - - t.Run("MismatchedNestLevels_Scalar", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - ) - assert.False(t, castable, "[Integer] should not be castable to Integer") - }) - - t.Run("MismatchedNestLevels_Collections", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - }, - }, - ) - assert.False(t, castable, "[Integer] should not be castable to [[Integer]]") - }) - - t.Run("Nullable_Collections", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_NONE, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - }, - }, - ) - assert.False(t, castable, "Non-optional collections are not nullable") - }) -} - -func TestMapCasting(t *testing.T) { - t.Run("BaseCase_SingleIntegerMap", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - ) - assert.True(t, castable, "{k: Integer} should be castable to {k: Integer}.") - }) - - t.Run("ScalarIntegerMapToScalarFloatMap", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, - }, - }, - }, - ) - assert.False(t, castable, "{k: Integer} should not be castable to {k: Float}") - }) - - t.Run("ScalarIntegerMapToScalarFloatMap", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, - }, - }, - }, - ) - - assert.True(t, castable, "{k: None} should be castable to {k: Float}") - }) - - t.Run("ScalarStructToStruct", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_STRUCT, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_STRUCT, - }, - }, - ) - assert.True(t, castable, "castable from Struct to struct") - }) - - t.Run("MismatchedMapNestLevels_Scalar", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - ) - assert.False(t, castable, "{k: Integer} should not be castable to Integer") - }) - - t.Run("MismatchedMapNestLevels_Maps", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, - }, - }, - }, - }, - }, - ) - assert.False(t, castable, "{k: Integer} should not be castable to {k: {k: Integer}}") - }) -} - -func TestSchemaCasting(t *testing.T) { - genericSchema := &core.LiteralType{ - Type: &core.LiteralType_Schema{ - Schema: &core.SchemaType{ - Columns: []*core.SchemaType_SchemaColumn{}, - }, - }, - } - genericStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{}, - Format: "", - }, - }, - } - subsetIntegerSchema := &core.LiteralType{ - Type: &core.LiteralType_Schema{ - Schema: &core.SchemaType{ - Columns: []*core.SchemaType_SchemaColumn{ - { - Name: "a", - Type: core.SchemaType_SchemaColumn_INTEGER, - }, - }, - }, - }, - } - supersetIntegerAndFloatSchema := &core.LiteralType{ - Type: &core.LiteralType_Schema{ - Schema: &core.SchemaType{ - Columns: []*core.SchemaType_SchemaColumn{ - { - Name: "a", - Type: core.SchemaType_SchemaColumn_INTEGER, - }, - { - Name: "b", - Type: core.SchemaType_SchemaColumn_FLOAT, - }, - }, - }, - }, - } - supersetStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{ - { - Name: "a", - LiteralType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, - }, - { - Name: "b", - LiteralType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}, - }, - }, - Format: "parquet", - }, - }, - } - mismatchedSubsetSchema := &core.LiteralType{ - Type: &core.LiteralType_Schema{ - Schema: &core.SchemaType{ - Columns: []*core.SchemaType_SchemaColumn{ - { - Name: "a", - Type: core.SchemaType_SchemaColumn_FLOAT, - }, - }, - }, - }, - } - - t.Run("BaseCase_GenericSchema", func(t *testing.T) { - castable := AreTypesCastable(genericSchema, genericSchema) - assert.True(t, castable, "Schema() should be castable to Schema()") - }) - - t.Run("GenericSchemaToNonGeneric", func(t *testing.T) { - castable := AreTypesCastable(genericSchema, subsetIntegerSchema) - assert.True(t, castable, "Schema() should be castable to Schema(a=Integer)") - }) - - t.Run("NonGenericSchemaToGeneric", func(t *testing.T) { - castable := AreTypesCastable(subsetIntegerSchema, genericSchema) - assert.True(t, castable, "Schema(a=Integer) should be castable to Schema()") - }) - - t.Run("SupersetToSubsetTypedSchema", func(t *testing.T) { - castable := AreTypesCastable(supersetIntegerAndFloatSchema, subsetIntegerSchema) - assert.True(t, castable, "Schema(a=Integer, b=Float) should be castable to Schema(a=Integer)") - }) - - t.Run("GenericToSubsetTypedSchema", func(t *testing.T) { - castable := AreTypesCastable(genericStructuredDataset, subsetIntegerSchema) - assert.True(t, castable, "StructuredDataset() with generic format should be castable to Schema(a=Integer)") - }) - - t.Run("SubsetTypedSchemaToGeneric", func(t *testing.T) { - castable := AreTypesCastable(subsetIntegerSchema, genericStructuredDataset) - assert.True(t, castable, "Schema(a=Integer) should be castable to StructuredDataset() with generic format") - }) - - t.Run("SupersetStructuredToSubsetTypedSchema", func(t *testing.T) { - castable := AreTypesCastable(supersetStructuredDataset, subsetIntegerSchema) - assert.True(t, castable, "StructuredDataset(a=Integer, b=Float) should be castable to Schema(a=Integer)") - }) - - t.Run("SubsetToSupersetSchema", func(t *testing.T) { - castable := AreTypesCastable(subsetIntegerSchema, supersetIntegerAndFloatSchema) - assert.False(t, castable, "Schema(a=Integer) should not be castable to Schema(a=Integer, b=Float)") - }) - - t.Run("MismatchedColumns", func(t *testing.T) { - castable := AreTypesCastable(subsetIntegerSchema, mismatchedSubsetSchema) - assert.False(t, castable, "Schema(a=Integer) should not be castable to Schema(a=Float)") - }) - - t.Run("MismatchedColumnsFlipped", func(t *testing.T) { - castable := AreTypesCastable(mismatchedSubsetSchema, subsetIntegerSchema) - assert.False(t, castable, "Schema(a=Float) should not be castable to Schema(a=Integer)") - }) - - t.Run("SchemasAreNullable", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_NONE, - }, - }, - subsetIntegerSchema) - assert.False(t, castable, "Non-optional schemas are not nullable") - }) -} - -func TestStructuredDatasetCasting(t *testing.T) { - emptyStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{}, - Format: "", - }, - }, - } - genericStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{}, - Format: "parquet", - }, - }, - } - subsetStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{ - { - Name: "a", - LiteralType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, - }, - { - Name: "b", - LiteralType: &core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}, - }, - }, - Format: "parquet", - }, - }, - } - supersetStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{ - { - Name: "a", - LiteralType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, - }, - { - Name: "b", - LiteralType: &core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}, - }, - { - Name: "c", - LiteralType: &core.LiteralType{Type: &core.LiteralType_MapValueType{MapValueType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}, - }, - }, - Format: "parquet", - }, - }, - } - integerSchema := &core.LiteralType{ - Type: &core.LiteralType_Schema{ - Schema: &core.SchemaType{ - Columns: []*core.SchemaType_SchemaColumn{ - { - Name: "a", - Type: core.SchemaType_SchemaColumn_INTEGER, - }, - }, - }, - }, - } - integerStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{ - { - Name: "a", - LiteralType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, - }, - }, - Format: "parquet", - }, - }, - } - mismatchedSubsetStructuredDataset := &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: &core.StructuredDatasetType{ - Columns: []*core.StructuredDatasetType_DatasetColumn{ - { - Name: "a", - LiteralType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}, - }, - }, - }, - }, - } - - t.Run("BaseCase_GenericStructuredDataset", func(t *testing.T) { - castable := AreTypesCastable(genericStructuredDataset, genericStructuredDataset) - assert.True(t, castable, "StructuredDataset() should be castable to StructuredDataset()") - }) - - t.Run("GenericStructuredDatasetToNonGeneric", func(t *testing.T) { - castable := AreTypesCastable(genericStructuredDataset, subsetStructuredDataset) - assert.True(t, castable, "StructuredDataset() should be castable to StructuredDataset(a=Integer, b=Collection)") - }) - - t.Run("NonGenericStructuredDatasetToGeneric", func(t *testing.T) { - castable := AreTypesCastable(subsetStructuredDataset, genericStructuredDataset) - assert.True(t, castable, "StructuredDataset(a=Integer, b=Collection) should be castable to StructuredDataset()") - }) - - t.Run("SupersetToSubsetTypedStructuredDataset", func(t *testing.T) { - castable := AreTypesCastable(supersetStructuredDataset, subsetStructuredDataset) - assert.True(t, castable, "StructuredDataset(a=Integer, b=Collection, c=Map) should be castable to StructuredDataset(a=Integer, b=Collection)") - }) - - t.Run("SubsetToSupersetStructuredDataset", func(t *testing.T) { - castable := AreTypesCastable(subsetStructuredDataset, supersetStructuredDataset) - assert.False(t, castable, "StructuredDataset(a=Integer, b=Collection) should not be castable to StructuredDataset(a=Integer, b=Collection, c=Map)") - }) - - t.Run("SchemaToStructuredDataset", func(t *testing.T) { - castable := AreTypesCastable(integerSchema, integerStructuredDataset) - assert.True(t, castable, "Schema(a=Integer) should be castable to StructuredDataset(a=Integer)") - }) - - t.Run("MismatchedSchemaColumns", func(t *testing.T) { - castable := AreTypesCastable(integerSchema, mismatchedSubsetStructuredDataset) - assert.False(t, castable, "Schema(a=Integer) should not be castable to StructuredDataset(a=Float)") - }) - - t.Run("MismatchedColumns", func(t *testing.T) { - castable := AreTypesCastable(subsetStructuredDataset, mismatchedSubsetStructuredDataset) - assert.False(t, castable, "StructuredDataset(a=Integer, b=Collection) should not be castable to StructuredDataset(a=Float)") - }) - - t.Run("MismatchedColumnsFlipped", func(t *testing.T) { - castable := AreTypesCastable(mismatchedSubsetStructuredDataset, subsetStructuredDataset) - assert.False(t, castable, "StructuredDataset(a=Float) should not be castable to StructuredDataset(a=Integer, b=Collection)") - }) - - t.Run("GenericToEmptyFormat", func(t *testing.T) { - castable := AreTypesCastable(genericStructuredDataset, emptyStructuredDataset) - assert.True(t, castable, "StructuredDataset(format='Parquet') should be castable to StructuredDataset()") - }) - - t.Run("EmptyFormatToGeneric", func(t *testing.T) { - castable := AreTypesCastable(genericStructuredDataset, emptyStructuredDataset) - assert.True(t, castable, "StructuredDataset() should be castable to StructuredDataset(format='Parquet')") - }) - - t.Run("StructuredDatasetsAreNullable", func(t *testing.T) { - castable := AreTypesCastable( - &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_NONE, - }, - }, - subsetStructuredDataset) - assert.True(t, castable, "StructuredDataset are nullable") - }) -} diff --git a/pkg/compiler/validators/utils.go b/pkg/compiler/validators/utils.go index 03e4636ce..05cd787b1 100644 --- a/pkg/compiler/validators/utils.go +++ b/pkg/compiler/validators/utils.go @@ -3,11 +3,8 @@ package validators import ( "fmt" - "golang.org/x/exp/slices" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/golang/protobuf/proto" - "golang.org/x/exp/maps" "k8s.io/apimachinery/pkg/util/sets" ) @@ -30,82 +27,6 @@ func findVariableByName(vars *core.VariableMap, name string) (variable *core.Var return } -// Gets literal type for scalar value. This can be used to compare the underlying type of two scalars for compatibility. -func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { - // TODO: Should we just pass the type information with the value? That way we don't have to guess? - var literalType *core.LiteralType - switch v := scalar.GetValue().(type) { - case *core.Scalar_Primitive: - literalType = literalTypeForPrimitive(scalar.GetPrimitive()) - case *core.Scalar_Blob: - if scalar.GetBlob().GetMetadata() == nil { - return nil - } - - literalType = &core.LiteralType{Type: &core.LiteralType_Blob{Blob: scalar.GetBlob().GetMetadata().GetType()}} - case *core.Scalar_Binary: - literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} - case *core.Scalar_Schema: - literalType = &core.LiteralType{ - Type: &core.LiteralType_Schema{ - Schema: scalar.GetSchema().Type, - }, - } - case *core.Scalar_StructuredDataset: - if v.StructuredDataset == nil || v.StructuredDataset.Metadata == nil { - return &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{}, - } - } - - literalType = &core.LiteralType{ - Type: &core.LiteralType_StructuredDatasetType{ - StructuredDatasetType: scalar.GetStructuredDataset().GetMetadata().StructuredDatasetType, - }, - } - case *core.Scalar_NoneType: - literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} - case *core.Scalar_Error: - literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_ERROR}} - case *core.Scalar_Generic: - literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} - case *core.Scalar_Union: - literalType = &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: []*core.LiteralType{ - scalar.GetUnion().GetType(), - }, - }, - }, - } - default: - return nil - } - - return literalType -} - -func literalTypeForPrimitive(primitive *core.Primitive) *core.LiteralType { - simpleType := core.SimpleType_NONE - switch primitive.GetValue().(type) { - case *core.Primitive_Integer: - simpleType = core.SimpleType_INTEGER - case *core.Primitive_FloatValue: - simpleType = core.SimpleType_FLOAT - case *core.Primitive_StringValue: - simpleType = core.SimpleType_STRING - case *core.Primitive_Boolean: - simpleType = core.SimpleType_BOOLEAN - case *core.Primitive_Datetime: - simpleType = core.SimpleType_DATETIME - case *core.Primitive_Duration: - simpleType = core.SimpleType_DURATION - } - - return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: simpleType}} -} - func buildVariablesIndex(params *core.VariableMap) (map[string]*core.Variable, sets.String) { paramMap := make(map[string]*core.Variable, len(params.Variables)) paramSet := sets.NewString() @@ -161,66 +82,3 @@ func UnionDistinctVariableMaps(m1, m2 map[string]*core.Variable) (map[string]*co return res, nil } - -func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { - innerType := make([]*core.LiteralType, 0, 1) - innerTypeSet := sets.NewString() - for _, x := range literals { - otherType := LiteralTypeForLiteral(x) - otherTypeKey := otherType.String() - - if !innerTypeSet.Has(otherTypeKey) { - innerType = append(innerType, otherType) - innerTypeSet.Insert(otherTypeKey) - } - } - - if len(innerType) == 0 { - return &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - } - } else if len(innerType) == 1 { - return innerType[0] - } - - // sort inner types to ensure consistent union types are generated - slices.SortFunc(innerType, func(a, b *core.LiteralType) bool { return a.String() < b.String() }) - - return &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: innerType, - }, - }, - } -} - -// LiteralTypeForLiteral gets LiteralType for literal, nil if the value of literal is unknown, or type collection/map of -// type None if the literal is a non-homogeneous type. -func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { - switch l.GetValue().(type) { - case *core.Literal_Scalar: - return literalTypeForScalar(l.GetScalar()) - case *core.Literal_Collection: - return &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: literalTypeForLiterals(l.GetCollection().Literals), - }, - } - case *core.Literal_Map: - return &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)), - }, - } - } - - return nil -} - -func GetTagForType(x *core.LiteralType) string { - if x.GetStructure() == nil { - return "" - } - return x.GetStructure().GetTag() -} diff --git a/pkg/compiler/validators/utils_test.go b/pkg/compiler/validators/utils_test.go index 3557ba0ec..c1c73747c 100644 --- a/pkg/compiler/validators/utils_test.go +++ b/pkg/compiler/validators/utils_test.go @@ -8,50 +8,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLiteralTypeForLiterals(t *testing.T) { - t.Run("empty", func(t *testing.T) { - lt := literalTypeForLiterals(nil) - assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) - }) - - t.Run("homogenous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ - coreutils.MustMakeLiteral(5), - coreutils.MustMakeLiteral(0), - coreutils.MustMakeLiteral(5), - }) - - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String()) - }) - - t.Run("non-homogenous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ - coreutils.MustMakeLiteral("hello"), - coreutils.MustMakeLiteral(5), - coreutils.MustMakeLiteral("world"), - coreutils.MustMakeLiteral(0), - coreutils.MustMakeLiteral(2), - }) - - assert.Len(t, lt.GetUnionType().Variants, 2) - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) - assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) - }) - - t.Run("non-homogenous ensure ordering", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ - coreutils.MustMakeLiteral(5), - coreutils.MustMakeLiteral("world"), - coreutils.MustMakeLiteral(0), - coreutils.MustMakeLiteral(2), - }) - - assert.Len(t, lt.GetUnionType().Variants, 2) - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) - assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) - }) -} - func TestJoinVariableMapsUniqueKeys(t *testing.T) { intType := &core.LiteralType{ Type: &core.LiteralType_Simple{ diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go index abf36f128..736ed6308 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go @@ -8,12 +8,11 @@ import ( "strconv" "strings" + "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" - "github.com/flyteorg/flytepropeller/pkg/compiler/validators" - "github.com/flyteorg/flytestdlib/pbhash" ) @@ -55,8 +54,8 @@ func GenerateTaskOutputsFromArtifact(id core.Identifier, taskInterface core.Type } expectedVarType := outputVariables[artifactData.Name].GetType() - inputType := validators.LiteralTypeForLiteral(artifactData.Value) - if !validators.AreTypesCastable(inputType, expectedVarType) { + inputType := coreutils.LiteralTypeForLiteral(artifactData.Value) + if !coreutils.AreTypesCastable(inputType, expectedVarType) { return nil, fmt.Errorf("unexpected artifactData: [%v] type: [%v] does not match any task output type: [%v]", artifactData.Name, inputType, expectedVarType) } From 2f5a30a4a06edbea5f8058b102fe5c06b1f959bf Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 31 Mar 2023 13:04:07 -0500 Subject: [PATCH 2/3] fixed unit tests Signed-off-by: Daniel Rammer --- pkg/compiler/validators/bindings_test.go | 20 ++++++++++---------- pkg/compiler/validators/utils_test.go | 1 - 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pkg/compiler/validators/bindings_test.go b/pkg/compiler/validators/bindings_test.go index f240cf1d8..6e581beb6 100644 --- a/pkg/compiler/validators/bindings_test.go +++ b/pkg/compiler/validators/bindings_test.go @@ -103,7 +103,7 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), }, }, } @@ -132,7 +132,7 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral([]interface{}{5})), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral([]interface{}{5})), }, }, } @@ -227,7 +227,7 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral( + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral( map[string]interface{}{ "xy": 5, })), @@ -265,7 +265,7 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), }, }, }, @@ -292,7 +292,7 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), }, }, } @@ -327,7 +327,7 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), }, }, }, @@ -349,7 +349,7 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), }, }, } @@ -906,7 +906,7 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), }, }, }, @@ -987,7 +987,7 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: coreutils.LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), }, }, }, @@ -1072,7 +1072,7 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(&core.Literal{ + Type: coreutils.LiteralTypeForLiteral(&core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Union{ diff --git a/pkg/compiler/validators/utils_test.go b/pkg/compiler/validators/utils_test.go index c1c73747c..e610574d4 100644 --- a/pkg/compiler/validators/utils_test.go +++ b/pkg/compiler/validators/utils_test.go @@ -3,7 +3,6 @@ package validators import ( "testing" - "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) From 0517c254b56eed5d4debb8baa899cb6220d1a6b3 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 31 Mar 2023 16:39:24 -0500 Subject: [PATCH 3/3] bumped flyteidl Signed-off-by: Daniel Rammer --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 46d413d82..92f898bf2 100644 --- a/go.mod +++ b/go.mod @@ -148,4 +148,4 @@ require ( replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d -replace github.com/flyteorg/flyteidl => ../flyteidl +replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.3.16-0.20230331180644-4649cf9c7cc7 diff --git a/go.sum b/go.sum index b594a7d5e..c84fbfb6d 100644 --- a/go.sum +++ b/go.sum @@ -260,6 +260,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/flyteorg/flyteidl v1.3.16-0.20230331180644-4649cf9c7cc7 h1:rzSpT+JTgpbc49Le3caq8Tt9sWWxdFjxrCTgr7tUVzU= +github.com/flyteorg/flyteidl v1.3.16-0.20230331180644-4649cf9c7cc7/go.mod h1:GdhmUeGpSBVf98mndMLgBsEM8X/emAaVY/LtyIMbMBo= github.com/flyteorg/flyteplugins v1.0.44 h1:uKizng+i0vfXslyPBlrsfecInhvy71fTB4kRg7eiifE= github.com/flyteorg/flyteplugins v1.0.44/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=