Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add protocol 3.0 support #360

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions edgedb-protocol/src/client_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pub struct Parse {
pub expected_cardinality: Cardinality,
pub command_text: String,
pub state: State,
pub input_language: InputLanguage,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -135,6 +136,7 @@ pub struct Execute1 {
pub input_typedesc_id: Uuid,
pub output_typedesc_id: Uuid,
pub arguments: Bytes,
pub input_language: InputLanguage,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -170,6 +172,12 @@ pub enum DescribeAspect {
DataDescription = 0x54,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InputLanguage {
EdgeQL = 0x45,
SQL = 0x53,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum IoFormat {
Binary = 0x62,
Expand Down Expand Up @@ -629,6 +637,9 @@ impl Encode for Execute1 {
buf.put_u64(self.allowed_capabilities.bits());
buf.put_u64(self.compilation_flags.bits());
buf.put_u64(self.implicit_limit.unwrap_or(0));
if buf.proto().is_multilingual() {
buf.put_u8(self.input_language as u8);
}
buf.put_u8(self.output_format as u8);
buf.put_u8(self.expected_cardinality as u8);
self.command_text.encode(buf)?;
Expand Down Expand Up @@ -663,6 +674,11 @@ impl Decode for Execute1 {
0 => None,
val => Some(val),
};
let input_language = if buf.proto().is_multilingual() {
TryFrom::try_from(buf.get_u8())?
} else {
InputLanguage::EdgeQL
};
let output_format = match buf.get_u8() {
0x62 => IoFormat::Binary,
0x6a => IoFormat::Json,
Expand Down Expand Up @@ -690,6 +706,7 @@ impl Decode for Execute1 {
input_typedesc_id,
output_typedesc_id,
arguments,
input_language,
})
}
}
Expand Down Expand Up @@ -790,6 +807,7 @@ impl Parse {
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
state,
input_language: opts.input_language,
}
}
}
Expand Down Expand Up @@ -851,6 +869,11 @@ impl Decode for Parse {
0 => None,
val => Some(val),
};
let input_language = if buf.proto().is_multilingual() {
TryFrom::try_from(buf.get_u8())?
} else {
InputLanguage::EdgeQL
};
let output_format = match buf.get_u8() {
0x62 => IoFormat::Binary,
0x6a => IoFormat::Json,
Expand All @@ -872,6 +895,7 @@ impl Decode for Parse {
expected_cardinality,
command_text,
state,
input_language,
})
}
}
Expand All @@ -894,6 +918,9 @@ impl Encode for Parse {
buf.put_u64(self.allowed_capabilities.bits());
buf.put_u64(self.compilation_flags.bits());
buf.put_u64(self.implicit_limit.unwrap_or(0));
if buf.proto().is_multilingual() {
buf.put_u8(self.input_language as u8);
}
buf.put_u8(self.output_format as u8);
buf.put_u8(self.expected_cardinality as u8);
self.command_text.encode(buf)?;
Expand Down
14 changes: 13 additions & 1 deletion edgedb-protocol/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::encoding::Input;
use crate::errors::DecodeError;
use crate::features::ProtocolVersion;

pub use crate::client_message::IoFormat;
pub use crate::client_message::{InputLanguage, IoFormat};

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Cardinality {
Expand Down Expand Up @@ -52,6 +52,7 @@ pub struct CompilationOptions {
pub explicit_objectids: bool,
pub io_format: IoFormat,
pub expected_cardinality: Cardinality,
pub input_language: InputLanguage,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -108,6 +109,17 @@ impl Cardinality {
}
}

impl std::convert::TryFrom<u8> for InputLanguage {
type Error = errors::DecodeError;
fn try_from(input_language: u8) -> Result<Self, errors::DecodeError> {
match input_language {
0x45 => Ok(InputLanguage::EdgeQL),
0x53 => Ok(InputLanguage::SQL),
_ => Err(errors::InvalidInputLanguage { input_language }.build()),
}
}
}

impl State {
pub fn empty() -> State {
State {
Expand Down
5 changes: 5 additions & 0 deletions edgedb-protocol/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ pub enum DecodeError {
backtrace: Backtrace,
cardinality: u8,
},
#[snafu(display("unsupported input language: {:x}", input_language))]
InvalidInputLanguage {
backtrace: Backtrace,
input_language: u8,
},
#[snafu(display("unsupported capability: {:b}", capabilities))]
InvalidCapabilities {
backtrace: Backtrace,
Expand Down
8 changes: 7 additions & 1 deletion edgedb-protocol/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub struct ProtocolVersion {
impl ProtocolVersion {
pub fn current() -> ProtocolVersion {
ProtocolVersion {
major_ver: 2,
major_ver: 3,
minor_ver: 0,
}
}
Expand All @@ -26,6 +26,9 @@ impl ProtocolVersion {
pub fn is_2(&self) -> bool {
self.major_ver >= 2
}
pub fn is_3(&self) -> bool {
self.major_ver >= 3
}
pub fn supports_inline_typenames(&self) -> bool {
self.version_tuple() >= (0, 9)
}
Expand All @@ -39,6 +42,9 @@ impl ProtocolVersion {
// portocols.
!self.is_1()
}
pub fn is_multilingual(&self) -> bool {
self.is_at_least(3, 0)
}
pub fn is_at_least(&self, major_ver: u16, minor_ver: u16) -> bool {
self.major_ver > major_ver || self.major_ver == major_ver && self.minor_ver >= minor_ver
}
Expand Down
14 changes: 11 additions & 3 deletions edgedb-protocol/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ impl QueryResult for Value {
ctx.build_codec(root_pos)
}
fn decode(codec: &mut Arc<dyn Codec>, msg: &Bytes) -> Result<Self, Error> {
codec
.decode(msg)
.map_err(ProtocolEncodingError::with_source)
let res = codec.decode(msg);

match res {
Ok(v) => Ok(v),
Err(e) => {
if let Some(bt) = snafu::ErrorCompat::backtrace(&e) {
eprintln!("{bt}");
}
Err(ProtocolEncodingError::with_source(e))
}
}
}
}
58 changes: 58 additions & 0 deletions edgedb-protocol/tests/client_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::error::Error;

use bytes::{Bytes, BytesMut};
use edgedb_protocol::common::InputLanguage;
use uuid::Uuid;

use edgedb_protocol::client_message::OptimisticExecute;
Expand Down Expand Up @@ -94,6 +95,7 @@ fn parse() -> Result<(), Box<dyn Error>> {
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
Expand All @@ -108,6 +110,31 @@ fn parse() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
fn parse3() -> Result<(), Box<dyn Error>> {
encoding_eq_ver!(
3,
0,
ClientMessage::Parse(Parse {
annotations: HashMap::new(),
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
state: State {
typedesc_id: Uuid::from_u128(0),
data: Bytes::from(""),
},
}),
b"P\0\0\0B\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\
\0\0\0\tSELECT 1;\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
);
Ok(())
}

#[test]
fn describe_statement() -> Result<(), Box<dyn Error>> {
encoding_eq!(
Expand Down Expand Up @@ -146,6 +173,7 @@ fn execute1() -> Result<(), Box<dyn Error>> {
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
Expand All @@ -165,6 +193,36 @@ fn execute1() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
fn execute3() -> Result<(), Box<dyn Error>> {
encoding_eq_ver!(
3,
0,
ClientMessage::Execute1(Execute1 {
annotations: HashMap::new(),
allowed_capabilities: Capabilities::MODIFICATIONS,
compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES,
implicit_limit: Some(77),
input_language: InputLanguage::EdgeQL,
output_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
command_text: String::from("SELECT 1;"),
state: State {
typedesc_id: Uuid::from_u128(0),
data: Bytes::from(""),
},
input_typedesc_id: Uuid::from_u128(123),
output_typedesc_id: Uuid::from_u128(456),
arguments: Bytes::new(),
}),
b"O\0\0\0f\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\
\0\0\0\tSELECT 1;\
\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\
\0\0\0{\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\xc8\0\0\0\0"
);
Ok(())
}

#[test]
fn optimistic_execute() -> Result<(), Box<dyn Error>> {
encoding_eq_ver!(
Expand Down
9 changes: 8 additions & 1 deletion edgedb-tokio/src/raw/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use edgedb_protocol::client_message::{ClientMessage, Parse, Prepare};
use edgedb_protocol::client_message::{DescribeAspect, DescribeStatement};
use edgedb_protocol::client_message::{Execute0, Execute1};
use edgedb_protocol::common::CompilationOptions;
use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat};
use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat};
use edgedb_protocol::descriptors::Typedesc;
use edgedb_protocol::features::ProtocolVersion;
use edgedb_protocol::model::Uuid;
Expand Down Expand Up @@ -234,6 +234,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -367,6 +368,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -422,6 +424,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -475,6 +478,7 @@ impl Connection {
allowed_capabilities: opts.allow_capabilities,
compilation_flags: opts.flags(),
implicit_limit: opts.implicit_limit,
input_language: opts.input_language,
output_format: opts.io_format,
expected_cardinality: opts.expected_cardinality,
command_text: query.into(),
Expand Down Expand Up @@ -575,6 +579,7 @@ impl Connection {
explicit_objectids: true,
allow_capabilities,
io_format,
input_language: InputLanguage::EdgeQL,
expected_cardinality: cardinality,
};
let desc = self.parse(&flags, query, state).await?;
Expand Down Expand Up @@ -631,6 +636,7 @@ impl Connection {
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
Expand Down Expand Up @@ -686,6 +692,7 @@ impl PoolConnection {
implicit_typeids: false,
explicit_objectids: false,
allow_capabilities: Capabilities::ALL,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many, // no result is unsupported
};
Expand Down
3 changes: 2 additions & 1 deletion edgedb-tokio/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use bytes::BytesMut;
use edgedb_protocol::common::CompilationOptions;
use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat};
use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat};
use edgedb_protocol::model::Json;
use edgedb_protocol::query_arg::{Encoder, QueryArgs};
use edgedb_protocol::QueryResult;
Expand Down Expand Up @@ -385,6 +385,7 @@ impl Transaction {
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
Expand Down
3 changes: 2 additions & 1 deletion edgedb-tokio/tests/func/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use bytes::Bytes;

use edgedb_protocol::common::Capabilities;
use edgedb_protocol::common::{Cardinality, CompilationOptions, IoFormat};
use edgedb_protocol::common::{Cardinality, CompilationOptions, InputLanguage, IoFormat};
use edgedb_tokio::raw::{Pool, PoolState};

use crate::server::SERVER;
Expand All @@ -21,6 +21,7 @@ async fn poll_connect() -> anyhow::Result<()> {
implicit_typeids: false,
allow_capabilities: Capabilities::empty(),
explicit_objectids: true,
input_language: InputLanguage::EdgeQL,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
Expand Down
Loading