add param for the OutlierRemoveTransformer (#3)
* add param for the OutlierRemoveTransformer

* test improvements

* up default spark version and log level
deadsalmonbrain authored Jun 25, 2024
1 parent a4f9c0d commit 5414ab4
Showing 13 changed files with 110 additions and 60 deletions.
17 changes: 15 additions & 2 deletions python/ai/salmonbrain/ruleofthumb.py
@@ -272,6 +272,13 @@ class OutlierRemoveTransformer(
typeConverter=TypeConverters.toFloat,
)

excludedMetrics = Param(
Params._dummy(),
"excludedMetrics",
"metrics excluded from filtering",
typeConverter=TypeConverters.toListString,
)

@keyword_only
def __init__(
self,
@@ -286,14 +293,16 @@ def __init__(
entityCategoryNameColumn="categoryName",
entityCategoryValueColumn="categoryValue",
lowerPercentile=0.01,
upperPercentile=0.99
upperPercentile=0.99,
excludedMetrics=[]
):
super(OutlierRemoveTransformer, self).__init__()
self._java_obj = self._new_java_obj(
"ai.salmonbrain.ruleofthumb.OutlierRemoveTransformer", self.uid
)
self._setDefault(lowerPercentile=0.01)
self._setDefault(upperPercentile=0.99)
self._setDefault(excludedMetrics=[])
kwargs = self._input_kwargs
self.setParams(**kwargs)

@@ -311,7 +320,8 @@ def setParams(
entityCategoryNameColumn="categoryName",
entityCategoryValueColumn="categoryValue",
lowerPercentile=0.01,
upperPercentile=0.99
upperPercentile=0.99,
excludedMetrics=[]
):
kwargs = self._input_kwargs
return self._set(**kwargs)
@@ -321,3 +331,6 @@ def setLowerPercentile(self, value):

def setUpperPercentile(self, value):
return self._set(upperPercentile=value)

def setExcludedMetrics(self, value):
return self._set(excludedMetrics=value)
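
For context, a minimal usage sketch of the new parameter on the Python side. The import path follows the file location above; `metrics_df` and its exact columns are assumptions standing in for a real experiment-metrics DataFrame, and a Spark session with the salmon-brain JAR on the classpath is assumed to be running.

```python
from ai.salmonbrain.ruleofthumb import OutlierRemoveTransformer

# Assumed input: metrics_df is an experiment-metrics DataFrame with the columns
# the transformer expects (categoryName, categoryValue, metricName, metricValue, ...).
outlier = OutlierRemoveTransformer(
    lowerPercentile=0.05,
    upperPercentile=0.95,
    excludedMetrics=["clicks"],  # rows with these metric names skip outlier removal
)
cleaned = outlier.transform(metrics_df)

# The new setter works as well:
outlier.setExcludedMetrics(["clicks"])
```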
14 changes: 11 additions & 3 deletions python/tests/test_ruleofthumb.py
@@ -20,6 +20,9 @@ def data_sample_for_outlier(spark: SparkSession):
[
("common", "all", "feedback", "1", "exp", "treatment", i, "shows", True)
for i in range(100)
] + [
("common", "all", "feedback", "1", "exp", "treatment", i, "clicks", True)
for i in range(100)
],
[
"categoryName",
@@ -158,8 +161,13 @@ def test_autoStatisticsTransformer(data_sample: DataFrame):


def test_outlierRemoveTransformer(data_sample_for_outlier: DataFrame):
outlier = OutlierRemoveTransformer(lowerPercentile=0.05, upperPercentile=0.95)
outlier = OutlierRemoveTransformer(lowerPercentile=0.05, upperPercentile=0.95, excludedMetrics=["clicks"])
result = outlier.transform(data_sample_for_outlier)

values = [i["metricValue"] for i in result.select("metricValue").collect()]
assert len(values) == 89
result.show()

countViews = result.filter("metricName = 'shows'").count()
countClicks = result.filter("metricName = 'clicks'").count()
assert countViews == 89
assert countClicks == 100

5 changes: 2 additions & 3 deletions ruleofthumb/build.gradle
@@ -20,7 +20,7 @@ def scalaVersion = findProperty("scalaVersion") ?: "2.12.11"
println "Scala version: $scalaVersion"
def scalaVersionShort = "${VersionNumber.parse(scalaVersion).getMajor()}.${VersionNumber.parse(scalaVersion).getMinor()}"

def sparkVersion = findProperty("sparkVersion") ?: "3.1.2"
def sparkVersion = findProperty("sparkVersion") ?: "3.2.0"
println "Spark version: $sparkVersion"


@@ -49,10 +49,9 @@ dependencies {
testImplementation "org.mockito:mockito-scala_$scalaVersionShort:1.16.42"
}



test{
maxHeapSize = '1G'
maxParallelForks = 1
}

task pythonTest(type: Exec) {
@@ -1,7 +1,7 @@
package ai.salmonbrain.ruleofthumb

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ Param, ParamMap }
import org.apache.spark.ml.param.{ Param, ParamMap, StringArrayParam }
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions.{ broadcast, callUDF, col, lit }
import org.apache.spark.sql.types.StructType
@@ -30,6 +30,13 @@ class OutlierRemoveTransformer(override val uid: String)
)
setDefault(upperPercentile, 0.99)

val excludedMetrics: StringArrayParam = new StringArrayParam(
this,
"excludedMetrics",
"metrics excluded from filtering"
)
setDefault(excludedMetrics, Array[String]())

/** @group setParam */
def setLowerPercentile(value: Double): this.type =
set(lowerPercentile, value)
@@ -38,6 +45,10 @@
def setUpperPercentile(value: Double): this.type =
set(upperPercentile, value)

/** @group setParam */
def setExcludedMetrics(value: Array[String]): this.type =
set(excludedMetrics, value)

override def transform(dataset: Dataset[_]): DataFrame = {
assert($(upperPercentile) > 0 && $(upperPercentile) < 1, "upperPercentile must be in (0, 1)")
assert(
@@ -76,15 +87,16 @@ class OutlierRemoveTransformer(override val uid: String)
$(entityCategoryValueColumn)
)
val percentilesBound = dataset
.filter(!col($(metricNameColumn)).isin($(excludedMetrics): _*))
.groupBy(columns.head, columns: _*)
.agg(
aggFunc.head,
aggFunc.tail: _*
)

dataset
.join(broadcast(percentilesBound), columns)
.filter(filterFunc)
.join(broadcast(percentilesBound), columns, "left")
.filter(filterFunc || col($(metricNameColumn)).isin($(excludedMetrics): _*))
.drop(dropCols: _*)
}

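
To make the behavioral change above concrete: percentile bounds are now computed only over metrics that are not in `excludedMetrics`, the bounds are joined back with a left join, and the final filter lets excluded metrics pass through untouched (the `filterFunc || col(metricNameColumn).isin(excludedMetrics)` clause). Below is a standalone PySpark sketch of the same semantics; the `lo`/`hi` column names and the `percentile_approx` aggregation are illustrative assumptions, not the transformer's actual internals.

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Toy data: "shows" gets trimmed by percentile bounds, "clicks" is excluded.
df = spark.createDataFrame(
    [("shows", float(i)) for i in range(100)]
    + [("clicks", float(i)) for i in range(100)],
    ["metricName", "metricValue"],
)

excluded = ["clicks"]
lower, upper = 0.05, 0.95

# Bounds are computed only over metrics that are NOT excluded ...
bounds = (
    df.filter(~F.col("metricName").isin(excluded))
    .groupBy("metricName")
    .agg(
        F.expr(f"percentile_approx(metricValue, {lower})").alias("lo"),
        F.expr(f"percentile_approx(metricValue, {upper})").alias("hi"),
    )
)

# ... and excluded metrics survive the filter via the isin() clause, so they
# come out untouched even though they have no matching bounds row.
cleaned = (
    df.join(F.broadcast(bounds), ["metricName"], "left")
    .filter(
        ((F.col("metricValue") >= F.col("lo")) & (F.col("metricValue") <= F.col("hi")))
        | F.col("metricName").isin(excluded)
    )
    .drop("lo", "hi")
)
cleaned.groupBy("metricName").count().show()  # clicks keeps all 100 rows
```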
3 changes: 3 additions & 0 deletions ruleofthumb/src/test/resources/log4j.properties
@@ -0,0 +1,3 @@
log4j.rootLogger=ERROR, Console
log4j.appender.Console=org.apache.log4j.ConsoleAppender
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
@@ -1,16 +1,15 @@
package ai.salmonbrain.inputs;

import ai.salmonbrain.ruleofthumb.ExpData
import helpers.SparkHelper
import helpers.SharedSparkSession
import org.apache.spark.sql.DataFrame
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

class AccessLogTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {

class AccessLogTransformerSpec extends AnyFlatSpec with SharedSparkSession with Matchers {
import spark.implicits._
"AccessLogTransformer" should "be" in {
import sqlc.implicits._
val logsDF: DataFrame = sc
.parallelize(
Seq(
@@ -1,6 +1,6 @@
package ai.salmonbrain.inputs

import helpers.SparkHelper
import helpers.SharedSparkSession
import org.apache.spark.sql.DataFrame
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers
@@ -9,12 +9,11 @@ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter

class NginxRawLogTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {

class NginxRawLogTransformerSpec extends AnyFlatSpec with SharedSparkSession with Matchers {
import spark.implicits._
private val logsResource: String = getClass.getResource("/nginx_sample_1.txt").getPath

"NginxRawLogTransformer" should "be" in {
import sqlc.implicits._
val rawLogs: DataFrame = CsvHelper.readCsv(spark, Seq(logsResource))

val transformer = new NginxRawLogTransformer()
@@ -1,15 +1,15 @@
package ai.salmonbrain.ruleofthumb

import helpers.ExperimentDataGenerator.{ experimentDataGenerator, seqExpDataToDataFrame }
import helpers.SparkHelper
import helpers.ExperimentDataGenerator.{experimentDataGenerator, seqExpDataToDataFrame}
import helpers.SharedSparkSession
import org.apache.spark.ml.Pipeline
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class ComputingFlowSpec extends AnyFlatSpec with SparkHelper with Matchers {
class ComputingFlowSpec extends AnyFlatSpec with SharedSparkSession with Matchers {
"ComputingFlow" should "be" in {
import spark.implicits._
val metrics =
lazy val metrics =
seqExpDataToDataFrame(
experimentDataGenerator(
uplift = 0,
@@ -1,15 +1,14 @@
package ai.salmonbrain.ruleofthumb

import helpers.ExperimentDataGenerator.{ experimentDataGenerator, seqExpDataToDataFrame }
import helpers.SparkHelper
import helpers.ExperimentDataGenerator.{experimentDataGenerator, seqExpDataToDataFrame}
import helpers.SharedSparkSession
import org.apache.spark.sql.functions.first
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class CumulativeMetricTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {
class CumulativeMetricTransformerSpec extends AnyFlatSpec with SharedSparkSession with Matchers {
"CumulativeMetricTransformer" should "be" in {
val metrics = seqExpDataToDataFrame(experimentDataGenerator(withAggregation = false))

lazy val metrics = seqExpDataToDataFrame(experimentDataGenerator(withAggregation = false))
val cumulativeData = new CumulativeMetricTransformer()
.setNumeratorNames(Array("clicks"))
.setDenominatorNames(Array("views"))
@@ -1,11 +1,14 @@
package ai.salmonbrain.ruleofthumb

import helpers.ExperimentDataGenerator.generateDataForWelchTest
import helpers.SparkHelper
import helpers.ExperimentDataGenerator.{experimentDataGenerator, generateDataForWelchTest, seqExpDataToDataFrame}
import helpers.SharedSparkSession
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class OutlierRemoveTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {
class OutlierRemoveTransformerSpec extends AnyFlatSpec with SharedSparkSession with Matchers {
implicit val sparkSession: SparkSession = spark
"OutlierRemoveTransformerSpec" should "be" in {
val data = generateDataForWelchTest()
val clearData = new OutlierRemoveTransformer().transform(data)
@@ -17,4 +20,30 @@ class OutlierRemoveTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {
val clearData = new OutlierRemoveTransformer().setLowerPercentile(0).transform(data)
assert(clearData.count() == 28)
}

"OutlierRemoveTransformerSpec with excluded columns" should "be" in {
val data =
seqExpDataToDataFrame(
experimentDataGenerator(
uplift = 0,
controlSkew = 0.1,
treatmentSkew = 0.1,
controlSize = 100,
treatmentSize = 100,
withAggregation = false
)
)
val pipe = new Pipeline().setStages(
Array(
new CumulativeMetricTransformer(),
new OutlierRemoveTransformer()
.setLowerPercentile(0)
.setUpperPercentile(0.99)
.setExcludedMetrics(Array("clicks"))
)
)
val clearData = pipe.fit(data).transform(data)
assert(clearData.filter("metricName = 'clicks'").count() == 200)
assert(clearData.filter("metricName = 'views'").count() < 200)
}
}
@@ -1,18 +1,14 @@
package ai.salmonbrain.ruleofthumb

import helpers.ExperimentDataGenerator.{
experimentDataGenerator,
generateDataForWelchTest,
seqExpDataToDataFrame
}
import helpers.SparkHelper
import helpers.ExperimentDataGenerator.{experimentDataGenerator, generateDataForWelchTest, seqExpDataToDataFrame}
import helpers.SharedSparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.first
import org.scalactic.{ Equality, TolerantNumerics }
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class StatisticsTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {
class StatisticsTransformerSpec extends AnyFlatSpec with SharedSparkSession with Matchers {
import spark.implicits._
val epsilon = 1e-4f
val statWelch = new WelchStatisticsTransformer()
@@ -29,7 +25,7 @@ class StatisticsTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {
.setRatioNames(Array("ctr"))
.setNumBuckets(256)

private val metricsWithUplift: DataFrame = cumWithBuckets.transform(
private lazy val metricsWithUplift: DataFrame = cumWithBuckets.transform(
seqExpDataToDataFrame(
experimentDataGenerator(
uplift = 0.2,
Expand All @@ -40,7 +36,7 @@ class StatisticsTransformerSpec extends AnyFlatSpec with SparkHelper with Matche
)
)

private val metricsWithoutUplift: DataFrame = cumWithBuckets.transform(
private lazy val metricsWithoutUplift: DataFrame = cumWithBuckets.transform(
seqExpDataToDataFrame(
experimentDataGenerator(
uplift = 0.0,
@@ -54,8 +50,8 @@ class StatisticsTransformerSpec extends AnyFlatSpec with SparkHelper with Matchers {
seqExpDataToDataFrame(
experimentDataGenerator(
uplift = 0.0,
controlSize = 3000,
treatmentSize = 3000,
controlSize = 300,
treatmentSize = 300,
treatmentSkew = 10,
controlSkew = 10
)
20 changes: 8 additions & 12 deletions ruleofthumb/src/test/scala/helpers/ExperimentDataGenerator.scala
@@ -1,14 +1,10 @@
package helpers

import ai.salmonbrain.ruleofthumb.ExpData
import org.apache.commons.math3.distribution.{
BetaDistribution,
BinomialDistribution,
NormalDistribution
}
import org.apache.commons.math3.distribution.{BetaDistribution, BinomialDistribution, NormalDistribution}
import org.apache.commons.math3.random.Well19937a
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.from_unixtime
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{from_unixtime, when}

import java.sql.Timestamp
import java.time.LocalDate
@@ -18,11 +14,10 @@ import scala.util.Random
Inspired by
https://vkteam.medium.com/practitioners-guide-to-statistical-tests-ed2d580ef04f#609f
*/
object ExperimentDataGenerator extends SparkHelper {
import spark.implicits._
object ExperimentDataGenerator {
val randomGenerator = new Well19937a(777)

def generateDataForWelchTest(): DataFrame = {
def generateDataForWelchTest()(implicit spark: SparkSession): DataFrame = {
//sigma = 1,N = 10
val controlMetricValues = Seq(19.8, 20.4, 19.6, 17.8, 18.5, 18.9, 18.3, 18.9, 19.5, 22)
//sigma = 16, N = 20
@@ -134,8 +129,9 @@ object ExperimentDataGenerator extends SparkHelper {
}
}

def seqExpDataToDataFrame(data: Seq[ExpData]): DataFrame = {
sc
def seqExpDataToDataFrame(data: Seq[ExpData])(implicit spark: SparkSession): DataFrame = {
import spark.implicits._
spark.sparkContext
.parallelize(data)
.toDF
.withColumn("date", from_unixtime($"timestamp" / 1000).cast("date"))