support for spark.sql.datetime.java8API.enabled #125

Open · wants to merge 2 commits into base: main
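Summary: this change wires the connector into Spark's spark.sql.datetime.java8API.enabled setting. With the flag on, reads surface DateType columns as java.time.LocalDate and TimestampType columns as java.time.Instant instead of java.sql.Date and java.sql.Timestamp, TimestampNTZType columns surface as java.time.LocalDateTime, and writes accept Instant, LocalDateTime and LocalDate values. A minimal usage sketch follows; the connection URI, namespace, option names and local master are illustrative assumptions, not taken from this PR:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class Java8DateTimeApiExample {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        // Illustrative connection settings for the 10.x connector.
        .config("spark.mongodb.read.connection.uri", "mongodb://localhost:27017/test.events")
        .config("spark.sql.datetime.java8API.enabled", "true")
        .getOrCreate();

    Dataset<Row> df = spark.read().format("mongodb").load();
    // With the flag enabled, collected Rows hold java.time.LocalDate for DateType columns
    // and java.time.Instant for TimestampType columns.
    df.show();
  }
}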
@@ -17,6 +17,7 @@

package com.mongodb.spark.sql.connector;

import static java.time.ZoneOffset.UTC;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
@@ -30,7 +31,7 @@
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -41,6 +42,8 @@
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

public class RoundTripTest extends MongoSparkConnectorTestCase {

@@ -92,37 +95,45 @@ void testBoxedBean() {
assertIterableEquals(dataSetOriginal, dataSetMongo);
}

@Test
void testDateTimeBean() {
@ParameterizedTest
@ValueSource(strings = {"true", "false"})
void testDateTimeBean(final String java8DateTimeAPI) {
TimeZone original = TimeZone.getDefault();
try {
TimeZone.setDefault(TimeZone.getTimeZone(ZoneOffset.UTC));
TimeZone.setDefault(TimeZone.getTimeZone(UTC));

// Given
long oneHour = TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS);
long oneDay = oneHour * 24;

List<DateTimeBean> dataSetOriginal = singletonList(new DateTimeBean(
new Date(oneDay * 365),
new Timestamp(oneDay + oneHour),
LocalDate.of(2000, 1, 1),
Instant.EPOCH));
Instant epoch = Instant.EPOCH;
List<DateTimeBean> dataSetOriginal =
singletonList(
new DateTimeBean(
new Date(oneDay * 365),
new Timestamp(oneDay + oneHour),
LocalDate.of(2000, 1, 1),
epoch,
LocalDateTime.ofInstant(epoch, UTC)));

// when
SparkSession spark = getOrCreateSparkSession();
SparkSession spark =
getOrCreateSparkSession(
getSparkConf().set("spark.sql.datetime.java8API.enabled", java8DateTimeAPI));
Encoder<DateTimeBean> encoder = Encoders.bean(DateTimeBean.class);

Dataset<DateTimeBean> dataset = spark.createDataset(dataSetOriginal, encoder);
dataset.write().format("mongodb").mode("Overwrite").save();

// Then
List<DateTimeBean> dataSetMongo = spark
.read()
.format("mongodb")
.schema(encoder.schema())
.load()
.as(encoder)
.collectAsList();
List<DateTimeBean> dataSetMongo =
spark
.read()
.format("mongodb")
.schema(encoder.schema())
.load()
.as(encoder)
.collectAsList();
assertIterableEquals(dataSetOriginal, dataSetMongo);
} finally {
TimeZone.setDefault(original);
@@ -21,24 +21,29 @@
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Objects;

public class DateTimeBean implements Serializable {
private java.sql.Date sqlDate;
private java.sql.Timestamp sqlTimestamp;
private java.time.LocalDate localDate;
private java.time.Instant instant;
private java.time.LocalDateTime localDateTime;

public DateTimeBean() {}

public DateTimeBean(
final Date sqlDate,
final Timestamp sqlTimestamp,
final LocalDate localDate,
final Instant instant) {
final Instant instant,
final LocalDateTime localDateTime) {
this.sqlDate = sqlDate;
this.sqlTimestamp = sqlTimestamp;
this.localDate = localDate;
this.localDateTime = localDateTime;
this.instant = instant;
}

@@ -66,6 +71,14 @@ public void setLocalDate(final LocalDate localDate) {
this.localDate = localDate;
}

public LocalDateTime getLocalDateTime() {
return localDateTime;
}

public void setLocalDateTime(final LocalDateTime localDateTime) {
this.localDateTime = localDateTime;
}

public Instant getInstant() {
return instant;
}
@@ -76,30 +89,39 @@ public void setInstant(final Instant instant) {

@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DateTimeBean that = (DateTimeBean) o;
return Objects.equals(sqlDate, that.sqlDate)
&& Objects.equals(sqlTimestamp, that.sqlTimestamp)
&& Objects.equals(localDate, that.localDate)
&& Objects.equals(localDateTime, that.localDateTime)
&& Objects.equals(instant, that.instant);
}

@Override
public int hashCode() {
return Objects.hash(sqlDate, sqlTimestamp, localDate, instant);
return Objects.hash(sqlDate, sqlTimestamp, localDate, localDateTime, instant);
}

@Override
public String toString() {
return "DateTimeBean{" + "sqlDate="
+ sqlDate + ", sqlTimestamp="
+ sqlTimestamp + ", localDate="
+ localDate + ", instant="
+ instant + '}';
return "DateTimeBean{"
+ "sqlDate="
+ sqlDate
+ ", sqlTimestamp="
+ sqlTimestamp
+ ", localDate="
+ localDate
+ ", localDateTime="
+ localDateTime
+ ", instant="
+ instant
+ '}';
}
}
@@ -44,6 +44,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.internal.SqlApiConf;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.BooleanType;
@@ -62,6 +63,7 @@
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampNTZType;
import org.apache.spark.sql.types.TimestampType;
import org.bson.BsonArray;
import org.bson.BsonBinaryWriter;
@@ -76,6 +78,8 @@
import org.bson.types.Decimal128;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* The helper for conversion of BsonDocuments to GenericRowWithSchema instances.
@@ -89,13 +93,15 @@
@NotNull
public final class BsonDocumentToRowConverter implements Serializable {
private static final long serialVersionUID = 1L;
private static final Logger log = LoggerFactory.getLogger(BsonDocumentToRowConverter.class);
private final Function<Row, InternalRow> rowToInternalRowFunction;
private final StructType schema;
private final boolean outputExtendedJson;
private final boolean isPermissive;
private final boolean dropMalformed;
private final String columnNameOfCorruptRecord;
private final boolean schemaContainsCorruptRecordColumn;
private final boolean dateTimeJava8APIEnabled;

private boolean corruptedRecord;

@@ -114,6 +120,7 @@ public BsonDocumentToRowConverter(final StructType originalSchema, final ReadCon
this.columnNameOfCorruptRecord = readConfig.getColumnNameOfCorruptRecord();
this.schemaContainsCorruptRecordColumn = !columnNameOfCorruptRecord.isEmpty()
&& Arrays.asList(schema.fieldNames()).contains(columnNameOfCorruptRecord);
this.dateTimeJava8APIEnabled = SqlApiConf.get().datetimeJava8ApiEnabled();
}

/** @return the schema for the converter */
@@ -165,6 +172,7 @@ GenericRowWithSchema toRow(final BsonDocument bsonDocument) {
@VisibleForTesting
Object convertBsonValue(
final String fieldName, final DataType dataType, final BsonValue bsonValue) {
log.trace("converting bson to value: {} {} {}", fieldName, dataType, bsonValue);
try {
if (bsonValue.isNull()) {
return null;
@@ -179,9 +187,22 @@ Object convertBsonValue(
} else if (dataType instanceof BooleanType) {
return convertToBoolean(fieldName, dataType, bsonValue);
} else if (dataType instanceof DateType) {
return convertToDate(fieldName, dataType, bsonValue);
Date date = convertToDate(fieldName, dataType, bsonValue);
if (dateTimeJava8APIEnabled) {
return date.toLocalDate();
} else {
return date;
}
} else if (dataType instanceof TimestampType) {
return convertToTimestamp(fieldName, dataType, bsonValue);
Timestamp timestamp = convertToTimestamp(fieldName, dataType, bsonValue);
if (dateTimeJava8APIEnabled) {
return timestamp.toInstant();
} else {
return timestamp;
}
} else if (dataType instanceof TimestampNTZType) {
Timestamp timestamp = convertToTimestamp(fieldName, dataType, bsonValue);
return timestamp.toLocalDateTime();
} else if (dataType instanceof FloatType) {
return convertToFloat(fieldName, dataType, bsonValue);
} else if (dataType instanceof IntegerType) {
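For reference, the read-side mapping introduced above can be summarized in a small sketch. The helper class and method names below are illustrative only, and the BSON date-time is taken to carry epoch milliseconds, as in BsonDateTime:

import java.sql.Date;
import java.sql.Timestamp;
import java.time.LocalDateTime;

final class DatetimeMappingSketch {

  // DateType: java.sql.Date by default, java.time.LocalDate when the java8API flag is on.
  static Object dateValue(final long epochMillis, final boolean java8ApiEnabled) {
    Date date = new Date(epochMillis);
    return java8ApiEnabled ? date.toLocalDate() : date;
  }

  // TimestampType: java.sql.Timestamp by default, java.time.Instant when the flag is on.
  static Object timestampValue(final long epochMillis, final boolean java8ApiEnabled) {
    Timestamp timestamp = new Timestamp(epochMillis);
    return java8ApiEnabled ? timestamp.toInstant() : timestamp;
  }

  // TimestampNTZType: always java.time.LocalDateTime, regardless of the flag. As in the hunk
  // above, Timestamp.toLocalDateTime() interprets the millis in the JVM default time zone.
  static LocalDateTime timestampNtzValue(final long epochMillis) {
    return new Timestamp(epochMillis).toLocalDateTime();
  }
}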
@@ -26,6 +26,10 @@
import com.mongodb.spark.sql.connector.interop.JavaScala;
import java.io.Serializable;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
@@ -139,7 +143,7 @@ public static ObjectToBsonValue createObjectToBsonValue(
} catch (Exception e) {
throw new DataException(format(
"Cannot cast %s into a BsonValue. %s has no matching BsonValue. Error: %s",
data, dataType, e.getMessage()));
data, dataType, e.getMessage()), e);
}
};
}
@@ -177,14 +181,33 @@ private static ObjectToBsonValue objectToBsonValue(
} else if (DataTypes.StringType.acceptsType(dataType)) {
return (data) -> processString((String) data, convertJson);
} else if (DataTypes.DateType.acceptsType(dataType)
|| DataTypes.TimestampType.acceptsType(dataType)) {
|| DataTypes.TimestampType.acceptsType(dataType)
|| DataTypes.TimestampNTZType.acceptsType(dataType)) {
return (data) -> {
if (data instanceof Date) {
// Covers java.util.Date, java.sql.Date, java.sql.Timestamp
return new BsonDateTime(((Date) data).getTime());
}
throw new MongoSparkException(
"Unsupported date type: " + data.getClass().getSimpleName());
if (data instanceof Instant) {
return new BsonDateTime(((Instant) data).toEpochMilli());
}
if (data instanceof LocalDateTime) {
LocalDateTime dateTime = (LocalDateTime) data;
return new BsonDateTime(Timestamp.valueOf(dateTime).getTime());
}
if (data instanceof LocalDate) {
long epochSeconds = ((LocalDate) data).toEpochDay() * 24L * 3600L;
return new BsonDateTime(epochSeconds * 1000L);
}

/*
NOTE 1: ZonedDateTime, OffsetDateTime and OffsetTime are not explicitly supported by Spark and cause the
Encoder resolver to fail due to a cyclic dependency in ZoneOffset. Subject to review if that ever changes.
NOTE 2: LocalTime is represented neither in BSON nor in Spark.
*/

throw new MongoSparkException("Unsupported date type: " + data.getClass().getSimpleName());
};
} else if (DataTypes.NullType.acceptsType(dataType)) {
return (data) -> BsonNull.VALUE;
@@ -259,7 +282,7 @@
};
}

private static BsonDocument rowToBsonDocument(
private static BsonDocument rowToBsonDocument(
final Row row,
final List<ObjectToBsonElement> objectToBsonElements,
final boolean ignoreNulls) {
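One write-side detail worth spelling out: a LocalDate is converted by treating it as midnight UTC (epoch day * 86,400 seconds * 1,000 ms), independent of the session or JVM time zone. A standalone check of that arithmetic (illustrative snippet, not part of the PR):

import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;

final class LocalDateEpochCheck {
  public static void main(final String[] args) {
    LocalDate date = LocalDate.of(2000, 1, 1);

    // Conversion as in the diff above: epoch day -> seconds -> milliseconds.
    long millisFromEpochDay = date.toEpochDay() * 24L * 3600L * 1000L;

    // Equivalent formulation via java.time, anchored at UTC midnight.
    long millisAtUtcMidnight = date.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli();

    System.out.println(millisFromEpochDay == millisAtUtcMidnight);  // true
    System.out.println(Instant.ofEpochMilli(millisFromEpochDay));   // 2000-01-01T00:00:00Z
  }
}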