Skip to content

Commit

Permalink
Merge pull request #600 from nvanbenschoten/nvanbenschoten/savepoints…
Browse files Browse the repository at this point in the history
…-2-the-revival-album

Re-add savepoint method to Transaction
  • Loading branch information
sfackler authored May 1, 2020
2 parents e3d3c6d + 64d6e97 commit c6a6686
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 16 deletions.
51 changes: 51 additions & 0 deletions postgres/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,57 @@ fn nested_transactions() {
assert_eq!(rows[2].get::<_, i32>(0), 4);
}

#[test]
fn savepoints() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
.unwrap();

let mut transaction = client.transaction().unwrap();

transaction
.execute("INSERT INTO foo (id) VALUES (1)", &[])
.unwrap();

let mut savepoint1 = transaction.savepoint("savepoint1").unwrap();

savepoint1
.execute("INSERT INTO foo (id) VALUES (2)", &[])
.unwrap();

savepoint1.rollback().unwrap();

let rows = transaction
.query("SELECT id FROM foo ORDER BY id", &[])
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);

let mut savepoint2 = transaction.savepoint("savepoint2").unwrap();

savepoint2
.execute("INSERT INTO foo (id) VALUES(3)", &[])
.unwrap();

let mut savepoint3 = savepoint2.savepoint("savepoint3").unwrap();

savepoint3
.execute("INSERT INTO foo (id) VALUES(4)", &[])
.unwrap();

savepoint3.commit().unwrap();
savepoint2.commit().unwrap();
transaction.commit().unwrap();

let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[1].get::<_, i32>(0), 3);
assert_eq!(rows[2].get::<_, i32>(0), 4);
}

#[test]
fn copy_in() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
Expand Down
13 changes: 12 additions & 1 deletion postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,23 @@ impl<'a> Transaction<'a> {
CancelToken::new(self.transaction.cancel_token())
}

/// Like `Client::transaction`.
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = self.connection.block_on(self.transaction.transaction())?;
Ok(Transaction {
connection: self.connection.as_ref(),
transaction,
})
}
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
where
I: Into<String>,
{
let transaction = self.connection.block_on(self.transaction.savepoint(name))?;
Ok(Transaction {
connection: self.connection.as_ref(),
transaction,
})
}
}
49 changes: 34 additions & 15 deletions tokio-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,26 @@ use tokio::io::{AsyncRead, AsyncWrite};
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
depth: u32,
savepoint: Option<Savepoint>,
done: bool,
}

/// A representation of a PostgreSQL database savepoint.
struct Savepoint {
name: String,
depth: u32,
}

impl<'a> Drop for Transaction<'a> {
fn drop(&mut self) {
if self.done {
return;
}

let query = if self.depth == 0 {
"ROLLBACK".to_string()
let query = if let Some(sp) = self.savepoint.as_ref() {
format!("ROLLBACK TO {}", sp.name)
} else {
format!("ROLLBACK TO sp{}", self.depth)
"ROLLBACK".to_string()
};
let buf = self.client.inner().with_buf(|buf| {
frontend::query(&query, buf).unwrap();
Expand All @@ -53,18 +59,18 @@ impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
depth: 0,
savepoint: None,
done: false,
}
}

/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<(), Error> {
self.done = true;
let query = if self.depth == 0 {
"COMMIT".to_string()
let query = if let Some(sp) = self.savepoint.as_ref() {
format!("RELEASE {}", sp.name)
} else {
format!("RELEASE sp{}", self.depth)
"COMMIT".to_string()
};
self.client.batch_execute(&query).await
}
Expand All @@ -74,10 +80,10 @@ impl<'a> Transaction<'a> {
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<(), Error> {
self.done = true;
let query = if self.depth == 0 {
"ROLLBACK".to_string()
let query = if let Some(sp) = self.savepoint.as_ref() {
format!("ROLLBACK TO {}", sp.name)
} else {
format!("ROLLBACK TO sp{}", self.depth)
"ROLLBACK".to_string()
};
self.client.batch_execute(&query).await
}
Expand Down Expand Up @@ -272,15 +278,28 @@ impl<'a> Transaction<'a> {
self.client.cancel_query_raw(stream, tls).await
}

/// Like `Client::transaction`.
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let depth = self.depth + 1;
let query = format!("SAVEPOINT sp{}", depth);
self._savepoint(None).await
}

/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
where
I: Into<String>,
{
self._savepoint(Some(name.into())).await
}

async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
let name = name.unwrap_or_else(|| format!("sp_{}", depth));
let query = format!("SAVEPOINT {}", name);
self.batch_execute(&query).await?;

Ok(Transaction {
client: self.client,
depth,
savepoint: Some(Savepoint { name, depth }),
done: false,
})
}
Expand Down

0 comments on commit c6a6686

Please sign in to comment.