
Commit

Merge remote-tracking branch 'origin' into new-chart-prototyping
tenphi committed Sep 12, 2024
2 parents 6c13052 + 0f7bb3d commit 71c174c
Showing 7 changed files with 803 additions and 300 deletions.
10 changes: 8 additions & 2 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
@@ -465,12 +465,14 @@ impl CubeScanWrapperNode {
node
)));
}
let mut meta_with_user = load_request_meta.as_ref().clone();
meta_with_user.set_change_user(node.options.change_user.clone());
let sql = transport
.sql(
node.span_id.clone(),
node.request.clone(),
node.auth_context,
load_request_meta.as_ref().clone(),
meta_with_user,
Some(
node.member_fields
.iter()
@@ -843,12 +845,16 @@ impl CubeScanWrapperNode {
}
// TODO time dimensions, filters, segments

let mut meta_with_user = load_request_meta.as_ref().clone();
meta_with_user.set_change_user(
ungrouped_scan_node.options.change_user.clone(),
);
let sql_response = transport
.sql(
ungrouped_scan_node.span_id.clone(),
load_request.clone(),
ungrouped_scan_node.auth_context.clone(),
load_request_meta.as_ref().clone(),
meta_with_user,
// TODO use aliases or push everything through names?
None,
Some(sql.values.clone()),
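Both hunks in wrapper.rs apply the same pattern: clone the shared LoadRequestMeta per scan, attach the node's change_user option via set_change_user, and hand the enriched copy to transport.sql(...) instead of the original. The sketch below is a minimal, self-contained mirror of that pattern; the LoadRequestMeta struct here is a hypothetical stand-in reduced to the one field this diff touches (the real cubesql type carries more metadata), and change_user_from_scan_node is an illustrative value, not an actual field name.

// Hypothetical stand-in for LoadRequestMeta, reduced to the field this diff touches.
#[derive(Clone, Debug)]
struct LoadRequestMeta {
    change_user: Option<String>,
}

impl LoadRequestMeta {
    fn set_change_user(&mut self, user: Option<String>) {
        self.change_user = user;
    }
}

fn main() {
    // Shared, user-agnostic metadata built once for the request.
    let load_request_meta = LoadRequestMeta { change_user: None };

    // The pattern from both hunks: clone the shared meta, attach the per-scan
    // change_user, and pass the enriched copy onward instead of the original.
    let change_user_from_scan_node = Some("gopher".to_string());
    let mut meta_with_user = load_request_meta.clone();
    meta_with_user.set_change_user(change_user_from_scan_node);

    assert_eq!(meta_with_user.change_user, Some("gopher".to_string()));
    assert_eq!(load_request_meta.change_user, None); // the shared copy stays untouched
}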
181 changes: 0 additions & 181 deletions rust/cubesql/cubesql/src/compile/mod.rs
@@ -288,68 +288,6 @@ mod tests {
);
}

#[tokio::test]
async fn test_change_user_via_filter() {
init_testing_logger();

let query_plan = convert_select_to_query_plan(
"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE __user = 'gopher'"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let cube_scan = query_plan.as_logical_plan().find_cube_scan();

assert_eq!(cube_scan.options.change_user, Some("gopher".to_string()));

assert_eq!(
cube_scan.request,
V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]),
segments: Some(vec![]),
dimensions: Some(vec![]),
time_dimensions: None,
order: None,
limit: None,
offset: None,
filters: None,
ungrouped: None,
}
)
}

#[tokio::test]
async fn test_change_user_via_in_filter() {
init_testing_logger();

let query_plan = convert_select_to_query_plan(
"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE __user IN ('gopher')"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let cube_scan = query_plan.as_logical_plan().find_cube_scan();

assert_eq!(cube_scan.options.change_user, Some("gopher".to_string()));

assert_eq!(
cube_scan.request,
V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]),
segments: Some(vec![]),
dimensions: Some(vec![]),
time_dimensions: None,
order: None,
limit: None,
offset: None,
filters: None,
ungrouped: None,
}
)
}

#[tokio::test]
async fn test_starts_with() {
init_testing_logger();
@@ -481,92 +419,6 @@ mod tests {
assert!(sql.contains("LOWER("));
}

#[tokio::test]
async fn test_change_user_via_in_filter_thoughtspot() {
init_testing_logger();

let query_plan = convert_select_to_query_plan(
r#"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce "ta_1" WHERE (LOWER("ta_1"."__user") IN ('gopher')) = TRUE"#.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let expected_request = V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
segments: Some(vec![]),
dimensions: Some(vec![]),
time_dimensions: None,
order: None,
limit: None,
offset: None,
filters: None,
ungrouped: None,
};

let cube_scan = query_plan.as_logical_plan().find_cube_scan();
assert_eq!(cube_scan.options.change_user, Some("gopher".to_string()));
assert_eq!(cube_scan.request, expected_request);

let query_plan = convert_select_to_query_plan(
r#"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce "ta_1" WHERE ((LOWER("ta_1"."__user") IN ('gopher') = TRUE) = TRUE)"#.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let cube_scan = query_plan.as_logical_plan().find_cube_scan();
assert_eq!(cube_scan.options.change_user, Some("gopher".to_string()));
assert_eq!(cube_scan.request, expected_request);
}

#[tokio::test]
async fn test_change_user_via_filter_and() {
let query_plan = convert_select_to_query_plan(
"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE __user = 'gopher' AND customer_gender = 'male'".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let cube_scan = query_plan.as_logical_plan().find_cube_scan();

assert_eq!(cube_scan.options.change_user, Some("gopher".to_string()));

assert_eq!(
cube_scan.request,
V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]),
segments: Some(vec![]),
dimensions: Some(vec![]),
time_dimensions: None,
order: None,
limit: None,
offset: None,
filters: Some(vec![V1LoadRequestQueryFilterItem {
member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()),
operator: Some("equals".to_string()),
values: Some(vec!["male".to_string()]),
or: None,
and: None,
}]),
ungrouped: None,
}
)
}

#[tokio::test]
async fn test_change_user_via_filter_or() {
// OR is not allowed for __user
let meta = get_test_tenant_ctx();
let query =
convert_sql_to_cube_query(
&"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE __user = 'gopher' OR customer_gender = 'male'".to_string(),
meta.clone(),
get_test_session(DatabaseProtocol::PostgreSQL, meta).await
).await;

// TODO: We need to propagate error to result, to assert message
query.unwrap_err();
}

#[tokio::test]
async fn test_order_alias_for_measure_default() {
let query_plan = convert_select_to_query_plan(
@@ -8806,39 +8658,6 @@ ORDER BY "source"."str0" ASC
)
}

#[tokio::test]
async fn test_user_with_join() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let logical_plan = convert_select_to_query_plan(
"SELECT aliased.count as c, aliased.user_1 as u1, aliased.user_2 as u2 FROM (SELECT \"KibanaSampleDataEcommerce\".count as count, \"KibanaSampleDataEcommerce\".__user as user_1, Logs.__user as user_2 FROM \"KibanaSampleDataEcommerce\" CROSS JOIN Logs WHERE __user = 'foo') aliased".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await
.as_logical_plan();

let cube_scan = logical_plan.find_cube_scan();
assert_eq!(
cube_scan.request,
V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
dimensions: Some(vec![]),
segments: Some(vec![]),
time_dimensions: None,
order: Some(vec![]),
limit: None,
offset: None,
filters: None,
ungrouped: Some(true),
}
);

assert_eq!(cube_scan.options.change_user, Some("foo".to_string()))
}

#[tokio::test]
async fn test_sort_relations() -> Result<(), CubeError> {
init_testing_logger();
51 changes: 44 additions & 7 deletions rust/cubesql/cubesql/src/compile/test/mod.rs
@@ -36,6 +36,8 @@ pub mod test_df_execution;
pub mod test_introspection;
#[cfg(test)]
pub mod test_udfs;
#[cfg(test)]
pub mod test_user_change;
pub mod utils;
pub use utils::*;

@@ -659,20 +661,34 @@ pub fn get_test_auth() -> Arc<dyn SqlAuthService> {
Arc::new(TestSqlAuth {})
}

#[derive(Clone, Debug)]
pub struct TestTransportLoadCall {
pub query: TransportLoadRequestQuery,
pub sql_query: Option<SqlQuery>,
pub ctx: AuthContextRef,
pub meta: LoadRequestMeta,
}

#[derive(Debug)]
struct TestConnectionTransport {
meta_context: Arc<MetaContext>,
load_mocks: tokio::sync::Mutex<Vec<(TransportLoadRequestQuery, TransportLoadResponse)>>,
load_calls: tokio::sync::Mutex<Vec<TestTransportLoadCall>>,
}

impl TestConnectionTransport {
pub fn new(meta_context: Arc<MetaContext>) -> Self {
Self {
meta_context,
load_mocks: tokio::sync::Mutex::new(vec![]),
load_calls: tokio::sync::Mutex::new(vec![]),
}
}

pub async fn load_calls(&self) -> Vec<TestTransportLoadCall> {
self.load_calls.lock().await.clone()
}

pub async fn add_cube_load_mock(
&self,
req: TransportLoadRequestQuery,
@@ -694,13 +710,17 @@ impl TransportService for TestConnectionTransport {
_span_id: Option<Arc<SpanId>>,
query: TransportLoadRequestQuery,
_ctx: AuthContextRef,
_meta_fields: LoadRequestMeta,
meta: LoadRequestMeta,
_member_to_alias: Option<HashMap<String, String>>,
expression_params: Option<Vec<Option<String>>>,
) -> Result<SqlResponse, CubeError> {
let inputs = serde_json::json!({
"query": query,
"meta": meta,
});
Ok(SqlResponse {
sql: SqlQuery::new(
format!("SELECT * FROM {}", serde_json::to_string(&query).unwrap()),
format!("SELECT * FROM {}", serde_json::to_string(&inputs).unwrap()),
expression_params.unwrap_or(Vec::new()),
),
})
Expand All @@ -712,16 +732,30 @@ impl TransportService for TestConnectionTransport {
_span_id: Option<Arc<SpanId>>,
query: TransportLoadRequestQuery,
sql_query: Option<SqlQuery>,
_ctx: AuthContextRef,
_meta_fields: LoadRequestMeta,
ctx: AuthContextRef,
meta: LoadRequestMeta,
) -> Result<TransportLoadResponse, CubeError> {
if sql_query.is_some() {
unimplemented!("load with sql_query");
{
let mut calls = self.load_calls.lock().await;
calls.push(TestTransportLoadCall {
query: query.clone(),
sql_query: sql_query.clone(),
ctx: ctx.clone(),
meta: meta.clone(),
});
}

if let Some(sql_query) = sql_query {
return Err(CubeError::internal(format!(
"Test transport does not support load with SQL query: {sql_query:?}"
)));
}

let mocks = self.load_mocks.lock().await;
let Some((_req, res)) = mocks.iter().find(|(req, _res)| req == &query) else {
panic!("Unexpected query: {:?}", query);
return Err(CubeError::internal(format!(
"Unexpected query in test transport: {query:?}"
)));
};
Ok(res.clone())
}
@@ -862,6 +896,9 @@ impl TestContext {
.or(Some(config_limit));
self.transport.add_cube_load_mock(req, res).await
}
pub async fn load_calls(&self) -> Vec<TestTransportLoadCall> {
self.transport.load_calls().await
}

pub async fn convert_sql_to_cube_query(&self, query: &str) -> CompilationResult<QueryPlan> {
// TODO push to_string() deeper
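The new bookkeeping in test/mod.rs is a record-and-inspect mock: every load() first pushes its inputs (query, sql_query, auth context, meta) into a mutex-guarded vector, and tests read them back through load_calls() to assert what actually reached the transport, for example that change_user was propagated. The __user tests deleted from compile/mod.rs above appear to have moved into the new test_user_change module registered in this file, though that module's contents are not shown in this diff. Below is a simplified, self-contained sketch of the recording pattern; RecordingTransport and RecordedCall are hypothetical names, it uses a blocking std::sync::Mutex and a trimmed-down call record instead of the real tokio::sync::Mutex and TestTransportLoadCall, and it omits the mock lookup so it runs without an async runtime.

use std::sync::Mutex;

// Trimmed-down stand-in for the recorded call; the real TestTransportLoadCall also
// captures sql_query and the auth context.
#[derive(Clone, Debug)]
struct RecordedCall {
    query: String,
    change_user: Option<String>,
}

// Minimal mirror of the TestConnectionTransport recording pattern.
struct RecordingTransport {
    calls: Mutex<Vec<RecordedCall>>,
}

impl RecordingTransport {
    fn new() -> Self {
        Self { calls: Mutex::new(Vec::new()) }
    }

    // Every load records its inputs before doing anything else, so assertions can
    // inspect them later.
    fn load(&self, query: &str, change_user: Option<String>) {
        self.calls.lock().unwrap().push(RecordedCall {
            query: query.to_string(),
            change_user,
        });
        // ...the real transport would now look up the mocked response for this query...
    }

    fn load_calls(&self) -> Vec<RecordedCall> {
        self.calls.lock().unwrap().clone()
    }
}

fn main() {
    let transport = RecordingTransport::new();
    transport.load(
        "SELECT COUNT(*) FROM KibanaSampleDataEcommerce",
        Some("gopher".to_string()),
    );

    // A test built on top of this transport can assert that change_user reached it.
    let calls = transport.load_calls();
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].change_user, Some("gopher".to_string()));
}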
