diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 73b169288..b11d7985e 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -93,6 +94,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +192,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -524,6 +536,27 @@ impl CopyOutResponseBody { } } +pub struct CopyBothResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + #[derive(Debug, Clone)] pub struct DataRowBody { storage: Bytes, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 92eabde36..48137fa36 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,7 @@ use crate::codec::BackendMessages; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::CopyBothStream; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -13,8 +14,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -493,6 +495,11 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a copy both query, returning a stream of the resulting data. + pub async fn copy_both(&self, query: &str) -> Result { + copy_both::copy_both(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 62b45f793..f029734a9 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -84,6 +84,15 @@ pub enum Host { Unix(PathBuf), } +/// Connection replication mode. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReplicationMode { + /// Logical replication. + Logical, + /// Physical replication. + Physical, +} + /// Connection configuration. /// /// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: @@ -209,6 +218,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) replication_mode: Option, } impl Default for Config { @@ -242,6 +252,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + replication_mode: None, } } @@ -524,6 +535,17 @@ impl Config { self.load_balance_hosts } + /// Sets connection replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Gets connection replication mode. + pub fn get_replication_mode(&self) -> Option<&ReplicationMode> { + self.replication_mode.as_ref() + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -660,6 +682,21 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "replication" => { + let replication_mode = match value { + "database" => Some(ReplicationMode::Logical), + "true" => Some(ReplicationMode::Physical), + "off" => None, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "replication_mode", + )))) + } + }; + if let Some(replication_mode) = replication_mode { + self.replication_mode(replication_mode); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -744,6 +781,7 @@ impl fmt::Debug for Config { config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..f3068382c 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -133,6 +133,11 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } + match config.replication_mode { + Some(config::ReplicationMode::Logical) => params.push(("replication", "database")), + Some(config::ReplicationMode::Physical) => params.push(("replication", "true")), + _ => {} + } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..f8f2bbbe2 --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,56 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{simple_query, Error}; +use bytes::Bytes; +use futures_util::{ready, Stream}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub async fn copy_both(client: &InnerClient, query: &str) -> Result { + debug!("executing copy out statement {query}"); + + let buf = simple_query::encode(client, query)?; + let responses = start(client, buf).await?; + Ok(CopyBothStream { + responses, + _p: PhantomPinned, + }) +} + +async fn start(client: &InnerClient, buf: Bytes) -> Result { + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::CopyBothResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(responses) +} + +pin_project! { + pub struct CopyBothStream { + responses: Responses, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for CopyBothStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), + Message::CopyDone => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index ec843d511..46fa81eb3 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -160,6 +160,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 24473b896..29a60ab8d 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -63,7 +63,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result { +pub fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze())