Skip to content

Commit

Permalink
Merge pull request #15 from Salmon-Brain/brokenDataFix
Browse files Browse the repository at this point in the history
Broken data fix
  • Loading branch information
deadsalmonbrain authored May 13, 2022
2 parents 3bdabb8 + eb4b9b9 commit 773f9f5
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 60 deletions.
22 changes: 22 additions & 0 deletions python/ai/salmonbrain/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,27 @@ class BasicStatInferenceParameters(Params):
typeConverter=TypeConverters.toFloat,
)

# Minimum number of observations a variant group must have before it is
# considered statistically usable; smaller groups are skipped as invalid.
minValidSampleSize = Param(
Params._dummy(),
"minValidSampleSize",
"parameter for skip invalid groups",
typeConverter=TypeConverters.toInt,
)

# When True, nonparametric tests estimate variance via a linear
# approximation instead of the exact computation.
useLinearApproximationForVariance = Param(
Params._dummy(),
"useLinearApproximationForVariance",
"parameter for control variance computing method for nonparametric tests",
typeConverter=TypeConverters.toBoolean,
)

def __init__(self):
    """Initialize the shared statistical-inference params with their defaults.

    Defaults mirror the Scala side (``BasicStatInferenceParameters`` trait):
    alpha=0.05, beta=0.2, srmAlpha=0.05, minValidSampleSize=10,
    useLinearApproximationForVariance=False.
    """
    super(BasicStatInferenceParameters, self).__init__()
    # A single _setDefault call with keyword args has the same effect as
    # one call per parameter.
    self._setDefault(
        alpha=0.05,
        beta=0.2,
        srmAlpha=0.05,
        minValidSampleSize=10,
        useLinearApproximationForVariance=False,
    )

# Set the test significance level (type-I error rate); returns self for chaining.
def setAlpha(self, value):
return self._set(alpha=value)
Expand All @@ -156,3 +172,9 @@ def setBeta(self, value):

# Set the significance level used by the sample-ratio-mismatch (SRM) check;
# returns self for chaining.
def setSrmAlpha(self, value):
return self._set(srmAlpha=value)

# Set the minimum per-group sample size below which a group is skipped as
# invalid; returns self for chaining.
def setMinValidSampleSize(self, value):
return self._set(minValidSampleSize=value)

# Toggle linear-approximation variance computation for nonparametric tests;
# returns self for chaining.
def setUseLinearApproximationForVariance(self, value):
return self._set(useLinearApproximationForVariance=value)
8 changes: 8 additions & 0 deletions python/ai/salmonbrain/ruleofthumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def setParams(
alpha=0.05,
beta=0.2,
srmAlpha=0.05,
minValidSampleSize=10,
useLinearApproximationForVariance=False,
metricSourceColumn="metricSource",
entityIdColumn="entityUid",
experimentColumn="experimentUid",
Expand Down Expand Up @@ -154,6 +156,8 @@ def __init__(
alpha=0.05,
beta=0.2,
srmAlpha=0.05,
minValidSampleSize=10,
useLinearApproximationForVariance=False,
metricSourceColumn="metricSource",
entityIdColumn="entityUid",
experimentColumn="experimentUid",
Expand Down Expand Up @@ -190,6 +194,8 @@ def __init__(
alpha=0.05,
beta=0.2,
srmAlpha=0.05,
minValidSampleSize=10,
useLinearApproximationForVariance=False,
metricSourceColumn="metricSource",
entityIdColumn="entityUid",
experimentColumn="experimentUid",
Expand Down Expand Up @@ -226,6 +232,8 @@ def __init__(
alpha=0.05,
beta=0.2,
srmAlpha=0.05,
minValidSampleSize=10,
useLinearApproximationForVariance=False,
metricSourceColumn="metricSource",
entityIdColumn="entityUid",
experimentColumn="experimentUid",
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setup(
name="dead-salmon-brain",
version="0.0.6",
version="0.0.7",
description="Dead Salmon Brain is a cluster computing system for analysing A/B experiments",
license="Apache License v2.0",
author="Dead Salmon Brain",
Expand Down
8 changes: 4 additions & 4 deletions python/tests/test_ruleofthumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
CumulativeMetricTransformer,
WelchStatisticsTransformer,
OutlierRemoveTransformer,
AutoStatisticsTransformer
AutoStatisticsTransformer,
)


Expand Down Expand Up @@ -126,7 +126,7 @@ def test_cumulativeMetricTransformer(data_sample: DataFrame):

def test_welchStatisticsTransformer(data_sample: DataFrame):
cum = CumulativeMetricTransformer()
welch = WelchStatisticsTransformer()
welch = WelchStatisticsTransformer(minValidSampleSize=3)
result = welch.transform(cum.transform(data_sample))

p_values = [
Expand All @@ -137,7 +137,7 @@ def test_welchStatisticsTransformer(data_sample: DataFrame):

def test_mannWhitneyStatisticsTransformer(data_sample: DataFrame):
cum = CumulativeMetricTransformer()
welch = WelchStatisticsTransformer()
welch = WelchStatisticsTransformer(minValidSampleSize=3)
result = welch.transform(cum.transform(data_sample))

p_values = [
Expand All @@ -148,7 +148,7 @@ def test_mannWhitneyStatisticsTransformer(data_sample: DataFrame):

def test_autoStatisticsTransformer(data_sample: DataFrame):
cum = CumulativeMetricTransformer()
auto = AutoStatisticsTransformer()
auto = AutoStatisticsTransformer(minValidSampleSize=3)
result = auto.transform(cum.transform(data_sample))

p_values = [
Expand Down
2 changes: 1 addition & 1 deletion ruleofthumb/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ application {
mainClass = 'ai.salmonbrain.ruleofthumb.Main'
}

version '0.0.6'
version '0.0.7'

repositories {
mavenCentral()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class AutoStatisticsTransformer(override val uid: String) extends BaseStatisticT
.setAdditiveColumn($(additiveColumn))
.setAlpha($(alpha))
.setBeta($(beta))
.setMinValidSampleSize($(minValidSampleSize))
.setDataProviderColumn($(metricSourceColumn))
.setEntityIdColumn($(entityIdColumn))
.setExperimentColumn($(experimentColumn))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package ai.salmonbrain.ruleofthumb

import ai.salmonbrain.ruleofthumb.CentralTendency.CentralTendency
import org.apache.commons.math3.stat.inference.TestUtils
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.types.{ BooleanType, StringType, StructField, StructType }
import org.apache.spark.sql.{ Dataset, Encoders }

trait BaseStatisticTransformer
extends Transformer
Expand Down Expand Up @@ -43,17 +44,19 @@ trait BaseStatisticTransformer
)
}

protected def checkVariants(dataset: Dataset[_]): Unit = {
val expectedVariants = Set($(treatmentName), $(controlName))
val observedVariants = dataset
.select($(variantColumn))
.distinct()
.collect()
.map(row => row.getAs[String]($(variantColumn)))
.toSet
assert(
expectedVariants == observedVariants,
s"Variants must be named ${$(treatmentName)} and ${$(controlName)}"
/**
 * Build a placeholder [[StatResult]] for groups that cannot be analysed
 * (e.g. fewer observations than `minValidSampleSize`). All numeric fields
 * are NaN (sample-size slot is -1) so downstream consumers can detect the
 * sentinel.
 *
 * @param centralTendency the central-tendency kind to record in the result
 */
protected def getInvalidStatResult(centralTendency: CentralTendency): StatResult = {
  val nan = Double.NaN
  StatResult(
    nan,
    nan,
    -1,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    centralTendency.toString,
    isZeroVariance = false
  )
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ trait BasicStatInferenceParameters extends Params {
)
setDefault(srmAlpha, 0.05)

/**
 * Minimum number of observations a variant group must contain to be
 * analysed; smaller groups are skipped as invalid. Default: 10.
 */
val minValidSampleSize: Param[Int] = new Param[Int](
this,
"minValidSampleSize",
"parameter for skip invalid groups"
)
setDefault(minValidSampleSize, 10)

/**
 * When true, nonparametric tests compute variance with a linear
 * approximation instead of the exact method. Default: false.
 */
val useLinearApproximationForVariance: Param[Boolean] = new Param[Boolean](
this,
"useLinearApproximationForVariance",
"parameter for control variance computing method for nonparametric tests"
)
setDefault(useLinearApproximationForVariance, false)

/** @group setParam */
def setUseLinearApproximationForVariance(value: Boolean): this.type =
set(useLinearApproximationForVariance, value)

/** @group setParam */
def setAlpha(value: Double): this.type =
set(alpha, value)
Expand All @@ -36,4 +54,8 @@ trait BasicStatInferenceParameters extends Params {
def setSrmAlpha(value: Double): this.type =
set(srmAlpha, value)

/**
 * Set the minimum per-group sample size below which a group is skipped.
 * @group setParam
 */
def setMinValidSampleSize(value: Int): this.type =
set(minValidSampleSize, value)

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ case class StatisticsReport(
statResult: StatResult,
alpha: Double,
beta: Double,
minValidSampleSize: Int,
srm: Boolean,
controlSize: Long,
treatmentSize: Long,
testType: String
testType: String,
isEnoughData: Boolean
)

case class CI(
Expand Down Expand Up @@ -66,7 +68,8 @@ case class StatResult(
treatmentVariance: Double,
percentageLeft: Double,
percentageRight: Double,
centralTendencyType: String = CentralTendency.MEAN.toString
centralTendencyType: String,
isZeroVariance: Boolean
)

/** A single named metric observation (name plus its numeric value). */
case class Metric(metricName: String, metricValue: Double)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{ DataFrame, Dataset }

import scala.collection.mutable
import scala.collection.mutable.WrappedArray.make

/**
* Transformer to apply Mann–Whitney U test
Expand All @@ -19,7 +20,6 @@ class MannWhitneyStatisticsTransformer(override val uid: String) extends BaseSta
def this() = this(Identifiable.randomUID("mannWhitneyStatisticsTransformer"))

override def transform(dataset: Dataset[_]): DataFrame = {
checkVariants(dataset)
dataset
.groupBy(
$(experimentColumn),
Expand All @@ -35,7 +35,10 @@ class MannWhitneyStatisticsTransformer(override val uid: String) extends BaseSta
)
.withColumn(
"statisticsData",
doStatistic($(alpha), $(beta))(col($(controlName)), col($(treatmentName)))
doStatistic($(alpha), $(beta), $(minValidSampleSize), $(useLinearApproximationForVariance))(
col($(controlName)),
col($(treatmentName))
)
)
.drop("control", "treatment")
}
Expand All @@ -44,23 +47,43 @@ class MannWhitneyStatisticsTransformer(override val uid: String) extends BaseSta

override def transformSchema(schema: StructType): StructType = schema

def doStatistic(alpha: Double, beta: Double): UserDefinedFunction = udf {
/**
 * Build the UDF that runs the Mann-Whitney U test over the collected
 * control/treatment value arrays of one experiment group.
 *
 * Null input arrays are treated as empty via `Option(...).getOrElse(...)`,
 * so size checks never NPE. When either group is smaller than
 * `minValidSampleSize`, a sentinel invalid StatResult (NaN fields) is
 * reported with `isEnoughData = false` and the SRM check is skipped.
 *
 * @param alpha significance level for the test
 * @param beta type-II error rate used for power-related fields
 * @param minValidSampleSize minimum per-group size to run the test
 * @param useLinearApproximationForVariance variance method toggle for the
 *        nonparametric test
 */
def doStatistic(
alpha: Double,
beta: Double,
minValidSampleSize: Int,
useLinearApproximationForVariance: Boolean
): UserDefinedFunction = udf {
(
control: mutable.WrappedArray[Double],
treatment: mutable.WrappedArray[Double]
) =>
// NOTE(review): the two lines below duplicate the pre-fix statistic call;
// they appear to be a stale diff artifact — confirm against the real file.
val statResult =
MannWhitneyTest.mannWhitneyTest(control.toArray, treatment.toArray, alpha, beta)
val controlSize = control.length
val treatmentSize = treatment.length
// Guard against null arrays from the aggregation (broken-data fix).
val controlSize = Option(control).getOrElse(make[Double](Array())).length
val treatmentSize = Option(treatment).getOrElse(make[Double](Array())).length
val isEnoughData = math.min(controlSize, treatmentSize) >= minValidSampleSize
val (statResult, srmResult) =
if (isEnoughData)
(
MannWhitneyTest.mannWhitneyTest(
control.toArray,
treatment.toArray,
alpha,
beta,
useLinearApproximationForVariance
),
// SRM (sample-ratio-mismatch) is only meaningful with enough data.
srm(controlSize, treatmentSize, $(srmAlpha))
)
// Median is the central tendency reported for nonparametric tests.
else (getInvalidStatResult(CentralTendency.MEDIAN), false)

StatisticsReport(
statResult,
alpha,
beta,
minValidSampleSize,
srmResult,
controlSize,
treatmentSize,
TestType.MANN_WHITNEY.toString,
isEnoughData
)
}
}
Loading

0 comments on commit 773f9f5

Please sign in to comment.