From 9638db0bf446f111d031cee5ffe3ebfd26df906e Mon Sep 17 00:00:00 2001 From: ddl-rliu <140021987+ddl-rliu@users.noreply.github.com> Date: Wed, 17 Jul 2024 10:51:01 -0700 Subject: [PATCH] Add blob typechecker (#5519) Signed-off-by: ddl-rliu --- .../pkg/compiler/validators/typing.go | 24 ++++++++ .../pkg/compiler/validators/typing_test.go | 55 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go index 388cb32123..2bde60b47b 100644 --- a/flytepropeller/pkg/compiler/validators/typing.go +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -72,6 +72,26 @@ func (t mapTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { return false } +type blobTypeChecker struct { + literalType *flyte.LiteralType +} + +// CastsFrom checks that the target blob type can be cast to the current blob type. When the blob has no format +// specified, it accepts all blob inputs since it is generic. +func (t blobTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + blobType := upstreamType.GetBlob() + if blobType == nil { + return false + } + + // Empty blobs should match any blob. + if blobType.GetFormat() == "" || t.literalType.GetBlob().GetFormat() == "" { + return true + } + + return blobType.GetFormat() == t.literalType.GetBlob().GetFormat() +} + type collectionTypeChecker struct { literalType *flyte.LiteralType } @@ -333,6 +353,10 @@ func getTypeChecker(t *flyte.LiteralType) typeChecker { return mapTypeChecker{ literalType: t, } + case *flyte.LiteralType_Blob: + return blobTypeChecker{ + literalType: t, + } case *flyte.LiteralType_Schema: return schemaTypeChecker{ literalType: t, diff --git a/flytepropeller/pkg/compiler/validators/typing_test.go b/flytepropeller/pkg/compiler/validators/typing_test.go index 17e55b1e0a..f2e407b986 100644 --- a/flytepropeller/pkg/compiler/validators/typing_test.go +++ b/flytepropeller/pkg/compiler/validators/typing_test.go @@ -893,3 +893,58 @@ func TestStructuredDatasetCasting(t *testing.T) { assert.True(t, castable, "StructuredDataset are nullable") }) } + +func TestBlobCasting(t *testing.T) { + emptyBlob := &core.LiteralType{ + Type: &core.LiteralType_Blob{ + Blob: &core.BlobType{ + Format: "", + }, + }, + } + genericBlob := &core.LiteralType{ + Type: &core.LiteralType_Blob{ + Blob: &core.BlobType{ + Format: "csv", + }, + }, + } + mismatchedFormatBlob := &core.LiteralType{ + Type: &core.LiteralType_Blob{ + Blob: &core.BlobType{ + Format: "pdf", + }, + }, + } + + t.Run("BaseCase_GenericBlob", func(t *testing.T) { + castable := AreTypesCastable(genericBlob, genericBlob) + assert.True(t, castable, "Blob() should be castable to Blob()") + }) + + t.Run("GenericToEmptyFormat", func(t *testing.T) { + castable := AreTypesCastable(genericBlob, emptyBlob) + assert.True(t, castable, "Blob(format='csv') should be castable to Blob()") + }) + + t.Run("EmptyFormatToGeneric", func(t *testing.T) { + castable := AreTypesCastable(genericBlob, emptyBlob) + assert.True(t, castable, "Blob() should be castable to Blob(format='csv')") + }) + + t.Run("MismatchedFormat", func(t *testing.T) { + castable := AreTypesCastable(genericBlob, mismatchedFormatBlob) + assert.False(t, castable, "Blob(format='csv') should not be castable to Blob(format='pdf')") + }) + + t.Run("BlobsAreNullable", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + }, + genericBlob) + assert.False(t, castable, "Blob is not nullable") + }) +}