Fix bug in GBQ batching
Signed-off-by: Heinz N. Gies <[email protected]>
Licenser committed Aug 15, 2023
1 parent bfe12a7 commit 19afe0a
Showing 2 changed files with 98 additions and 108 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -9,10 +9,15 @@
* gcl now supports the `timestamp` metadata overwrite
* basic avro codec

### Fixes

* Fix bug in GBQ connector where batches were not sent

### Breaking Changes

* All google connectors now require `token` to be set to either `{"file": "<path to json>"}` or `{"json": {...}}` (see the sketch below)

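A minimal sketch of the two accepted `token` shapes, for illustration only: the two forms themselves come from the entry above, while the `serde_json` usage, variable names, and the inlined field are assumptions, not the connector's actual configuration API.

```rust
use serde_json::json;

fn main() {
    // Either point the connector at a service-account key file on disk ...
    let token_from_file = json!({ "file": "/path/to/service-account.json" });
    // ... or inline the service-account JSON document itself (fields elided here).
    let token_inline = json!({ "json": { "type": "service_account" } });
    println!("{token_from_file}\n{token_inline}");
}
```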

## [0.13.0-rc.12]

### Fixes
199 changes: 92 additions & 107 deletions src/connectors/impls/gbq/writer/sink.rs
Expand Up @@ -19,7 +19,6 @@ use crate::connectors::{
prelude::*,
};
use futures::{stream, StreamExt};
use googapis::google::cloud::bigquery::storage::v1::append_rows_request::Rows;
use googapis::google::cloud::bigquery::storage::v1::{
append_rows_request::{self, ProtoData},
append_rows_response::{AppendResult, Response},
@@ -29,7 +28,6 @@ use googapis::google::cloud::bigquery::storage::v1::{
TableFieldSchema, WriteStream,
};
use prost::encoding::WireType;
use prost::Message;
use prost_types::{field_descriptor_proto, DescriptorProto, FieldDescriptorProto};
use std::collections::hash_map::Entry;
use std::marker::PhantomData;
@@ -79,6 +77,7 @@ pub(crate) struct GbqSink<
struct Field {
table_type: TableType,
tag: u32,
// mode: Mode,

// ignored if the table_type is not struct
subfields: HashMap<String, Field>,
@@ -167,9 +166,18 @@ fn map_field(
proto3_optional: None,
});

// fn mode_from_raw(mode: i32) -> Result<Mode> {
// match mode {
// 0 => Ok(Mode::Required),
// 1 => Ok(Mode::Nullable),
// 2 => Ok(Mode::Repeated),
// _ => Err(format!("Invalid field mode: {}", mode).into()),
// }
// }
fields.insert(
raw_field.name.to_string(),
Field {
// mode: Mode::from(raw_field.mode),
table_type,
tag: u32::from(tag),
subfields,
@@ -297,14 +305,13 @@ impl JsonToProtobufMapping {
let mut result = Vec::with_capacity(obj.len());

for (key, val) in obj {
if let Some(field) = self.fields.get(&key.to_string()) {
let k: &str = key;
if let Some(field) = self.fields.get(k) {
encode_field(val, field, &mut result)?;
}
}

return Ok(result);
}

Err(ErrorKind::BigQueryTypeMismatch("object", value.value_type()).into())
}

@@ -360,7 +367,6 @@ where
"BigQuery",
"The client is not connected",
))?;

for request in request_data {
let req_timeout = Duration::from_nanos(self.config.request_timeout);
let append_response =
@@ -500,63 +506,34 @@ where
request_size_limit: usize,
) -> Result<Vec<AppendRowsRequest>> {
let mut request_data: Vec<AppendRowsRequest> = Vec::new();
let mut requests: HashMap<String, AppendRowsRequest> = HashMap::new();
let mut requests: HashMap<_, Vec<_>> = HashMap::new();

for (data, meta) in event.value_meta_iter() {
let write_stream = self
.get_or_create_write_stream(
&ctx.extract_meta(meta)
requests
.entry(
ctx.extract_meta(meta)
.get("table_id")
.as_str()
.map_or_else(|| self.config.table_id.clone(), ToString::to_string),
ctx,
.map(String::from),
)
.await?;

match requests.entry(write_stream.name.clone()) {
Entry::Occupied(entry) => {
let mut request = entry.remove();

let serialized_event = write_stream.mapping.map(data)?;
Self::rows_from_request(&mut request)?.push(serialized_event);

if request.encoded_len() > request_size_limit {
let rows = Self::rows_from_request(&mut request)?;
let row_count = rows.len();
let last_event = rows.pop().ok_or(ErrorKind::GbqSinkFailed(
"Failed to pop last event from request",
))?;
request_data.push(request);

let mut new_rows = Vec::with_capacity(event.len() - row_count);
new_rows.push(last_event);

requests.insert(
write_stream.name.clone(),
AppendRowsRequest {
write_stream: write_stream.name.clone(),
offset: None,
rows: Some(append_rows_request::Rows::ProtoRows(ProtoData {
writer_schema: Some(ProtoSchema {
proto_descriptor: Some(
write_stream.mapping.descriptor().clone(),
),
}),
rows: Some(ProtoRows {
serialized_rows: new_rows,
}),
})),
trace_id: String::new(),
},
);
}
}
Entry::Vacant(entry) => {
let serialized_event = write_stream.mapping.map(data)?;
let mut serialized_rows = Vec::with_capacity(event.len());
serialized_rows.push(serialized_event);
.or_default()
.push(data);
}

for (tid, data) in requests {
let tid = tid.map_or_else(|| self.config.table_id.clone(), String::from);
let write_stream = self.get_or_create_write_stream(tid, ctx).await?;

let data_len = data.len();

entry.insert(AppendRowsRequest {
let mut serialized_rows = Vec::with_capacity(data_len);
let mut size = 0;
for serialized in data.into_iter().map(|d| write_stream.mapping.map(d)) {
let serialized = serialized?;
if size + serialized.len() > request_size_limit {
let last_len = serialized_rows.len();

request_data.push(AppendRowsRequest {
write_stream: write_stream.name.clone(),
offset: None,
rows: Some(append_rows_request::Rows::ProtoRows(ProtoData {
@@ -567,12 +544,26 @@
})),
trace_id: String::new(),
});
size = 0;
serialized_rows = Vec::with_capacity(data_len - last_len);
}
size += serialized.len();
serialized_rows.push(serialized);
}
}

for (_, request) in requests {
request_data.push(request);
if !serialized_rows.is_empty() {
request_data.push(AppendRowsRequest {
write_stream: write_stream.name.clone(),
offset: None,
rows: Some(append_rows_request::Rows::ProtoRows(ProtoData {
writer_schema: Some(ProtoSchema {
proto_descriptor: Some(write_stream.mapping.descriptor().clone()),
}),
rows: Some(ProtoRows { serialized_rows }),
})),
trace_id: String::new(),
});
}
}

Ok(request_data)
Expand All @@ -589,15 +580,15 @@ where
{
async fn get_or_create_write_stream(
&mut self,
table_id: &str,
table_id: String,
ctx: &SinkContext,
) -> Result<&ConnectedWriteStream> {
let client = self.client.as_mut().ok_or(ErrorKind::ClientNotAvailable(
"BigQuery",
"The client is not connected",
))?;

match self.write_streams.entry(table_id.to_string()) {
match self.write_streams.entry(table_id.clone()) {
Entry::Occupied(entry) => {
// NOTE: `into_mut` is needed here, even though we only need a non-mutable reference.
// This is because `get` returns a reference whose lifetime is bound to the entry,
@@ -607,7 +598,7 @@
Entry::Vacant(entry) => {
let stream = client
.create_write_stream(CreateWriteStreamRequest {
parent: table_id.to_string(),
parent: table_id.clone(),
write_stream: Some(WriteStream {
// The stream name here will be ignored and a generated value will be set in the response
name: String::new(),
@@ -624,7 +615,7 @@
&stream
.table_schema
.as_ref()
.ok_or_else(|| ErrorKind::GbqSchemaNotProvided(table_id.to_string()))?
.ok_or_else(|| ErrorKind::GbqSchemaNotProvided(table_id))?
.clone()
.fields,
ctx,
@@ -639,50 +630,25 @@
}
}

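// Aside (illustrative only, not part of this diff): the point of the NOTE above
// in miniature. `OccupiedEntry::get` returns a reference borrowed from the local
// `entry` value, which cannot be returned from the enclosing function, whereas
// `into_mut` consumes the entry and yields a reference tied to the map itself.
fn get_or_insert(map: &mut std::collections::HashMap<String, u64>, key: String) -> &mut u64 {
    match map.entry(key) {
        std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(),
        std::collections::hash_map::Entry::Vacant(entry) => entry.insert(0),
    }
}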
impl<
T: TokenProvider + 'static,
TChannel: GbqChannel<TChannelError> + 'static,
TChannelError: GbqChannelError,
> GbqSink<T, TChannel, TChannelError>
where
TChannel::Future: Send,
{
fn rows_from_request(request: &mut AppendRowsRequest) -> Result<&mut Vec<Vec<u8>>> {
let rows = match request
.rows
.as_mut()
.ok_or(ErrorKind::GbqSinkFailed("No rows in request"))?
{
Rows::ProtoRows(ref mut x) => {
&mut x
.rows
.as_mut()
.ok_or(ErrorKind::GbqSinkFailed("No rows in request"))?
.serialized_rows
}
};

Ok(rows)
}
}

#[cfg(test)]
#[cfg(feature = "gcp-integration")]
mod test {
use super::*;
use crate::connectors::reconnect::ConnectionLostNotifier;
use crate::connectors::tests::ConnectorHarness;
use crate::connectors::{
google::tests::TestTokenProvider, utils::quiescence::QuiescenceBeacon,
google::{tests::TestTokenProvider, TokenSrc},
impls::gbq,
reconnect::ConnectionLostNotifier,
tests::ConnectorHarness,
utils::quiescence::QuiescenceBeacon,
};
use crate::connectors::{google::TokenSrc, impls::gbq};
use bytes::Bytes;
use futures::future::Ready;
use googapis::google::cloud::bigquery::storage::v1::table_field_schema::Mode;
use googapis::google::cloud::bigquery::storage::v1::{
append_rows_response, AppendRowsResponse, TableSchema,
use googapis::google::{
cloud::bigquery::storage::v1::{
append_rows_response, table_field_schema::Mode, AppendRowsResponse, TableSchema,
},
rpc::Status,
};
use googapis::google::rpc::Status;
use http::{HeaderMap, HeaderValue};
use prost::Message;
use std::collections::VecDeque;
@@ -1478,7 +1444,18 @@ mod test {
r#type: i32::from(write_stream::Type::Committed),
create_time: None,
commit_time: None,
table_schema: Some(TableSchema { fields: vec![] }),
table_schema: Some(TableSchema {
fields: vec![TableFieldSchema {
name: "newfield".to_string(),
r#type: i32::from(table_field_schema::Type::String),
mode: i32::from(Mode::Required),
fields: vec![],
description: "test".to_string(),
max_length: 10,
precision: 0,
scale: 0,
}],
}),
}
.encode(&mut buffer_write_stream)
.map_err(|_| "encode failed")?;
@@ -1488,7 +1465,7 @@
fields: vec![TableFieldSchema {
name: "newfield".to_string(),
r#type: i32::from(table_field_schema::Type::String),
mode: i32::from(Mode::Nullable),
mode: i32::from(Mode::Required),
fields: vec![],
description: "test".to_string(),
max_length: 10,
@@ -1505,6 +1482,10 @@
.encode(&mut buffer_append_rows_response)
.map_err(|_| "encode failed")?;

let responses = Arc::new(RwLock::new(VecDeque::from([
buffer_write_stream,
buffer_append_rows_response,
])));
let mut sink = GbqSink::<TestTokenProvider, _, _>::new(
Config {
token: TokenSrc::dummy(),
Expand All @@ -1514,10 +1495,7 @@ mod test {
request_size_limit: 10 * 1024 * 1024,
},
Box::new(MockChannelFactory {
responses: Arc::new(RwLock::new(VecDeque::from([
buffer_write_stream,
buffer_append_rows_response,
]))),
responses: responses.clone(),
}),
);

@@ -1531,10 +1509,18 @@

sink.connect(&ctx, &Attempt::default()).await?;

let mut event = Event {
data: EventPayload::from(literal!({
"newfield": "test"
})),
..Event::default()
};
event.transactional = true;

let result = sink
.on_event(
"",
Event::default(),
event,
&ctx,
&mut EventSerializer::new(
None,
@@ -1546,7 +1532,6 @@
0,
)
.await?;

assert_eq!(result.ack, SinkAck::Fail);
assert_eq!(result.cb, CbAction::None);
Ok(())

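The changelog entry above ("batches were not sent") is addressed by the rewritten request-assembly code in sink.rs: rows are now grouped per target table up front and each group is split into size-limited `AppendRowsRequest`s, rather than incrementally updating requests cached in a map as the removed code did. The following is a simplified, self-contained sketch of that strategy; the function name, signature, and types are illustrative assumptions, not the connector's actual API.

```rust
use std::collections::HashMap;

/// Group serialized rows by target table, then split each group into batches
/// whenever adding another row would push the batch past the size limit.
fn batch_rows(
    rows: Vec<(Option<String>, Vec<u8>)>, // (per-event table id override, serialized row)
    default_table: &str,
    request_size_limit: usize,
) -> Vec<(String, Vec<Vec<u8>>)> {
    // Group by table id, falling back to the configured default table.
    let mut by_table: HashMap<String, Vec<Vec<u8>>> = HashMap::new();
    for (table, row) in rows {
        by_table
            .entry(table.unwrap_or_else(|| default_table.to_string()))
            .or_default()
            .push(row);
    }

    // Split each group into size-limited batches; no rows are left behind.
    let mut batches = Vec::new();
    for (table, rows) in by_table {
        let mut current: Vec<Vec<u8>> = Vec::new();
        let mut size = 0usize;
        for row in rows {
            if size + row.len() > request_size_limit && !current.is_empty() {
                batches.push((table.clone(), std::mem::take(&mut current)));
                size = 0;
            }
            size += row.len();
            current.push(row);
        }
        if !current.is_empty() {
            batches.push((table, current));
        }
    }
    batches
}
```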