Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text format result support #961

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ impl Client {
Config::new()
}

/// Return the result format of client
///
/// true indicates that the client will receive the result in binary format
/// false indicates that the client will receive the result in text format
pub fn result_format(&self) -> bool {
self.client.result_format()
}

/// Set the format of return result.
///
/// format
/// true: binary format
/// false: text format
/// default format is binary format(result_format = true)
pub fn set_result_format(&mut self, format: bool) {
self.client.set_result_format(format);
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
33 changes: 33 additions & 0 deletions postgres/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,36 @@ fn check_send() {
is_send::<Statement>();
is_send::<Transaction<'_>>();
}

#[test]
fn query_text() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client.set_result_format(false);

let rows = client.query("SELECT $1::TEXT", &[&"hello"]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get_text(0).unwrap(), "hello");

let rows = client.query("SELECT 2,'2022-01-01'::date", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get_text(0).unwrap(), "2");
assert_eq!(rows[0].get_text(1).unwrap(), "2022-01-01");
}

#[test]
fn transaction_text() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client.set_result_format(false);

let mut transaction = client.transaction().unwrap();

let prepare_stmt = transaction.prepare("SELECT $1::INT8,$2::FLOAT4").unwrap();
let portal = transaction
.bind(&prepare_stmt, &[&64_i64, &3.9999_f32])
.unwrap();
let rows = transaction.query_portal(&portal, 0).unwrap();

assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get_text(0).unwrap(), "64");
assert_eq!(rows[0].get_text(1).unwrap(), "3.9999");
}
3 changes: 2 additions & 1 deletion tokio-postgres/src/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub async fn bind<P, I>(
client: &Arc<InnerClient>,
statement: Statement,
params: I,
result_format: bool,
) -> Result<Portal, Error>
where
P: BorrowToSql,
Expand All @@ -22,7 +23,7 @@ where
{
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let buf = client.with_buf(|buf| {
query::encode_bind(&statement, params, &name, buf)?;
query::encode_bind(&statement, params, &name, buf, result_format)?;
frontend::sync(buf);
Ok(buf.split().freeze())
})?;
Expand Down
22 changes: 21 additions & 1 deletion tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ pub struct Client {
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
result_format: bool,
}

impl Client {
Expand All @@ -190,6 +191,7 @@ impl Client {
ssl_mode,
process_id,
secret_key,
result_format: true,
}
}

Expand All @@ -202,6 +204,24 @@ impl Client {
self.socket_config = Some(socket_config);
}

/// Return the result format of client
///
/// true indicates that the client will receive the result in binary format
/// false indicates that the client will receive the result in text format
pub fn result_format(&self) -> bool {
self.result_format
}

/// Set the format of return result.
///
/// format
/// true: binary format
/// false: text format
/// default format is binary format(result_format = true)
pub fn set_result_format(&mut self, format: bool) {
self.result_format = format;
}

/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
Expand Down Expand Up @@ -369,7 +389,7 @@ impl Client {
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query(&self.inner, statement, params).await
query::query(&self.inner, statement, params, self.result_format).await
}

/// Executes a statement, returning the number of rows modified.
Expand Down
4 changes: 2 additions & 2 deletions tokio-postgres/src/copy_in.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{query, slice_iter, Error, Statement};
use crate::{query, slice_iter, Error, Statement, DEFAULT_RESULT_FORMAT};
use bytes::{Buf, BufMut, BytesMut};
use futures_channel::mpsc;
use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt};
Expand Down Expand Up @@ -200,7 +200,7 @@ where
{
debug!("executing copy in statement {}", statement.name());

let buf = query::encode(client, &statement, slice_iter(&[]))?;
let buf = query::encode(client, &statement, slice_iter(&[]), DEFAULT_RESULT_FORMAT)?;

let (mut sender, receiver) = mpsc::channel(1);
let receiver = CopyInReceiver::new(receiver);
Expand Down
4 changes: 2 additions & 2 deletions tokio-postgres/src/copy_out.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{query, slice_iter, Error, Statement};
use crate::{query, slice_iter, Error, Statement, DEFAULT_RESULT_FORMAT};
use bytes::Bytes;
use futures_util::{ready, Stream};
use log::debug;
Expand All @@ -14,7 +14,7 @@ use std::task::{Context, Poll};
pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result<CopyOutStream, Error> {
debug!("executing copy out statement {}", statement.name());

let buf = query::encode(client, &statement, slice_iter(&[]))?;
let buf = query::encode(client, &statement, slice_iter(&[]), DEFAULT_RESULT_FORMAT)?;
let responses = start(client, buf).await?;
Ok(CopyOutStream {
responses,
Expand Down
3 changes: 3 additions & 0 deletions tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ mod transaction;
mod transaction_builder;
pub mod types;

// Default result format : binary(true)
const DEFAULT_RESULT_FORMAT: bool = true;

/// A convenience function which parses a connection string and connects to the database.
///
/// See the documentation for [`Config`] for details on the connection string format.
Expand Down
8 changes: 4 additions & 4 deletions tokio-postgres/src/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::error::SqlState;
use crate::types::{Field, Kind, Oid, Type};
use crate::{query, slice_iter};
use crate::{query, slice_iter, DEFAULT_RESULT_FORMAT};
use crate::{Column, Error, Statement};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
Expand Down Expand Up @@ -137,7 +137,7 @@ async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {

let stmt = typeinfo_statement(client).await?;

let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid]), DEFAULT_RESULT_FORMAT).await?;
pin_mut!(rows);

let row = match rows.try_next().await? {
Expand Down Expand Up @@ -207,7 +207,7 @@ async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Erro
async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client).await?;

query::query(client, stmt, slice_iter(&[&oid]))
query::query(client, stmt, slice_iter(&[&oid]), DEFAULT_RESULT_FORMAT)
.await?
.and_then(|row| async move { row.try_get(0) })
.try_collect()
Expand All @@ -234,7 +234,7 @@ async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement,
async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client).await?;

let rows = query::query(client, stmt, slice_iter(&[&oid]))
let rows = query::query(client, stmt, slice_iter(&[&oid]), DEFAULT_RESULT_FORMAT)
.await?
.try_collect::<Vec<_>>()
.await?;
Expand Down
24 changes: 16 additions & 8 deletions tokio-postgres/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::{BorrowToSql, IsNull};
use crate::{Error, Portal, Row, Statement};
use crate::{Error, Portal, Row, Statement, DEFAULT_RESULT_FORMAT};
use bytes::{Bytes, BytesMut};
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
Expand Down Expand Up @@ -31,6 +31,7 @@ pub async fn query<P, I>(
client: &InnerClient,
statement: Statement,
params: I,
result_format: bool,
) -> Result<RowStream, Error>
where
P: BorrowToSql,
Expand All @@ -44,9 +45,9 @@ where
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
encode(client, &statement, params, result_format)?
} else {
encode(client, &statement, params)?
encode(client, &statement, params, result_format)?
};
let responses = start(client, buf).await?;
Ok(RowStream {
Expand Down Expand Up @@ -93,9 +94,9 @@ where
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
encode(client, &statement, params, DEFAULT_RESULT_FORMAT)?
} else {
encode(client, &statement, params)?
encode(client, &statement, params, DEFAULT_RESULT_FORMAT)?
};
let mut responses = start(client, buf).await?;

Expand Down Expand Up @@ -131,14 +132,19 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
Ok(responses)
}

pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
pub fn encode<P, I>(
client: &InnerClient,
statement: &Statement,
params: I,
result_format: bool,
) -> Result<Bytes, Error>
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
client.with_buf(|buf| {
encode_bind(statement, params, "", buf)?;
encode_bind(statement, params, "", buf, result_format)?;
frontend::execute("", 0, buf).map_err(Error::encode)?;
frontend::sync(buf);
Ok(buf.split().freeze())
Expand All @@ -150,6 +156,7 @@ pub fn encode_bind<P, I>(
params: I,
portal: &str,
buf: &mut BytesMut,
result_format: bool,
) -> Result<(), Error>
where
P: BorrowToSql,
Expand All @@ -174,6 +181,7 @@ where
let params = params.into_iter();

let mut error_idx = 0;
let result_format = if result_format { Some(1) } else { Some(0) };
let r = frontend::bind(
portal,
statement.name(),
Expand All @@ -187,7 +195,7 @@ where
Err(e)
}
},
Some(1),
result_format,
buf,
);
match r {
Expand Down
40 changes: 40 additions & 0 deletions tokio-postgres/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,46 @@ impl Row {
FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx))
}

/// Returns a value(text format) from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// NOTE: user should gurantee the result is text format
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the TEXT type.
pub fn get_text<I>(&self, idx: I) -> Option<&str>
where
I: RowIndex + fmt::Display,
{
match self.get_text_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}

/// Like `Row::get_text`, but returns a `Result` rather than panicking.
pub fn try_get_text<I>(&self, idx: I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
self.get_text_inner(&idx)
}

fn get_text_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
let idx = match idx.__idx(self.columns()) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};

let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
}

/// Get the raw bytes for the column at the given index.
fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
let range = self.ranges[idx].to_owned()?;
Expand Down
6 changes: 5 additions & 1 deletion tokio-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct Transaction<'a> {
client: &'a mut Client,
savepoint: Option<Savepoint>,
done: bool,
result_format: bool,
}

/// A representation of a PostgreSQL database savepoint.
Expand Down Expand Up @@ -57,10 +58,12 @@ impl<'a> Drop for Transaction<'a> {

impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
let result_format = client.result_format();
Transaction {
client,
savepoint: None,
done: false,
result_format,
}
}

Expand Down Expand Up @@ -202,7 +205,7 @@ impl<'a> Transaction<'a> {
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self.client).await?;
bind::bind(self.client.inner(), statement, params).await
bind::bind(self.client.inner(), statement, params, self.result_format).await
}

/// Continues execution of a portal, returning a stream of the resulting rows.
Expand Down Expand Up @@ -304,6 +307,7 @@ impl<'a> Transaction<'a> {
client: self.client,
savepoint: Some(Savepoint { name, depth }),
done: false,
result_format: self.result_format,
})
}

Expand Down