diff --git a/cmd/gpq/command/convert.go b/cmd/gpq/command/convert.go
index 5557282..0130c9d 100644
--- a/cmd/gpq/command/convert.go
+++ b/cmd/gpq/command/convert.go
@@ -35,6 +35,7 @@ type ConvertCmd struct {
 	Max                int    `help:"Maximum number of features to consider when building a schema." default:"100"`
 	InputPrimaryColumn string `help:"Primary geometry column name when reading Parquet without metadata." default:"geometry"`
 	Compression        string `help:"Parquet compression to use. Possible values: ${enum}." enum:"uncompressed, snappy, gzip, brotli, zstd" default:"zstd"`
+	RowGroupLength     int    `help:"Maximum number of rows per group when writing Parquet."`
 }
 
 type FormatType string
@@ -149,7 +150,12 @@ func (c *ConvertCmd) Run() error {
 		if outputFormat != ParquetType && outputFormat != GeoParquetType {
 			return errors.New("GeoJSON input can only be converted to GeoParquet")
 		}
-		convertOptions := &geojson.ConvertOptions{MinFeatures: c.Min, MaxFeatures: c.Max, Compression: c.Compression}
+		convertOptions := &geojson.ConvertOptions{
+			MinFeatures:    c.Min,
+			MaxFeatures:    c.Max,
+			Compression:    c.Compression,
+			RowGroupLength: c.RowGroupLength,
+		}
 		return geojson.ToParquet(input, output, convertOptions)
 	}
 
@@ -160,6 +166,7 @@ func (c *ConvertCmd) Run() error {
 		convertOptions := &geoparquet.ConvertOptions{
 			InputPrimaryColumn: c.InputPrimaryColumn,
 			Compression:        c.Compression,
+			RowGroupLength:     c.RowGroupLength,
 		}
 
 		return geoparquet.FromParquet(input, output, convertOptions)
diff --git a/internal/geojson/geojson.go b/internal/geojson/geojson.go
index bfeb93a..2a12f42 100644
--- a/internal/geojson/geojson.go
+++ b/internal/geojson/geojson.go
@@ -58,10 +58,11 @@ func FromParquet(reader parquet.ReaderAtSeeker, writer io.Writer) error {
 }
 
 type ConvertOptions struct {
-	MinFeatures int
-	MaxFeatures int
-	Compression string
-	Metadata    string
+	MinFeatures    int
+	MaxFeatures    int
+	Compression    string
+	RowGroupLength int
+	Metadata       string
 }
 
 var defaultOptions = &ConvertOptions{
@@ -80,12 +81,19 @@ func ToParquet(input io.Reader, output io.Writer, convertOptions *ConvertOptions
 	featuresRead := 0
 
 	var pqWriterProps *parquet.WriterProperties
+	var writerOptions []parquet.WriterProperty
 	if convertOptions.Compression != "" {
 		compression, err := pqutil.GetCompression(convertOptions.Compression)
 		if err != nil {
 			return err
 		}
-		pqWriterProps = parquet.NewWriterProperties(parquet.WithCompression(compression))
+		writerOptions = append(writerOptions, parquet.WithCompression(compression))
+	}
+	if convertOptions.RowGroupLength > 0 {
+		writerOptions = append(writerOptions, parquet.WithMaxRowGroupLength(int64(convertOptions.RowGroupLength)))
+	}
+	if len(writerOptions) > 0 {
+		pqWriterProps = parquet.NewWriterProperties(writerOptions...)
 	}
 
 	var featureWriter *geoparquet.FeatureWriter
diff --git a/internal/geojson/geojson_test.go b/internal/geojson/geojson_test.go
index cc4fc82..c5f4832 100644
--- a/internal/geojson/geojson_test.go
+++ b/internal/geojson/geojson_test.go
@@ -105,6 +105,42 @@ func TestToParquet(t *testing.T) {
 	assert.JSONEq(t, string(expected), geojsonBuffer.String())
 }
 
+func TestToParquetRowGroupLength3(t *testing.T) {
+	geojsonFile, openErr := os.Open("testdata/ten-points.geojson")
+	require.NoError(t, openErr)
+
+	parquetBuffer := &bytes.Buffer{}
+	toParquetErr := geojson.ToParquet(geojsonFile, parquetBuffer, &geojson.ConvertOptions{
+		RowGroupLength: 3,
+	})
+	assert.NoError(t, toParquetErr)
+
+	parquetInput := bytes.NewReader(parquetBuffer.Bytes())
+	fileReader, fileErr := file.NewParquetReader(parquetInput)
+	require.NoError(t, fileErr)
+	defer fileReader.Close()
+
+	assert.Equal(t, 4, fileReader.NumRowGroups())
+}
+
+func TestToParquetRowGroupLength5(t *testing.T) {
+	geojsonFile, openErr := os.Open("testdata/ten-points.geojson")
+	require.NoError(t, openErr)
+
+	parquetBuffer := &bytes.Buffer{}
+	toParquetErr := geojson.ToParquet(geojsonFile, parquetBuffer, &geojson.ConvertOptions{
+		RowGroupLength: 5,
+	})
+	assert.NoError(t, toParquetErr)
+
+	parquetInput := bytes.NewReader(parquetBuffer.Bytes())
+	fileReader, fileErr := file.NewParquetReader(parquetInput)
+	require.NoError(t, fileErr)
+	defer fileReader.Close()
+
+	assert.Equal(t, 2, fileReader.NumRowGroups())
+}
+
 func TestToParquetMismatchedTypes(t *testing.T) {
 	geojsonFile, openErr := os.Open("testdata/mismatched-types.geojson")
 	require.NoError(t, openErr)
diff --git a/internal/geojson/testdata/ten-points.geojson b/internal/geojson/testdata/ten-points.geojson
new file mode 100644
index 0000000..9db30f2
--- /dev/null
+++ b/internal/geojson/testdata/ten-points.geojson
@@ -0,0 +1,105 @@
+{
+  "type": "FeatureCollection",
+  "features": [
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 0
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [0, 0]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 1
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [1, 1]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 2
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [2, 2]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 3
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [3, 3]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 4
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [4, 4]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 5
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [5, 5]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 6
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [6, 6]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 7
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [7, 7]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 8
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [8, 8]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "num": 9
+      },
+      "geometry": {
+        "type": "Point",
+        "coordinates": [9, 9]
+      }
+    }
+  ]
+}
\ No newline at end of file
diff --git a/internal/geoparquet/geoparquet.go b/internal/geoparquet/geoparquet.go
index 6bc9605..67eeaa2 100644
--- a/internal/geoparquet/geoparquet.go
+++ b/internal/geoparquet/geoparquet.go
@@ -21,6 +21,7 @@ import (
 type ConvertOptions struct {
 	InputPrimaryColumn string
 	Compression        string
+	RowGroupLength     int
 }
 
 func getMetadata(fileReader *file.Reader, convertOptions *ConvertOptions) *Metadata {
@@ -171,6 +172,7 @@ func FromParquet(input parquet.ReaderAtSeeker, output io.Writer, convertOptions
 		TransformColumn: transformColumn,
 		BeforeClose:     beforeClose,
 		Compression:     compression,
+		RowGroupLength:  convertOptions.RowGroupLength,
 	}
 
 	return pqutil.TransformByColumn(config)
diff --git a/internal/pqutil/transform.go b/internal/pqutil/transform.go
index 3e4e946..2061a7d 100644
--- a/internal/pqutil/transform.go
+++ b/internal/pqutil/transform.go
@@ -23,6 +23,7 @@ type TransformConfig struct {
 	Reader          parquet.ReaderAtSeeker
 	Writer          io.Writer
 	Compression     *compress.Compression
+	RowGroupLength  int
 	TransformSchema SchemaTransformer
 	TransformColumn ColumnTransformer
 	BeforeClose     func(*file.Reader, *file.Writer) error
@@ -50,6 +51,10 @@ func getWriterProperties(config *TransformConfig, fileReader *file.Reader) (*par
 		}
 	}
 
+	if config.RowGroupLength > 0 {
+		writerProperties = append(writerProperties, parquet.WithMaxRowGroupLength(int64(config.RowGroupLength)))
+	}
+
 	return parquet.NewWriterProperties(writerProperties...), nil
 }
 
@@ -104,34 +109,85 @@ func TransformByColumn(config *TransformConfig) error {
 
 	ctx := pqarrow.NewArrowWriteContext(context.Background(), nil)
 
-	numRowGroups := fileReader.NumRowGroups()
-	for rowGroupIndex := 0; rowGroupIndex < numRowGroups; rowGroupIndex += 1 {
-		rowGroupReader := arrowReader.RowGroup(rowGroupIndex)
-		rowGroupWriter := fileWriter.AppendRowGroup()
+	if config.RowGroupLength > 0 {
+		columnReaders := make([]*pqarrow.ColumnReader, numFields)
 		for fieldNum := 0; fieldNum < numFields; fieldNum += 1 {
-			arr, readErr := rowGroupReader.Column(fieldNum).Read(ctx)
-			if readErr != nil {
-				return readErr
+			colReader, err := arrowReader.GetColumn(ctx, fieldNum)
+			if err != nil {
+				return err
 			}
-			if config.TransformColumn != nil {
-				inputField := inputManifest.Fields[fieldNum].Field
-				outputField := outputManifest.Fields[fieldNum].Field
-				transformed, err := config.TransformColumn(inputField, outputField, arr)
-				if err != nil {
-					return err
+			columnReaders[fieldNum] = colReader
+		}
+
+		numRows := fileReader.NumRows()
+		numRowsWritten := int64(0)
+		for {
+			rowGroupWriter := fileWriter.AppendRowGroup()
+			for fieldNum := 0; fieldNum < numFields; fieldNum += 1 {
+				colReader := columnReaders[fieldNum]
+				arr, readErr := colReader.NextBatch(int64(config.RowGroupLength))
+				if readErr != nil {
+					return readErr
 				}
-				if transformed.DataType() != outputField.Type {
-					return fmt.Errorf("transform generated an unexpected type, got %s, expected %s", transformed.DataType().Name(), outputField.Type.Name())
+				if config.TransformColumn != nil {
+					inputField := inputManifest.Fields[fieldNum].Field
+					outputField := outputManifest.Fields[fieldNum].Field
+					transformed, err := config.TransformColumn(inputField, outputField, arr)
+					if err != nil {
+						return err
+					}
+					if transformed.DataType() != outputField.Type {
+						return fmt.Errorf("transform generated an unexpected type, got %s, expected %s", transformed.DataType().Name(), outputField.Type.Name())
+					}
+					arr = transformed
+				}
+				colWriter, colWriterErr := pqarrow.NewArrowColumnWriter(arr, 0, int64(arr.Len()), outputManifest, rowGroupWriter, fieldNum)
+				if colWriterErr != nil {
+					return colWriterErr
+				}
+				if err := colWriter.Write(ctx); err != nil {
+					return err
 				}
-				arr = transformed
-			}
-			colWriter, colWriterErr := pqarrow.NewArrowColumnWriter(arr, 0, int64(arr.Len()), outputManifest, rowGroupWriter, fieldNum)
-			if colWriterErr != nil {
-				return colWriterErr
 			}
-			if err := colWriter.Write(ctx); err != nil {
+			numRowsInGroup, err := rowGroupWriter.NumRows()
+			if err != nil {
 				return err
 			}
+			numRowsWritten += int64(numRowsInGroup)
+			if numRowsWritten >= numRows {
+				break
+			}
+		}
+	} else {
+		numRowGroups := fileReader.NumRowGroups()
+		for rowGroupIndex := 0; rowGroupIndex < numRowGroups; rowGroupIndex += 1 {
+			rowGroupReader := arrowReader.RowGroup(rowGroupIndex)
+			rowGroupWriter := fileWriter.AppendRowGroup()
+			for fieldNum := 0; fieldNum < numFields; fieldNum += 1 {
+				arr, readErr := rowGroupReader.Column(fieldNum).Read(ctx)
+				if readErr != nil {
+					return readErr
+				}
+				if config.TransformColumn != nil {
+					inputField := inputManifest.Fields[fieldNum].Field
+					outputField := outputManifest.Fields[fieldNum].Field
+					transformed, err := config.TransformColumn(inputField, outputField, arr)
+					if err != nil {
+						return err
+					}
+					if transformed.DataType() != outputField.Type {
+						return fmt.Errorf("transform generated an unexpected type, got %s, expected %s", transformed.DataType().Name(), outputField.Type.Name())
+					}
+					arr = transformed
+				}
+				colWriter, colWriterErr := pqarrow.NewArrowColumnWriter(arr, 0, int64(arr.Len()), outputManifest, rowGroupWriter, fieldNum)
+				if colWriterErr != nil {
+					return colWriterErr
+				}
+				if err := colWriter.Write(ctx); err != nil {
+					return err
+				}
+			}
 		}
 	}
diff --git a/internal/pqutil/transform_test.go b/internal/pqutil/transform_test.go
index 45a3928..5402554 100644
--- a/internal/pqutil/transform_test.go
+++ b/internal/pqutil/transform_test.go
@@ -2,7 +2,9 @@ package pqutil_test
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
+	"math"
 	"strconv"
 	"testing"
 
@@ -83,7 +85,7 @@ func TestTransformByColumn(t *testing.T) {
 
 	for i, c := range cases {
 		t.Run(fmt.Sprintf("%s (case %d)", c.name, i), func(t *testing.T) {
-			input := test.ParquetFromJSON(t, c.data)
+			input := test.ParquetFromJSON(t, c.data, nil)
 			output := &bytes.Buffer{}
 			config := c.config
 			if config == nil {
@@ -121,6 +123,97 @@ func TestTransformByColumn(t *testing.T) {
 	}
 }
 
+func TestTransformByRowGroupLength(t *testing.T) {
+	numRows := 100
+	rows := make([]map[string]any, numRows)
+	for i := 0; i < numRows; i += 1 {
+		rows[i] = map[string]any{"num": i}
+	}
+	inputData, err := json.Marshal(rows)
+	require.NoError(t, err)
+
+	cases := []struct {
+		name                string
+		inputRowGroupLength int
+		config              *pqutil.TransformConfig
+	}{
+		{
+			name:                "no row group length, use input",
+			inputRowGroupLength: 50,
+		},
+		{
+			name:                "read row group length 50, write 13",
+			inputRowGroupLength: 50,
+			config: &pqutil.TransformConfig{
+				RowGroupLength: 13,
+			},
+		},
+		{
+			name:                "read row group length 50, write 60",
+			inputRowGroupLength: 50,
+			config: &pqutil.TransformConfig{
+				RowGroupLength: 60,
+			},
+		},
+		{
+			name:                "read row group length 50, write 110",
+			inputRowGroupLength: 50,
+			config: &pqutil.TransformConfig{
+				RowGroupLength: 110,
+			},
+		},
+		{
+			name:                "read row group length 110, write 110",
+			inputRowGroupLength: 110,
+			config: &pqutil.TransformConfig{
+				RowGroupLength: 110,
+			},
+		},
+	}
+
+	for i, c := range cases {
+		t.Run(fmt.Sprintf("%s (case %d)", c.name, i), func(t *testing.T) {
+			writerProperties := parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(c.inputRowGroupLength)))
+			input := test.ParquetFromJSON(t, string(inputData), writerProperties)
+			output := &bytes.Buffer{}
+			config := c.config
+			if config == nil {
+				config = &pqutil.TransformConfig{}
+			}
+			config.Reader = input
+			config.Writer = output
+
+			require.NoError(t, pqutil.TransformByColumn(config))
+
+			outputAsJSON := test.ParquetToJSON(t, bytes.NewReader(output.Bytes()))
+			assert.JSONEq(t, string(inputData), outputAsJSON)
+
+			fileReader, err := file.NewParquetReader(bytes.NewReader(output.Bytes()))
+			require.NoError(t, err)
+			defer fileReader.Close()
+
+			var expectedNumRowGroups int
+			if config.RowGroupLength > 0 {
+				expectedNumRowGroups = int(math.Ceil(float64(numRows) / float64(c.config.RowGroupLength)))
+			} else {
+				inputFileReader, err := file.NewParquetReader(input)
+				require.NoError(t, err)
+				defer inputFileReader.Close()
+				expectedNumRowGroups = inputFileReader.NumRowGroups()
+			}
+			require.Equal(t, expectedNumRowGroups, fileReader.NumRowGroups())
+
+			if config.RowGroupLength > 0 {
+				for rowGroupIndex := 0; rowGroupIndex < fileReader.NumRowGroups(); rowGroupIndex += 1 {
+					numRows := fileReader.MetaData().RowGroups[rowGroupIndex].NumRows
+					require.LessOrEqual(t, numRows, int64(config.RowGroupLength), "row group index: %d", rowGroupIndex)
+				}
+			}
+		})
+
+	}
+}
+
 func TestTransformColumn(t *testing.T) {
 	data := `[
 		{
@@ -200,7 +293,7 @@ func TestTransformColumn(t *testing.T) {
 			return arrow.NewChunked(builder.Type(), transformed), nil
 		}
 
-		input := test.ParquetFromJSON(t, data)
+		input := test.ParquetFromJSON(t, data, nil)
 		output := &bytes.Buffer{}
 		config := &pqutil.TransformConfig{
 			Reader: input,
diff --git a/internal/test/test.go b/internal/test/test.go
index 467187a..0377483 100644
--- a/internal/test/test.go
+++ b/internal/test/test.go
@@ -21,7 +21,10 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func ParquetFromJSON(t *testing.T, data string) parquet.ReaderAtSeeker {
+func ParquetFromJSON(t *testing.T, data string, writerProperties *parquet.WriterProperties) parquet.ReaderAtSeeker {
+	if writerProperties == nil {
+		writerProperties = parquet.NewWriterProperties()
+	}
 	var rows []map[string]any
 	require.NoError(t, json.Unmarshal([]byte(data), &rows))
 
@@ -41,7 +44,7 @@ func ParquetFromJSON(t *testing.T, data string) parquet.ReaderAtSeeker {
 
 	output := &bytes.Buffer{}
 
-	writer, err := pqarrow.NewFileWriter(schema, output, parquet.NewWriterProperties(), pqarrow.DefaultWriterProps())
+	writer, err := pqarrow.NewFileWriter(schema, output, writerProperties, pqarrow.DefaultWriterProps())
 	require.NoError(t, err)
 
 	require.NoError(t, writer.WriteBuffered(rec))
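
Usage note (not part of the patch): ConvertCmd is declared with kong-style struct tags, and kong derives flag names from field names, so the new RowGroupLength field should surface on the CLI as --row-group-length. A hypothetical invocation, with file names invented for illustration:

    gpq convert points.geojson points.parquet --row-group-length 3

With the ten-feature input used by the new tests, this should produce four row groups of at most three rows each, matching TestToParquetRowGroupLength3.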
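For reviewers who want to poke at the writer-property behavior this change relies on without going through gpq, here is a minimal standalone sketch (not part of the patch). It uses only Arrow Go APIs that already appear in the diff (parquet.WithMaxRowGroupLength, pqarrow.NewFileWriter, WriteBuffered, file.NewParquetReader); the module path version (v14 here), the one-column schema, and the ten-row data are assumptions for illustration. The expectation that WriteBuffered starts a new row group once the configured maximum is reached is the same behavior the new tests above assert.

package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/array"
	"github.com/apache/arrow/go/v14/arrow/memory"
	"github.com/apache/arrow/go/v14/parquet"
	"github.com/apache/arrow/go/v14/parquet/file"
	"github.com/apache/arrow/go/v14/parquet/pqarrow"
)

func main() {
	// One int64 column, mirroring the "num" property in ten-points.geojson.
	schema := arrow.NewSchema([]arrow.Field{{Name: "num", Type: arrow.PrimitiveTypes.Int64}}, nil)

	// Build a single record with ten rows.
	builder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
	defer builder.Release()
	for i := int64(0); i < 10; i++ {
		builder.Field(0).(*array.Int64Builder).Append(i)
	}
	record := builder.NewRecord()
	defer record.Release()

	// Cap row groups at 3 rows: the same writer property the patch wires through.
	props := parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(3))

	output := &bytes.Buffer{}
	writer, err := pqarrow.NewFileWriter(schema, output, props, pqarrow.DefaultWriterProps())
	if err != nil {
		log.Fatal(err)
	}
	// WriteBuffered flushes to a new row group whenever the current one is full.
	if err := writer.WriteBuffered(record); err != nil {
		log.Fatal(err)
	}
	if err := writer.Close(); err != nil {
		log.Fatal(err)
	}

	reader, err := file.NewParquetReader(bytes.NewReader(output.Bytes()))
	if err != nil {
		log.Fatal(err)
	}
	defer reader.Close()

	// Expect ceil(10 / 3) = 4, matching TestToParquetRowGroupLength3.
	fmt.Println("row groups:", reader.NumRowGroups())
}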