Skip to content

Commit

Permalink
fix postgres bytes decoding (encode does not support bytea[])
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky committed Jan 16, 2024
1 parent 4d8c52c commit e7af6ad
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 136 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions psl/psl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ serde_json.workspace = true
enumflags2 = "0.7"
indoc.workspace = true
either = "1.8.1"
hex = "0.4"

# For the connector API.
lsp-types = "0.91.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,23 @@ impl Connector for CockroachDatamodelConnector {
None => self.parse_json_datetime(str, self.default_native_type_for_scalar_type(&ScalarType::DateTime)),
}
}

fn parse_json_bytes(&self, str: &str, nt: Option<NativeTypeInstance>) -> prisma_value::PrismaValueResult<Vec<u8>> {
let native_type: Option<&CockroachType> = nt.as_ref().map(|nt| nt.downcast_ref());

match native_type {
Some(ct) => match ct {
CockroachType::Bytes => {
super::utils::postgres::parse_bytes(str).map_err(|_| prisma_value::ConversionFailure {
from: "hex".into(),
to: "bytes".into(),
})
}
_ => unreachable!(),
},
None => self.parse_json_bytes(str, self.default_native_type_for_scalar_type(&ScalarType::Bytes)),
}
}
}

/// An `@default(sequence())` function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod validations;

use chrono::FixedOffset;
pub use native_types::MySqlType;
use prisma_value::{decode_bytes, PrismaValueResult};

use super::completions;
use crate::{
Expand Down Expand Up @@ -306,4 +307,9 @@ impl Connector for MySqlDatamodelConnector {
None => self.parse_json_datetime(str, self.default_native_type_for_scalar_type(&ScalarType::DateTime)),
}
}

// On MySQL, bytes are encoded as base64 in the database directly.
fn parse_json_bytes(&self, str: &str, _nt: Option<NativeTypeInstance>) -> PrismaValueResult<Vec<u8>> {
decode_bytes(str)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,23 @@ impl Connector for PostgresDatamodelConnector {
None => self.parse_json_datetime(str, self.default_native_type_for_scalar_type(&ScalarType::DateTime)),
}
}

fn parse_json_bytes(&self, str: &str, nt: Option<NativeTypeInstance>) -> prisma_value::PrismaValueResult<Vec<u8>> {
let native_type: Option<&PostgresType> = nt.as_ref().map(|nt| nt.downcast_ref());

match native_type {
Some(ct) => match ct {
PostgresType::ByteA => {
super::utils::postgres::parse_bytes(str).map_err(|_| prisma_value::ConversionFailure {
from: "hex".into(),
to: "bytes".into(),
})
}
_ => unreachable!(),
},
None => self.parse_json_bytes(str, self.default_native_type_for_scalar_type(&ScalarType::Bytes)),
}
}
}

fn allowed_index_operator_classes(algo: IndexAlgorithm, field: walkers::ScalarFieldWalker<'_>) -> Vec<OperatorClass> {
Expand Down
4 changes: 4 additions & 0 deletions psl/psl-core/src/builtin_connectors/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub(crate) mod postgres {

super::common::parse_time(time_without_tz)
}

pub(crate) fn parse_bytes(str: &str) -> Result<Vec<u8>, hex::FromHexError> {
hex::decode(&str[2..])
}
}

pub(crate) mod mysql {
Expand Down
8 changes: 8 additions & 0 deletions psl/psl-core/src/datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ pub trait Connector: Send + Sync {
) -> chrono::ParseResult<DateTime<FixedOffset>> {
unreachable!("This method is only implemented on connectors with lateral join support.")
}

fn parse_json_bytes(
&self,
_str: &str,
_nt: Option<NativeTypeInstance>,
) -> prisma_value::PrismaValueResult<Vec<u8>> {
unreachable!("This method is only implemented on connectors with lateral join support.")
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down
7 changes: 0 additions & 7 deletions quaint/src/visitor/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,6 @@ impl<'a> Postgres<'a> {

Ok(())
}
(_, Some("BYTES") | Some("BYTEA")) => {
self.write("ENCODE(")?;
self.visit_expression(expr)?;
self.write(", 'base64')")?;

Ok(())
}
_ => self.visit_expression(expr),
},
_ => self.visit_expression(expr),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,115 +282,3 @@ mod bytes {
Ok(())
}
}

// Napi & Wasm DAs excluded because of a bytes bug
#[test_suite(
schema(schema),
only(Postgres("9", "10", "11", "12", "13", "14", "15", "pg.js", "neon.js"))
)]
mod others {
fn schema_other_types() -> String {
let schema = indoc! {
r#"
model Parent {
#id(id, Int, @id)
childId Int? @unique
child Child? @relation(fields: [childId], references: [id])
}
model Child {
#id(id, Int, @id)
bool Boolean @test.Boolean
byteA Bytes @test.ByteA
json Json @test.Json
jsonb Json @test.JsonB
parent Parent?
}"#
};

schema.to_owned()
}

// "Other Postgres native types" should "work"
#[connector_test(schema(schema_other_types))]
async fn native_other_types(runner: Runner) -> TestResult<()> {
create_row(
&runner,
r#"{
id: 1,
child: {
create: {
id: 1,
bool: true
byteA: "dGVzdA=="
json: "{}"
jsonb: "{\"a\": \"b\"}"
}
}
}"#,
)
.await?;

insta::assert_snapshot!(
run_query!(&runner, r#"{ findManyParent { id child { id bool byteA json jsonb } } }"#),
@r###"{"data":{"findManyParent":[{"id":1,"child":{"id":1,"bool":true,"byteA":"dGVzdA==","json":"{}","jsonb":"{\"a\":\"b\"}"}}]}}"###
);

Ok(())
}

fn schema_xml() -> String {
let schema = indoc! {
r#"
model Parent {
#id(id, Int, @id)
childId Int? @unique
child Child? @relation(fields: [childId], references: [id])
}
model Child {
#id(id, Int, @id)
xml String @test.Xml
parent Parent?
}"#
};

schema.to_owned()
}

#[connector_test(schema(schema_xml), only(Postgres))]
async fn native_xml(runner: Runner) -> TestResult<()> {
create_row(
&runner,
r#"{
id: 1,
child: {
create: {
id: 1,
xml: "<salad>wurst</salad>"
}
}
}"#,
)
.await?;

insta::assert_snapshot!(
run_query!(&runner, r#"{ findManyParent { id child { xml } } }"#),
@r###"{"data":{"findManyParent":[{"id":1,"child":{"xml":"<salad>wurst</salad>"}}]}}"###
);

Ok(())
}

async fn create_row(runner: &Runner, data: &str) -> TestResult<()> {
runner
.query(format!("mutation {{ createOneParent(data: {}) {{ id }} }}", data))
.await?
.assert_success();
Ok(())
}
}
1 change: 0 additions & 1 deletion query-engine/connectors/sql-query-connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ uuid.workspace = true
opentelemetry = { version = "0.17", features = ["tokio"] }
tracing-opentelemetry = "0.17.3"
cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" }
hex = "0.4"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
quaint.workspace = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarFiel
)
})?)),
TypeIdentifier::Bytes => {
let bytes = decode_bytes(&s).map_err(|err| {
let bytes = sf.parse_json_bytes(&s).map_err(|err| {
build_conversion_error_with_reason(
sf,
&format!("String({s})"),
Expand Down
29 changes: 15 additions & 14 deletions query-engine/query-structure/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,21 @@ impl ScalarField {

pub fn native_type(&self) -> Option<NativeTypeInstance> {
let connector = self.dm.schema.connector;

let raw_nt = match self.id {
ScalarFieldId::InModel(id) => self.dm.walk(id).raw_native_type(),
ScalarFieldId::InCompositeType(id) => self.dm.walk(id).raw_native_type(),
};

let psl_nt = raw_nt
.and_then(|(_, name, args, span)| connector.parse_native_type(name, args, span, &mut Default::default()));

let nt = match self.id {
ScalarFieldId::InModel(id) => psl_nt.or_else(|| {
self.dm
.walk(id)
.scalar_type()
.and_then(|st| connector.default_native_type_for_scalar_type(&st))
}),
ScalarFieldId::InCompositeType(id) => psl_nt.or_else(|| {
self.dm
.walk(id)
.scalar_type()
.and_then(|st| connector.default_native_type_for_scalar_type(&st))
}),
}?;
let scalar_type = match self.id {
ScalarFieldId::InModel(id) => self.dm.walk(id).scalar_type(),
ScalarFieldId::InCompositeType(id) => self.dm.walk(id).scalar_type(),
};

let nt = psl_nt.or_else(|| scalar_type.and_then(|st| connector.default_native_type_for_scalar_type(&st)))?;

Some(NativeTypeInstance {
native_type: nt,
Expand All @@ -191,6 +185,13 @@ impl ScalarField {
connector.parse_json_datetime(value, nt)
}

pub fn parse_json_bytes(&self, value: &str) -> PrismaValueResult<Vec<u8>> {
let nt = self.native_type().map(|nt| nt.native_type);
let connector = self.dm.schema.connector;

connector.parse_json_bytes(value, nt)
}

pub fn is_autoincrement(&self) -> bool {
match self.id {
ScalarFieldId::InModel(id) => self.dm.walk(id).is_autoincrement(),
Expand Down

0 comments on commit e7af6ad

Please sign in to comment.