Skip to content

Commit

Permalink
Fix list and map converters
Browse files Browse the repository at this point in the history
  • Loading branch information
grouzen committed Dec 11, 2023
1 parent 12751a7 commit da4508a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,16 @@ object Value {
}

/**
 * Adds an entry to this map value.
 *
 * Parquet encodes a MAP as repeated `key_value` groups, so each incoming
 * `value` for a map is expected to be a record with exactly two fields,
 * `"key"` and `"value"` (see the `mapKeyValue` converter). The `name`
 * parameter is ignored for that case — the actual map key comes from the
 * record's `"key"` field, not from the Parquet field name.
 *
 * @param name  Parquet field name (unused for key/value records)
 * @param value the converted child value
 * @return an updated MapValue; unrecognized shapes leave this value unchanged
 */
override def put(name: String, value: Value): MapValue =
  value match {
    case RecordValue(values0) =>
      // A key_value pair: insert it into the accumulated map.
      (values0.get("key"), values0.get("value")) match {
        case (Some(k), Some(v)) =>
          this.copy(values = values.updated(k, v))
        case _ => this // malformed pair (missing key or value) — keep map as-is
      }
    // A nested MapValue replaces this one wholesale.
    case mv: MapValue => mv
    // Any other value kind cannot belong to a map; ignore it.
    case _ => this
  }
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object ValueDecoderDeriver {
Unsafe.unsafe { implicit unsafe =>
record.construct(
Chunk
.fromIterable(values.values)
.fromIterable(record.fields.map(f => values(f.name)))
.zip(fields.map(_.unwrap))
.map { case (v, decoder) =>
decoder.decode(v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ abstract class GroupValueConverter[V <: GroupValue[V]](
case _: LogicalTypeAnnotation.MapLogicalTypeAnnotation =>
map(schema0.asGroupType(), name)
case _ =>
val name = schema0.getName
val repetition = schema0.getRepetition

val p = if (name == "list" && repetition == Repetition.REPEATED) Some(this) else None

record(schema0.asGroupType(), name, p)
(name, schema0.getRepetition) match {
case ("list", Repetition.REPEATED) =>
listElement(schema0.asGroupType())
case ("key_value", Repetition.REPEATED) =>
mapKeyValue(schema0.asGroupType(), name)
case _ =>
record(schema0.asGroupType(), name)
}
}
}
)
Expand Down Expand Up @@ -73,25 +75,17 @@ abstract class GroupValueConverter[V <: GroupValue[V]](

private def record(
schema: GroupType,
name: String,
parent: Option[GroupValueConverter[_]]
): GroupValueConverter[GroupValue.RecordValue] = parent match {
case Some(_) =>
new GroupValueConverter[GroupValue.RecordValue](schema, parent) {
override def start(): Unit = ()
override def end(): Unit = ()
}
case _ =>
new GroupValueConverter[GroupValue.RecordValue](schema, parent) {
name: String
): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema, parent) {

override def start(): Unit =
this.groupValue = Value.record(Map.empty)
override def start(): Unit =
this.groupValue = Value.record(Map.empty)

override def end(): Unit =
self.put(name, this.groupValue)
override def end(): Unit =
put(name, this.groupValue)

}
}
}

private def list(
schema: GroupType,
Expand All @@ -106,6 +100,15 @@ abstract class GroupValueConverter[V <: GroupValue[V]](
self.put(name, this.groupValue)
}

/**
 * Converter for the intermediate `list` group of Parquet's 3-level LIST
 * encoding (`list` / repeated group / element). It is a pass-through: child
 * element conversions are forwarded to the parent (`Some(self)`), so `start`
 * and `end` intentionally do nothing — the enclosing list converter owns the
 * accumulated value. NOTE(review): relies on parent-forwarding behavior in
 * the `GroupValueConverter` constructor, which is not visible here — confirm.
 */
private def listElement(schema: GroupType): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema, Some(self)) {

override def start(): Unit = ()

override def end(): Unit = ()

}

private def map(
schema: GroupType,
name: String
Expand All @@ -119,6 +122,20 @@ abstract class GroupValueConverter[V <: GroupValue[V]](
self.put(name, this.groupValue)
}

/**
 * Converter for the repeated `key_value` group of Parquet's MAP encoding
 * (map / repeated key_value / key + value fields).
 *
 * Each repetition starts a fresh two-field record pre-seeded with nil
 * placeholders for `"key"` and `"value"` (so a missing optional value still
 * leaves a well-formed pair), and on `end` hands the completed pair to the
 * parent via `self.put`, which folds it into the enclosing map value.
 *
 * @param schema the `key_value` group schema
 * @param name   the Parquet field name reported to the parent on `end`
 */
private def mapKeyValue(
schema: GroupType,
name: String
): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema) {

override def start(): Unit =
this.groupValue = Value.record(Map("key" -> Value.nil, "value" -> Value.nil))

override def end(): Unit =
self.put(name, this.groupValue)

}

}

object GroupValueConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object ParquetIOSpec extends ZIOSpecDefault {
val tmpCrcPath = tmpDir / ".parquet-writer-spec.parquet.crc"
val tmpPath = tmpDir / tmpFile

case class Record(a: Int, b: String, c: Option[Long], d: List[Int])
case class Record(a: Int, b: String, c: Option[Long], d: List[Int], e: Map[String, Int])
object Record {
implicit val schema: Schema[Record] =
DeriveSchema.gen[Record]
Expand All @@ -31,8 +31,8 @@ object ParquetIOSpec extends ZIOSpecDefault {
suite("ParquetIOSpec")(
test("write and read") {
val payload = Chunk(
Record(1, "foo", None, List(1, 2)),
Record(2, "bar", Some(3L), List.empty)
Record(1, "foo", None, List(1, 2), Map("first" -> 1, "second" -> 2)),
Record(2, "bar", Some(3L), List.empty, Map("third" -> 3))
)

for {
Expand All @@ -48,7 +48,7 @@ object ParquetIOSpec extends ZIOSpecDefault {
) @@ after(cleanTmpFile(tmpDir))
)

def cleanTmpFile(path: Path) =
private def cleanTmpFile(path: Path) =
for {
_ <- ZIO.attemptBlockingIO(Files.delete(tmpCrcPath.toJava))
_ <- ZIO.attemptBlockingIO(Files.delete(tmpPath.toJava))
Expand Down

0 comments on commit da4508a

Please sign in to comment.