Skip to content

Commit

Permalink
pgsql: track transaction progress per direction
Browse files Browse the repository at this point in the history
PGSQL's current implementation tracks the transaction progress without
taking into consideration flow direction, and also has indirections
that make it harder to understand how the progress is tracked, as well
as when a request or response is actually complete.

This patch introduces tracking such progress per direction and adds
completion status per direction, too. This will help when triggering
raw stream reassembly or for unidirectional transactions, and may be
useful when we implement sub-protocols that can have multiple requests
per transaction, as well.

CancelRequests and TerminationRequests are examples of unidirectional
transactions. There won't be any responses to those requests, so we can
also mark the response side as done, and set their transactions as
completed.

Bug #7113
  • Loading branch information
jufajardini authored and victorjulien committed Sep 20, 2024
1 parent 2c7824a commit dcccbb1
Showing 1 changed file with 121 additions and 60 deletions.
181 changes: 121 additions & 60 deletions rust/src/pgsql/pgsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@ static mut PGSQL_MAX_TX: usize = 1024;

#[repr(u8)]
#[derive(Copy, Clone, PartialOrd, PartialEq, Eq, Debug)]
pub enum PgsqlTransactionState {
Init = 0,
RequestReceived,
ResponseDone,
FlushedOut,
pub enum PgsqlTxProgress {
TxInit = 0,
TxReceived,
TxDone,
TxFlushedOut,
}

#[derive(Debug)]
pub struct PgsqlTransaction {
pub tx_id: u64,
pub tx_state: PgsqlTransactionState,
pub tx_req_state: PgsqlTxProgress,
pub tx_res_state: PgsqlTxProgress,
pub request: Option<PgsqlFEMessage>,
pub responses: Vec<PgsqlBEMessage>,

Expand All @@ -72,7 +73,8 @@ impl PgsqlTransaction {
pub fn new() -> Self {
Self {
tx_id: 0,
tx_state: PgsqlTransactionState::Init,
tx_req_state: PgsqlTxProgress::TxInit,
tx_res_state: PgsqlTxProgress::TxInit,
request: None,
responses: Vec::<PgsqlBEMessage>::new(),
data_row_cnt: 0,
Expand Down Expand Up @@ -205,8 +207,11 @@ impl PgsqlState {
let mut index = self.tx_index_completed;
for tx_old in &mut self.transactions.range_mut(self.tx_index_completed..) {
index += 1;
if tx_old.tx_state < PgsqlTransactionState::ResponseDone {
tx_old.tx_state = PgsqlTransactionState::FlushedOut;
if tx_old.tx_res_state < PgsqlTxProgress::TxDone {
// we don't check for TxReqDone for the majority of requests are basically completed
// when they're parsed, as of now
tx_old.tx_req_state = PgsqlTxProgress::TxFlushedOut;
tx_old.tx_res_state = PgsqlTxProgress::TxFlushedOut;
//TODO set event
break;
}
Expand Down Expand Up @@ -242,26 +247,6 @@ impl PgsqlState {
return self.transactions.back_mut();
}

/// Process State progress to decide if PgsqlTransaction is finished
///
/// As Pgsql transactions are bidirectional and may be comprised of several
/// responses, we must track State progress to decide on tx completion
fn is_tx_completed(&self) -> bool {
if let PgsqlStateProgress::ReadyForQueryReceived
| PgsqlStateProgress::SSLRejectedReceived
| PgsqlStateProgress::SimpleAuthenticationReceived
| PgsqlStateProgress::SASLAuthenticationReceived
| PgsqlStateProgress::SASLAuthenticationContinueReceived
| PgsqlStateProgress::SASLAuthenticationFinalReceived
| PgsqlStateProgress::ConnectionTerminated
| PgsqlStateProgress::Finished = self.state_progress
{
true
} else {
false
}
}

/// Define PgsqlState progression, based on the request received
///
/// As PostgreSQL transactions can have multiple messages, State progression
Expand Down Expand Up @@ -315,6 +300,22 @@ impl PgsqlState {
}
}

/// Process State progress to decide if request is finished
///
fn request_is_complete(state: PgsqlStateProgress) -> bool {
match state {
PgsqlStateProgress::SSLRequestReceived
| PgsqlStateProgress::StartupMessageReceived
| PgsqlStateProgress::SimpleQueryReceived
| PgsqlStateProgress::PasswordMessageReceived
| PgsqlStateProgress::SASLInitialResponseReceived
| PgsqlStateProgress::SASLResponseReceived
| PgsqlStateProgress::CancelRequestReceived
| PgsqlStateProgress::ConnectionTerminated => true,
_ => false,
}
}

fn parse_request(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult {
// We're not interested in empty requests.
if input.is_empty() {
Expand Down Expand Up @@ -348,14 +349,33 @@ impl PgsqlState {
Direction::ToServer as i32,
);
start = rem;
if let Some(state) = PgsqlState::request_next_state(&request) {
let new_state = PgsqlState::request_next_state(&request);

if let Some(state) = new_state {
self.state_progress = state;
};
let tx_completed = self.is_tx_completed();
// PostreSQL progress states can be represented as a finite state machine
// After the connection phase, the backend/ server will be mostly waiting in a state of `ReadyForQuery`, unless
// it's processing some request.
// When the frontend wants to cancel a request, it will send a CancelRequest message over a new connection - to
// which there won't be any responses.
// If the frontend wants to terminate the connection, the backend won't send any confirmation after receiving a
// Terminate request.
// A simplified finite state machine for PostgreSQL v3 can be found at:
// https://samadhiweb.com/blog/2013.04.28.graphviz.postgresv3.html
if let Some(tx) = self.find_or_create_tx() {
tx.request = Some(request);
if tx_completed {
tx.tx_state = PgsqlTransactionState::ResponseDone;
if let Some(state) = new_state {
if Self::request_is_complete(state) {
// The request is always complete at this point
tx.tx_req_state = PgsqlTxProgress::TxDone;
if state == PgsqlStateProgress::ConnectionTerminated
|| state == PgsqlStateProgress::CancelRequestReceived
{
/* The server won't send any responses to such requests, so transaction should be over */
tx.tx_res_state = PgsqlTxProgress::TxDone;
}
}
}
} else {
// If there isn't a new transaction, we'll consider Suri should move on
Expand Down Expand Up @@ -455,6 +475,21 @@ impl PgsqlState {
}
}

/// Process State progress to decide if response is finished
///
fn response_is_complete(state: PgsqlStateProgress) -> bool {
match state {
PgsqlStateProgress::ReadyForQueryReceived
| PgsqlStateProgress::SSLRejectedReceived
| PgsqlStateProgress::SimpleAuthenticationReceived
| PgsqlStateProgress::SASLAuthenticationReceived
| PgsqlStateProgress::SASLAuthenticationContinueReceived
| PgsqlStateProgress::SASLAuthenticationFinalReceived
| PgsqlStateProgress::Finished => true,
_ => false,
}
}

fn parse_response(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult {
// We're not interested in empty responses.
if input.is_empty() {
Expand Down Expand Up @@ -482,30 +517,36 @@ impl PgsqlState {
);
start = rem;
SCLogDebug!("Response is {:?}", &response);
if let Some(state) = self.response_process_next_state(&response, flow) {
let new_state = self.response_process_next_state(&response, flow);
if let Some(state) = new_state {
self.state_progress = state;
};
let tx_completed = self.is_tx_completed();
let curr_state = self.state_progress;
}
if let Some(tx) = self.find_or_create_tx() {
if curr_state == PgsqlStateProgress::DataRowReceived {
tx.incr_row_cnt();
} else if curr_state == PgsqlStateProgress::CommandCompletedReceived
&& tx.get_row_cnt() > 0
{
// let's summarize the info from the data_rows in one response
let dummy_resp =
PgsqlBEMessage::ConsolidatedDataRow(ConsolidatedDataRowPacket {
identifier: b'D',
row_cnt: tx.get_row_cnt(),
data_size: tx.data_size, // total byte count of all data_row messages combined
});
tx.responses.push(dummy_resp);
tx.responses.push(response);
} else {
tx.responses.push(response);
if tx_completed {
tx.tx_state = PgsqlTransactionState::ResponseDone;
if tx.tx_res_state == PgsqlTxProgress::TxInit {
tx.tx_res_state = PgsqlTxProgress::TxReceived;
}
if let Some(state) = new_state {
if state == PgsqlStateProgress::DataRowReceived {
tx.incr_row_cnt();
} else if state == PgsqlStateProgress::CommandCompletedReceived
&& tx.get_row_cnt() > 0
{
// let's summarize the info from the data_rows in one response
let dummy_resp = PgsqlBEMessage::ConsolidatedDataRow(
ConsolidatedDataRowPacket {
identifier: b'D',
row_cnt: tx.get_row_cnt(),
data_size: tx.data_size, // total byte count of all data_row messages combined
},
);
tx.responses.push(dummy_resp);
tx.responses.push(response);
} else {
tx.responses.push(response);
if Self::response_is_complete(state) {
tx.tx_req_state = PgsqlTxProgress::TxDone;
tx.tx_res_state = PgsqlTxProgress::TxDone;
}
}
}
} else {
Expand Down Expand Up @@ -557,6 +598,22 @@ fn probe_tc(input: &[u8]) -> bool {
false
}

fn pgsql_tx_get_req_state(tx: *mut std::os::raw::c_void) -> PgsqlTxProgress {
let tx_safe: &mut PgsqlTransaction;
unsafe {
tx_safe = cast_pointer!(tx, PgsqlTransaction);
}
tx_safe.tx_req_state
}

fn pgsql_tx_get_res_state(tx: *mut std::os::raw::c_void) -> PgsqlTxProgress {
let tx_safe: &mut PgsqlTransaction;
unsafe {
tx_safe = cast_pointer!(tx, PgsqlTransaction);
}
tx_safe.tx_res_state
}

// C exports.

/// C entry point for a probing parser.
Expand Down Expand Up @@ -712,10 +769,14 @@ pub extern "C" fn SCPgsqlStateGetTxCount(state: *mut std::os::raw::c_void) -> u6

#[no_mangle]
pub unsafe extern "C" fn SCPgsqlTxGetALStateProgress(
tx: *mut std::os::raw::c_void, _direction: u8,
tx: *mut std::os::raw::c_void, direction: u8,
) -> std::os::raw::c_int {
let tx = cast_pointer!(tx, PgsqlTransaction);
tx.tx_state as i32
if direction == Direction::ToServer as u8 {
return pgsql_tx_get_req_state(tx) as i32;
}

// Direction has only two possible values, so we don't need to check for the other one
pgsql_tx_get_res_state(tx) as i32
}

export_tx_data_get!(rs_pgsql_get_tx_data, PgsqlTransaction);
Expand Down Expand Up @@ -743,8 +804,8 @@ pub unsafe extern "C" fn SCRegisterPgsqlParser() {
parse_tc: SCPgsqlParseResponse,
get_tx_count: SCPgsqlStateGetTxCount,
get_tx: SCPgsqlStateGetTx,
tx_comp_st_ts: PgsqlTransactionState::RequestReceived as i32,
tx_comp_st_tc: PgsqlTransactionState::ResponseDone as i32,
tx_comp_st_ts: PgsqlTxProgress::TxDone as i32,
tx_comp_st_tc: PgsqlTxProgress::TxDone as i32,
tx_get_progress: SCPgsqlTxGetALStateProgress,
get_eventinfo: None,
get_eventinfo_byid: None,
Expand Down

0 comments on commit dcccbb1

Please sign in to comment.