Implement ParquetReader
grouzen committed Dec 10, 2023
1 parent ec96f42 commit eaf9be3
Showing 9 changed files with 316 additions and 61 deletions.
@@ -59,7 +59,7 @@ object Value {

}

case class ByteArrayValue(value: Binary) extends PrimitiveValue[Binary] {
case class BinaryValue(value: Binary) extends PrimitiveValue[Binary] {

override def write(schema: Type, recordConsumer: RecordConsumer): Unit =
recordConsumer.addBinary(value)
@@ -68,11 +68,15 @@ object Value {

}

sealed trait GroupValue extends Value
sealed trait GroupValue[Self <: GroupValue[Self]] extends Value {

def put(name: String, value: Value): Self

}

object GroupValue {

case class RecordValue(values: Map[String, Value]) extends GroupValue {
case class RecordValue(values: Map[String, Value]) extends GroupValue[RecordValue] {

override def write(schema: Type, recordConsumer: RecordConsumer): Unit = {
val groupSchema = schema.asGroupType()
@@ -91,9 +95,15 @@
recordConsumer.endGroup()
}

override def put(name: String, value: Value): RecordValue =
if (values.contains(name))
this.copy(values.updated(name, value))
else
throw new IllegalArgumentException(s"Record doesn't contain field $name")

}

case class ListValue(values: Chunk[Value]) extends GroupValue {
case class ListValue(values: Chunk[Value]) extends GroupValue[ListValue] {

override def write(schema: Type, recordConsumer: RecordConsumer): Unit = {
recordConsumer.startGroup()
@@ -102,7 +112,7 @@
val groupSchema = schema.asGroupType()
val listSchema = groupSchema.getFields.get(0).asGroupType()
val listFieldName = listSchema.getName
val elementName = listSchema.getFields.get(0).getName
val elementName = listSchema.getFields.get(0).getName // TODO: validate, must be "element"
val listIndex = groupSchema.getFieldIndex(listFieldName)

recordConsumer.startField(listFieldName, listIndex)
@@ -117,9 +127,12 @@
recordConsumer.endGroup()
}

override def put(name: String, value: Value): ListValue =
this.copy(values = values :+ value)

}

case class MapValue(values: Map[Value, Value]) extends GroupValue {
case class MapValue(values: Map[Value, Value]) extends GroupValue[MapValue] {

override def write(schema: Type, recordConsumer: RecordConsumer): Unit = {
recordConsumer.startGroup()
@@ -142,6 +155,8 @@
recordConsumer.endGroup()
}

override def put(name: String, value: Value): MapValue = ???
// this.copy(values = values.updated(name, value))
}

}
@@ -150,7 +165,7 @@ object Value {
NullValue

def string(v: String) =
PrimitiveValue.ByteArrayValue(Binary.fromString(v))
PrimitiveValue.BinaryValue(Binary.fromString(v))

def boolean(v: Boolean) =
PrimitiveValue.BooleanValue(v)
@@ -171,7 +186,7 @@
PrimitiveValue.DoubleValue(v)

def binary(v: Chunk[Byte]) =
PrimitiveValue.ByteArrayValue(Binary.fromConstantByteArray(v.toArray))
PrimitiveValue.BinaryValue(Binary.fromConstantByteArray(v.toArray))

def char(v: Char) =
PrimitiveValue.Int32Value(v.toInt)
@@ -182,7 +197,7 @@
bb.putLong(v.getMostSignificantBits)
bb.putLong(v.getLeastSignificantBits)

PrimitiveValue.ByteArrayValue(Binary.fromConstantByteArray(bb.array()))
PrimitiveValue.BinaryValue(Binary.fromConstantByteArray(bb.array()))
}

def record(r: Map[String, Value]) =
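The put method added above gives each GroupValue an immutable update operation, which the new reader's converters rely on. A minimal sketch of its semantics, using only constructors that appear in this diff (field names and values are illustrative):

import zio.Chunk
import me.mnedokushev.zio.apache.parquet.core.Value
import me.mnedokushev.zio.apache.parquet.core.Value.PrimitiveValue

// RecordValue.put replaces a field that already exists in the record.
val rec     = Value.record(Map("id" -> Value.nil, "name" -> Value.nil))
val updated = rec.put("name", Value.string("zio"))
// rec.put("nickname", Value.nil) would throw IllegalArgumentException

// ListValue.put ignores the field name and simply appends the element.
val tags = Value.list(Chunk.empty)
  .put("element", PrimitiveValue.Int32Value(1))
  .put("element", PrimitiveValue.Int32Value(2))

// MapValue.put is still left unimplemented (???) by this commit.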
@@ -55,7 +55,7 @@ object ValueDecoderDeriver {
): ValueDecoder[A] = new ValueDecoder[A] {
override def decode(value: Value): A =
(st, value) match {
case (StandardType.StringType, PrimitiveValue.ByteArrayValue(v)) =>
case (StandardType.StringType, PrimitiveValue.BinaryValue(v)) =>
new String(v.getBytes, StandardCharsets.UTF_8)
case (StandardType.BoolType, PrimitiveValue.BooleanValue(v)) =>
v
@@ -71,11 +71,11 @@
v
case (StandardType.DoubleType, PrimitiveValue.DoubleValue(v)) =>
v
case (StandardType.BinaryType, PrimitiveValue.ByteArrayValue(v)) =>
case (StandardType.BinaryType, PrimitiveValue.BinaryValue(v)) =>
Chunk.fromArray(v.getBytes)
case (StandardType.CharType, PrimitiveValue.Int32Value(v)) =>
v.toChar
case (StandardType.UUIDType, PrimitiveValue.ByteArrayValue(v)) =>
case (StandardType.UUIDType, PrimitiveValue.BinaryValue(v)) =>
val bb = ByteBuffer.wrap(v.getBytes)
new UUID(bb.getLong, bb.getLong)
case (other, _) =>
@@ -135,6 +135,7 @@
fields: => Chunk[Deriver.WrappedF[ValueDecoder, _]],
summoned: => Option[ValueDecoder[B]]
): ValueDecoder[B] = ???

}.cached

def summoned: Deriver[ValueDecoder] =
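The decoder branches above mirror the encoding done in Value: a UUID, for example, travels as a 16-byte Binary. A self-contained round-trip sketch of that case, using only JDK classes and the Parquet Binary API seen in this diff (variable names are illustrative):

import java.nio.ByteBuffer
import java.util.UUID
import me.mnedokushev.zio.apache.parquet.core.Value.PrimitiveValue
import org.apache.parquet.io.api.Binary

val id = UUID.randomUUID()

// Encode: pack the UUID into a 16-byte Binary, most significant long first,
// as the Value hunk above does.
val bb = ByteBuffer.allocate(16)
bb.putLong(id.getMostSignificantBits)
bb.putLong(id.getLeastSignificantBits)
val encoded = PrimitiveValue.BinaryValue(Binary.fromConstantByteArray(bb.array()))

// Decode: the StandardType.UUIDType branch reads the two longs back in the same order.
val decoded = encoded match {
  case PrimitiveValue.BinaryValue(binary) =>
    val buf = ByteBuffer.wrap(binary.getBytes)
    new UUID(buf.getLong, buf.getLong)
}
// decoded == id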
@@ -0,0 +1,123 @@
package me.mnedokushev.zio.apache.parquet.core.hadoop

import me.mnedokushev.zio.apache.parquet.core.Value
import me.mnedokushev.zio.apache.parquet.core.Value.{ GroupValue, PrimitiveValue }
import org.apache.parquet.io.api.{ Binary, Converter, GroupConverter, PrimitiveConverter }
import org.apache.parquet.schema.{ GroupType, LogicalTypeAnnotation, Type }
import zio.Chunk

import scala.jdk.CollectionConverters._

abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extends GroupConverter { parent =>

def get: V =
this.groupValue

def put(name: String, value: Value): Unit =
this.groupValue = this.groupValue.put(name, value)

protected var groupValue: V = _

private val converters: Chunk[Converter] =
Chunk.fromIterable(schema.getFields.asScala.toList.map(fromSchema))

private def fromSchema(schema0: Type) = {
val name = schema0.getName

schema0.getLogicalTypeAnnotation match {
case _ if schema0.isPrimitive =>
primitive(name)
case _: LogicalTypeAnnotation.ListLogicalTypeAnnotation =>
GroupValueConverter.list(schema0.asGroupType(), name, parent)
case _: LogicalTypeAnnotation.MapLogicalTypeAnnotation =>
GroupValueConverter.map(schema0.asGroupType(), name, parent)
case _ =>
GroupValueConverter.record(schema0.asGroupType(), name, parent)
}
}

override def getConverter(fieldIndex: Int): Converter =
converters(fieldIndex)

private def primitive(name: String) =
new PrimitiveConverter {

override def addBinary(value: Binary): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.BinaryValue(value))

override def addBoolean(value: Boolean): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.BooleanValue(value))

override def addDouble(value: Double): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.DoubleValue(value))

override def addFloat(value: Float): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.FloatValue(value))

override def addInt(value: Int): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.Int32Value(value))

override def addLong(value: Long): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.Int64Value(value))

}

}

object GroupValueConverter {

def root(schema: GroupType): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema) {

override def start(): Unit =
this.groupValue = Value.record(
schema.getFields.asScala.toList.map(_.getName -> Value.nil).toMap
)

override def end(): Unit = ()
}

def record[V <: GroupValue[V]](
schema: GroupType,
name: String,
parent: GroupValueConverter[V]
): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema) {

override def start(): Unit =
this.groupValue = Value.record(Map.empty)

override def end(): Unit =
parent.put(name, this.groupValue)

}

def list[V <: GroupValue[V]](
schema: GroupType,
name: String,
parent: GroupValueConverter[V]
): GroupValueConverter[GroupValue.ListValue] =
new GroupValueConverter[GroupValue.ListValue](schema) {

override def start(): Unit =
this.groupValue = Value.list(Chunk.empty)

override def end(): Unit =
parent.put(name, this.groupValue)
}

def map[V <: GroupValue[V]](
schema: GroupType,
name: String,
parent: GroupValueConverter[V]
): GroupValueConverter[GroupValue.MapValue] =
new GroupValueConverter[GroupValue.MapValue](schema) {

override def start(): Unit =
this.groupValue = Value.map(Map.empty)

override def end(): Unit =
parent.put(name, this.groupValue)
}

}
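For illustration, the assembly these converters perform while materializing one record can be traced with the Value API alone. The schema and field names below are hypothetical; the lifecycle steps are the ones defined above:

import zio.Chunk
import me.mnedokushev.zio.apache.parquet.core.Value
import me.mnedokushev.zio.apache.parquet.core.Value.PrimitiveValue
import org.apache.parquet.io.api.Binary

// 1. root(schema).start() pre-populates every field of the message with Value.nil.
val started = Value.record(Map("name" -> Value.nil, "tags" -> Value.nil))

// 2. The primitive converter for "name" receives addBinary and overwrites the placeholder.
val withName = started.put("name", PrimitiveValue.BinaryValue(Binary.fromString("zio")))

// 3. The list converter starts from an empty ListValue, its element converter appends,
//    and end() pushes the finished value into the parent under the field name.
val tags     = Value.list(Chunk.empty).put("element", PrimitiveValue.BinaryValue(Binary.fromString("parquet")))
val finished = withName.put("tags", tags)

// finished is what GroupValueConverter.root(schema).get would return for this record.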
@@ -0,0 +1,56 @@
package me.mnedokushev.zio.apache.parquet.core.hadoop

import me.mnedokushev.zio.apache.parquet.core.Value.GroupValue.RecordValue
import me.mnedokushev.zio.apache.parquet.core.codec.ValueDecoder
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.hadoop.{ ParquetReader => HadoopParquetReader }
import org.apache.parquet.hadoop.api.{ ReadSupport => HadoopReadSupport }
import org.apache.parquet.io.InputFile
import zio._
import zio.stream._

trait ParquetReader[A <: Product] {

def read(path: Path): ZStream[Scope, Throwable, A]

}

final class ParquetReaderLive[A <: Product](conf: Configuration)(implicit decoder: ValueDecoder[A])
extends ParquetReader[A] {

override def read(path: Path): ZStream[Scope, Throwable, A] =
for {
inputFile <- ZStream.fromZIO(ZIO.attemptBlockingIO(path.toInputFile(conf)))
reader <- ZStream.fromZIO(
ZIO.fromAutoCloseable(
ZIO.attemptBlockingIO(
new ParquetReader.Builder(inputFile).withConf(conf).build()
)
)
)
value <- ZStream.repeatZIOOption(
ZIO
.attemptBlockingIO(reader.read())
.asSomeError
.filterOrFail(_ != null)(None)
.flatMap(decoder.decodeZIO(_).asSomeError)
)
} yield value

}

object ParquetReader {

final class Builder(file: InputFile) extends HadoopParquetReader.Builder[RecordValue](file) {

override def getReadSupport: HadoopReadSupport[RecordValue] =
new ReadSupport

}

def configured[A <: Product: ValueDecoder: Tag](
hadoopConf: Configuration = new Configuration()
): ULayer[ParquetReader[A]] =
ZLayer.succeed(new ParquetReaderLive[A](hadoopConf))

}
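A hedged usage sketch for the new reader. The User record, its fields, and the Derive.derive call (zio-schema's deriver entry point) are assumptions not shown in this diff; ParquetReader.configured and read are taken from the code above:

import zio._
import zio.schema.{ Derive, DeriveSchema, Schema }
import me.mnedokushev.zio.apache.parquet.core.codec.{ ValueDecoder, ValueDecoderDeriver }

// Hypothetical record type; not part of this commit.
case class User(id: Int, name: String)

object User {
  implicit val schema: Schema[User]        = DeriveSchema.gen[User]
  // Assumes zio-schema's Derive.derive macro; the deriver itself is from this library.
  implicit val decoder: ValueDecoder[User] = Derive.derive[ValueDecoder, User](ValueDecoderDeriver.summoned)
}

// Scope is required because read acquires and releases the underlying Hadoop reader.
def readAll(path: Path): Task[Chunk[User]] =
  ZIO
    .scoped[ParquetReader[User]](
      ZIO.serviceWithZIO[ParquetReader[User]](_.read(path).runCollect)
    )
    .provide(ParquetReader.configured[User]())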
@@ -16,6 +16,8 @@ trait ParquetWriter[A <: Product] {

def write(data: Chunk[A]): Task[Unit]

def close: Task[Unit]

}

final class ParquetWriterLive[A <: Product](
@@ -24,9 +26,15 @@
extends ParquetWriter[A] {

override def write(data: Chunk[A]): Task[Unit] =
ZIO.attemptBlocking(
data.foreach(v => underlying.write(encoder.encode(v).asInstanceOf[RecordValue]))
)
ZIO.foreachDiscard(data) { value =>
for {
record <- encoder.encodeZIO(value)
_ <- ZIO.attemptBlockingIO(underlying.write(record.asInstanceOf[RecordValue]))
} yield ()
}

override def close: Task[Unit] =
ZIO.attemptBlockingIO(underlying.close())

}

@@ -42,7 +50,7 @@ object ParquetWriter {

}

def configured[A <: Product](
def configured[A <: Product: ValueEncoder](
path: Path,
writeMode: ParquetFileWriter.Mode = ParquetFileWriter.Mode.CREATE,
compressionCodecName: CompressionCodecName = HadoopParquetWriter.DEFAULT_COMPRESSION_CODEC_NAME,
@@ -56,11 +64,10 @@
)(implicit
schema: Schema[A],
schemaEncoder: SchemaEncoder[A],
encoder: ValueEncoder[A],
tag: Tag[A]
): TaskLayer[ParquetWriter[A]] = {

def castSchema(schema: Type) =
def castToMessageSchema(schema: Type) =
ZIO.attempt {
val groupSchema = schema.asGroupType()
val name = groupSchema.getName
@@ -72,7 +79,7 @@
ZLayer.scoped(
for {
schema <- schemaEncoder.encodeZIO(schema, tag.tag.shortName, optional = false)
messageSchema <- castSchema(schema)
messageSchema <- castToMessageSchema(schema)
hadoopFile <- ZIO.attemptBlockingIO(HadoopOutputFile.fromPath(path.toHadoop, hadoopConf))
builder = new Builder(hadoopFile, messageSchema)
.withWriteMode(writeMode)
@@ -84,7 +91,7 @@
.withRowGroupSize(rowGroupSize)
.withValidation(validationEnabled)
.withConf(hadoopConf)
underlying <- ZIO.fromAutoCloseable(ZIO.attemptBlocking(builder.build()))
underlying <- ZIO.fromAutoCloseable(ZIO.attemptBlockingIO(builder.build()))
writer = new ParquetWriterLive[A](underlying)
} yield writer
)
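A matching sketch for the writer, reusing the hypothetical User record from the reader sketch; configured now takes the encoder as a context bound, so a ValueEncoder[User] (plus Schema and SchemaEncoder instances) is assumed to be in scope:

import zio._

def writeAll(path: Path, users: Chunk[User]): Task[Unit] =
  ZIO
    .serviceWithZIO[ParquetWriter[User]] { writer =>
      writer.write(users) *> writer.close
    }
    .provide(ParquetWriter.configured[User](path))

// The scoped layer still closes the writer on release; the new close method lets
// callers finish the file eagerly before the layer is torn down.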