From 3ab0936b6c895f8fdf532bb7266aacc80dcc30b9 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Fri, 13 Dec 2024 23:16:36 -0800 Subject: [PATCH 1/9] Switch to using a variant vs overloaded functions, this simplifies the handlers, and gives us much better error messages for mismatched types. --- src/config.cpp | 257 +++++++++++++++++++++---------------------------- src/json.cpp | 29 +++--- src/json.h | 24 +++-- 3 files changed, 143 insertions(+), 167 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 0c6de4d69..459da0d68 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -22,8 +22,8 @@ ONNXTensorElementDataType TranslateTensorType(std::string_view value) { struct ProviderOptions_Element : JSON::Element { explicit ProviderOptions_Element(Config::ProviderOptions& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - v_.options.emplace_back(name, value); + void OnValue(std::string_view name, JSON::Value value) override { + v_.options.emplace_back(name, JSON::Get(value)); } private: @@ -65,45 +65,35 @@ struct ProviderOptionsArray_Element : JSON::Element { struct SessionOptions_Element : JSON::Element { explicit SessionOptions_Element(Config::SessionOptions& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "log_id") - v_.log_id = value; + v_.log_id = JSON::Get(value); else if (name == "enable_profiling") - v_.enable_profiling = value; + v_.enable_profiling = JSON::Get(value); else if (name == "ep_context_embed_mode") - v_.ep_context_embed_mode = value; + v_.ep_context_embed_mode = JSON::Get(value); else if (name == "ep_context_file_path") - v_.ep_context_file_path = value; - else - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { - if (name == "intra_op_num_threads") - v_.intra_op_num_threads = static_cast(value); + v_.ep_context_file_path = JSON::Get(value); + else if (name == "intra_op_num_threads") + v_.intra_op_num_threads = static_cast(JSON::Get(value)); else if (name == "inter_op_num_threads") - v_.inter_op_num_threads = static_cast(value); + v_.inter_op_num_threads = static_cast(JSON::Get(value)); else if (name == "log_severity_level") - v_.log_severity_level = static_cast(value); - else - throw JSON::unknown_value_error{}; - } - - void OnBool(std::string_view name, bool value) override { - if (name == "enable_cpu_mem_arena") - v_.enable_cpu_mem_arena = value; + v_.log_severity_level = static_cast(JSON::Get(value)); + else if (name == "enable_cpu_mem_arena") + v_.enable_cpu_mem_arena = JSON::Get(value); else if (name == "enable_mem_pattern") - v_.enable_mem_pattern = value; + v_.enable_mem_pattern = JSON::Get(value); else if (name == "disable_cpu_ep_fallback") - v_.disable_cpu_ep_fallback = value; + v_.disable_cpu_ep_fallback = JSON::Get(value); else if (name == "disable_quant_qdq") - v_.disable_quant_qdq = value; + v_.disable_quant_qdq = JSON::Get(value); else if (name == "enable_quant_qdq_cleanup") - v_.enable_quant_qdq_cleanup = value; + v_.enable_quant_qdq_cleanup = JSON::Get(value); else if (name == "ep_context_enable") - v_.ep_context_enable = value; + v_.ep_context_enable = JSON::Get(value); else if (name == "use_env_allocators") - v_.use_env_allocators = value; + v_.use_env_allocators = JSON::Get(value); else throw JSON::unknown_value_error{}; } @@ -122,9 +112,9 @@ struct SessionOptions_Element : JSON::Element { struct EncoderDecoderInit_Element : JSON::Element { explicit EncoderDecoderInit_Element(Config::Model::EncoderDecoderInit& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -136,29 +126,29 @@ struct EncoderDecoderInit_Element : JSON::Element { struct Inputs_Element : JSON::Element { explicit Inputs_Element(Config::Model::Decoder::Inputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = value; + v_.input_ids = JSON::Get(value); } else if (name == "inputs_embeds") { - v_.embeddings = value; + v_.embeddings = JSON::Get(value); } else if (name == "position_ids") { - v_.position_ids = value; + v_.position_ids = JSON::Get(value); } else if (name == "attention_mask") { - v_.attention_mask = value; + v_.attention_mask = JSON::Get(value); } else if (name == "past_key_names") { - v_.past_key_names = value; + v_.past_key_names = JSON::Get(value); } else if (name == "past_value_names") { - v_.past_value_names = value; + v_.past_value_names = JSON::Get(value); } else if (name == "past_names") { - v_.past_names = value; + v_.past_names = JSON::Get(value); } else if (name == "cross_past_key_names") { - v_.cross_past_key_names = value; + v_.cross_past_key_names = JSON::Get(value); } else if (name == "cross_past_value_names") { - v_.cross_past_value_names = value; + v_.cross_past_value_names = JSON::Get(value); } else if (name == "current_sequence_length") { - v_.current_sequence_length = value; + v_.current_sequence_length = JSON::Get(value); } else if (name == "past_sequence_length") { - v_.past_sequence_length = value; + v_.past_sequence_length = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -170,19 +160,19 @@ struct Inputs_Element : JSON::Element { struct Outputs_Element : JSON::Element { explicit Outputs_Element(Config::Model::Decoder::Outputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "logits") { - v_.logits = value; + v_.logits = JSON::Get(value); } else if (name == "present_key_names") { - v_.present_key_names = value; + v_.present_key_names = JSON::Get(value); } else if (name == "present_value_names") { - v_.present_value_names = value; + v_.present_value_names = JSON::Get(value); } else if (name == "present_names") { - v_.present_names = value; + v_.present_names = JSON::Get(value); } else if (name == "cross_present_key_names") { - v_.cross_present_key_names = value; + v_.cross_present_key_names = JSON::Get(value); } else if (name == "cross_present_value_names") { - v_.cross_present_value_names = value; + v_.cross_present_value_names = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -194,8 +184,8 @@ struct Outputs_Element : JSON::Element { struct StringArray_Element : JSON::Element { explicit StringArray_Element(std::vector& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - v_.push_back(std::string(value)); + void OnValue(std::string_view name, JSON::Value value) override { + v_.push_back(std::string{JSON::Get(value)}); } private: @@ -205,8 +195,8 @@ struct StringArray_Element : JSON::Element { struct StringStringMap_Element : JSON::Element { explicit StringStringMap_Element(std::unordered_map& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - v_[std::string(name)] = std::string(value); + void OnValue(std::string_view name, JSON::Value value) override { + v_[std::string(name)] = std::string(JSON::Get(value)); } private: @@ -216,18 +206,13 @@ struct StringStringMap_Element : JSON::Element { struct PipelineModel_Element : JSON::Element { explicit PipelineModel_Element(Config::Model::Decoder::PipelineModel& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; - } else - throw JSON::unknown_value_error{}; - } - - void OnBool(std::string_view name, bool value) override { - if (name == "run_on_prompt") { - v_.run_on_prompt = value; + v_.filename = JSON::Get(value); + } else if (name == "run_on_prompt") { + v_.run_on_prompt = JSON::Get(value); } else if (name == "run_on_token_gen") { - v_.run_on_token_gen = value; + v_.run_on_token_gen = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -289,24 +274,19 @@ struct Pipeline_Element : JSON::Element { struct Decoder_Element : JSON::Element { explicit Decoder_Element(Config::Model::Decoder& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; - } else - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { - if (name == "hidden_size") { - v_.hidden_size = static_cast(value); + v_.filename = JSON::Get(value); + } else if (name == "hidden_size") { + v_.hidden_size = static_cast(JSON::Get(value)); } else if (name == "num_attention_heads") { - v_.num_attention_heads = static_cast(value); + v_.num_attention_heads = static_cast(JSON::Get(value)); } else if (name == "num_key_value_heads") { - v_.num_key_value_heads = static_cast(value); + v_.num_key_value_heads = static_cast(JSON::Get(value)); } else if (name == "num_hidden_layers") { - v_.num_hidden_layers = static_cast(value); + v_.num_hidden_layers = static_cast(JSON::Get(value)); } else if (name == "head_size") { - v_.head_size = static_cast(value); + v_.head_size = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } @@ -341,11 +321,11 @@ struct Decoder_Element : JSON::Element { struct VisionInputs_Element : JSON::Element { explicit VisionInputs_Element(Config::Model::Vision::Inputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "pixel_values") { - v_.pixel_values = value; + v_.pixel_values = JSON::Get(value); } else if (name == "image_sizes") { - v_.image_sizes = value; + v_.image_sizes = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -357,9 +337,9 @@ struct VisionInputs_Element : JSON::Element { struct VisionOutputs_Element : JSON::Element { explicit VisionOutputs_Element(Config::Model::Vision::Outputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "image_features") { - v_.image_features = value; + v_.image_features = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -371,9 +351,9 @@ struct VisionOutputs_Element : JSON::Element { struct Vision_Element : JSON::Element { explicit Vision_Element(Config::Model::Vision& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -396,8 +376,8 @@ struct Vision_Element : JSON::Element { struct Eos_Array_Element : JSON::Element { explicit Eos_Array_Element(Config::Model& v) : v_{v} {} - void OnNumber(std::string_view name, double value) override { - v_.eos_token_ids.push_back(static_cast(value)); + void OnValue(std::string_view name, JSON::Value value) override { + v_.eos_token_ids.push_back(static_cast(JSON::Get(value))); } void OnComplete(bool empty) override { @@ -419,11 +399,11 @@ struct Eos_Array_Element : JSON::Element { struct EmbeddingInputs_Element : JSON::Element { explicit EmbeddingInputs_Element(Config::Model::Embedding::Inputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = value; + v_.input_ids = JSON::Get(value); } else if (name == "image_features") { - v_.image_features = value; + v_.image_features = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -435,9 +415,9 @@ struct EmbeddingInputs_Element : JSON::Element { struct EmbeddingOutputs_Element : JSON::Element { explicit EmbeddingOutputs_Element(Config::Model::Embedding::Outputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "inputs_embeds") { - v_.embeddings = value; + v_.embeddings = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -449,9 +429,9 @@ struct EmbeddingOutputs_Element : JSON::Element { struct Embedding_Element : JSON::Element { explicit Embedding_Element(Config::Model::Embedding& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -474,20 +454,20 @@ struct Embedding_Element : JSON::Element { struct PromptTemplates_Element : JSON::Element { explicit PromptTemplates_Element(std::optional& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { // if one of templates is given in json, then any non-specified template will be default "{Content}" if (name == "assistant") { EnsureAvailable(); - v_->assistant = value; + v_->assistant = JSON::Get(value); } else if (name == "prompt") { EnsureAvailable(); - v_->prompt = value; + v_->prompt = JSON::Get(value); } else if (name == "system") { EnsureAvailable(); - v_->system = value; + v_->system = JSON::Get(value); } else if (name == "user") { EnsureAvailable(); - v_->user = value; + v_->user = JSON::Get(value); } else { throw JSON::unknown_value_error{}; } @@ -506,28 +486,23 @@ struct PromptTemplates_Element : JSON::Element { struct Model_Element : JSON::Element { explicit Model_Element(Config::Model& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "type") { - v_.type = value; - } else - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { - if (name == "vocab_size") { - v_.vocab_size = static_cast(value); + v_.type = JSON::Get(value); + } else if (name == "vocab_size") { + v_.vocab_size = static_cast(JSON::Get(value)); } else if (name == "context_length") { - v_.context_length = static_cast(value); + v_.context_length = static_cast(JSON::Get(value)); } else if (name == "pad_token_id") { - v_.pad_token_id = static_cast(value); + v_.pad_token_id = static_cast(JSON::Get(value)); } else if (name == "eos_token_id") { - v_.eos_token_id = static_cast(value); + v_.eos_token_id = static_cast(JSON::Get(value)); } else if (name == "bos_token_id") { - v_.bos_token_id = static_cast(value); + v_.bos_token_id = static_cast(JSON::Get(value)); } else if (name == "decoder_start_token_id") { - v_.decoder_start_token_id = static_cast(value); + v_.decoder_start_token_id = static_cast(JSON::Get(value)); } else if (name == "sep_token_id") { - v_.sep_token_id = static_cast(value); + v_.sep_token_id = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } @@ -570,50 +545,41 @@ struct Model_Element : JSON::Element { struct Search_Element : JSON::Element { explicit Search_Element(Config::Search& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "min_length") { - v_.min_length = static_cast(value); + v_.min_length = static_cast(JSON::Get(value)); } else if (name == "max_length") { - v_.max_length = static_cast(value); + v_.max_length = static_cast(JSON::Get(value)); } else if (name == "batch_size") { - v_.batch_size = static_cast(value); + v_.batch_size = static_cast(JSON::Get(value)); } else if (name == "num_beams") { - v_.num_beams = static_cast(value); + v_.num_beams = static_cast(JSON::Get(value)); } else if (name == "num_return_sequences") { - v_.num_return_sequences = static_cast(value); + v_.num_return_sequences = static_cast(JSON::Get(value)); } else if (name == "top_k") { - v_.top_k = static_cast(value); + v_.top_k = static_cast(JSON::Get(value)); } else if (name == "top_p") { - v_.top_p = static_cast(value); + v_.top_p = static_cast(JSON::Get(value)); } else if (name == "temperature") { - v_.temperature = static_cast(value); + v_.temperature = static_cast(JSON::Get(value)); } else if (name == "repetition_penalty") { - v_.repetition_penalty = static_cast(value); + v_.repetition_penalty = static_cast(JSON::Get(value)); } else if (name == "length_penalty") { - v_.length_penalty = static_cast(value); + v_.length_penalty = static_cast(JSON::Get(value)); } else if (name == "no_repeat_ngram_size") { - v_.no_repeat_ngram_size = static_cast(value); + v_.no_repeat_ngram_size = static_cast(JSON::Get(value)); } else if (name == "diversity_penalty") { - v_.diversity_penalty = static_cast(value); + v_.diversity_penalty = static_cast(JSON::Get(value)); } else if (name == "length_penalty") { - v_.length_penalty = static_cast(value); + v_.length_penalty = static_cast(JSON::Get(value)); } else if (name == "random_seed") { - v_.random_seed = static_cast(value); - } else - throw JSON::unknown_value_error{}; - } - - void OnBool(std::string_view name, bool value) override { - if (name == "do_sample") { - v_.do_sample = value; + v_.random_seed = static_cast(JSON::Get(value)); + } else if (name == "do_sample") { + v_.do_sample = JSON::Get(value); } else if (name == "past_present_share_buffer") { - v_.past_present_share_buffer = value; + v_.past_present_share_buffer = JSON::Get(value); } else if (name == "early_stopping") { - v_.early_stopping = value; + v_.early_stopping = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -623,11 +589,11 @@ struct Search_Element : JSON::Element { }; void SetSearchNumber(Config::Search& search, std::string_view name, double value) { - Search_Element(search).OnNumber(name, value); + Search_Element(search).OnValue(name, value); } void SetSearchBool(Config::Search& search, std::string_view name, bool value) { - Search_Element(search).OnBool(name, value); + Search_Element(search).OnValue(name, value); } void ClearProviders(Config& config) { @@ -663,10 +629,7 @@ bool IsCudaGraphEnabled(Config::SessionOptions& session_options) { struct Root_Element : JSON::Element { explicit Root_Element(Config& config) : config_{config} {} - void OnString(std::string_view name, std::string_view value) override { - } - - void OnNumber(std::string_view name, double value) override { + void OnValue(std::string_view name, JSON::Value value) override { } Element& OnObject(std::string_view name) override { diff --git a/src/json.cpp b/src/json.cpp index 4d4d0aa91..bd2c2aefb 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -6,14 +6,8 @@ #include namespace JSON { - -Element& Element::OnArray(std::string_view /*name*/) { - throw unknown_value_error{}; -} - -Element& Element::OnObject(std::string_view /*name*/) { - throw unknown_value_error{}; -} +static constexpr const char* value_names[] = {"string", "number", "bool", "null"}; +static_assert(std::size(value_names) == std::variant_size_v); struct JSON { JSON(Element& element, std::string_view document); @@ -148,34 +142,41 @@ void JSON::Parse_Value(Element& element, std::string_view name) { Parse_Array(element_array); } break; case '"': { - element.OnString(name, Parse_String()); + element.OnValue(name, Parse_String()); } break; case 't': if (Skip("rue")) { - element.OnBool(name, true); + element.OnValue(name, true); } break; case 'f': if (Skip("alse")) { - element.OnBool(name, false); + element.OnValue(name, false); } break; case 'n': if (Skip("ull")) { - element.OnNull(name); + element.OnValue(name, nullptr); } break; default: if (c >= '0' && c <= '9' || c == '-') { --current_; - element.OnNumber(name, Parse_Number()); + element.OnValue(name, Parse_Number()); } else throw unknown_value_error{}; break; } } catch (const unknown_value_error&) { - throw std::runtime_error("Unknown value: " + std::string(name)); + throw std::runtime_error(" Unknown value \"" + std::string(name) + "\""); + } catch (const type_mismatch& e) { + throw std::runtime_error(std::string(name) + " - Expected a " + std::string(value_names[e.expected]) + " but saw a " + std::string(value_names[e.seen])); + } catch (const std::runtime_error& e) { + if (!name.empty()) + throw std::runtime_error(std::string(name) + ":" + e.what()); + throw; } + Parse_Whitespace(); } diff --git a/src/json.h b/src/json.h index 58bc16319..502e56bb9 100644 --- a/src/json.h +++ b/src/json.h @@ -9,17 +9,29 @@ // namespace JSON { struct unknown_value_error : std::exception {}; // Throw this from any Element callback to throw a std::runtime error reporting the unknown value name +struct type_mismatch { // When a file has one type, but we're expecting another type. "seen" & "expected" are indices into the Value std::variant below + size_t seen, expected; +}; + +using Value = std::variant; + +// To see descriptive errors when types don't match, use this instead of std::get +template +T& Get(Value& var) { + try { + return std::get(var); + } catch (const std::bad_variant_access&) { + throw type_mismatch{var.index(), Value{T{}}.index()}; + } +} struct Element { virtual void OnComplete(bool empty) {} // Called when parsing for this element is finished (empty is true when it's an empty element) - virtual void OnString(std::string_view name, std::string_view value) { throw unknown_value_error{}; } - virtual void OnNumber(std::string_view name, double value) { throw unknown_value_error{}; } - virtual void OnBool(std::string_view name, bool value) { throw unknown_value_error{}; } - virtual void OnNull(std::string_view name) { throw unknown_value_error{}; } + virtual void OnValue(std::string_view name, Value value) { throw unknown_value_error{}; } - virtual Element& OnArray(std::string_view name); - virtual Element& OnObject(std::string_view name); + virtual Element& OnArray(std::string_view name) { throw unknown_value_error{}; } + virtual Element& OnObject(std::string_view name) { throw unknown_value_error{}; } }; void Parse(Element& element, std::string_view document); From 8859070a893f81fa88a0cb2b88543b0cb004073a Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Mon, 16 Dec 2024 13:29:25 -0800 Subject: [PATCH 2/9] Wrap the variant in a richer type to simplify calling code. --- src/config.cpp | 164 ++++++++++++++++++++++++------------------------- src/json.cpp | 2 +- src/json.h | 26 +++++--- 3 files changed, 100 insertions(+), 92 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 459da0d68..8fc150f48 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -23,7 +23,7 @@ struct ProviderOptions_Element : JSON::Element { explicit ProviderOptions_Element(Config::ProviderOptions& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_.options.emplace_back(name, JSON::Get(value)); + v_.options.emplace_back(name, value); } private: @@ -67,33 +67,33 @@ struct SessionOptions_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "log_id") - v_.log_id = JSON::Get(value); + v_.log_id = value; else if (name == "enable_profiling") - v_.enable_profiling = JSON::Get(value); + v_.enable_profiling = value; else if (name == "ep_context_embed_mode") - v_.ep_context_embed_mode = JSON::Get(value); + v_.ep_context_embed_mode = value; else if (name == "ep_context_file_path") - v_.ep_context_file_path = JSON::Get(value); + v_.ep_context_file_path = value; else if (name == "intra_op_num_threads") - v_.intra_op_num_threads = static_cast(JSON::Get(value)); + v_.intra_op_num_threads = value; else if (name == "inter_op_num_threads") - v_.inter_op_num_threads = static_cast(JSON::Get(value)); + v_.inter_op_num_threads = value; else if (name == "log_severity_level") - v_.log_severity_level = static_cast(JSON::Get(value)); + v_.log_severity_level = value; else if (name == "enable_cpu_mem_arena") - v_.enable_cpu_mem_arena = JSON::Get(value); + v_.enable_cpu_mem_arena = value; else if (name == "enable_mem_pattern") - v_.enable_mem_pattern = JSON::Get(value); + v_.enable_mem_pattern = value; else if (name == "disable_cpu_ep_fallback") - v_.disable_cpu_ep_fallback = JSON::Get(value); + v_.disable_cpu_ep_fallback = value; else if (name == "disable_quant_qdq") - v_.disable_quant_qdq = JSON::Get(value); + v_.disable_quant_qdq = value; else if (name == "enable_quant_qdq_cleanup") - v_.enable_quant_qdq_cleanup = JSON::Get(value); + v_.enable_quant_qdq_cleanup = value; else if (name == "ep_context_enable") - v_.ep_context_enable = JSON::Get(value); + v_.ep_context_enable = value; else if (name == "use_env_allocators") - v_.use_env_allocators = JSON::Get(value); + v_.use_env_allocators = value; else throw JSON::unknown_value_error{}; } @@ -114,7 +114,7 @@ struct EncoderDecoderInit_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = JSON::Get(value); + v_.filename = value; } else throw JSON::unknown_value_error{}; } @@ -128,27 +128,27 @@ struct Inputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = JSON::Get(value); + v_.input_ids = value; } else if (name == "inputs_embeds") { - v_.embeddings = JSON::Get(value); + v_.embeddings = value; } else if (name == "position_ids") { - v_.position_ids = JSON::Get(value); + v_.position_ids = value; } else if (name == "attention_mask") { - v_.attention_mask = JSON::Get(value); + v_.attention_mask = value; } else if (name == "past_key_names") { - v_.past_key_names = JSON::Get(value); + v_.past_key_names = value; } else if (name == "past_value_names") { - v_.past_value_names = JSON::Get(value); + v_.past_value_names = value; } else if (name == "past_names") { - v_.past_names = JSON::Get(value); + v_.past_names = value; } else if (name == "cross_past_key_names") { - v_.cross_past_key_names = JSON::Get(value); + v_.cross_past_key_names = value; } else if (name == "cross_past_value_names") { - v_.cross_past_value_names = JSON::Get(value); + v_.cross_past_value_names = value; } else if (name == "current_sequence_length") { - v_.current_sequence_length = JSON::Get(value); + v_.current_sequence_length = value; } else if (name == "past_sequence_length") { - v_.past_sequence_length = JSON::Get(value); + v_.past_sequence_length = value; } else throw JSON::unknown_value_error{}; } @@ -162,17 +162,17 @@ struct Outputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "logits") { - v_.logits = JSON::Get(value); + v_.logits = value; } else if (name == "present_key_names") { - v_.present_key_names = JSON::Get(value); + v_.present_key_names = value; } else if (name == "present_value_names") { - v_.present_value_names = JSON::Get(value); + v_.present_value_names = value; } else if (name == "present_names") { - v_.present_names = JSON::Get(value); + v_.present_names = value; } else if (name == "cross_present_key_names") { - v_.cross_present_key_names = JSON::Get(value); + v_.cross_present_key_names = value; } else if (name == "cross_present_value_names") { - v_.cross_present_value_names = JSON::Get(value); + v_.cross_present_value_names = value; } else throw JSON::unknown_value_error{}; } @@ -185,7 +185,7 @@ struct StringArray_Element : JSON::Element { explicit StringArray_Element(std::vector& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_.push_back(std::string{JSON::Get(value)}); + v_.push_back(value); } private: @@ -196,7 +196,7 @@ struct StringStringMap_Element : JSON::Element { explicit StringStringMap_Element(std::unordered_map& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_[std::string(name)] = std::string(JSON::Get(value)); + v_[std::string(name)] = value; } private: @@ -208,11 +208,11 @@ struct PipelineModel_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = JSON::Get(value); + v_.filename = value; } else if (name == "run_on_prompt") { - v_.run_on_prompt = JSON::Get(value); + v_.run_on_prompt = value; } else if (name == "run_on_token_gen") { - v_.run_on_token_gen = JSON::Get(value); + v_.run_on_token_gen = value; } else throw JSON::unknown_value_error{}; } @@ -276,17 +276,17 @@ struct Decoder_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = JSON::Get(value); + v_.filename = value; } else if (name == "hidden_size") { - v_.hidden_size = static_cast(JSON::Get(value)); + v_.hidden_size = value; } else if (name == "num_attention_heads") { - v_.num_attention_heads = static_cast(JSON::Get(value)); + v_.num_attention_heads = value; } else if (name == "num_key_value_heads") { - v_.num_key_value_heads = static_cast(JSON::Get(value)); + v_.num_key_value_heads = value; } else if (name == "num_hidden_layers") { - v_.num_hidden_layers = static_cast(JSON::Get(value)); + v_.num_hidden_layers = value; } else if (name == "head_size") { - v_.head_size = static_cast(JSON::Get(value)); + v_.head_size = value; } else throw JSON::unknown_value_error{}; } @@ -323,9 +323,9 @@ struct VisionInputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "pixel_values") { - v_.pixel_values = JSON::Get(value); + v_.pixel_values = value; } else if (name == "image_sizes") { - v_.image_sizes = JSON::Get(value); + v_.image_sizes = value; } else throw JSON::unknown_value_error{}; } @@ -339,7 +339,7 @@ struct VisionOutputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "image_features") { - v_.image_features = JSON::Get(value); + v_.image_features = value; } else throw JSON::unknown_value_error{}; } @@ -353,7 +353,7 @@ struct Vision_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = JSON::Get(value); + v_.filename = value; } else throw JSON::unknown_value_error{}; } @@ -377,7 +377,7 @@ struct Eos_Array_Element : JSON::Element { explicit Eos_Array_Element(Config::Model& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_.eos_token_ids.push_back(static_cast(JSON::Get(value))); + v_.eos_token_ids.push_back(value); } void OnComplete(bool empty) override { @@ -401,9 +401,9 @@ struct EmbeddingInputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = JSON::Get(value); + v_.input_ids = value; } else if (name == "image_features") { - v_.image_features = JSON::Get(value); + v_.image_features = value; } else throw JSON::unknown_value_error{}; } @@ -417,7 +417,7 @@ struct EmbeddingOutputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "inputs_embeds") { - v_.embeddings = JSON::Get(value); + v_.embeddings = value; } else throw JSON::unknown_value_error{}; } @@ -431,7 +431,7 @@ struct Embedding_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = JSON::Get(value); + v_.filename = value; } else throw JSON::unknown_value_error{}; } @@ -458,16 +458,16 @@ struct PromptTemplates_Element : JSON::Element { // if one of templates is given in json, then any non-specified template will be default "{Content}" if (name == "assistant") { EnsureAvailable(); - v_->assistant = JSON::Get(value); + v_->assistant = value; } else if (name == "prompt") { EnsureAvailable(); - v_->prompt = JSON::Get(value); + v_->prompt = value; } else if (name == "system") { EnsureAvailable(); - v_->system = JSON::Get(value); + v_->system = value; } else if (name == "user") { EnsureAvailable(); - v_->user = JSON::Get(value); + v_->user = value; } else { throw JSON::unknown_value_error{}; } @@ -488,21 +488,21 @@ struct Model_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "type") { - v_.type = JSON::Get(value); + v_.type = value; } else if (name == "vocab_size") { - v_.vocab_size = static_cast(JSON::Get(value)); + v_.vocab_size = value; } else if (name == "context_length") { - v_.context_length = static_cast(JSON::Get(value)); + v_.context_length = value; } else if (name == "pad_token_id") { - v_.pad_token_id = static_cast(JSON::Get(value)); + v_.pad_token_id = value; } else if (name == "eos_token_id") { - v_.eos_token_id = static_cast(JSON::Get(value)); + v_.eos_token_id = value; } else if (name == "bos_token_id") { - v_.bos_token_id = static_cast(JSON::Get(value)); + v_.bos_token_id = value; } else if (name == "decoder_start_token_id") { - v_.decoder_start_token_id = static_cast(JSON::Get(value)); + v_.decoder_start_token_id = value; } else if (name == "sep_token_id") { - v_.sep_token_id = static_cast(JSON::Get(value)); + v_.sep_token_id = value; } else throw JSON::unknown_value_error{}; } @@ -547,39 +547,39 @@ struct Search_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "min_length") { - v_.min_length = static_cast(JSON::Get(value)); + v_.min_length = value; } else if (name == "max_length") { - v_.max_length = static_cast(JSON::Get(value)); + v_.max_length = value; } else if (name == "batch_size") { - v_.batch_size = static_cast(JSON::Get(value)); + v_.batch_size = value; } else if (name == "num_beams") { - v_.num_beams = static_cast(JSON::Get(value)); + v_.num_beams = value; } else if (name == "num_return_sequences") { - v_.num_return_sequences = static_cast(JSON::Get(value)); + v_.num_return_sequences = value; } else if (name == "top_k") { - v_.top_k = static_cast(JSON::Get(value)); + v_.top_k = value; } else if (name == "top_p") { - v_.top_p = static_cast(JSON::Get(value)); + v_.top_p = value; } else if (name == "temperature") { - v_.temperature = static_cast(JSON::Get(value)); + v_.temperature = value; } else if (name == "repetition_penalty") { - v_.repetition_penalty = static_cast(JSON::Get(value)); + v_.repetition_penalty = value; } else if (name == "length_penalty") { - v_.length_penalty = static_cast(JSON::Get(value)); + v_.length_penalty = value; } else if (name == "no_repeat_ngram_size") { - v_.no_repeat_ngram_size = static_cast(JSON::Get(value)); + v_.no_repeat_ngram_size = value; } else if (name == "diversity_penalty") { - v_.diversity_penalty = static_cast(JSON::Get(value)); + v_.diversity_penalty = value; } else if (name == "length_penalty") { - v_.length_penalty = static_cast(JSON::Get(value)); + v_.length_penalty = value; } else if (name == "random_seed") { - v_.random_seed = static_cast(JSON::Get(value)); + v_.random_seed = value; } else if (name == "do_sample") { - v_.do_sample = JSON::Get(value); + v_.do_sample = value; } else if (name == "past_present_share_buffer") { - v_.past_present_share_buffer = JSON::Get(value); + v_.past_present_share_buffer = value; } else if (name == "early_stopping") { - v_.early_stopping = JSON::Get(value); + v_.early_stopping = value; } else throw JSON::unknown_value_error{}; } diff --git a/src/json.cpp b/src/json.cpp index bd2c2aefb..48ca85612 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -7,7 +7,7 @@ namespace JSON { static constexpr const char* value_names[] = {"string", "number", "bool", "null"}; -static_assert(std::size(value_names) == std::variant_size_v); +static_assert(std::size(value_names) == std::variant_size_v); struct JSON { JSON(Element& element, std::string_view document); diff --git a/src/json.h b/src/json.h index 502e56bb9..dd84b01f5 100644 --- a/src/json.h +++ b/src/json.h @@ -13,17 +13,25 @@ struct type_mismatch { // When a file has one type, bu size_t seen, expected; }; -using Value = std::variant; +struct Value : std::variant { + using std::variant::variant; -// To see descriptive errors when types don't match, use this instead of std::get -template -T& Get(Value& var) { - try { - return std::get(var); - } catch (const std::bad_variant_access&) { - throw type_mismatch{var.index(), Value{T{}}.index()}; + // This will generate a descriptive error when the types don't match + template + T Get() { + try { + return std::get(*this); + } catch (const std::bad_variant_access&) { + throw type_mismatch{index(), Value{T{}}.index()}; + } } -} + + operator std::string() { return std::string{Get()}; } + operator double() { return Get(); } + operator float() { return static_cast(Get()); } + operator int() { return static_cast(Get()); } + operator bool() { return Get(); } +}; struct Element { virtual void OnComplete(bool empty) {} // Called when parsing for this element is finished (empty is true when it's an empty element) From 5de9f540113a7e4f4d10d194419366859c1d8ce2 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Mon, 16 Dec 2024 14:34:38 -0800 Subject: [PATCH 3/9] Try to fix gcc error --- src/json.cpp | 2 +- src/json.h | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/json.cpp b/src/json.cpp index 48ca85612..d4b4a2041 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -7,7 +7,7 @@ namespace JSON { static constexpr const char* value_names[] = {"string", "number", "bool", "null"}; -static_assert(std::size(value_names) == std::variant_size_v); +static_assert(std::size(value_names) == Value::type_count_v); struct JSON { JSON(Element& element, std::string_view document); diff --git a/src/json.h b/src/json.h index dd84b01f5..e5452c579 100644 --- a/src/json.h +++ b/src/json.h @@ -13,12 +13,13 @@ struct type_mismatch { // When a file has one type, bu size_t seen, expected; }; -struct Value : std::variant { +struct Value : private std::variant { using std::variant::variant; + static constexpr size_t type_count_v = std::variant_size_v; // This will generate a descriptive error when the types don't match template - T Get() { + T Get() const { try { return std::get(*this); } catch (const std::bad_variant_access&) { @@ -26,11 +27,11 @@ struct Value : std::variant { } } - operator std::string() { return std::string{Get()}; } - operator double() { return Get(); } - operator float() { return static_cast(Get()); } - operator int() { return static_cast(Get()); } - operator bool() { return Get(); } + operator std::string() const { return std::string{Get()}; } + operator double() const { return Get(); } + operator float() const { return static_cast(Get()); } + operator int() const { return static_cast(Get()); } + operator bool() const { return Get(); } }; struct Element { From 18323773f5a6b63b454ab8a94a42e49b9fa82fda Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Mon, 16 Dec 2024 15:47:07 -0800 Subject: [PATCH 4/9] GCC error test --- src/json.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/json.h b/src/json.h index e5452c579..c5ba308da 100644 --- a/src/json.h +++ b/src/json.h @@ -32,6 +32,7 @@ struct Value : private std::variant(Get()); } operator int() const { return static_cast(Get()); } operator bool() const { return Get(); } + operator char() const = delete; // To avoid ambiguity when converting to std::string }; struct Element { From 56b9ae0e4de2061f634c2dfee511197475cd4e19 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Mon, 16 Dec 2024 17:28:18 -0800 Subject: [PATCH 5/9] Try again --- src/json.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/json.h b/src/json.h index c5ba308da..143072e81 100644 --- a/src/json.h +++ b/src/json.h @@ -32,7 +32,7 @@ struct Value : private std::variant(Get()); } operator int() const { return static_cast(Get()); } operator bool() const { return Get(); } - operator char() const = delete; // To avoid ambiguity when converting to std::string + explicit operator char() const = delete; // To avoid ambiguity when converting to std::string }; struct Element { From 59b5c300701155bd4988ab82ff092bee4a5286d8 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 17 Dec 2024 13:53:43 -0800 Subject: [PATCH 6/9] Revert "Try again" This reverts commit 56b9ae0e4de2061f634c2dfee511197475cd4e19. --- src/json.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/json.h b/src/json.h index 143072e81..c5ba308da 100644 --- a/src/json.h +++ b/src/json.h @@ -32,7 +32,7 @@ struct Value : private std::variant(Get()); } operator int() const { return static_cast(Get()); } operator bool() const { return Get(); } - explicit operator char() const = delete; // To avoid ambiguity when converting to std::string + operator char() const = delete; // To avoid ambiguity when converting to std::string }; struct Element { From 51013003a3b25b1d8899a1640ca24aaf0804cc93 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 17 Dec 2024 13:53:49 -0800 Subject: [PATCH 7/9] Revert "GCC error test" This reverts commit 18323773f5a6b63b454ab8a94a42e49b9fa82fda. --- src/json.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/json.h b/src/json.h index c5ba308da..e5452c579 100644 --- a/src/json.h +++ b/src/json.h @@ -32,7 +32,6 @@ struct Value : private std::variant(Get()); } operator int() const { return static_cast(Get()); } operator bool() const { return Get(); } - operator char() const = delete; // To avoid ambiguity when converting to std::string }; struct Element { From ee41eb521d4f28f95ae627799a9eee3a44333abd Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 17 Dec 2024 13:53:53 -0800 Subject: [PATCH 8/9] Reapply "GCC error test" This reverts commit 51013003a3b25b1d8899a1640ca24aaf0804cc93. --- src/json.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/json.h b/src/json.h index e5452c579..c5ba308da 100644 --- a/src/json.h +++ b/src/json.h @@ -32,6 +32,7 @@ struct Value : private std::variant(Get()); } operator int() const { return static_cast(Get()); } operator bool() const { return Get(); } + operator char() const = delete; // To avoid ambiguity when converting to std::string }; struct Element { From ce9d5c6f8f738cacd69cc7450483a7d74e2befc3 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 17 Dec 2024 13:55:11 -0800 Subject: [PATCH 9/9] Revert back to JSON::Get --- src/config.cpp | 164 ++++++++++++++++++++++++------------------------- src/json.cpp | 2 +- src/json.h | 28 +++------ 3 files changed, 92 insertions(+), 102 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 8fc150f48..459da0d68 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -23,7 +23,7 @@ struct ProviderOptions_Element : JSON::Element { explicit ProviderOptions_Element(Config::ProviderOptions& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_.options.emplace_back(name, value); + v_.options.emplace_back(name, JSON::Get(value)); } private: @@ -67,33 +67,33 @@ struct SessionOptions_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "log_id") - v_.log_id = value; + v_.log_id = JSON::Get(value); else if (name == "enable_profiling") - v_.enable_profiling = value; + v_.enable_profiling = JSON::Get(value); else if (name == "ep_context_embed_mode") - v_.ep_context_embed_mode = value; + v_.ep_context_embed_mode = JSON::Get(value); else if (name == "ep_context_file_path") - v_.ep_context_file_path = value; + v_.ep_context_file_path = JSON::Get(value); else if (name == "intra_op_num_threads") - v_.intra_op_num_threads = value; + v_.intra_op_num_threads = static_cast(JSON::Get(value)); else if (name == "inter_op_num_threads") - v_.inter_op_num_threads = value; + v_.inter_op_num_threads = static_cast(JSON::Get(value)); else if (name == "log_severity_level") - v_.log_severity_level = value; + v_.log_severity_level = static_cast(JSON::Get(value)); else if (name == "enable_cpu_mem_arena") - v_.enable_cpu_mem_arena = value; + v_.enable_cpu_mem_arena = JSON::Get(value); else if (name == "enable_mem_pattern") - v_.enable_mem_pattern = value; + v_.enable_mem_pattern = JSON::Get(value); else if (name == "disable_cpu_ep_fallback") - v_.disable_cpu_ep_fallback = value; + v_.disable_cpu_ep_fallback = JSON::Get(value); else if (name == "disable_quant_qdq") - v_.disable_quant_qdq = value; + v_.disable_quant_qdq = JSON::Get(value); else if (name == "enable_quant_qdq_cleanup") - v_.enable_quant_qdq_cleanup = value; + v_.enable_quant_qdq_cleanup = JSON::Get(value); else if (name == "ep_context_enable") - v_.ep_context_enable = value; + v_.ep_context_enable = JSON::Get(value); else if (name == "use_env_allocators") - v_.use_env_allocators = value; + v_.use_env_allocators = JSON::Get(value); else throw JSON::unknown_value_error{}; } @@ -114,7 +114,7 @@ struct EncoderDecoderInit_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -128,27 +128,27 @@ struct Inputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = value; + v_.input_ids = JSON::Get(value); } else if (name == "inputs_embeds") { - v_.embeddings = value; + v_.embeddings = JSON::Get(value); } else if (name == "position_ids") { - v_.position_ids = value; + v_.position_ids = JSON::Get(value); } else if (name == "attention_mask") { - v_.attention_mask = value; + v_.attention_mask = JSON::Get(value); } else if (name == "past_key_names") { - v_.past_key_names = value; + v_.past_key_names = JSON::Get(value); } else if (name == "past_value_names") { - v_.past_value_names = value; + v_.past_value_names = JSON::Get(value); } else if (name == "past_names") { - v_.past_names = value; + v_.past_names = JSON::Get(value); } else if (name == "cross_past_key_names") { - v_.cross_past_key_names = value; + v_.cross_past_key_names = JSON::Get(value); } else if (name == "cross_past_value_names") { - v_.cross_past_value_names = value; + v_.cross_past_value_names = JSON::Get(value); } else if (name == "current_sequence_length") { - v_.current_sequence_length = value; + v_.current_sequence_length = JSON::Get(value); } else if (name == "past_sequence_length") { - v_.past_sequence_length = value; + v_.past_sequence_length = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -162,17 +162,17 @@ struct Outputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "logits") { - v_.logits = value; + v_.logits = JSON::Get(value); } else if (name == "present_key_names") { - v_.present_key_names = value; + v_.present_key_names = JSON::Get(value); } else if (name == "present_value_names") { - v_.present_value_names = value; + v_.present_value_names = JSON::Get(value); } else if (name == "present_names") { - v_.present_names = value; + v_.present_names = JSON::Get(value); } else if (name == "cross_present_key_names") { - v_.cross_present_key_names = value; + v_.cross_present_key_names = JSON::Get(value); } else if (name == "cross_present_value_names") { - v_.cross_present_value_names = value; + v_.cross_present_value_names = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -185,7 +185,7 @@ struct StringArray_Element : JSON::Element { explicit StringArray_Element(std::vector& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_.push_back(value); + v_.push_back(std::string{JSON::Get(value)}); } private: @@ -196,7 +196,7 @@ struct StringStringMap_Element : JSON::Element { explicit StringStringMap_Element(std::unordered_map& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_[std::string(name)] = value; + v_[std::string(name)] = std::string(JSON::Get(value)); } private: @@ -208,11 +208,11 @@ struct PipelineModel_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else if (name == "run_on_prompt") { - v_.run_on_prompt = value; + v_.run_on_prompt = JSON::Get(value); } else if (name == "run_on_token_gen") { - v_.run_on_token_gen = value; + v_.run_on_token_gen = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -276,17 +276,17 @@ struct Decoder_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else if (name == "hidden_size") { - v_.hidden_size = value; + v_.hidden_size = static_cast(JSON::Get(value)); } else if (name == "num_attention_heads") { - v_.num_attention_heads = value; + v_.num_attention_heads = static_cast(JSON::Get(value)); } else if (name == "num_key_value_heads") { - v_.num_key_value_heads = value; + v_.num_key_value_heads = static_cast(JSON::Get(value)); } else if (name == "num_hidden_layers") { - v_.num_hidden_layers = value; + v_.num_hidden_layers = static_cast(JSON::Get(value)); } else if (name == "head_size") { - v_.head_size = value; + v_.head_size = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } @@ -323,9 +323,9 @@ struct VisionInputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "pixel_values") { - v_.pixel_values = value; + v_.pixel_values = JSON::Get(value); } else if (name == "image_sizes") { - v_.image_sizes = value; + v_.image_sizes = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -339,7 +339,7 @@ struct VisionOutputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "image_features") { - v_.image_features = value; + v_.image_features = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -353,7 +353,7 @@ struct Vision_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -377,7 +377,7 @@ struct Eos_Array_Element : JSON::Element { explicit Eos_Array_Element(Config::Model& v) : v_{v} {} void OnValue(std::string_view name, JSON::Value value) override { - v_.eos_token_ids.push_back(value); + v_.eos_token_ids.push_back(static_cast(JSON::Get(value))); } void OnComplete(bool empty) override { @@ -401,9 +401,9 @@ struct EmbeddingInputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = value; + v_.input_ids = JSON::Get(value); } else if (name == "image_features") { - v_.image_features = value; + v_.image_features = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -417,7 +417,7 @@ struct EmbeddingOutputs_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "inputs_embeds") { - v_.embeddings = value; + v_.embeddings = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -431,7 +431,7 @@ struct Embedding_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -458,16 +458,16 @@ struct PromptTemplates_Element : JSON::Element { // if one of templates is given in json, then any non-specified template will be default "{Content}" if (name == "assistant") { EnsureAvailable(); - v_->assistant = value; + v_->assistant = JSON::Get(value); } else if (name == "prompt") { EnsureAvailable(); - v_->prompt = value; + v_->prompt = JSON::Get(value); } else if (name == "system") { EnsureAvailable(); - v_->system = value; + v_->system = JSON::Get(value); } else if (name == "user") { EnsureAvailable(); - v_->user = value; + v_->user = JSON::Get(value); } else { throw JSON::unknown_value_error{}; } @@ -488,21 +488,21 @@ struct Model_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "type") { - v_.type = value; + v_.type = JSON::Get(value); } else if (name == "vocab_size") { - v_.vocab_size = value; + v_.vocab_size = static_cast(JSON::Get(value)); } else if (name == "context_length") { - v_.context_length = value; + v_.context_length = static_cast(JSON::Get(value)); } else if (name == "pad_token_id") { - v_.pad_token_id = value; + v_.pad_token_id = static_cast(JSON::Get(value)); } else if (name == "eos_token_id") { - v_.eos_token_id = value; + v_.eos_token_id = static_cast(JSON::Get(value)); } else if (name == "bos_token_id") { - v_.bos_token_id = value; + v_.bos_token_id = static_cast(JSON::Get(value)); } else if (name == "decoder_start_token_id") { - v_.decoder_start_token_id = value; + v_.decoder_start_token_id = static_cast(JSON::Get(value)); } else if (name == "sep_token_id") { - v_.sep_token_id = value; + v_.sep_token_id = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } @@ -547,39 +547,39 @@ struct Search_Element : JSON::Element { void OnValue(std::string_view name, JSON::Value value) override { if (name == "min_length") { - v_.min_length = value; + v_.min_length = static_cast(JSON::Get(value)); } else if (name == "max_length") { - v_.max_length = value; + v_.max_length = static_cast(JSON::Get(value)); } else if (name == "batch_size") { - v_.batch_size = value; + v_.batch_size = static_cast(JSON::Get(value)); } else if (name == "num_beams") { - v_.num_beams = value; + v_.num_beams = static_cast(JSON::Get(value)); } else if (name == "num_return_sequences") { - v_.num_return_sequences = value; + v_.num_return_sequences = static_cast(JSON::Get(value)); } else if (name == "top_k") { - v_.top_k = value; + v_.top_k = static_cast(JSON::Get(value)); } else if (name == "top_p") { - v_.top_p = value; + v_.top_p = static_cast(JSON::Get(value)); } else if (name == "temperature") { - v_.temperature = value; + v_.temperature = static_cast(JSON::Get(value)); } else if (name == "repetition_penalty") { - v_.repetition_penalty = value; + v_.repetition_penalty = static_cast(JSON::Get(value)); } else if (name == "length_penalty") { - v_.length_penalty = value; + v_.length_penalty = static_cast(JSON::Get(value)); } else if (name == "no_repeat_ngram_size") { - v_.no_repeat_ngram_size = value; + v_.no_repeat_ngram_size = static_cast(JSON::Get(value)); } else if (name == "diversity_penalty") { - v_.diversity_penalty = value; + v_.diversity_penalty = static_cast(JSON::Get(value)); } else if (name == "length_penalty") { - v_.length_penalty = value; + v_.length_penalty = static_cast(JSON::Get(value)); } else if (name == "random_seed") { - v_.random_seed = value; + v_.random_seed = static_cast(JSON::Get(value)); } else if (name == "do_sample") { - v_.do_sample = value; + v_.do_sample = JSON::Get(value); } else if (name == "past_present_share_buffer") { - v_.past_present_share_buffer = value; + v_.past_present_share_buffer = JSON::Get(value); } else if (name == "early_stopping") { - v_.early_stopping = value; + v_.early_stopping = JSON::Get(value); } else throw JSON::unknown_value_error{}; } diff --git a/src/json.cpp b/src/json.cpp index d4b4a2041..bd2c2aefb 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -7,7 +7,7 @@ namespace JSON { static constexpr const char* value_names[] = {"string", "number", "bool", "null"}; -static_assert(std::size(value_names) == Value::type_count_v); +static_assert(std::size(value_names) == std::variant_size_v); struct JSON { JSON(Element& element, std::string_view document); diff --git a/src/json.h b/src/json.h index c5ba308da..b489a2ad3 100644 --- a/src/json.h +++ b/src/json.h @@ -13,27 +13,17 @@ struct type_mismatch { // When a file has one type, bu size_t seen, expected; }; -struct Value : private std::variant { - using std::variant::variant; - static constexpr size_t type_count_v = std::variant_size_v; +using Value = std::variant; - // This will generate a descriptive error when the types don't match - template - T Get() const { - try { - return std::get(*this); - } catch (const std::bad_variant_access&) { - throw type_mismatch{index(), Value{T{}}.index()}; - } +// To see descriptive errors when types don't match, use this instead of std::get +template +T Get(Value& var) { + try { + return std::get(var); + } catch (const std::bad_variant_access&) { + throw type_mismatch{var.index(), Value{T{}}.index()}; } - - operator std::string() const { return std::string{Get()}; } - operator double() const { return Get(); } - operator float() const { return static_cast(Get()); } - operator int() const { return static_cast(Get()); } - operator bool() const { return Get(); } - operator char() const = delete; // To avoid ambiguity when converting to std::string -}; +} struct Element { virtual void OnComplete(bool empty) {} // Called when parsing for this element is finished (empty is true when it's an empty element)