Skip to content

Commit

Permalink
Propagate array item types from SQL to Avro schema (#931)
Browse files Browse the repository at this point in the history
* Propagate array item types from SQL to Avro schema

* Add tests

* Fix checkstyle errors

* Update e2e/e2e.sh

Co-authored-by: Luís Bianchin <[email protected]>

* Update e2e/e2e.sh

Co-authored-by: Luís Bianchin <[email protected]>

* fix e2e

* Update docs

---------

Co-authored-by: Luís Bianchin <[email protected]>
  • Loading branch information
shnapz and labianchin authored Dec 6, 2024
1 parent 9c72eac commit addacc3
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ static SqlFunction<ResultSet, Object> computeMapping(
} else {
return resultSet -> nullableBytes(resultSet.getBytes(column));
}
case ARRAY:
return resultSet -> resultSet.getArray(column);
case BINARY:
case VARBINARY:
case LONGVARBINARY:
case ARRAY:
case BLOB:
return resultSet -> nullableBytes(resultSet.getBytes(column));
case DOUBLE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,39 @@ public ByteBuffer convertResultSetIntoAvroBytes() throws SQLException, IOExcepti
binaryEncoder.writeNull();
} else {
binaryEncoder.writeIndex(1);
if (value instanceof String) {
binaryEncoder.writeString((String) value);
} else if (value instanceof Long) {
binaryEncoder.writeLong((Long) value);
} else if (value instanceof Integer) {
binaryEncoder.writeInt((Integer) value);
} else if (value instanceof Boolean) {
binaryEncoder.writeBoolean((Boolean) value);
} else if (value instanceof ByteBuffer) {
binaryEncoder.writeBytes((ByteBuffer) value);
} else if (value instanceof Double) {
binaryEncoder.writeDouble((Double) value);
} else if (value instanceof Float) {
binaryEncoder.writeFloat((Float) value);
}
writeValue(value, binaryEncoder);
}
}
binaryEncoder.flush();
return ByteBuffer.wrap(out.getBufffer(), 0, out.size());
}

/**
 * Encodes a single column value into the Avro binary stream, recursing into
 * SQL arrays so that each element is written with its own concrete type.
 *
 * <p>Supported runtime types: {@code String}, {@code Long}, {@code Integer},
 * {@code Boolean}, {@code ByteBuffer}, {@code Double}, {@code Float} and
 * {@code java.sql.Array}. A value of any other runtime type is silently
 * skipped — NOTE(review): presumably unreachable given the upstream
 * column-type mapping; confirm, since a skipped value would leave the
 * encoder stream inconsistent with the schema.
 *
 * @param value the column (or array element) value to encode; may be null-free
 *     here, null handling is done by the caller
 * @param binaryEncoder the Avro encoder receiving the serialized bytes
 * @throws SQLException if materializing a {@code java.sql.Array} fails
 * @throws IOException if the encoder fails to write
 */
private void writeValue(Object value, BinaryEncoder binaryEncoder)
    throws SQLException, IOException {
  if (value instanceof String) {
    binaryEncoder.writeString((String) value);
    return;
  }
  if (value instanceof Long) {
    binaryEncoder.writeLong((Long) value);
    return;
  }
  if (value instanceof Integer) {
    binaryEncoder.writeInt((Integer) value);
    return;
  }
  if (value instanceof Boolean) {
    binaryEncoder.writeBoolean((Boolean) value);
    return;
  }
  if (value instanceof ByteBuffer) {
    binaryEncoder.writeBytes((ByteBuffer) value);
    return;
  }
  if (value instanceof Double) {
    binaryEncoder.writeDouble((Double) value);
    return;
  }
  if (value instanceof Float) {
    binaryEncoder.writeFloat((Float) value);
    return;
  }
  if (value instanceof java.sql.Array) {
    // Open the array block before materializing it, matching the original
    // statement order (relevant only if getArray() throws).
    binaryEncoder.writeArrayStart();
    final Object[] elements = (Object[]) ((java.sql.Array) value).getArray();
    binaryEncoder.setItemCount(elements.length);
    for (final Object element : elements) {
      binaryEncoder.startItem();
      // Recurse: array elements are plain boxed values handled above.
      writeValue(element, binaryEncoder);
    }
    binaryEncoder.writeArrayEnd();
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public static Schema createSchemaByReadingOneRow(
try (Statement statement = connection.createStatement()) {
final ResultSet resultSet = statement.executeQuery(queryBuilderArgs.sqlQueryWithLimitOne());

resultSet.next();

final Schema schema =
createAvroSchema(
resultSet,
Expand Down Expand Up @@ -107,7 +109,7 @@ public static Schema createAvroSchema(
.prop("tableName", tableName)
.prop("connectionUrl", connectionUrl)
.fields();
return createAvroFields(meta, builder, useLogicalTypes).endRecord();
return createAvroFields(resultSet, builder, useLogicalTypes).endRecord();
}

static String getDatabaseTableName(final ResultSetMetaData meta) throws SQLException {
Expand All @@ -123,11 +125,13 @@ static String getDatabaseTableName(final ResultSetMetaData meta) throws SQLExcep
}

private static SchemaBuilder.FieldAssembler<Schema> createAvroFields(
final ResultSetMetaData meta,
final SchemaBuilder.FieldAssembler<Schema> builder,
final ResultSet resultSet,
final SchemaBuilder.FieldAssembler<Schema> builder,
final boolean useLogicalTypes)
throws SQLException {

ResultSetMetaData meta = resultSet.getMetaData();

for (int i = 1; i <= meta.getColumnCount(); i++) {

final String columnName;
Expand All @@ -140,7 +144,8 @@ private static SchemaBuilder.FieldAssembler<Schema> createAvroFields(
final int columnType = meta.getColumnType(i);
final String typeName = JDBCType.valueOf(columnType).getName();
final String columnClassName = meta.getColumnClassName(i);
final SchemaBuilder.FieldBuilder<Schema> field =
final String columnTypeName = meta.getColumnTypeName(i);
SchemaBuilder.FieldBuilder<Schema> field =
builder
.name(normalizeForAvro(columnName))
.doc(String.format("From sqlType %d %s (%s)", columnType, typeName, columnClassName))
Expand All @@ -149,13 +154,21 @@ private static SchemaBuilder.FieldAssembler<Schema> createAvroFields(
.prop("typeName", typeName)
.prop("columnClassName", columnClassName);

if (columnTypeName != null) {
field = field.prop("columnTypeName", columnTypeName);
}

final SchemaBuilder.BaseTypeBuilder<
SchemaBuilder.UnionAccumulator<SchemaBuilder.NullDefault<Schema>>>
fieldSchemaBuilder = field.type().unionOf().nullBuilder().endNull().and();

Integer arrayItemType = resultSet.isFirst() && columnType == ARRAY
? resultSet.getArray(i).getBaseType() : null;

final SchemaBuilder.UnionAccumulator<SchemaBuilder.NullDefault<Schema>> schemaFieldAssembler =
setAvroColumnType(
columnType,
arrayItemType,
meta.getPrecision(i),
columnClassName,
useLogicalTypes,
Expand All @@ -181,6 +194,7 @@ private static SchemaBuilder.FieldAssembler<Schema> createAvroFields(
private static SchemaBuilder.UnionAccumulator<SchemaBuilder.NullDefault<Schema>>
setAvroColumnType(
final int columnType,
final Integer arrayItemType,
final int precision,
final String columnClassName,
final boolean useLogicalTypes,
Expand Down Expand Up @@ -225,10 +239,12 @@ private static SchemaBuilder.FieldAssembler<Schema> createAvroFields(
} else {
return field.bytesType();
}
case ARRAY:
return setAvroColumnType(arrayItemType, null, precision, columnClassName,
useLogicalTypes, field.array().items());
case BINARY:
case VARBINARY:
case LONGVARBINARY:
case ARRAY:
case BLOB:
return field.bytesType();
case DOUBLE:
Expand Down
53 changes: 46 additions & 7 deletions dbeam-core/src/test/java/com/spotify/dbeam/Coffee.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@

import com.google.auto.value.AutoValue;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;

// A fictitious DB model to test different SQL types
@AutoValue
Expand All @@ -42,7 +45,9 @@ public static Coffee create(
final java.sql.Timestamp created,
final Optional<java.sql.Timestamp> updated,
final UUID uid,
final Long rownum) {
final Long rownum,
final List<Integer> intArr,
final List<String> textArr) {
return new AutoValue_Coffee(
name,
supId,
Expand All @@ -55,7 +60,9 @@ public static Coffee create(
created,
updated,
uid,
rownum);
rownum,
new ArrayList<>(intArr),
new ArrayList<>(textArr));
}

public abstract String name();
Expand All @@ -82,10 +89,15 @@ public static Coffee create(

public abstract Long rownum();

public abstract List<Integer> intArr();

public abstract List<String> textArr();

public String insertStatement() {
return String.format(
Locale.ENGLISH,
"INSERT INTO COFFEES " + "VALUES ('%s', %s, '%s', %f, %f, %b, %d, %d, '%s', %s, '%s', %d)",
"INSERT INTO COFFEES " + "VALUES ('%s', %s, '%s', %f, %f, %b, %d, %d, '%s', %s, '%s', %d,"
+ " ARRAY [%s], ARRAY ['%s'])",
name(),
supId().orElse(null),
price().toString(),
Expand All @@ -97,7 +109,9 @@ public String insertStatement() {
created(),
updated().orElse(null),
uid(),
rownum());
rownum(),
String.join(",", intArr().stream().map(x -> (CharSequence) x.toString())::iterator),
String.join("','", textArr()));
}

public static String ddl() {
Expand All @@ -114,7 +128,9 @@ public static String ddl() {
+ "\"CREATED\" TIMESTAMP NOT NULL,"
+ "\"UPDATED\" TIMESTAMP,"
+ "\"UID\" UUID NOT NULL,"
+ "\"ROWNUM\" BIGINT NOT NULL);";
+ "\"ROWNUM\" BIGINT NOT NULL,"
+ "\"INT_ARR\" INTEGER ARRAY NOT NULL,"
+ "\"TEXT_ARR\" VARCHAR ARRAY NOT NULL);";
}

public static Coffee COFFEE1 =
Expand All @@ -130,7 +146,19 @@ public static String ddl() {
new java.sql.Timestamp(1488300933000L),
Optional.empty(),
UUID.fromString("123e4567-e89b-12d3-a456-426655440000"),
1L);
1L,
new ArrayList<Integer>() {{
add(5);
add(7);
add(11);
}},
new ArrayList<String>() {{
add("rock");
add("scissors");
add("paper");
}}
);

public static Coffee COFFEE2 =
create(
"colombian caffee",
Expand All @@ -144,5 +172,16 @@ public static String ddl() {
new java.sql.Timestamp(1488300723000L),
Optional.empty(),
UUID.fromString("123e4567-e89b-a456-12d3-426655440000"),
2L);
2L,
new ArrayList<Integer>() {{
add(7);
add(11);
add(23);
}},
new ArrayList<String>() {{
add("scissors");
add("paper");
add("rock");
}}
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -42,9 +45,11 @@
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.file.SeekableByteArrayInput;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
Expand All @@ -62,7 +67,7 @@ public static void beforeAll() throws SQLException, ClassNotFoundException {

@Test
public void shouldCreateSchema() throws ClassNotFoundException, SQLException {
final int fieldCount = 12;
final int fieldCount = 14;
final Schema actual =
JdbcAvroSchema.createSchemaByReadingOneRow(
DbTestHelper.createConnection(CONNECTION_URL),
Expand Down Expand Up @@ -92,7 +97,9 @@ public void shouldCreateSchema() throws ClassNotFoundException, SQLException {
"CREATED",
"UPDATED",
"UID",
"ROWNUM"),
"ROWNUM",
"INT_ARR",
"TEXT_ARR"),
actual.getFields().stream().map(Schema.Field::name).collect(Collectors.toList()));
for (Schema.Field f : actual.getFields()) {
Assert.assertEquals(Schema.Type.UNION, f.schema().getType());
Expand Down Expand Up @@ -128,7 +135,7 @@ public void shouldCreateSchema() throws ClassNotFoundException, SQLException {

@Test
public void shouldCreateSchemaWithLogicalTypes() throws ClassNotFoundException, SQLException {
final int fieldCount = 12;
final int fieldCount = 14;
final Schema actual =
JdbcAvroSchema.createSchemaByReadingOneRow(
DbTestHelper.createConnection(CONNECTION_URL),
Expand Down Expand Up @@ -163,8 +170,10 @@ public void shouldEncodeResultSetToValidAvro()
throws ClassNotFoundException, SQLException, IOException {
final ResultSet rs =
DbTestHelper.createConnection(CONNECTION_URL)
.createStatement()
.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_READ_ONLY)
.executeQuery("SELECT * FROM COFFEES");

rs.first();
final Schema schema =
JdbcAvroSchema.createAvroSchema(
rs, "dbeam_generated", "connection", Optional.empty(), "doc", false);
Expand All @@ -173,6 +182,7 @@ public void shouldEncodeResultSetToValidAvro()
new DataFileWriter<>(new GenericDatumWriter<>(schema));
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
dataFileWriter.create(schema, outputStream);
rs.previous();
// convert and write
while (rs.next()) {
dataFileWriter.appendEncoded(converter.convertResultSetIntoAvroBytes());
Expand All @@ -194,8 +204,11 @@ public void shouldEncodeResultSetToValidAvro()
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("not found"));

Assert.assertEquals(12, record.getSchema().getFields().size());
Assert.assertEquals(14, record.getSchema().getFields().size());
Assert.assertEquals(schema, record.getSchema());
List<String> actualTxtArray =
((GenericData.Array<Utf8>) record.get(13)).stream().map(x -> x.toString()).collect(
Collectors.toList());
final Coffee actual =
Coffee.create(
record.get(0).toString(),
Expand All @@ -209,7 +222,9 @@ public void shouldEncodeResultSetToValidAvro()
new java.sql.Timestamp((Long) record.get(8)),
Optional.ofNullable((Long) record.get(9)).map(Timestamp::new),
TestHelper.byteBufferToUuid((ByteBuffer) record.get(10)),
(Long) record.get(11));
(Long) record.get(11),
new ArrayList<>((GenericData.Array<Integer>) record.get(12)),
actualTxtArray);
Assert.assertEquals(Coffee.COFFEE1, actual);
}

Expand Down
Loading

0 comments on commit addacc3

Please sign in to comment.