diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java
index 284ad28a..c9ddc63a 100644
--- a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java
+++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java
@@ -17,6 +17,7 @@
 package com.mongodb.spark.sql.connector;
 
+import static java.time.ZoneOffset.UTC;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static org.junit.jupiter.api.Assertions.assertIterableEquals;
@@ -30,7 +31,7 @@
 import java.sql.Timestamp;
 import java.time.Instant;
 import java.time.LocalDate;
-import java.time.ZoneOffset;
+import java.time.LocalDateTime;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -41,6 +42,8 @@
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 public class RoundTripTest extends MongoSparkConnectorTestCase {
 
@@ -92,37 +95,45 @@ void testBoxedBean() {
     assertIterableEquals(dataSetOriginal, dataSetMongo);
   }
 
-  @Test
-  void testDateTimeBean() {
+  @ParameterizedTest
+  @ValueSource(strings = {"true", "false"})
+  void testDateTimeBean(String java8DateTimeAPI) {
     TimeZone original = TimeZone.getDefault();
     try {
-      TimeZone.setDefault(TimeZone.getTimeZone(ZoneOffset.UTC));
+      TimeZone.setDefault(TimeZone.getTimeZone(UTC));
 
       // Given
       long oneHour = TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS);
       long oneDay = oneHour * 24;
-      List<DateTimeBean> dataSetOriginal = singletonList(new DateTimeBean(
-          new Date(oneDay * 365),
-          new Timestamp(oneDay + oneHour),
-          LocalDate.of(2000, 1, 1),
-          Instant.EPOCH));
+      Instant epoch = Instant.EPOCH;
+      List<DateTimeBean> dataSetOriginal =
+          singletonList(
+              new DateTimeBean(
+                  new Date(oneDay * 365),
+                  new Timestamp(oneDay + oneHour),
+                  LocalDate.of(2000, 1, 1),
+                  epoch,
+                  LocalDateTime.ofInstant(epoch, UTC)));
 
       // when
-      SparkSession spark = getOrCreateSparkSession();
+      SparkSession spark =
+          getOrCreateSparkSession(
+              getSparkConf().set("spark.sql.datetime.java8API.enabled", java8DateTimeAPI));
       Encoder<DateTimeBean> encoder = Encoders.bean(DateTimeBean.class);
       Dataset<DateTimeBean> dataset = spark.createDataset(dataSetOriginal, encoder);
       dataset.write().format("mongodb").mode("Overwrite").save();
 
       // Then
-      List<DateTimeBean> dataSetMongo = spark
-          .read()
-          .format("mongodb")
-          .schema(encoder.schema())
-          .load()
-          .as(encoder)
-          .collectAsList();
+      List<DateTimeBean> dataSetMongo =
+          spark
+              .read()
+              .format("mongodb")
+              .schema(encoder.schema())
+              .load()
+              .as(encoder)
+              .collectAsList();
       assertIterableEquals(dataSetOriginal, dataSetMongo);
     } finally {
       TimeZone.setDefault(original);
diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java
index 5b4ff473..7ae42abf 100644
--- a/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java
+++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java
@@ -21,6 +21,7 @@
 import java.sql.Timestamp;
 import java.time.Instant;
 import java.time.LocalDate;
+import java.time.LocalDateTime;
 import java.util.Objects;
 
 public class DateTimeBean implements Serializable {
@@ -28,6 +29,7 @@ public class DateTimeBean implements Serializable {
   private java.sql.Timestamp sqlTimestamp;
   private java.time.LocalDate localDate;
   private java.time.Instant instant;
+  private java.time.LocalDateTime localDateTime;
 
   public DateTimeBean() {}
 
@@ -35,10 +37,13 @@
   public DateTimeBean(
       final Date sqlDate,
       final Timestamp sqlTimestamp,
       final LocalDate localDate,
-      final Instant instant) {
+      final Instant instant,
+      final LocalDateTime localDateTime) {
     this.sqlDate = sqlDate;
     this.sqlTimestamp = sqlTimestamp;
     this.localDate = localDate;
+    this.localDateTime = localDateTime;
     this.instant = instant;
   }
 
@@ -66,6 +71,14 @@ public void setLocalDate(final LocalDate localDate) {
     this.localDate = localDate;
   }
 
+  public LocalDateTime getLocalDateTime() {
+    return localDateTime;
+  }
+
+  public void setLocalDateTime(final LocalDateTime localDateTime) {
+    this.localDateTime = localDateTime;
+  }
+
   public Instant getInstant() {
     return instant;
   }
@@ -76,30 +89,39 @@ public void setInstant(final Instant instant) {
 
   @Override
   public boolean equals(final Object o) {
-    if (this == o) {
-      return true;
-    }
-    if (o == null || getClass() != o.getClass()) {
-      return false;
-    }
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
     DateTimeBean that = (DateTimeBean) o;
     return Objects.equals(sqlDate, that.sqlDate)
         && Objects.equals(sqlTimestamp, that.sqlTimestamp)
         && Objects.equals(localDate, that.localDate)
+        && Objects.equals(localDateTime, that.localDateTime)
        && Objects.equals(instant, that.instant);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(sqlDate, sqlTimestamp, localDate, instant);
+    return Objects.hash(
+        sqlDate,
+        sqlTimestamp,
+        localDate,
+        localDateTime,
+        instant);
   }
 
   @Override
   public String toString() {
-    return "DateTimeBean{" + "sqlDate="
-        + sqlDate + ", sqlTimestamp="
-        + sqlTimestamp + ", localDate="
-        + localDate + ", instant="
-        + instant + '}';
+    return "DateTimeBean{"
+        + "sqlDate="
+        + sqlDate
+        + ", sqlTimestamp="
+        + sqlTimestamp
+        + ", localDate="
+        + localDate
+        + ", localDateTime="
+        + localDateTime
+        + ", instant="
+        + instant
+        + '}';
   }
 }
diff --git a/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java b/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java
index a97b720a..e0ab78eb 100644
--- a/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java
+++ b/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java
@@ -44,6 +44,7 @@
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
+import org.apache.spark.sql.internal.SqlApiConf;
 import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.BinaryType;
 import org.apache.spark.sql.types.BooleanType;
@@ -62,6 +63,7 @@
 import org.apache.spark.sql.types.StringType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampNTZType;
 import org.apache.spark.sql.types.TimestampType;
 import org.bson.BsonArray;
 import org.bson.BsonBinaryWriter;
@@ -76,6 +78,8 @@
 import org.bson.types.Decimal128;
 import org.jetbrains.annotations.NotNull;
 import org.jetbrains.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * The helper for conversion of BsonDocuments to GenericRowWithSchema instances.
@@ -89,6 +93,7 @@
 @NotNull
 public final class BsonDocumentToRowConverter implements Serializable {
   private static final long serialVersionUID = 1L;
+  private static final Logger log = LoggerFactory.getLogger(BsonDocumentToRowConverter.class);
   private final Function<Row, InternalRow> rowToInternalRowFunction;
   private final StructType schema;
   private final boolean outputExtendedJson;
@@ -96,6 +101,7 @@ public final class BsonDocumentToRowConverter implements Serializable {
   private final boolean dropMalformed;
   private final String columnNameOfCorruptRecord;
   private final boolean schemaContainsCorruptRecordColumn;
+  private final boolean dateTimeJava8APIEnabled;
 
   private boolean corruptedRecord;
 
@@ -114,6 +120,7 @@ public BsonDocumentToRowConverter(final StructType originalSchema, final ReadCon
     this.columnNameOfCorruptRecord = readConfig.getColumnNameOfCorruptRecord();
     this.schemaContainsCorruptRecordColumn = !columnNameOfCorruptRecord.isEmpty()
         && Arrays.asList(schema.fieldNames()).contains(columnNameOfCorruptRecord);
+    this.dateTimeJava8APIEnabled = SqlApiConf.get().datetimeJava8ApiEnabled();
   }
 
   /** @return the schema for the converter */
@@ -165,6 +172,7 @@ GenericRowWithSchema toRow(final BsonDocument bsonDocument) {
   @VisibleForTesting
   Object convertBsonValue(
       final String fieldName, final DataType dataType, final BsonValue bsonValue) {
+    log.trace("Converting bson to value: {} {} {}", fieldName, dataType, bsonValue);
     try {
       if (bsonValue.isNull()) {
         return null;
@@ -179,9 +187,22 @@ Object convertBsonValue(
       } else if (dataType instanceof BooleanType) {
         return convertToBoolean(fieldName, dataType, bsonValue);
       } else if (dataType instanceof DateType) {
-        return convertToDate(fieldName, dataType, bsonValue);
+        Date date = convertToDate(fieldName, dataType, bsonValue);
+        if (dateTimeJava8APIEnabled) {
+          return date.toLocalDate();
+        } else {
+          return date;
+        }
       } else if (dataType instanceof TimestampType) {
-        return convertToTimestamp(fieldName, dataType, bsonValue);
+        Timestamp timestamp = convertToTimestamp(fieldName, dataType, bsonValue);
+        if (dateTimeJava8APIEnabled) {
+          return timestamp.toInstant();
+        } else {
+          return timestamp;
+        }
+      } else if (dataType instanceof TimestampNTZType) {
+        Timestamp timestamp = convertToTimestamp(fieldName, dataType, bsonValue);
+        return timestamp.toLocalDateTime();
       } else if (dataType instanceof FloatType) {
         return convertToFloat(fieldName, dataType, bsonValue);
       } else if (dataType instanceof IntegerType) {
diff --git a/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java b/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java
index d1378515..53fd6e4f 100644
--- a/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java
+++ b/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java
@@ -26,6 +26,10 @@
 import com.mongodb.spark.sql.connector.interop.JavaScala;
 import java.io.Serializable;
 import java.math.BigDecimal;
+import java.sql.Timestamp;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
 import java.util.Arrays;
 import java.util.Date;
 import java.util.List;
@@ -139,7 +143,7 @@ public static ObjectToBsonValue createObjectToBsonValue(
       } catch (Exception e) {
         throw new DataException(format(
             "Cannot cast %s into a BsonValue. %s has no matching BsonValue. Error: %s",
-            data, dataType, e.getMessage()));
+            data, dataType, e.getMessage()), e);
       }
     };
   }
@@ -177,14 +181,33 @@ private static ObjectToBsonValue objectToBsonValue(
     } else if (DataTypes.StringType.acceptsType(dataType)) {
       return (data) -> processString((String) data, convertJson);
     } else if (DataTypes.DateType.acceptsType(dataType)
-        || DataTypes.TimestampType.acceptsType(dataType)) {
+        || DataTypes.TimestampType.acceptsType(dataType)
+        || DataTypes.TimestampNTZType.acceptsType(dataType)) {
       return (data) -> {
         if (data instanceof Date) {
           // Covers java.util.Date, java.sql.Date, java.sql.Timestamp
           return new BsonDateTime(((Date) data).getTime());
         }
-        throw new MongoSparkException(
-            "Unsupported date type: " + data.getClass().getSimpleName());
+        if (data instanceof Instant) {
+          return new BsonDateTime(((Instant) data).toEpochMilli());
+        }
+        if (data instanceof LocalDateTime) {
+          LocalDateTime dateTime = (LocalDateTime) data;
+          return new BsonDateTime(Timestamp.valueOf(dateTime).getTime());
+        }
+        if (data instanceof LocalDate) {
+          long epochSeconds = ((LocalDate) data).toEpochDay() * 24L * 3600L;
+          return new BsonDateTime(epochSeconds * 1000L);
+        }
+
+        /*
+         NOTE 1: ZonedDateTime, OffsetDateTime and OffsetTime are not explicitly supported by Spark and cause the
+         Encoder resolver to fail due to a cyclic dependency in ZoneOffset. Revisit if that ever changes.
+         NOTE 2: LocalTime is represented neither in Bson nor in Spark.
+        */
+
+        throw new MongoSparkException("Unsupported date type: " + data.getClass().getSimpleName());
       };
     } else if (DataTypes.NullType.acceptsType(dataType)) {
       return (data) -> BsonNull.VALUE;
@@ -259,7 +282,7 @@ private static ObjectToBsonValue objectToBsonValue(
     };
   }
 
-  private static BsonDocument rowToBsonDocument(
+  private static BsonDocument rowToBsonDocument(
       final Row row, final List objectToBsonElements, final boolean ignoreNulls) {
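
Note for reviewers: a minimal, self-contained sketch of what this change looks like from the user side of the read path. With `spark.sql.datetime.java8API.enabled` set to `true`, `DateType` and `TimestampType` columns come back as `java.time.LocalDate` / `java.time.Instant` instead of `java.sql.Date` / `java.sql.Timestamp`, and `TimestampNTZType` columns come back as `java.time.LocalDateTime` regardless of the flag. The connection URIs, collection and the `created` field below are placeholders, not part of this patch.

```java
import java.time.Instant;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class Java8DateTimeReadSketch {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        // Placeholder URIs; point these at a real deployment.
        .config("spark.mongodb.read.connection.uri", "mongodb://localhost:27017/test.dates")
        .config("spark.mongodb.write.connection.uri", "mongodb://localhost:27017/test.dates")
        // Ask Spark to expose java.time.* types instead of java.sql.* types.
        .config("spark.sql.datetime.java8API.enabled", "true")
        .getOrCreate();

    Dataset<Row> df = spark.read().format("mongodb").load();

    // With the flag enabled, a TimestampType column is surfaced as java.time.Instant.
    Instant created = df.first().getAs("created");
    System.out.println(created);

    spark.stop();
  }
}
```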
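
A quick cross-check of the write-side `LocalDate` arithmetic added in `RowToBsonDocumentConverter` (epoch days scaled to milliseconds, i.e. midnight UTC of that day). The class name and sample date are illustrative only.

```java
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;

public final class LocalDateToBsonMillisCheck {
  public static void main(final String[] args) {
    LocalDate localDate = LocalDate.of(2000, 1, 1);

    // Same arithmetic as the converter: days since 1970-01-01, scaled to milliseconds.
    long epochSeconds = localDate.toEpochDay() * 24L * 3600L;
    long bsonDateTimeMillis = epochSeconds * 1000L;

    // Cross-check against java.time: midnight UTC on the same day.
    long expected = localDate.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli();
    if (bsonDateTimeMillis != expected) {
      throw new AssertionError(bsonDateTimeMillis + " != " + expected);
    }
    System.out.println(Instant.ofEpochMilli(bsonDateTimeMillis)); // 2000-01-01T00:00:00Z
  }
}
```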