docs: add example for custom file format with COPY TO (#11174)

* feat: add example for copy to

* better docs plus tempdir

* build: clean examples if over 10GB

* only 1GB

* build: try clearing some disk space before running

* build: remove sudo

* build: try clean

* build: run clean

* build: only clean examples

* docs: better output for example
tshauck authored Jul 2, 2024
1 parent 09cdb78 commit ab8761d
Showing 2 changed files with 235 additions and 0 deletions.
1 change: 1 addition & 0 deletions ci/scripts/rust_example.sh
@@ -29,5 +29,6 @@ do
  # Skip tests that rely on external storage and flight
  if [ ! -d $filename ]; then
    cargo run --example $example_name
    cargo clean -p datafusion-examples
  fi
done
234 changes: 234 additions & 0 deletions datafusion-examples/examples/custom_file_format.rs
@@ -0,0 +1,234 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{any::Any, sync::Arc};

use arrow::{
    array::{AsArray, RecordBatch, StringArray, UInt8Array},
    datatypes::UInt64Type,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::{
    datasource::{
        file_format::{
            csv::CsvFormatFactory, file_compression_type::FileCompressionType,
            FileFormat, FileFormatFactory,
        },
        physical_plan::{FileScanConfig, FileSinkConfig},
        MemTable,
    },
    error::Result,
    execution::{context::SessionState, runtime_env::RuntimeEnv},
    physical_plan::ExecutionPlan,
    prelude::{SessionConfig, SessionContext},
};
use datafusion_common::{GetExt, Statistics};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use object_store::{ObjectMeta, ObjectStore};
use tempfile::tempdir;

/// Example of a custom file format that reads and writes TSV files.
///
/// TSVFileFactory is responsible for creating instances of TSVFileFormat.
/// Once the factory is registered with the SessionState, it is used to
/// facilitate SQL operations on TSV files, such as the `COPY TO` shown here.

#[derive(Debug)]
/// Custom file format that reads and writes TSV files
///
/// This file format is a wrapper around the CSV file format
/// for demonstration purposes.
struct TSVFileFormat {
    csv_file_format: Arc<dyn FileFormat>,
}

impl TSVFileFormat {
    pub fn new(csv_file_format: Arc<dyn FileFormat>) -> Self {
        Self { csv_file_format }
    }
}

#[async_trait::async_trait]
impl FileFormat for TSVFileFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        "tsv".to_string()
    }

    fn get_ext_with_compression(
        &self,
        c: &FileCompressionType,
    ) -> datafusion::error::Result<String> {
        if c == &FileCompressionType::UNCOMPRESSED {
            Ok("tsv".to_string())
        } else {
            todo!("Compression not supported")
        }
    }

    async fn infer_schema(
        &self,
        state: &SessionState,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        self.csv_file_format
            .infer_schema(state, store, objects)
            .await
    }

    async fn infer_stats(
        &self,
        state: &SessionState,
        store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        object: &ObjectMeta,
    ) -> Result<Statistics> {
        self.csv_file_format
            .infer_stats(state, store, table_schema, object)
            .await
    }

    async fn create_physical_plan(
        &self,
        state: &SessionState,
        conf: FileScanConfig,
        filters: Option<&Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.csv_file_format
            .create_physical_plan(state, conf, filters)
            .await
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &SessionState,
        conf: FileSinkConfig,
        order_requirements: Option<Vec<PhysicalSortRequirement>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.csv_file_format
            .create_writer_physical_plan(input, state, conf, order_requirements)
            .await
    }
}

#[derive(Default)]
/// Factory for creating TSV file formats
///
/// This factory is a wrapper around the CSV file format factory
/// for demonstration purposes.
pub struct TSVFileFactory {
    csv_file_factory: CsvFormatFactory,
}

impl TSVFileFactory {
    pub fn new() -> Self {
        Self {
            csv_file_factory: CsvFormatFactory::new(),
        }
    }
}

impl FileFormatFactory for TSVFileFactory {
    fn create(
        &self,
        state: &SessionState,
        format_options: &std::collections::HashMap<String, String>,
    ) -> Result<std::sync::Arc<dyn FileFormat>> {
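        // Clone the caller's options and force a tab delimiter so the wrapped
        // CSV factory reads and writes tab-separated values.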
        let mut new_options = format_options.clone();
        new_options.insert("format.delimiter".to_string(), "\t".to_string());

        let csv_file_format = self.csv_file_factory.create(state, &new_options)?;
        let tsv_file_format = Arc::new(TSVFileFormat::new(csv_file_format));

        Ok(tsv_file_format)
    }

    fn default(&self) -> std::sync::Arc<dyn FileFormat> {
        todo!()
    }
}

impl GetExt for TSVFileFactory {
    fn get_ext(&self) -> String {
        "tsv".to_string()
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    // Create a new context with the default configuration
    let config = SessionConfig::new();
    let runtime = RuntimeEnv::default();
    let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime));

    // Register the custom file format
    let file_format = Arc::new(TSVFileFactory::new());
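    // The second argument is an overwrite flag: `true` replaces any factory
    // previously registered for the same file extension.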
    state.register_file_format(file_format, true).unwrap();

    // Create a new context with the custom file format
    let ctx = SessionContext::new_with_state(state);

    let mem_table = create_mem_table();
    ctx.register_table("mem_table", mem_table).unwrap();

    let temp_dir = tempdir().unwrap();
    let table_save_path = temp_dir.path().join("mem_table.tsv");

    let d = ctx
        .sql(&format!(
            "COPY mem_table TO '{}' STORED AS TSV;",
            table_save_path.display(),
        ))
        .await?;

    let results = d.collect().await?;
    println!(
        "Number of inserted rows: {:?}",
        (results[0]
            .column_by_name("count")
            .unwrap()
            .as_primitive::<UInt64Type>()
            .value(0))
    );

    Ok(())
}

// create a simple mem table
fn create_mem_table() -> Arc<MemTable> {
    let fields = vec![
        Field::new("id", DataType::UInt8, false),
        Field::new("data", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let partitions = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(UInt8Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["foo", "bar"])),
        ],
    )
    .unwrap();

    Arc::new(MemTable::try_new(schema, vec![vec![partitions]]).unwrap())
}
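
One way to sanity-check the output outside of the example itself is to read the file written by `COPY TO` back with DataFusion's built-in CSV reader configured for tab-separated values. The helper below is a minimal sketch and is not part of this commit: `verify_tsv` and its `path` argument are illustrative names, and it assumes `path` points at the single TSV file produced above.

use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

// Sketch only: read a TSV file back with the stock CSV reader by overriding
// the delimiter and the expected file extension, then print the rows.
async fn verify_tsv(path: &str) -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx
        .read_csv(
            path,
            CsvReadOptions::new().delimiter(b'\t').file_extension(".tsv"),
        )
        .await?;
    df.show().await?;
    Ok(())
}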
