Skip to content

Commit

Permalink
FIx serialization issue
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Feb 12, 2024
1 parent 222f761 commit 155c95a
Showing 1 changed file with 31 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import com.spotify.scio.parquet.avro.ParquetAvroIO.WriteParam._
import com.spotify.scio.parquet.read.ParquetReadConfiguration
import com.spotify.scio.parquet.{GcsConnectorUtil, ParquetConfiguration}
import com.spotify.scio.testing.TestDataManager
import com.spotify.scio.transforms._
import com.spotify.scio.transforms.DoFnWithResource.ResourceType
import com.spotify.scio.util.{FilenamePolicySupplier, ScioUtil}
import com.spotify.scio.values.SCollection
import org.apache.avro.Schema
Expand All @@ -55,6 +57,7 @@ import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

sealed trait ParquetAvroIO[T <: IndexedRecord] extends ScioIO[T] {
import ParquetAvroIO._

override type ReadP = ParquetAvroIO.ReadParam[T]
override type WriteP = ParquetAvroIO.WriteParam[T]
Expand Down Expand Up @@ -141,36 +144,29 @@ sealed trait ParquetAvroIO[T <: IndexedRecord] extends ScioIO[T] {
override protected def readTest(sc: ScioContext, params: ReadP): SCollection[T] = {
val datumFactory = Option(params.datumFactory).getOrElse(defaultDatumFactory)
implicit val coder: Coder[T] = avroCoder(datumFactory, schema)
// SpecificData.getForClass is only available for 1.9+
val recordClass = datumFactory.getType
val data = if (classOf[SpecificRecordBase].isAssignableFrom(recordClass)) {
val classModelField = recordClass.getDeclaredField("MODEL$")
classModelField.setAccessible(true)
classModelField.get(null).asInstanceOf[SpecificData]
} else {
SpecificData.get()
}

// The projection function is not part of the test input, so it must be applied directly
val projectedFields = Option(params.projection).map(_.getFields.asScala.map(_.name()).toSet)
TestDataManager
.getInput(sc.testId.get)(this)
.toSCollection(sc)
.map { record =>
projectedFields match {
case None => record
case Some(projection) =>
// beam forbids mutations. Create a new record
val copy = data.deepCopy(record.getSchema, record)
record.getSchema.getFields.asScala
.foldLeft(copy) { (c, f) =>
val names = Set(f.name()) ++ f.aliases().asScala.toSet
if (projection.intersect(names).isEmpty) {
// field is not part of the projection. user default value
c.put(f.pos(), data.getDefaultValue(f))
.mapWithResource(dataForClass(datumFactory.getType), ResourceType.PER_INSTANCE) {
case (data, record) =>
projectedFields match {
case None => record
case Some(projection) =>
// beam forbids mutations. Create a new record
val copy = data.deepCopy(record.getSchema, record)
record.getSchema.getFields.asScala
.foldLeft(copy) { (c, f) =>
val names = Set(f.name()) ++ f.aliases().asScala.toSet
if (projection.intersect(names).isEmpty) {
// field is not part of the projection. user default value
c.put(f.pos(), data.getDefaultValue(f))
}
c
}
c
}
}
}
}
}

Expand Down Expand Up @@ -235,6 +231,17 @@ sealed trait ParquetAvroIO[T <: IndexedRecord] extends ScioIO[T] {

object ParquetAvroIO {

// SpecificData.getForClass is only available for 1.9+
private def dataForClass[T](recordClass: Class[T]) = {
if (classOf[SpecificRecordBase].isAssignableFrom(recordClass)) {
val classModelField = recordClass.getDeclaredField("MODEL$")
classModelField.setAccessible(true)
classModelField.get(null).asInstanceOf[SpecificData]
} else {
SpecificData.get()
}
}

private class Identity[T](cls: Class[T])
extends SimpleFunction[T, T](SerializableFunctions.identity[T]) {
override def getInputTypeDescriptor: TypeDescriptor[T] = TypeDescriptor.of(cls)
Expand Down

0 comments on commit 155c95a

Please sign in to comment.