Skip to content

Commit

Permalink
Add protocol 3.0 support (#360)
Browse files Browse the repository at this point in the history
Add support for the `input_language` property of `Parse` and `Execute`.
  • Loading branch information
elprans authored Nov 14, 2024
1 parent 4214a07 commit ac6d9cd
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 8 deletions.
27 changes: 27 additions & 0 deletions edgedb-protocol/src/client_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pub struct Parse {
pub expected_cardinality: Cardinality,
pub command_text: String,
pub state: State,
pub input_language: InputLanguage,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -135,6 +136,7 @@ pub struct Execute1 {
pub input_typedesc_id: Uuid,
pub output_typedesc_id: Uuid,
pub arguments: Bytes,
pub input_language: InputLanguage,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -170,6 +172,12 @@ pub enum DescribeAspect {
DataDescription = 0x54,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InputLanguage {
EdgeQL = 0x45,
SQL = 0x53,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum IoFormat {
Binary = 0x62,
Expand Down Expand Up @@ -629,6 +637,9 @@ impl Encode for Execute1 {
buf.put_u64(self.allowed_capabilities.bits());
buf.put_u64(self.compilation_flags.bits());
buf.put_u64(self.implicit_limit.unwrap_or(0));
if buf.proto().is_multilingual() {
buf.put_u8(self.input_language as u8);
}
buf.put_u8(self.output_format as u8);
buf.put_u8(self.expected_cardinality as u8);
self.command_text.encode(buf)?;
Expand Down Expand Up @@ -663,6 +674,11 @@ impl Decode for Execute1 {
0 => None,
val => Some(val),
};
let input_language = if buf.proto().is_multilingual() {
TryFrom::try_from(buf.get_u8())?
} else {
InputLanguage::EdgeQL
};
let output_format = match buf.get_u8() {
0x62 => IoFormat::Binary,
0x6a => IoFormat::Json,
Expand Down Expand Up @@ -690,6 +706,7 @@ impl Decode for Execute1 {
input_typedesc_id,
output_typedesc_id,
arguments,
input_language,
})
}
}
Expand Down Expand Up @@ -790,6 +807,7 @@ impl Parse {
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
state,
input_language: opts.input_language,
}
}
}
Expand Down Expand Up @@ -851,6 +869,11 @@ impl Decode for Parse {
0 => None,
val => Some(val),
};
let input_language = if buf.proto().is_multilingual() {
TryFrom::try_from(buf.get_u8())?
} else {
InputLanguage::EdgeQL
};
let output_format = match buf.get_u8() {
0x62 => IoFormat::Binary,
0x6a => IoFormat::Json,
Expand All @@ -872,6 +895,7 @@ impl Decode for Parse {
expected_cardinality,
command_text,
state,
input_language,
})
}
}
Expand All @@ -894,6 +918,9 @@ impl Encode for Parse {
buf.put_u64(self.allowed_capabilities.bits());
buf.put_u64(self.compilation_flags.bits());
buf.put_u64(self.implicit_limit.unwrap_or(0));
if buf.proto().is_multilingual() {
buf.put_u8(self.input_language as u8);
}
buf.put_u8(self.output_format as u8);
buf.put_u8(self.expected_cardinality as u8);
self.command_text.encode(buf)?;
Expand Down
14 changes: 13 additions & 1 deletion edgedb-protocol/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::encoding::Input;
use crate::errors::DecodeError;
use crate::features::ProtocolVersion;

pub use crate::client_message::IoFormat;
pub use crate::client_message::{InputLanguage, IoFormat};

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Cardinality {
Expand Down Expand Up @@ -52,6 +52,7 @@ pub struct CompilationOptions {
pub explicit_objectids: bool,
pub io_format: IoFormat,
pub expected_cardinality: Cardinality,
pub input_language: InputLanguage,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -108,6 +109,17 @@ impl Cardinality {
}
}

impl std::convert::TryFrom<u8> for InputLanguage {
type Error = errors::DecodeError;
fn try_from(input_language: u8) -> Result<Self, errors::DecodeError> {
match input_language {
0x45 => Ok(InputLanguage::EdgeQL),
0x53 => Ok(InputLanguage::SQL),
_ => Err(errors::InvalidInputLanguage { input_language }.build()),
}
}
}

impl State {
pub fn empty() -> State {
State {
Expand Down
5 changes: 5 additions & 0 deletions edgedb-protocol/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ pub enum DecodeError {
backtrace: Backtrace,
cardinality: u8,
},
#[snafu(display("unsupported input language: {:x}", input_language))]
InvalidInputLanguage {
backtrace: Backtrace,
input_language: u8,
},
#[snafu(display("unsupported capability: {:b}", capabilities))]
InvalidCapabilities {
backtrace: Backtrace,
Expand Down
8 changes: 7 additions & 1 deletion edgedb-protocol/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub struct ProtocolVersion {
impl ProtocolVersion {
pub fn current() -> ProtocolVersion {
ProtocolVersion {
major_ver: 2,
major_ver: 3,
minor_ver: 0,
}
}
Expand All @@ -26,6 +26,9 @@ impl ProtocolVersion {
pub fn is_2(&self) -> bool {
self.major_ver >= 2
}
pub fn is_3(&self) -> bool {
self.major_ver >= 3
}
pub fn supports_inline_typenames(&self) -> bool {
self.version_tuple() >= (0, 9)
}
Expand All @@ -39,6 +42,9 @@ impl ProtocolVersion {
// portocols.
!self.is_1()
}
pub fn is_multilingual(&self) -> bool {
self.is_at_least(3, 0)
}
pub fn is_at_least(&self, major_ver: u16, minor_ver: u16) -> bool {
self.major_ver > major_ver || self.major_ver == major_ver && self.minor_ver >= minor_ver
}
Expand Down
14 changes: 11 additions & 3 deletions edgedb-protocol/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ impl QueryResult for Value {
ctx.build_codec(root_pos)
}
fn decode(codec: &mut Arc<dyn Codec>, msg: &Bytes) -> Result<Self, Error> {
codec
.decode(msg)
.map_err(ProtocolEncodingError::with_source)
let res = codec.decode(msg);

match res {
Ok(v) => Ok(v),
Err(e) => {
if let Some(bt) = snafu::ErrorCompat::backtrace(&e) {
eprintln!("{bt}");
}
Err(ProtocolEncodingError::with_source(e))
}
}
}
}
58 changes: 58 additions & 0 deletions edgedb-protocol/tests/client_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::error::Error;

use bytes::{Bytes, BytesMut};
use edgedb_protocol::common::InputLanguage;
use uuid::Uuid;

use edgedb_protocol::client_message::OptimisticExecute;
Expand Down Expand Up @@ -94,6 +95,7 @@ fn parse() -> Result<(), Box<dyn Error>> {
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
Expand All @@ -108,6 +110,31 @@ fn parse() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
fn parse3() -> Result<(), Box<dyn Error>> {
encoding_eq_ver!(
3,
0,
ClientMessage::Parse(Parse {
annotations: HashMap::new(),
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
state: State {
typedesc_id: Uuid::from_u128(0),
data: Bytes::from(""),
},
}),
b"P\0\0\0B\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\
\0\0\0\tSELECT 1;\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
);
Ok(())
}

#[test]
fn describe_statement() -> Result<(), Box<dyn Error>> {
encoding_eq!(
Expand Down Expand Up @@ -146,6 +173,7 @@ fn execute1() -> Result<(), Box<dyn Error>> {
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
Expand All @@ -165,6 +193,36 @@ fn execute1() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
fn execute3() -> Result<(), Box<dyn Error>> {
encoding_eq_ver!(
3,
0,
ClientMessage::Execute1(Execute1 {
annotations: HashMap::new(),
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
state: State {
typedesc_id: Uuid::from_u128(0),
data: Bytes::from(""),
},
input_typedesc_id: Uuid::from_u128(123),
output_typedesc_id: Uuid::from_u128(456),
arguments: Bytes::new(),
}),
b"O\0\0\0f\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\
\0\0\0\tSELECT 1;\
\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\
\0\0\0{\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\xc8\0\0\0\0"
);
Ok(())
}

#[test]
fn optimistic_execute() -> Result<(), Box<dyn Error>> {
encoding_eq_ver!(
Expand Down
9 changes: 8 additions & 1 deletion edgedb-tokio/src/raw/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use edgedb_protocol::client_message::{ClientMessage, Parse, Prepare};
use edgedb_protocol::client_message::{DescribeAspect, DescribeStatement};
use edgedb_protocol::client_message::{Execute0, Execute1};
use edgedb_protocol::common::CompilationOptions;
use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat};
use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat};
use edgedb_protocol::descriptors::Typedesc;
use edgedb_protocol::features::ProtocolVersion;
use edgedb_protocol::model::Uuid;
Expand Down Expand Up @@ -234,6 +234,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -367,6 +368,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -422,6 +424,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -475,6 +478,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -575,6 +579,7 @@ impl Connection {
explicit_objectids: true,
allow_capabilities,
io_format,
input_language: InputLanguage::EdgeQL,
expected_cardinality: cardinality,
};
let desc = self.parse(&flags, query, state).await?;
Expand Down Expand Up @@ -631,6 +636,7 @@ impl Connection {
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
Expand Down Expand Up @@ -686,6 +692,7 @@ impl PoolConnection {
implicit_typeids: false,
explicit_objectids: false,
allow_capabilities: Capabilities::ALL,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many, // no result is unsupported
};
Expand Down
3 changes: 2 additions & 1 deletion edgedb-tokio/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use bytes::BytesMut;
use edgedb_protocol::common::CompilationOptions;
use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat};
use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat};
use edgedb_protocol::model::Json;
use edgedb_protocol::query_arg::{Encoder, QueryArgs};
use edgedb_protocol::QueryResult;
Expand Down Expand Up @@ -385,6 +385,7 @@ impl Transaction {
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
Expand Down
3 changes: 2 additions & 1 deletion edgedb-tokio/tests/func/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use bytes::Bytes;

use edgedb_protocol::common::Capabilities;
use edgedb_protocol::common::{Cardinality, CompilationOptions, IoFormat};
use edgedb_protocol::common::{Cardinality, CompilationOptions, InputLanguage, IoFormat};
use edgedb_tokio::raw::{Pool, PoolState};

use crate::server::SERVER;
Expand All @@ -21,6 +21,7 @@ async fn poll_connect() -> anyhow::Result<()> {
implicit_typeids: false,
allow_capabilities: Capabilities::empty(),
explicit_objectids: true,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
Expand Down

0 comments on commit ac6d9cd

Please sign in to comment.