diff --git a/edgedb-protocol/src/client_message.rs b/edgedb-protocol/src/client_message.rs index f5e06c09..b4ded6bc 100644 --- a/edgedb-protocol/src/client_message.rs +++ b/edgedb-protocol/src/client_message.rs @@ -106,6 +106,7 @@ pub struct Parse { pub expected_cardinality: Cardinality, pub command_text: String, pub state: State, + pub input_language: InputLanguage, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -135,6 +136,7 @@ pub struct Execute1 { pub input_typedesc_id: Uuid, pub output_typedesc_id: Uuid, pub arguments: Bytes, + pub input_language: InputLanguage, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -170,6 +172,12 @@ pub enum DescribeAspect { DataDescription = 0x54, } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum InputLanguage { + EdgeQL = 0x45, + SQL = 0x53, +} + #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum IoFormat { Binary = 0x62, @@ -629,6 +637,9 @@ impl Encode for Execute1 { buf.put_u64(self.allowed_capabilities.bits()); buf.put_u64(self.compilation_flags.bits()); buf.put_u64(self.implicit_limit.unwrap_or(0)); + if buf.proto().is_multilingual() { + buf.put_u8(self.input_language as u8); + } buf.put_u8(self.output_format as u8); buf.put_u8(self.expected_cardinality as u8); self.command_text.encode(buf)?; @@ -663,6 +674,11 @@ impl Decode for Execute1 { 0 => None, val => Some(val), }; + let input_language = if buf.proto().is_multilingual() { + TryFrom::try_from(buf.get_u8())? + } else { + InputLanguage::EdgeQL + }; let output_format = match buf.get_u8() { 0x62 => IoFormat::Binary, 0x6a => IoFormat::Json, @@ -690,6 +706,7 @@ impl Decode for Execute1 { input_typedesc_id, output_typedesc_id, arguments, + input_language, }) } } @@ -790,6 +807,7 @@ impl Parse { expected_cardinality: opts.expected_cardinality, command_text: query.into(), state, + input_language: opts.input_language, } } } @@ -851,6 +869,11 @@ impl Decode for Parse { 0 => None, val => Some(val), }; + let input_language = if buf.proto().is_multilingual() { + TryFrom::try_from(buf.get_u8())? + } else { + InputLanguage::EdgeQL + }; let output_format = match buf.get_u8() { 0x62 => IoFormat::Binary, 0x6a => IoFormat::Json, @@ -872,6 +895,7 @@ impl Decode for Parse { expected_cardinality, command_text, state, + input_language, }) } } @@ -894,6 +918,9 @@ impl Encode for Parse { buf.put_u64(self.allowed_capabilities.bits()); buf.put_u64(self.compilation_flags.bits()); buf.put_u64(self.implicit_limit.unwrap_or(0)); + if buf.proto().is_multilingual() { + buf.put_u8(self.input_language as u8); + } buf.put_u8(self.output_format as u8); buf.put_u8(self.expected_cardinality as u8); self.command_text.encode(buf)?; diff --git a/edgedb-protocol/src/common.rs b/edgedb-protocol/src/common.rs index 270d9cb4..b73bff57 100644 --- a/edgedb-protocol/src/common.rs +++ b/edgedb-protocol/src/common.rs @@ -11,7 +11,7 @@ use crate::encoding::Input; use crate::errors::DecodeError; use crate::features::ProtocolVersion; -pub use crate::client_message::IoFormat; +pub use crate::client_message::{InputLanguage, IoFormat}; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Cardinality { @@ -52,6 +52,7 @@ pub struct CompilationOptions { pub explicit_objectids: bool, pub io_format: IoFormat, pub expected_cardinality: Cardinality, + pub input_language: InputLanguage, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -108,6 +109,17 @@ impl Cardinality { } } +impl std::convert::TryFrom for InputLanguage { + type Error = errors::DecodeError; + fn try_from(input_language: u8) -> Result { + match input_language { + 0x45 => Ok(InputLanguage::EdgeQL), + 0x53 => Ok(InputLanguage::SQL), + _ => Err(errors::InvalidInputLanguage { input_language }.build()), + } + } +} + impl State { pub fn empty() -> State { State { diff --git a/edgedb-protocol/src/errors.rs b/edgedb-protocol/src/errors.rs index a81d8f12..67324626 100644 --- a/edgedb-protocol/src/errors.rs +++ b/edgedb-protocol/src/errors.rs @@ -35,6 +35,11 @@ pub enum DecodeError { backtrace: Backtrace, cardinality: u8, }, + #[snafu(display("unsupported input language: {:x}", input_language))] + InvalidInputLanguage { + backtrace: Backtrace, + input_language: u8, + }, #[snafu(display("unsupported capability: {:b}", capabilities))] InvalidCapabilities { backtrace: Backtrace, diff --git a/edgedb-protocol/src/features.rs b/edgedb-protocol/src/features.rs index c23a2b31..a71af0a3 100644 --- a/edgedb-protocol/src/features.rs +++ b/edgedb-protocol/src/features.rs @@ -7,7 +7,7 @@ pub struct ProtocolVersion { impl ProtocolVersion { pub fn current() -> ProtocolVersion { ProtocolVersion { - major_ver: 2, + major_ver: 3, minor_ver: 0, } } @@ -26,6 +26,9 @@ impl ProtocolVersion { pub fn is_2(&self) -> bool { self.major_ver >= 2 } + pub fn is_3(&self) -> bool { + self.major_ver >= 3 + } pub fn supports_inline_typenames(&self) -> bool { self.version_tuple() >= (0, 9) } @@ -39,6 +42,9 @@ impl ProtocolVersion { // portocols. !self.is_1() } + pub fn is_multilingual(&self) -> bool { + self.is_at_least(3, 0) + } pub fn is_at_least(&self, major_ver: u16, minor_ver: u16) -> bool { self.major_ver > major_ver || self.major_ver == major_ver && self.minor_ver >= minor_ver } diff --git a/edgedb-protocol/src/query_result.rs b/edgedb-protocol/src/query_result.rs index ed2cf5fd..5f163e64 100644 --- a/edgedb-protocol/src/query_result.rs +++ b/edgedb-protocol/src/query_result.rs @@ -52,8 +52,16 @@ impl QueryResult for Value { ctx.build_codec(root_pos) } fn decode(codec: &mut Arc, msg: &Bytes) -> Result { - codec - .decode(msg) - .map_err(ProtocolEncodingError::with_source) + let res = codec.decode(msg); + + match res { + Ok(v) => Ok(v), + Err(e) => { + if let Some(bt) = snafu::ErrorCompat::backtrace(&e) { + eprintln!("{bt}"); + } + Err(ProtocolEncodingError::with_source(e)) + } + } } } diff --git a/edgedb-protocol/tests/client_messages.rs b/edgedb-protocol/tests/client_messages.rs index 6cdf556b..96035c07 100644 --- a/edgedb-protocol/tests/client_messages.rs +++ b/edgedb-protocol/tests/client_messages.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::error::Error; use bytes::{Bytes, BytesMut}; +use edgedb_protocol::common::InputLanguage; use uuid::Uuid; use edgedb_protocol::client_message::OptimisticExecute; @@ -94,6 +95,7 @@ fn parse() -> Result<(), Box> { allowed_capabilities: Capabilities::MODIFICATIONS, compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, implicit_limit: Some(77), + input_language: InputLanguage::EdgeQL, output_format: IoFormat::Binary, expected_cardinality: Cardinality::AtMostOne, command_text: String::from("SELECT 1;"), @@ -108,6 +110,31 @@ fn parse() -> Result<(), Box> { Ok(()) } +#[test] +fn parse3() -> Result<(), Box> { + encoding_eq_ver!( + 3, + 0, + ClientMessage::Parse(Parse { + annotations: HashMap::new(), + allowed_capabilities: Capabilities::MODIFICATIONS, + compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, + implicit_limit: Some(77), + input_language: InputLanguage::EdgeQL, + output_format: IoFormat::Binary, + expected_cardinality: Cardinality::AtMostOne, + command_text: String::from("SELECT 1;"), + state: State { + typedesc_id: Uuid::from_u128(0), + data: Bytes::from(""), + }, + }), + b"P\0\0\0B\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\ + \0\0\0\tSELECT 1;\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" + ); + Ok(()) +} + #[test] fn describe_statement() -> Result<(), Box> { encoding_eq!( @@ -146,6 +173,7 @@ fn execute1() -> Result<(), Box> { allowed_capabilities: Capabilities::MODIFICATIONS, compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, implicit_limit: Some(77), + input_language: InputLanguage::EdgeQL, output_format: IoFormat::Binary, expected_cardinality: Cardinality::AtMostOne, command_text: String::from("SELECT 1;"), @@ -165,6 +193,36 @@ fn execute1() -> Result<(), Box> { Ok(()) } +#[test] +fn execute3() -> Result<(), Box> { + encoding_eq_ver!( + 3, + 0, + ClientMessage::Execute1(Execute1 { + annotations: HashMap::new(), + allowed_capabilities: Capabilities::MODIFICATIONS, + compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, + implicit_limit: Some(77), + input_language: InputLanguage::EdgeQL, + output_format: IoFormat::Binary, + expected_cardinality: Cardinality::AtMostOne, + command_text: String::from("SELECT 1;"), + state: State { + typedesc_id: Uuid::from_u128(0), + data: Bytes::from(""), + }, + input_typedesc_id: Uuid::from_u128(123), + output_typedesc_id: Uuid::from_u128(456), + arguments: Bytes::new(), + }), + b"O\0\0\0f\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\ + \0\0\0\tSELECT 1;\ + \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ + \0\0\0{\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\xc8\0\0\0\0" + ); + Ok(()) +} + #[test] fn optimistic_execute() -> Result<(), Box> { encoding_eq_ver!( diff --git a/edgedb-tokio/src/raw/queries.rs b/edgedb-tokio/src/raw/queries.rs index d44e773f..3858f075 100644 --- a/edgedb-tokio/src/raw/queries.rs +++ b/edgedb-tokio/src/raw/queries.rs @@ -9,7 +9,7 @@ use edgedb_protocol::client_message::{ClientMessage, Parse, Prepare}; use edgedb_protocol::client_message::{DescribeAspect, DescribeStatement}; use edgedb_protocol::client_message::{Execute0, Execute1}; use edgedb_protocol::common::CompilationOptions; -use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat}; +use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat}; use edgedb_protocol::descriptors::Typedesc; use edgedb_protocol::features::ProtocolVersion; use edgedb_protocol::model::Uuid; @@ -234,6 +234,7 @@ impl Connection { allowed_capabilities: opts.allow_capabilities, compilation_flags: opts.flags(), implicit_limit: opts.implicit_limit, + input_language: opts.input_language, output_format: opts.io_format, expected_cardinality: opts.expected_cardinality, command_text: query.into(), @@ -367,6 +368,7 @@ impl Connection { allowed_capabilities: opts.allow_capabilities, compilation_flags: opts.flags(), implicit_limit: opts.implicit_limit, + input_language: opts.input_language, output_format: opts.io_format, expected_cardinality: opts.expected_cardinality, command_text: query.into(), @@ -422,6 +424,7 @@ impl Connection { allowed_capabilities: opts.allow_capabilities, compilation_flags: opts.flags(), implicit_limit: opts.implicit_limit, + input_language: opts.input_language, output_format: opts.io_format, expected_cardinality: opts.expected_cardinality, command_text: query.into(), @@ -475,6 +478,7 @@ impl Connection { allowed_capabilities: opts.allow_capabilities, compilation_flags: opts.flags(), implicit_limit: opts.implicit_limit, + input_language: opts.input_language, output_format: opts.io_format, expected_cardinality: opts.expected_cardinality, command_text: query.into(), @@ -575,6 +579,7 @@ impl Connection { explicit_objectids: true, allow_capabilities, io_format, + input_language: InputLanguage::EdgeQL, expected_cardinality: cardinality, }; let desc = self.parse(&flags, query, state).await?; @@ -631,6 +636,7 @@ impl Connection { implicit_typeids: false, explicit_objectids: true, allow_capabilities, + input_language: InputLanguage::EdgeQL, io_format: IoFormat::Binary, expected_cardinality: Cardinality::Many, }; @@ -686,6 +692,7 @@ impl PoolConnection { implicit_typeids: false, explicit_objectids: false, allow_capabilities: Capabilities::ALL, + input_language: InputLanguage::EdgeQL, io_format: IoFormat::Binary, expected_cardinality: Cardinality::Many, // no result is unsupported }; diff --git a/edgedb-tokio/src/transaction.rs b/edgedb-tokio/src/transaction.rs index 3ca645e2..eed97eaf 100644 --- a/edgedb-tokio/src/transaction.rs +++ b/edgedb-tokio/src/transaction.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use bytes::BytesMut; use edgedb_protocol::common::CompilationOptions; -use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat}; +use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat}; use edgedb_protocol::model::Json; use edgedb_protocol::query_arg::{Encoder, QueryArgs}; use edgedb_protocol::QueryResult; @@ -385,6 +385,7 @@ impl Transaction { implicit_typeids: false, explicit_objectids: true, allow_capabilities: Capabilities::MODIFICATIONS, + input_language: InputLanguage::EdgeQL, io_format: IoFormat::Binary, expected_cardinality: Cardinality::Many, }; diff --git a/edgedb-tokio/tests/func/raw.rs b/edgedb-tokio/tests/func/raw.rs index 33468be8..da86f892 100644 --- a/edgedb-tokio/tests/func/raw.rs +++ b/edgedb-tokio/tests/func/raw.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use bytes::Bytes; use edgedb_protocol::common::Capabilities; -use edgedb_protocol::common::{Cardinality, CompilationOptions, IoFormat}; +use edgedb_protocol::common::{Cardinality, CompilationOptions, InputLanguage, IoFormat}; use edgedb_tokio::raw::{Pool, PoolState}; use crate::server::SERVER; @@ -21,6 +21,7 @@ async fn poll_connect() -> anyhow::Result<()> { implicit_typeids: false, allow_capabilities: Capabilities::empty(), explicit_objectids: true, + input_language: InputLanguage::EdgeQL, io_format: IoFormat::Binary, expected_cardinality: Cardinality::Many, };