From a4b11eba615c87b9253f61d9b66f02839490e12b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 1 Nov 2019 16:54:48 -0700 Subject: [PATCH] Update the SPV dialect type parser to use the methods on DialectAsmParser directly. This simplifies the implementation quite a bit, and removes the need for explicit string munging. One change is made to some of the enum elements of SPV_DimAttr to ensure that they are proper identifiers; The string form is now prefixed with 'Dim'. PiperOrigin-RevId: 278027132 --- include/mlir/Dialect/SPIRV/SPIRVBase.td | 6 +- include/mlir/IR/DialectImplementation.h | 23 +- include/mlir/IR/OpDefinition.h | 24 ++ lib/Dialect/SPIRV/SPIRVDialect.cpp | 408 +++++++++--------------- lib/Parser/Parser.cpp | 14 +- test/Dialect/SPIRV/types.mlir | 95 +++--- test/IR/invalid.mlir | 5 - 7 files changed, 260 insertions(+), 315 deletions(-) diff --git a/include/mlir/Dialect/SPIRV/SPIRVBase.td b/include/mlir/Dialect/SPIRV/SPIRVBase.td index ba8659c87c7f..990758a71084 100644 --- a/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -758,9 +758,9 @@ def SPV_DecorationAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_D_1D : I32EnumAttrCase<"1D", 0>; -def SPV_D_2D : I32EnumAttrCase<"2D", 1>; -def SPV_D_3D : I32EnumAttrCase<"3D", 2>; +def SPV_D_1D : I32EnumAttrCase<"Dim1D", 0>; +def SPV_D_2D : I32EnumAttrCase<"Dim2D", 1>; +def SPV_D_3D : I32EnumAttrCase<"Dim3D", 2>; def SPV_D_Cube : I32EnumAttrCase<"Cube", 3>; def SPV_D_Rect : I32EnumAttrCase<"Rect", 4>; def SPV_D_Buffer : I32EnumAttrCase<"Buffer", 5>; diff --git a/include/mlir/IR/DialectImplementation.h b/include/mlir/IR/DialectImplementation.h index f713c6030693..effcf49df75e 100644 --- a/include/mlir/IR/DialectImplementation.h +++ b/include/mlir/IR/DialectImplementation.h @@ -144,14 +144,26 @@ class DialectAsmParser { virtual ParseResult parseFloat(double &result) = 0; /// Parse an integer value from the stream. - virtual ParseResult parseInteger(uint64_t &result) = 0; - template ParseResult parseInteger(IntT &result) { auto loc = getCurrentLocation(); + OptionalParseResult parseResult = parseOptionalInteger(result); + if (!parseResult.hasValue()) + return emitError(loc, "expected integer value"); + return *parseResult; + } + + /// Parse an optional integer value from the stream. + virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; + + template + OptionalParseResult parseOptionalInteger(IntT &result) { + auto loc = getCurrentLocation(); + // Parse the unsigned variant. uint64_t uintResult; - if (failed(parseInteger(uintResult))) - return failure(); + OptionalParseResult parseResult = parseOptionalInteger(uintResult); + if (!parseResult.hasValue() || failed(*parseResult)) + return parseResult; // Try to convert to the provided integer type. result = IntT(uintResult); @@ -222,6 +234,9 @@ class DialectAsmParser { /// Parse a '>' token. virtual ParseResult parseGreater() = 0; + /// Parse a `>` token if present. + virtual ParseResult parseOptionalGreater() = 0; + /// Parse a `(` token. virtual ParseResult parseLParen() = 0; diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h index 65033b612df1..ebe373c71e4c 100644 --- a/include/mlir/IR/OpDefinition.h +++ b/include/mlir/IR/OpDefinition.h @@ -53,6 +53,30 @@ class ParseResult : public LogicalResult { /// Failure is true in a boolean context. explicit operator bool() const { return failed(*this); } }; +/// This class implements `Optional` functionality for ParseResult. We don't +/// directly use llvm::Optional here, because it provides an implicit conversion +/// to 'bool' which we want to avoid. This class is used to implement tri-state +/// 'parseOptional' functions that may have a failure mode when parsing that +/// shouldn't be attributed to "not present". +class OptionalParseResult { +public: + OptionalParseResult() = default; + OptionalParseResult(LogicalResult result) : impl(result) {} + OptionalParseResult(ParseResult result) : impl(result) {} + OptionalParseResult(const InFlightDiagnostic &) + : OptionalParseResult(failure()) {} + OptionalParseResult(llvm::NoneType) : impl(llvm::None) {} + + /// Returns true if we contain a valid ParseResult value. + bool hasValue() const { return impl.hasValue(); } + + /// Access the internal ParseResult value. + ParseResult getValue() const { return impl.getValue(); } + ParseResult operator*() const { return getValue(); } + +private: + Optional impl; +}; // These functions are out-of-line utilities, which avoids them being template // instantiated/duplicated. diff --git a/lib/Dialect/SPIRV/SPIRVDialect.cpp b/lib/Dialect/SPIRV/SPIRVDialect.cpp index abe47240b2fc..1460cf091eb7 100644 --- a/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -127,34 +127,15 @@ std::string SPIRVDialect::getAttributeName(Decoration decoration) { // Forward declarations. template -static Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec); +static Optional parseAndVerify(SPIRVDialect const &dialect, + DialectAsmParser &parser); template <> -Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec); +Optional parseAndVerify(SPIRVDialect const &dialect, + DialectAsmParser &parser); template <> Optional parseAndVerify(SPIRVDialect const &dialect, - Location loc, StringRef spec); - -// Parses " x" from the beginning of `spec`. -static bool parseNumberX(StringRef &spec, int64_t &number) { - spec = spec.ltrim(); - if (spec.empty() || !llvm::isDigit(spec.front())) - return false; - - number = 0; - do { - number = number * 10 + spec.front() - '0'; - spec = spec.drop_front(); - } while (!spec.empty() && llvm::isDigit(spec.front())); - - spec = spec.ltrim(); - if (!spec.consume_front("x")) - return false; - - return true; -} + DialectAsmParser &parser); static bool isValidSPIRVIntType(IntegerType type) { return llvm::is_contained(llvm::ArrayRef({1, 8, 16, 32, 64}), @@ -192,21 +173,12 @@ bool SPIRVDialect::isValidType(Type type) { return false; } -static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, - Location loc) { - spec = spec.trim(); - auto *context = dialect.getContext(); - size_t numCharsRead = 0; - auto type = mlir::parseType(spec.trim(), context, numCharsRead); - if (!type) { - emitError(loc, "cannot parse type: ") << spec; +static Type parseAndVerifyType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + Type type; + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + if (parser.parseType(type)) return Type(); - } - if (numCharsRead < spec.size()) { - emitError(loc, "unexpected additional tokens '") - << spec.substr(numCharsRead) << "' after parsing type: " << type; - return Type(); - } // Allow SPIR-V dialect types if (&type.getDialect() == &dialect) @@ -215,28 +187,30 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, // Check other allowed types if (auto t = type.dyn_cast()) { if (type.isBF16()) { - emitError(loc, "cannot use 'bf16' to compose SPIR-V types"); + parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); return Type(); } } else if (auto t = type.dyn_cast()) { if (!isValidSPIRVIntType(t)) { - emitError(loc, "only 1/8/16/32/64-bit integer type allowed but found ") + parser.emitError(typeLoc, + "only 1/8/16/32/64-bit integer type allowed but found ") << type; return Type(); } } else if (auto t = type.dyn_cast()) { if (t.getRank() != 1) { - emitError(loc, "only 1-D vector allowed but found ") << t; + parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); } if (t.getNumElements() > 4) { - emitError(loc, - "vector length has to be less than or equal to 4 but found ") + parser.emitError( + typeLoc, "vector length has to be less than or equal to 4 but found ") << t.getNumElements(); return Type(); } } else { - emitError(loc, "cannot use ") << type << " to compose SPIR-V types"; + parser.emitError(typeLoc, "cannot use ") + << type << " to compose SPIR-V types"; return Type(); } @@ -250,71 +224,51 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, // // array-type ::= `!spv.array<` integer-literal `x` element-type // (`[` integer-literal `]`)? `>` -static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, - Location loc) { - if (!spec.consume_front("array<") || !spec.consume_back(">")) { - emitError(loc, "spv.array delimiter <...> mismatch"); +static Type parseArrayType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) return Type(); - } - int64_t count = 0; - spec = spec.trim(); - if (!parseNumberX(spec, count)) { - emitError(loc, "expected array element count followed by 'x' but found '") - << spec << "'"; + SmallVector countDims; + llvm::SMLoc countLoc = parser.getCurrentLocation(); + if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) + return Type(); + if (countDims.size() != 1) { + parser.emitError(countLoc, + "expected single integer for array element count"); return Type(); } // According to the SPIR-V spec: // "Length is the number of elements in the array. It must be at least 1." - if (!count) { - emitError(loc, "expected array length greater than 0"); + int64_t count = countDims[0]; + if (count == 0) { + parser.emitError(countLoc, "expected array length greater than 0"); return Type(); } - if (spec.trim().empty()) { - emitError(loc, "expected element type"); + Type elementType = parseAndVerifyType(dialect, parser); + if (!elementType) return Type(); - } ArrayType::LayoutInfo layoutInfo = 0; - size_t lastLSquare; - - // Handle case when element type is not a trivial type - auto lastRDelimiter = spec.rfind('>'); - if (lastRDelimiter != StringRef::npos) { - lastLSquare = spec.find('[', lastRDelimiter); - } else { - lastLSquare = spec.rfind('['); - } - - if (lastLSquare != StringRef::npos) { - auto layoutSpec = spec.substr(lastLSquare); - layoutSpec = layoutSpec.trim(); - if (!layoutSpec.consume_front("[") || !layoutSpec.consume_back("]")) { - emitError(loc, "expected array stride within '[' ']' in '") - << layoutSpec << "'"; - return Type(); - } - layoutSpec = layoutSpec.trim(); - auto layout = - parseAndVerify(dialect, loc, layoutSpec); - if (!layout) { + if (succeeded(parser.parseOptionalLSquare())) { + llvm::SMLoc layoutLoc = parser.getCurrentLocation(); + auto layout = parseAndVerify(dialect, parser); + if (!layout) return Type(); - } if (!(layoutInfo = layout.getValue())) { - emitError(loc, "ArrayStride must be greater than zero"); + parser.emitError(layoutLoc, "ArrayStride must be greater than zero"); return Type(); } - spec = spec.substr(0, lastLSquare); + if (parser.parseRSquare()) + return Type(); } - Type elementType = parseAndVerifyType(dialect, spec, loc); - if (!elementType) + if (parser.parseGreater()) return Type(); - return ArrayType::get(elementType, count, layoutInfo); } @@ -327,104 +281,86 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, // | // // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>` -static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec, - Location loc) { - if (!spec.consume_front("ptr<") || !spec.consume_back(">")) { - emitError(loc, "spv.ptr delimiter <...> mismatch"); +static Type parsePointerType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) return Type(); - } - // Split into pointee type and storage class - StringRef scSpec, ptSpec; - std::tie(ptSpec, scSpec) = spec.rsplit(','); - if (scSpec.empty()) { - emitError(loc, - "expected comma to separate pointee type and storage class in '") - << spec << "'"; + auto pointeeType = parseAndVerifyType(dialect, parser); + if (!pointeeType) return Type(); - } - scSpec = scSpec.trim(); - auto storageClass = symbolizeStorageClass(scSpec); - if (!storageClass) { - emitError(loc, "unknown storage class: ") << scSpec; + StringRef storageClassSpec; + llvm::SMLoc storageClassLoc = parser.getCurrentLocation(); + if (parser.parseComma() || parser.parseKeyword(&storageClassSpec)) return Type(); - } - if (ptSpec.trim().empty()) { - emitError(loc, "expected pointee type"); + auto storageClass = symbolizeStorageClass(storageClassSpec); + if (!storageClass) { + parser.emitError(storageClassLoc, "unknown storage class: ") + << storageClassSpec; return Type(); } - - auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc); - if (!pointeeType) + if (parser.parseGreater()) return Type(); - return PointerType::get(pointeeType, *storageClass); } // runtime-array-type ::= `!spv.rtarray<` element-type `>` -static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec, - Location loc) { - if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) { - emitError(loc, "spv.rtarray delimiter <...> mismatch"); +static Type parseRuntimeArrayType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) return Type(); - } - if (spec.trim().empty()) { - emitError(loc, "expected element type"); - return Type(); - } - - Type elementType = parseAndVerifyType(dialect, spec, loc); + Type elementType = parseAndVerifyType(dialect, parser); if (!elementType) return Type(); + if (parser.parseGreater()) + return Type(); return RuntimeArrayType::get(elementType); } // Specialize this function to parse each of the parameters that define an // ImageType. By default it assumes this is an enum type. template -static Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec) { - auto val = spirv::symbolizeEnum()(spec); - if (!val) { - emitError(loc, "unknown attribute: '") << spec << "'"; +static Optional parseAndVerify(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + StringRef enumSpec; + llvm::SMLoc enumLoc = parser.getCurrentLocation(); + if (parser.parseKeyword(&enumSpec)) { + return llvm::None; } + + auto val = spirv::symbolizeEnum()(enumSpec); + if (!val) + parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'"; return val; } template <> -Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec) { +Optional parseAndVerify(SPIRVDialect const &dialect, + DialectAsmParser &parser) { // TODO(ravishankarm): Further verify that the element type can be sampled - auto ty = parseAndVerifyType(dialect, spec, loc); - if (!ty) { + auto ty = parseAndVerifyType(dialect, parser); + if (!ty) return llvm::None; - } return ty; } template static Optional parseAndVerifyInteger(SPIRVDialect const &dialect, - Location loc, StringRef spec) { + DialectAsmParser &parser) { IntTy offsetVal = std::numeric_limits::max(); - spec = spec.trim(); - if (spec.consumeInteger(10, offsetVal)) { - return llvm::None; - } - spec = spec.trim(); - if (!spec.empty()) { + if (parser.parseInteger(offsetVal)) return llvm::None; - } return offsetVal; } template <> Optional parseAndVerify(SPIRVDialect const &dialect, - Location loc, StringRef spec) { - return parseAndVerifyInteger(dialect, loc, spec); + DialectAsmParser &parser) { + return parseAndVerifyInteger(dialect, parser); } // Functor object to parse a comma separated list of specs. The function @@ -433,28 +369,17 @@ Optional parseAndVerify(SPIRVDialect const &dialect, // (termination condition) needs partial specialization. template struct parseCommaSeparatedList { Optional> - operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const { - auto numArgs = std::tuple_size>::value; - StringRef parseSpec, restSpec; - std::tie(parseSpec, restSpec) = spec.split(','); - - parseSpec = parseSpec.trim(); - if (numArgs != 0 && restSpec.empty()) { - emitError(loc, "expected more parameters for image type '") - << parseSpec << "'"; + operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { + auto parseVal = parseAndVerify(dialect, parser); + if (!parseVal) return llvm::None; - } - auto parseVal = parseAndVerify(dialect, loc, parseSpec); - if (!parseVal) { + auto numArgs = std::tuple_size>::value; + if (numArgs != 0 && failed(parser.parseComma())) return llvm::None; - } - - auto remainingValues = - parseCommaSeparatedList{}(dialect, loc, restSpec); - if (!remainingValues) { + auto remainingValues = parseCommaSeparatedList{}(dialect, parser); + if (!remainingValues) return llvm::None; - } return std::tuple_cat(std::tuple(parseVal.getValue()), remainingValues.getValue()); } @@ -463,14 +388,11 @@ template struct parseCommaSeparatedList { // Partial specialization of the function to parse a comma separated list of // specs to parse the last element of the list. template struct parseCommaSeparatedList { - Optional> - operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const { - spec = spec.trim(); - auto value = parseAndVerify(dialect, loc, spec); - if (!value) { - return llvm::None; - } - return std::tuple(value.getValue()); + Optional> operator()(SPIRVDialect const &dialect, + DialectAsmParser &parser) const { + if (auto value = parseAndVerify(dialect, parser)) + return std::tuple(value.getValue()); + return llvm::None; } }; @@ -489,118 +411,103 @@ template struct parseCommaSeparatedList { // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` // arrayed-info `,` sampling-info `,` // sampler-use-info `,` format `>` -static Type parseImageType(SPIRVDialect const &dialect, StringRef spec, - Location loc) { - if (!spec.consume_front("image<") || !spec.consume_back(">")) { - emitError(loc, "spv.image delimiter <...> mismatch"); +static Type parseImageType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) return Type(); - } auto value = parseCommaSeparatedList{}(dialect, loc, spec); - if (!value) { + ImageFormat>{}(dialect, parser); + if (!value) return Type(); - } + if (parser.parseGreater()) + return Type(); return ImageType::get(value.getValue()); } // Parse decorations associated with a member. static ParseResult parseStructMemberDecorations( - SPIRVDialect const &dialect, Location loc, StringRef spec, + SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef memberTypes, SmallVectorImpl &layoutInfo, SmallVectorImpl &memberDecorationInfo) { - spec = spec.trim(); - auto memberInfo = spec.split(','); + // Check if the first element is offset. - auto layout = - parseAndVerify(dialect, loc, memberInfo.first); - if (layout) { + llvm::SMLoc layoutLoc = parser.getCurrentLocation(); + StructType::LayoutInfo layout = 0; + OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout); + if (layoutParseResult.hasValue()) { + if (failed(*layoutParseResult)) + return failure(); + if (layoutInfo.size() != memberTypes.size() - 1) { - return emitError(loc, - "layout specification must be given for all members"); + return parser.emitError( + layoutLoc, "layout specification must be given for all members"); } - layoutInfo.push_back(layout.getValue()); - spec = memberInfo.second.trim(); + layoutInfo.push_back(layout); } + // Check for no spirv::Decorations. + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + // If there was a layout, make sure to parse the comma. + if (layoutParseResult.hasValue() && parser.parseComma()) + return failure(); + // Check for spirv::Decorations. - while (!spec.empty()) { - memberInfo = spec.split(','); - auto memberDecoration = - parseAndVerify(dialect, loc, memberInfo.first); - if (!memberDecoration) { + do { + auto memberDecoration = parseAndVerify(dialect, parser); + if (!memberDecoration) return failure(); - } + memberDecorationInfo.emplace_back( static_cast(memberTypes.size() - 1), memberDecoration.getValue()); - spec = memberInfo.second.trim(); - } - return success(); + } while (succeeded(parser.parseOptionalComma())); + + return parser.parseRSquare(); } // struct-member-decoration ::= integer-literal? spirv-decoration* // struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)? -// (`, ` spirv-type (`[` struct-member-decoration `]`)? -static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, - Location loc) { - if (!spec.consume_front("struct<") || !spec.consume_back(">")) { - emitError(loc, "spv.struct delimiter <...> mismatch"); +// (`, ` spirv-type (`[` struct-member-decoration `]`)? `>` +static Type parseStructType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) return Type(); - } + + if (succeeded(parser.parseOptionalGreater())) + return StructType::getEmpty(dialect.getContext()); SmallVector memberTypes; SmallVector layoutInfo; SmallVector memberDecorationInfo; - auto *context = dialect.getContext(); - while (!spec.empty()) { - spec = spec.trim(); - size_t pos = 0; - auto memberType = mlir::parseType(spec, context, pos); - if (!memberType) { - emitError(loc, "cannot parse type from '") << spec << "'"; - } + do { + Type memberType; + if (parser.parseType(memberType)) + return Type(); memberTypes.push_back(memberType); - spec = spec.substr(pos).trim(); - if (spec.consume_front("[")) { - auto rSquare = spec.find(']'); - if (rSquare == StringRef::npos) { - emitError(loc, "missing matching ']' in ") << spec; - return Type(); - } - if (parseStructMemberDecorations(dialect, loc, spec.substr(0, rSquare), - memberTypes, layoutInfo, + if (succeeded(parser.parseOptionalLSquare())) { + if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo, memberDecorationInfo)) { return Type(); } - spec = spec.substr(rSquare + 1).trim(); } + } while (succeeded(parser.parseOptionalComma())); - // Handle comma. - if (!spec.consume_front(",")) { - // End of decorations list. - break; - } - } - spec = spec.trim(); - if (!spec.empty()) { - emitError(loc, "unexpected substring '") - << spec << "' while parsing StructType"; - return Type(); - } if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) { - emitError(loc, "layout specification must be given for all members"); + parser.emitError(parser.getNameLoc(), + "layout specification must be given for all members"); return Type(); } - if (memberTypes.empty()) { - return StructType::getEmpty(dialect.getContext()); - } + if (parser.parseGreater()) + return Type(); return StructType::get(memberTypes, layoutInfo, memberDecorationInfo); } @@ -611,21 +518,22 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, // | runtime-array-type // | struct-type Type SPIRVDialect::parseType(DialectAsmParser &parser) const { - StringRef spec = parser.getFullSymbolSpec(); - Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - - if (spec.startswith("array")) - return parseArrayType(*this, spec, loc); - if (spec.startswith("image")) - return parseImageType(*this, spec, loc); - if (spec.startswith("ptr")) - return parsePointerType(*this, spec, loc); - if (spec.startswith("rtarray")) - return parseRuntimeArrayType(*this, spec, loc); - if (spec.startswith("struct")) - return parseStructType(*this, spec, loc); - - emitError(loc, "unknown SPIR-V type: ") << spec; + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + + if (keyword == "array") + return parseArrayType(*this, parser); + if (keyword == "image") + return parseImageType(*this, parser); + if (keyword == "ptr") + return parsePointerType(*this, parser); + if (keyword == "rtarray") + return parseRuntimeArrayType(*this, parser); + if (keyword == "struct") + return parseStructType(*this, parser); + + parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; return Type(); } diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp index 870e33d3b2e1..bc0070231ff9 100644 --- a/lib/Parser/Parser.cpp +++ b/lib/Parser/Parser.cpp @@ -464,10 +464,13 @@ class CustomDialectAsmParser : public DialectAsmParser { return emitError(getCurrentLocation(), "expected floating point literal"); } - /// Parse an integer value from the stream. - ParseResult parseInteger(uint64_t &result) override { - bool negative = parser.consumeIf(Token::minus); + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(uint64_t &result) override { + Token curToken = parser.getToken(); + if (curToken.isNot(Token::integer, Token::minus)) + return llvm::None; + bool negative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); if (parser.parseToken(Token::integer, "expected integer value")) return failure(); @@ -548,6 +551,11 @@ class CustomDialectAsmParser : public DialectAsmParser { return parser.parseToken(Token::greater, "expected '>'"); } + /// Parse a `>` token if present. + ParseResult parseOptionalGreater() override { + return success(parser.consumeIf(Token::greater)); + } + /// Parse a `(` token. ParseResult parseLParen() override { return parser.parseToken(Token::l_paren, "expected '('"); diff --git a/test/Dialect/SPIRV/types.mlir b/test/Dialect/SPIRV/types.mlir index f784caf93181..266790d4d954 100644 --- a/test/Dialect/SPIRV/types.mlir +++ b/test/Dialect/SPIRV/types.mlir @@ -17,32 +17,32 @@ func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> () // ----- -// expected-error @+1 {{spv.array delimiter <...> mismatch}} +// expected-error @+1 {{expected '<'}} func @missing_left_angle_bracket(!spv.array 4xf32>) -> () // ----- -// expected-error @+1 {{expected array element count followed by 'x' but found 'f32'}} +// expected-error @+1 {{expected single integer for array element count}} func @missing_count(!spv.array) -> () // ----- -// expected-error @+1 {{expected array element count followed by 'x' but found 'f32'}} +// expected-error @+1 {{expected 'x' in dimension list}} func @missing_x(!spv.array<4 f32>) -> () // ----- -// expected-error @+1 {{expected element type}} +// expected-error @+1 {{expected non-function type}} func @missing_element_type(!spv.array<4x>) -> () // ----- -// expected-error @+1 {{cannot parse type: blabla}} +// expected-error @+1 {{expected non-function type}} func @cannot_parse_type(!spv.array<4xblabla>) -> () // ----- -// expected-error @+1 {{cannot parse type: 3xf32}} +// expected-error @+1 {{expected single integer for array element count}} func @more_than_one_dim(!spv.array<4x3xf32>) -> () // ----- @@ -102,17 +102,17 @@ func @vector_ptr_type(!spv.ptr,PushConstant>) -> () // ----- -// expected-error @+1 {{spv.ptr delimiter <...> mismatch}} +// expected-error @+1 {{expected '<'}} func @missing_left_angle_bracket(!spv.ptr f32, Uniform>) -> () // ----- -// expected-error @+1 {{expected comma to separate pointee type and storage class in 'f32 Uniform'}} +// expected-error @+1 {{expected ','}} func @missing_comma(!spv.ptr) -> () // ----- -// expected-error @+1 {{expected pointee type}} +// expected-error @+1 {{expected non-function type}} func @missing_pointee_type(!spv.ptr<, Uniform>) -> () // ----- @@ -134,17 +134,17 @@ func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> () // ----- -// expected-error @+1 {{spv.rtarray delimiter <...> mismatch}} +// expected-error @+1 {{expected '<'}} func @missing_left_angle_bracket(!spv.rtarray f32>) -> () // ----- -// expected-error @+1 {{expected element type}} +// expected-error @+1 {{expected non-function type}} func @missing_element_type(!spv.rtarray<>) -> () // ----- -// expected-error @+1 {{cannot parse type: 4xf32}} +// expected-error @+1 {{expected non-function type}} func @redundant_count(!spv.rtarray<4xf32>) -> () // ----- @@ -153,68 +153,68 @@ func @redundant_count(!spv.rtarray<4xf32>) -> () // ImageType //===----------------------------------------------------------------------===// -// CHECK: func @image_parameters_1D(!spv.image) -func @image_parameters_1D(!spv.image) -> () +// CHECK: func @image_parameters_1D(!spv.image) +func @image_parameters_1D(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type 'f32'}} +// expected-error @+1 {{expected ','}} func @image_parameters_one_element(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type '1D'}} -func @image_parameters_two_elements(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_two_elements(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type 'NoDepth'}} -func @image_parameters_three_elements(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_three_elements(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type 'NonArrayed'}} -func @image_parameters_four_elements(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_four_elements(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type 'SingleSampled'}} -func @image_parameters_five_elements(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_five_elements(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type 'SamplerUnknown'}} -func @image_parameters_six_elements(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_six_elements(!spv.image) -> () // ----- -// expected-error @+1 {{spv.image delimiter <...> mismatch}} -func @image_parameters_delimiter(!spv.image f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>) -> () +// expected-error @+1 {{expected '<'}} +func @image_parameters_delimiter(!spv.image f32, Dim1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>) -> () // ----- -// expected-error @+1 {{unknown attribute: '1D NoDepth'}} -func @image_parameters_nocomma_1(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_nocomma_1(!spv.image) -> () // ----- -// expected-error @+1 {{unknown attribute: 'NoDepth NonArrayed'}} -func @image_parameters_nocomma_2(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_nocomma_2(!spv.image) -> () // ----- -// expected-error @+1 {{unknown attribute: 'NonArrayed SingleSampled'}} -func @image_parameters_nocomma_3(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_nocomma_3(!spv.image) -> () // ----- -// expected-error @+1 {{unknown attribute: 'SingleSampled SamplerUnknown'}} -func @image_parameters_nocomma_4(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_nocomma_4(!spv.image) -> () // ----- -// expected-error @+1 {{expected more parameters for image type 'SamplerUnknown Unknown'}} -func @image_parameters_nocomma_5(!spv.image) -> () +// expected-error @+1 {{expected ','}} +func @image_parameters_nocomma_5(!spv.image) -> () // ----- @@ -228,8 +228,8 @@ func @struct_type(!spv.struct) -> () // CHECK: func @struct_type2(!spv.struct) func @struct_type2(!spv.struct) -> () -// CHECK: func @struct_type_simple(!spv.struct>) -func @struct_type_simple(!spv.struct>) -> () +// CHECK: func @struct_type_simple(!spv.struct>) +func @struct_type_simple(!spv.struct>) -> () // CHECK: func @struct_type_with_offset(!spv.struct) func @struct_type_with_offset(!spv.struct) -> () @@ -282,21 +282,16 @@ func @struct_type_missing_offset2(!spv.struct) -> () // ----- -// expected-error @+1 {{unexpected substring 'i32' while parsing StructType}} +// expected-error @+1 {{expected '>'}} func @struct_type_missing_comma1(!spv.struct) -> () // ----- -// expected-error @+1 {{unexpected substring 'i32' while parsing StructType}} +// expected-error @+1 {{expected '>'}} func @struct_type_missing_comma2(!spv.struct) -> () // ----- -// expected-error @+1 {{unknown attribute: '-1'}} -func @struct_type_neg_offset(!spv.struct) -> () - -// ----- - // expected-error @+1 {{unbalanced '>' character in pretty dialect name}} func @struct_type_neg_offset(!spv.struct) -> () @@ -307,20 +302,20 @@ func @struct_type_neg_offset(!spv.struct) -> () // ----- -// expected-error @+1 {{unknown attribute: 'NonWritable 0'}} +// expected-error @+1 {{expected ']'}} func @struct_type_neg_offset(!spv.struct) -> () // ----- -// expected-error @+1 {{unknown attribute: '0'}} +// expected-error @+1 {{expected valid keyword}} func @struct_type_neg_offset(!spv.struct) -> () // ----- -// expected-error @+1 {{unknown attribute: '0 NonWritable'}} +// expected-error @+1 {{expected ','}} func @struct_type_missing_comma(!spv.struct) // ----- -// expected-error @+1 {{unknown attribute: 'NonWritable NonReadable'}} +// expected-error @+1 {{expected ']'}} func @struct_type_missing_comma(!spv.struct) diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir index 2f8dcc96b6d1..ed650290a908 100644 --- a/test/IR/invalid.mlir +++ b/test/IR/invalid.mlir @@ -1081,11 +1081,6 @@ func @bad_complex(complex) - -// ----- - func @invalid_region_dominance() { "foo.region"() ({ // expected-error @+1 {{operand #0 does not dominate this use}}