Refactor .predicate macro methods
grouzen committed Aug 31, 2024
1 parent e206c00 commit 17a5b5d
Showing 5 changed files with 90 additions and 62 deletions.
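In short: the public macro entry point is renamed from predicate to filter (in both the Scala 2 and Scala 3 syntax packages), and ParquetReader's optional filter parameters are split out into dedicated readStreamFiltered/readChunkFiltered methods. A minimal before/after sketch of a call site; the id, name, and tmpPath identifiers are borrowed from the test fixtures further down, not part of the commit itself:

    // Before: the compiled predicate was threaded through an Option parameter
    val pred = predicate(id > 1 `and` name =!= "foo")
    reader.readChunk(tmpPath, filter = Some(pred))

    // After: unfiltered and filtered reads are separate methods,
    // and the macro is invoked as filter(...)
    reader.readChunk(tmpPath)
    reader.readChunkFiltered(tmpPath, filter(id > 1 `and` name =!= "foo"))
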
@@ -11,7 +11,7 @@ package object syntax extends Predicate.Syntax {
       Column.Named(column.path)
     }
 
-  def predicate[A](predicate: Predicate[A]): CompiledPredicate = macro SanitizeOptionalsMacro.sanitizeImpl[A]
+  def filter[A](predicate: Predicate[A]): CompiledPredicate = macro SanitizeOptionalsMacro.sanitizeImpl[A]
 
   def concat[A, B, F](
     parent: Column[A],

@@ -10,7 +10,7 @@ package object syntax extends Predicate.Syntax {
       Column.Named(column.path)
     }
 
-  inline def predicate[A](inline predicate: Predicate[A]): CompiledPredicate =
+  inline def filter[A](inline predicate: Predicate[A]): CompiledPredicate =
     ${ SanitizeOptionalsMacro.sanitizeImpl[A]('predicate) }
 
   inline def concat[A, B, F](inline parent: Column[A], inline child: Column.Named[B, F])(using

@@ -16,9 +16,13 @@ import java.io.IOException
 
 trait ParquetReader[+A <: Product] {
 
-  def readStream(path: Path, filter: Option[CompiledPredicate] = None): ZStream[Scope, Throwable, A]
+  def readStream(path: Path): ZStream[Scope, Throwable, A]
+
+  def readStreamFiltered(path: Path, filter: CompiledPredicate): ZStream[Scope, Throwable, A]
 
-  def readChunk[B](path: Path, filter: Option[CompiledPredicate] = None): Task[Chunk[A]]
+  def readChunk[B](path: Path): Task[Chunk[A]]
+
+  def readChunkFiltered[B](path: Path, filter: CompiledPredicate): Task[Chunk[A]]
 
 }
 
@@ -29,30 +33,55 @@ final class ParquetReaderLive[A <: Product: Tag](
 )(implicit decoder: ValueDecoder[A])
     extends ParquetReader[A] {
 
-  override def readStream(path: Path, filter: Option[CompiledPredicate] = None): ZStream[Scope, Throwable, A] =
+  override def readStream(path: Path): ZStream[Scope, Throwable, A] =
     for {
-      reader <- ZStream.fromZIO(build(path, filter))
-      value  <- ZStream.repeatZIOOption(
-                  ZIO
-                    .attemptBlockingIO(reader.read())
-                    .asSomeError
-                    .filterOrFail(_ != null)(None)
-                    .flatMap(decoder.decodeZIO(_).asSomeError)
-                )
+      reader <- ZStream.fromZIO(build(path, None))
+      value  <- readStream0(reader)
+    } yield value
+
+  override def readStreamFiltered(path: Path, filter: CompiledPredicate): ZStream[Scope, Throwable, A] =
+    for {
+      reader <- ZStream.fromZIO(build(path, Some(filter)))
+      value  <- readStream0(reader)
     } yield value
 
-  override def readChunk[B](path: Path, filter: Option[CompiledPredicate] = None): Task[Chunk[A]] =
+  override def readChunk[B](path: Path): Task[Chunk[A]] =
+    ZIO.scoped(
+      for {
+        reader <- build(path, None)
+        result <- readChunk0(reader)
+      } yield result
+    )
+
+  override def readChunkFiltered[B](path: Path, filter: CompiledPredicate): Task[Chunk[A]] =
+    ZIO.scoped(
+      for {
+        reader <- build(path, Some(filter))
+        result <- readChunk0(reader)
+      } yield result
+    )
+
+  private def readStream0(reader: HadoopParquetReader[RecordValue]): ZStream[Any, Throwable, A] =
+    ZStream.repeatZIOOption(
+      ZIO
+        .attemptBlockingIO(reader.read())
+        .asSomeError
+        .filterOrFail(_ != null)(None)
+        .flatMap(decoder.decodeZIO(_).asSomeError)
+    )
+
+  private def readChunk0[B](reader: HadoopParquetReader[RecordValue]): Task[Chunk[A]] = {
+    val readNext = for {
+      value  <- ZIO.attemptBlockingIO(reader.read())
+      record <- if (value != null)
+                  decoder.decodeZIO(value)
+                else
+                  ZIO.succeed(null.asInstanceOf[A])
+    } yield record
+    val builder  = Chunk.newBuilder[A]
+
     ZIO.scoped(
       for {
-        reader  <- build(path, filter)
-        readNext = for {
-                     value  <- ZIO.attemptBlockingIO(reader.read())
-                     record <- if (value != null)
-                                 decoder.decodeZIO(value)
-                               else
-                                 ZIO.succeed(null.asInstanceOf[A])
-                   } yield record
-        builder  = Chunk.newBuilder[A]
         initial <- readNext
         _       <- {
                      var current = initial
@@ -64,6 +93,7 @@ final class ParquetReaderLive[A <: Product: Tag](
                    }
       } yield builder.result()
     )
+  }
 
   private def build[B](
     path: Path,

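The streaming side follows the same split. A hedged sketch of how the two entry points would be consumed (the reader: ParquetReader[MyRecordIO] service, tmpPath, and the id column are assumptions borrowed from the spec below):

    // No predicate: the reader is built with build(path, None).
    val all: ZStream[Scope, Throwable, MyRecordIO] =
      reader.readStream(tmpPath)

    // With a predicate: the compiled filter reaches the underlying
    // HadoopParquetReader via build(path, Some(filter)).
    val matching: ZStream[Scope, Throwable, MyRecordIO] =
      reader.readStreamFiltered(tmpPath, filter(id > 1))
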
@@ -20,7 +20,7 @@ object ExprSpec extends ZIOSpecDefault {
       test("compile all operators") {
         val (a, b, _, _, _) = Filter[MyRecord].columns
 
-        val result = predicate(
+        val result = filter(
           not(
             (b >= 3 `or` b <= 100 `and` a.in(Set("foo", "bar"))) `or`
               (a === "foo" `and` (b === 20 `or` b.notIn(Set(1, 2, 3)))) `or`
@@ -64,7 +64,7 @@ object ExprSpec extends ZIOSpecDefault {
       test("compile summoned") {
         val (a, b) = Filter[MyRecordSummoned].columns
 
-        val result = predicate(
+        val result = filter(
           a === 3 `and` b === "foo"
         )
 
@@ -260,35 +260,35 @@ object ExprSpec extends ZIOSpecDefault {
           Long.box(Value.zonedDateTime(zonedDateTimePayload).value)
         )
 
-        val stringResul          = predicate(string === stringPayload)
-        val booleanResult        = predicate(boolean === booleanPayload)
-        val byteResult           = predicate(byte === bytePayload)
-        val shortResult          = predicate(short === shortPayload)
-        val intResult            = predicate(int === intPayload)
-        val longResult           = predicate(long === longPayload)
-        val floatResult          = predicate(float === floatPayload)
-        val doubleResult         = predicate(double === doublePayload)
-        val binaryResult         = predicate(binary === binaryPayload)
-        val charResult           = predicate(char === charPayload)
-        val uuidResult           = predicate(uuid === uuidPayload)
-        val bigDecimalResult     = predicate(bigDecimal === bigDecimalPayload)
-        val bigIntegerResult     = predicate(bigInteger === bigIntegerPayload)
-        val dayOfWeekResult      = predicate(dayOfWeek === dayOfWeekPayload)
-        val monthResult          = predicate(month === monthPayload)
-        val monthDayResult       = predicate(monthDay === monthDayPayload)
-        val periodResult         = predicate(period === periodPayload)
-        val yearResult           = predicate(year === yearPayload)
-        val yearMonthResult      = predicate(yearMonth === yearMonthPayload)
-        val zoneIdResult         = predicate(zoneId === zoneIdPayload)
-        val zoneOffsetResult     = predicate(zoneOffset === zoneOffsetPayload)
-        val durationResult       = predicate(duration === durationPayload)
-        val instantResult        = predicate(instant === instantPayload)
-        val localDateResult      = predicate(localDate === localDatePayload)
-        val localTimeResult      = predicate(localTime === localTimePayload)
-        val localDateTimeResult  = predicate(localDateTime === localDateTimePayload)
-        val offsetTimeResult     = predicate(offsetTime === offsetTimePayload)
-        val offsetDateTimeResult = predicate(offsetDateTime === offsetDateTimePayload)
-        val zonedDateTimeResult  = predicate(zonedDateTime === zonedDateTimePayload)
+        val stringResul          = filter(string === stringPayload)
+        val booleanResult        = filter(boolean === booleanPayload)
+        val byteResult           = filter(byte === bytePayload)
+        val shortResult          = filter(short === shortPayload)
+        val intResult            = filter(int === intPayload)
+        val longResult           = filter(long === longPayload)
+        val floatResult          = filter(float === floatPayload)
+        val doubleResult         = filter(double === doublePayload)
+        val binaryResult         = filter(binary === binaryPayload)
+        val charResult           = filter(char === charPayload)
+        val uuidResult           = filter(uuid === uuidPayload)
+        val bigDecimalResult     = filter(bigDecimal === bigDecimalPayload)
+        val bigIntegerResult     = filter(bigInteger === bigIntegerPayload)
+        val dayOfWeekResult      = filter(dayOfWeek === dayOfWeekPayload)
+        val monthResult          = filter(month === monthPayload)
+        val monthDayResult       = filter(monthDay === monthDayPayload)
+        val periodResult         = filter(period === periodPayload)
+        val yearResult           = filter(year === yearPayload)
+        val yearMonthResult      = filter(yearMonth === yearMonthPayload)
+        val zoneIdResult         = filter(zoneId === zoneIdPayload)
+        val zoneOffsetResult     = filter(zoneOffset === zoneOffsetPayload)
+        val durationResult       = filter(duration === durationPayload)
+        val instantResult        = filter(instant === instantPayload)
+        val localDateResult      = filter(localDate === localDatePayload)
+        val localTimeResult      = filter(localTime === localTimePayload)
+        val localDateTimeResult  = filter(localDateTime === localDateTimePayload)
+        val offsetTimeResult     = filter(offsetTime === offsetTimePayload)
+        val offsetDateTimeResult = filter(offsetDateTime === offsetDateTimePayload)
+        val zonedDateTimeResult  = filter(zonedDateTime === zonedDateTimePayload)
 
         assert(stringResul)(isRight(equalTo(stringExpected))) &&
         assert(booleanResult)(isRight(equalTo(booleanExpected))) &&
@@ -325,14 +325,14 @@ object ExprSpec extends ZIOSpecDefault {
         val (_, _, _, _, opt) = Filter[MyRecord].columns
 
         val expected = FilterApi.gt(FilterApi.intColumn("opt"), Int.box(Value.int(3).value))
-        val result   = predicate(opt.nullable > 3)
+        val result   = filter(opt.nullable > 3)
 
         assert(result)(isRight(equalTo(expected)))
       },
       test("compile enum") {
         val (_, _, _, enm, _) = Filter[MyRecord].columns
 
-        val result   = predicate(enm === MyRecord.Enum.Done)
+        val result   = filter(enm === MyRecord.Enum.Done)
         val expected = FilterApi.eq(FilterApi.binaryColumn("enm"), Value.string("Done").value)
 
         assert(result)(isRight(equalTo(expected)))

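Note what these expectations imply: filter(...) (the renamed predicate(...)) evidently compiles the typed predicate down to Parquet's own FilterApi representation, such as FilterApi.gt(FilterApi.intColumn("opt"), ...) and FilterApi.eq(FilterApi.binaryColumn("enm"), ...), and the compilation result is an Either, hence the isRight(equalTo(expected)) assertions.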
@@ -1,12 +1,12 @@
 package me.mnedokushev.zio.apache.parquet.core.hadoop
 
 import me.mnedokushev.zio.apache.parquet.core.Fixtures._
-import me.mnedokushev.zio.apache.parquet.core.filter._
+import me.mnedokushev.zio.apache.parquet.core.filter.Filter
+import me.mnedokushev.zio.apache.parquet.core.filter.syntax._
 import zio._
 import zio.stream._
-import zio.test.TestAspect._
 import zio.test.Assertion._
+import zio.test.TestAspect._
 import zio.test._
 
 import java.nio.file.Files
@@ -89,21 +89,19 @@ object ParquetIOSpec extends ZIOSpecDefault {
       } yield assertTrue(result == projectedPayload)
     } @@ after(cleanTmpFile(tmpDir)),
     test("write and read with filter") {
-      val payload             = Chunk(
+      val payload = Chunk(
         MyRecordIO(1, "foo", None, List(1, 2), Map("first" -> 1, "second" -> 2)),
         MyRecordIO(2, "foo", None, List(1, 2), Map.empty),
         MyRecordIO(3, "bar", Some(3L), List.empty, Map("third" -> 3)),
         MyRecordIO(4, "baz", None, List.empty, Map("fourth" -> 3))
       )
 
       val (id, name, _, _, _) = Filter[MyRecordIO].columns
-      val pred                = predicate(id > 1 `and` name =!= "foo")
 
       for {
         writer <- ZIO.service[ParquetWriter[MyRecordIO]]
         reader <- ZIO.service[ParquetReader[MyRecordIO]]
         _      <- writer.writeChunk(tmpPath, payload)
-        result <- reader.readChunk(tmpPath, filter = Some(pred))
+        result <- reader.readChunkFiltered(tmpPath, filter(id > 1 `and` name =!= "foo"))
       } yield assertTrue(result.size == 2) && assert(result)(equalTo(payload.drop(2)))
     } @@ after(cleanTmpFile(tmpDir))
   ).provide(

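A sanity check on the rewritten assertion, read off the fixture data rather than added by the commit: the predicate id > 1 `and` name =!= "foo" rejects record 1 (id == 1) and record 2 (name == "foo"), leaving records 3 and 4, which is exactly payload.drop(2); hence result.size == 2.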
