From 8c33ed14cd8c571e69d558d27468d8c325bc0e8b Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Wed, 25 Sep 2024 16:51:35 -0700 Subject: [PATCH] [HLO Componentization] Create hlo/parser sub-component (Phase I). This CL takes care of 1. Migrating xla/hlo_parser* and xla/hlo_lexer* --> xla/hlo/parser 2. Setting up build aliases for xla/hlo_[parser|lexer]* ensuring external dependencies are still satisfied. Phase II will take care of migration of external projects dependencies from xla/hlo_[parser|lexer]* --> xla/hlo/parser PiperOrigin-RevId: 678897485 --- xla/hlo/parser/BUILD | 115 +++++++++ xla/{service => hlo/parser}/hlo_lexer.cc | 2 +- xla/hlo/parser/hlo_lexer.h | 218 ++++++++++++++++++ xla/{service => hlo/parser}/hlo_parser.cc | 4 +- xla/hlo/parser/hlo_parser.h | 116 ++++++++++ .../parser}/hlo_parser_test.cc | 4 +- xla/service/BUILD | 83 +------ xla/service/hlo_lexer.h | 200 +--------------- xla/service/hlo_parser.h | 97 +------- 9 files changed, 464 insertions(+), 375 deletions(-) create mode 100644 xla/hlo/parser/BUILD rename xla/{service => hlo/parser}/hlo_lexer.cc (99%) create mode 100644 xla/hlo/parser/hlo_lexer.h rename xla/{service => hlo/parser}/hlo_parser.cc (99%) create mode 100644 xla/hlo/parser/hlo_parser.h rename xla/{service => hlo/parser}/hlo_parser_test.cc (99%) diff --git a/xla/hlo/parser/BUILD b/xla/hlo/parser/BUILD new file mode 100644 index 0000000000000..5fa7d1cbffbef --- /dev/null +++ b/xla/hlo/parser/BUILD @@ -0,0 +1,115 @@ +# Description: +# XLA parser implementation. 
+ +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla:xla.bzl", + "xla_cc_test", +) +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//xla:array", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "//xla/service:name_uniquer", + "//xla/service:shape_inference", + "//xla/tsl/lib/gtl:map_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_lexer", + ":hlo_parser", + "//xla:array", + "//xla:shape_util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + 
"//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + ], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:regexp", + ], +) diff --git a/xla/service/hlo_lexer.cc b/xla/hlo/parser/hlo_lexer.cc similarity index 99% rename from xla/service/hlo_lexer.cc rename to xla/hlo/parser/hlo_lexer.cc index 546e4989a380f..4c294b8567e32 100644 --- a/xla/service/hlo_lexer.cc +++ b/xla/hlo/parser/hlo_lexer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_lexer.h" +#include "xla/hlo/parser/hlo_lexer.h" #include #include diff --git a/xla/hlo/parser/hlo_lexer.h b/xla/hlo/parser/hlo_lexer.h new file mode 100644 index 0000000000000..f787392b39b37 --- /dev/null +++ b/xla/hlo/parser/hlo_lexer.h @@ -0,0 +1,218 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PARSER_HLO_LEXER_H_ +#define XLA_HLO_PARSER_HLO_LEXER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/regexp.h" + +namespace xla { + +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * + kQuestionMark, // ? + kOctothorp, // # + kPlus, // + + kTilde, // ~ + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + kDots, // ... + + kArrow, // -> + kLeq, // <= + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_manual, + kw_last_tile_dim_replicate, + kw_shard_as, + kw_shard_like, + kw_unknown, + kw_inf, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. 
+ kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kSparsityDesc, // ([LR]\.[0-9]+@[0-9]+:[0-9]+_?)+ + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +std::string TokKindToString(TokKind kind); + +// Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. +class HloLexer { + public: + explicit HloLexer(absl::string_view buf) : buf_(buf) { + current_ptr_ = buf_.data(); + } + + TokKind Lex() { return token_state_.current_kind = LexToken(); } + + TokKind GetKind() const { return token_state_.current_kind; } + std::string GetStrVal() const { + switch (GetKind()) { + case TokKind::kName: + case TokKind::kAttributeName: + case TokKind::kDimLabels: + case TokKind::kDxD: + case TokKind::kPad: + case TokKind::kSparsityDesc: + case TokKind::kString: + case TokKind::kIdent: + return token_state_.str_val; + default: + LOG(FATAL) << "This token does not have string value"; + } + } + int64_t GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt) << TokKindToString(GetKind()); + return token_state_.int64_val; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; + } + + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_state_.token_start; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + absl::string_view GetLine(LocTy loc) const; + + // Looks ahead one token and returns it. Lexer state is unchanged. 
+ TokKind LookAhead(); + + // Lexes a string delimited by matching curly braces. Curlies contained + // inside double quotes don't count. + // + // Requires that you've already lexed the open curly brace. + // + // The returned string value includes the outer curlies. + // + // Returns TokKind::kString on success. + TokKind LexJsonDict(); + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates string_view with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + absl::string_view StringViewFromPointers(const char* begin, + const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexNumberOrPattern(); + TokKind LexString(); + + std::optional LexNanPayload(absl::string_view& consumable); + + absl::string_view buf_; + const char* current_ptr_; + + // Information about the current token. + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + std::string str_val; + int64_t int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. + mutable LineNoCacheTy line_no_cache_{nullptr, 0}; +}; + +// Does this string start with "{", end with "}", and contain valid-ish JSON +// in-between? If so, hlo_parser can parse e.g. backend_config={blah: "blah"} +// instead of the much uglier backend_config="{blah: \"blah\"}". 
+// +// (Technically we're not checking for fully-valid JSON, just something we can +// find the end of reasonably.) +bool LexesAsJsonDict(absl::string_view str); + +} // namespace xla + +#endif // XLA_HLO_PARSER_HLO_LEXER_H_ diff --git a/xla/service/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc similarity index 99% rename from xla/service/hlo_parser.cc rename to xla/hlo/parser/hlo_parser.cc index fcc3f93863ba0..49309fac356a3 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/hlo/parser/hlo_parser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include #include @@ -60,6 +60,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/hlo/parser/hlo_lexer.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -67,7 +68,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_lexer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/name_uniquer.h" #include "xla/service/shape_inference.h" diff --git a/xla/hlo/parser/hlo_parser.h b/xla/hlo/parser/hlo_parser.h new file mode 100644 index 0000000000000..9a59260976de9 --- /dev/null +++ b/xla/hlo/parser/hlo_parser.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PARSER_HLO_PARSER_H_ +#define XLA_HLO_PARSER_HLO_PARSER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_lexer.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class HloParserOptions { + public: + // When a shape layout is not set (e.g. in the entry computation layout or + // instruction layout), set the layout to be the default (e.g. {3,2,1,0}). + HloParserOptions& set_fill_missing_layouts(bool value) { + fill_missing_layouts_ = value; + return *this; + } + + bool fill_missing_layouts() const { return fill_missing_layouts_; } + + private: + bool fill_missing_layouts_ = true; +}; + +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with the given config. +// Note: Tests derived from HloTestBase should use +// ParseAndReturnVerifiedModule() instead! +absl::StatusOr> ParseAndReturnUnverifiedModule( + absl::string_view str, const HloModuleConfig& config = HloModuleConfig(), + const HloParserOptions& options = HloParserOptions()); + +// Parses sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., +// "{replicated}". +absl::StatusOr ParseSharding(absl::string_view str); + +// Parses frontend attributes from str. 
str is supposed to contain the body of +// the frontend attributes , i.e. just the rhs of the +// "frontend_attributes={...}" attribute string, e.g., +// "{attr_a=a,attr_b=b}". +absl::StatusOr ParseFrontendAttributes( + absl::string_view str); + +// Parses statistics viz from str. str is supposed to contain the body of the +// statistics visualization, i.e. just the rhs of the "statistics={...}" +// attribute string, e.g., "{visualizing_index=1,nan_percent=50}". +absl::StatusOr ParseStatisticsViz(absl::string_view str); + +// Parses parameter replication from str. str is supposed to contain the body of +// the parameter replication, i.e. just the rhs of the +// "parameter_replication={...}" attribute string, e.g., "{true, false}". +absl::StatusOr> ParseParameterReplication( + absl::string_view str); + +// Parses the result of window_util::ToString(const Window&). +absl::StatusOr ParseWindow(absl::string_view str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +absl::StatusOr ParseConvolutionDimensionNumbers( + absl::string_view str); + +// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". +absl::StatusOr ParsePaddingConfig(absl::string_view str); + +// Parses and returns a Shape::ToString-format string. +absl::StatusOr ParseShape(absl::string_view str); + +// Parses and returns a Layout::ToString-format string. +absl::StatusOr ParseLayout(absl::string_view str); + +// Parses and returns a std::vector from str. str is supposed to +// contain a list of the replica groups, i.e. just the rhs of the +// "replica_groups={...}" attribute string, e.g., "{{0,1}, {2,3}}". +absl::StatusOr> ParseReplicaGroupsOnly( + absl::string_view str); + +class HloParser { + public: + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns the error status in case an error occurred. 
+ virtual absl::Status Run(HloModule* module) = 0; + virtual ~HloParser() {} + + private: + static std::unique_ptr CreateHloParserForTests( + absl::string_view str); + friend class VerifiedHloModule; +}; + +} // namespace xla + +#endif // XLA_HLO_PARSER_HLO_PARSER_H_ diff --git a/xla/service/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc similarity index 99% rename from xla/service/hlo_parser_test.cc rename to xla/hlo/parser/hlo_parser_test.cc index 40c12a92972ca..0d0b7fabff36e 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/hlo/parser/hlo_parser_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include #include @@ -39,9 +39,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_lexer.h" #include "xla/layout.h" #include "xla/layout_util.h" -#include "xla/service/hlo_lexer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" diff --git a/xla/service/BUILD b/xla/service/BUILD index 1c88a5fb887f1..e1578a3b0c79c 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -7069,96 +7069,25 @@ xla_cc_test( ], ) +# Deprecated, use //xla/hlo/parser:hlo_parser +# instead. 
cc_library( name = "hlo_parser", - srcs = ["hlo_parser.cc"], hdrs = ["hlo_parser.h"], deps = [ - ":computation_layout", - ":hlo_lexer", - ":hlo_module_config", - ":hlo_proto_cc", - ":name_uniquer", - ":shape_inference", - "//xla:array", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:tile_assignment", - "//xla/tsl/lib/gtl:map_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_lexer", - ":hlo_module_config", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:array", - "//xla:shape_util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status_matchers", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", + "//xla/hlo/parser:hlo_parser", ], ) +# Deprecated, use 
//xla/hlo/parser:hlo_lexer +# instead. cc_library( name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], hdrs = [ "hlo_lexer.h", ], deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "@com_google_absl//absl/base", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:regexp", + "//xla/hlo/parser:hlo_lexer", ], ) diff --git a/xla/service/hlo_lexer.h b/xla/service/hlo_lexer.h index 7f6346da55804..aad399ed291f3 100644 --- a/xla/service/hlo_lexer.h +++ b/xla/service/hlo_lexer.h @@ -16,203 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_LEXER_H_ #define XLA_SERVICE_HLO_LEXER_H_ -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/regexp.h" - -namespace xla { - -// Defines different kinds of tokens used by the HLO lexer. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kAsterisk, // * - kQuestionMark, // ? - kOctothorp, // # - kPlus, // + - kTilde, // ~ - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - kDots, // ... - - kArrow, // -> - kLeq, // <= - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_manual, - kw_last_tile_dim_replicate, - kw_shard_as, - kw_shard_like, - kw_unknown, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. - kPrimitiveType, // F32, PRED, etc. 
- kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kSparsityDesc, // ([LR]\.[0-9]+@[0-9]+:[0-9]+_?)+ - kIdent, // other identifiers - kString, // "abcd\"\n" - kInt, // 42 - kDecimal, // 4.2 -}; - -std::string TokKindToString(TokKind kind); - -// Lexer for the HloModule::ToString() format text. -// -// This class is meant to be used by hlo_parser.cc. You shouldn't need to use -// it directly. -class HloLexer { - public: - explicit HloLexer(absl::string_view buf) : buf_(buf) { - current_ptr_ = buf_.data(); - } - - TokKind Lex() { return token_state_.current_kind = LexToken(); } - - TokKind GetKind() const { return token_state_.current_kind; } - std::string GetStrVal() const { - switch (GetKind()) { - case TokKind::kName: - case TokKind::kAttributeName: - case TokKind::kDimLabels: - case TokKind::kDxD: - case TokKind::kPad: - case TokKind::kSparsityDesc: - case TokKind::kString: - case TokKind::kIdent: - return token_state_.str_val; - default: - LOG(FATAL) << "This token does not have string value"; - } - } - int64_t GetInt64Val() const { - CHECK(GetKind() == TokKind::kInt) << TokKindToString(GetKind()); - return token_state_.int64_val; - } - double GetDecimalVal() const { - CHECK(GetKind() == TokKind::kDecimal); - return token_state_.decimal_val; - } - PrimitiveType GetPrimitiveTypeVal() const { - CHECK(GetKind() == TokKind::kPrimitiveType); - return token_state_.primitive_type_val; - } - - typedef const char* LocTy; - - // Returns the location of the current token. - LocTy GetLoc() const { return token_state_.token_start; } - - // Returns the line and column of a location in the buffer. - std::pair GetLineAndColumn(LocTy location) const; - - // Returns the whole line given the location. - absl::string_view GetLine(LocTy loc) const; - - // Looks ahead one token and returns it. Lexer state is unchanged. 
- TokKind LookAhead(); - - // Lexes a string delimited by matching curly braces. Curlies contained - // inside double quotes don't count. - // - // Requires that you've already lexed the open curly brace. - // - // The returned string value includes the outer curlies. - // - // Returns TokKind::kString on success. - TokKind LexJsonDict(); - - private: - // Returns the current character. If it's neither the end of input buffer nor - // an invalid character, moves the pointer forward. - int GetNextChar(); - - // Returns the current character. - int PeekCurrentChar() const; - - // Creates string_view with the given begin and end. Exits if the begin > end, - // or it's out of the range of the current buffer. - absl::string_view StringViewFromPointers(const char* begin, - const char* end) const; - - // Returns true if the given ptr is dereferenceable within the range of the - // current buffer. - bool CanDereference(const char* ptr) const; - - TokKind LexToken(); - - TokKind LexIdentifier(); - TokKind LexPercent(); - TokKind LexShape(); - TokKind LexConstant(); - TokKind LexNumberOrPattern(); - TokKind LexString(); - - std::optional LexNanPayload(absl::string_view& consumable); - - absl::string_view buf_; - const char* current_ptr_; - - // Information about the current token. - struct TokenState { - const char* token_start = nullptr; - TokKind current_kind; - std::string str_val; - int64_t int64_val; - double decimal_val; - PrimitiveType primitive_type_val; - }; - TokenState token_state_; - - struct LineNoCacheTy { - const char* last_query; - unsigned line_no_of_query; - }; - // This caches the line number of the previous query. - mutable LineNoCacheTy line_no_cache_{nullptr, 0}; -}; - -// Does this string start with "{", end with "}", and contain valid-ish JSON -// in-between? If so, hlo_parser can parse e.g. backend_config={blah: "blah"} -// instead of the much uglier backend_config="{blah: \"blah\"}". 
-// -// (Technically we're not checking for fully-valid JSON, just something we can -// find the end of reasonably.) -bool LexesAsJsonDict(absl::string_view str); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/parser/hlo_lexer.h" #endif // XLA_SERVICE_HLO_LEXER_H_ diff --git a/xla/service/hlo_parser.h b/xla/service/hlo_parser.h index 17012f779c01a..6a9e8d8be6039 100644 --- a/xla/service/hlo_parser.h +++ b/xla/service/hlo_parser.h @@ -16,100 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_PARSER_H_ #define XLA_SERVICE_HLO_PARSER_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -class HloParserOptions { - public: - // When a shape layout is not set (e.g. in the entry computation layout or - // instruction layout), set the layout to be the default (e.g. {3,2,1,0}). - HloParserOptions& set_fill_missing_layouts(bool value) { - fill_missing_layouts_ = value; - return *this; - } - - bool fill_missing_layouts() const { return fill_missing_layouts_; } - - private: - bool fill_missing_layouts_ = true; -}; - -// Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with the given config. -// Note: Tests derived from HloTestBase should use -// ParseAndReturnVerifiedModule() instead! -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, const HloModuleConfig& config = HloModuleConfig(), - const HloParserOptions& options = HloParserOptions()); - -// Parses sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., -// "{replicated}". -absl::StatusOr ParseSharding(absl::string_view str); - -// Parses frontend attributes from str. 
str is supposed to contain the body of -// the frontend attributes , i.e. just the rhs of the -// "frontend_attributes={...}" attribute string, e.g., -// "{attr_a=a,attr_b=b}". -absl::StatusOr ParseFrontendAttributes( - absl::string_view str); - -// Parses statistics viz from str. str is supposed to contain the body of the -// statistics visualization, i.e. just the rhs of the "statistics={...}" -// attribute string, e.g., "{visualizing_index=1,nan_percent=50}". -absl::StatusOr ParseStatisticsViz(absl::string_view str); - -// Parses parameter replication from str. str is supposed to contain the body of -// the parameter replication, i.e. just the rhs of the -// "parameter_replication={...}" attribute string, e.g., "{true, false}". -absl::StatusOr> ParseParameterReplication( - absl::string_view str); - -// Parses the result of window_util::ToString(const Window&). -absl::StatusOr ParseWindow(absl::string_view str); - -// Parses the result of ConvolutionDimensionNumbersToString(), e.g. -// "b0f_0io->b0f". -absl::StatusOr ParseConvolutionDimensionNumbers( - absl::string_view str); - -// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". -absl::StatusOr ParsePaddingConfig(absl::string_view str); - -// Parses and returns a Shape::ToString-format string. -absl::StatusOr ParseShape(absl::string_view str); - -// Parses and returns a Layout::ToString-format string. -absl::StatusOr ParseLayout(absl::string_view str); - -// Parses and returns a std::vector from str. str is supposed to -// contain a list of the replica groups, i.e. just the rhs of the -// "replica_groups={...}" attribute string, e.g., "{{0,1}, {2,3}}". -absl::StatusOr> ParseReplicaGroupsOnly( - absl::string_view str); - -class HloParser { - public: - // Runs the parser and constructs the resulting HLO in the given (empty) - // HloModule. Returns the error status in case an error occurred. 
- virtual absl::Status Run(HloModule* module) = 0; - virtual ~HloParser() {} - - private: - static std::unique_ptr CreateHloParserForTests( - absl::string_view str); - friend class VerifiedHloModule; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/parser/hlo_parser.h" #endif // XLA_SERVICE_HLO_PARSER_H_