Skip to content

Commit

Permalink
fixed OutlierRemoveTransformer and added srm stat (#2)
Browse files Browse the repository at this point in the history
* fixed OutlierRemoveTransformer and added srm stat
  • Loading branch information
deadsalmonbrain authored May 23, 2024
1 parent aef772d commit f463773
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 20 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ jobs:
strategy:
matrix:
include:
- scala-version: 2.11.8
spark-version: 2.3.0
- scala-version: 2.11.8
spark-version: 2.4.3
- scala-version: 2.12.11
spark-version: 3.0.0
- scala-version: 2.12.11
Expand Down
4 changes: 0 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ jobs:
strategy:
matrix:
include:
- scala-version: 2.11.8
spark-version: 2.3.0
- scala-version: 2.11.8
spark-version: 2.4.3
- scala-version: 2.12.11
spark-version: 3.0.0
- scala-version: 2.12.11
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=">=3.6",
install_requires=["pyspark>=2.3.0", "numpy"],
install_requires=["pyspark>=3.0.0", "numpy"],
tests_require=["pytest"],
project_urls={
"Source code": "https://github.com/Salmon-Brain/dead-salmon-brain/tree/main/python",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ trait BaseStatisticTransformer

override def transformSchema(schema: StructType): StructType = outputSchema

def srm(controlSize: Int, treatmentSize: Int, alpha: Double): Boolean = {
def srm(controlSize: Int, treatmentSize: Int): Double = {
val uniform = (treatmentSize + controlSize).toDouble / 2
TestUtils.chiSquareTest(
Array(uniform, uniform),
Array(controlSize, treatmentSize),
alpha
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ case class StatisticsReport(
beta: Double,
minValidSampleSize: Int,
srm: Boolean,
srmAlpha: Double,
pValueSrm: Double,
controlSize: Long,
treatmentSize: Long,
testType: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,17 @@ class MannWhitneyStatisticsTransformer(override val uid: String) extends BaseSta
beta,
useLinearApproximationForVariance
),
srm(controlSize, treatmentSize, $(srmAlpha))
srm(controlSize, treatmentSize)
)
else (getInvalidStatResult(CentralTendency.MEDIAN), false)
else (getInvalidStatResult(CentralTendency.MEDIAN), -1d)

StatisticsReport(
statResult,
alpha,
beta,
minValidSampleSize,
srmResult < $(srmAlpha),
$(srmAlpha),
srmResult,
controlSize,
treatmentSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,34 @@ class OutlierRemoveTransformer(override val uid: String)
set(upperPercentile, value)

override def transform(dataset: Dataset[_]): DataFrame = {
assert($(lowerPercentile) > 0 && $(lowerPercentile) < 1, "lowerPercentile must be in (0, 1)")
assert($(upperPercentile) > 0 && $(upperPercentile) < 1, "upperPercentile must be in (0, 1)")
assert(
$(upperPercentile) > $(lowerPercentile),
"upperPercentile must be greater than lowerPercentile"
)

val isDisabledLower = $(lowerPercentile) <= 0

import dataset.sparkSession.implicits._

val aggFunc = Seq(
callUDF("percentile_approx", col($(valueColumn)), lit($(upperPercentile))) as "rightBound"
) ++
(if (isDisabledLower) Seq()
else
Seq(
callUDF(
"percentile_approx",
col($(valueColumn)),
lit($(lowerPercentile))
) as "leftBound"
))
val filterFunc =
if (isDisabledLower) col($(valueColumn)) < $"rightBound"
else col($(valueColumn)) > $"leftBound" && col($(valueColumn)) < $"rightBound"

val dropCols = if (isDisabledLower) Seq("rightBound") else Seq("rightBound", "leftBound")

val columns = Seq(
$(variantColumn),
$(experimentColumn),
Expand All @@ -59,14 +78,14 @@ class OutlierRemoveTransformer(override val uid: String)
val percentilesBound = dataset
.groupBy(columns.head, columns: _*)
.agg(
callUDF("percentile_approx", col($(valueColumn)), lit($(lowerPercentile))) as "leftBound",
callUDF("percentile_approx", col($(valueColumn)), lit($(upperPercentile))) as "rightBound"
aggFunc.head,
aggFunc.tail: _*
)

dataset
.join(broadcast(percentilesBound), columns)
.filter(col($(valueColumn)) > $"leftBound" && col($(valueColumn)) < $"rightBound")
.drop("leftBound", "rightBound")
.filter(filterFunc)
.drop(dropCols: _*)
}

override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@ class WelchStatisticsTransformer(override val uid: String) extends BaseStatistic
if (isEnoughData)
(
WelchTTest.welchTTest(control, treatment, alpha, beta),
srm(controlSize.toInt, treatmentSize.toInt, $(srmAlpha))
srm(controlSize.toInt, treatmentSize.toInt)
)
else (getInvalidStatResult(CentralTendency.MEAN), false)
else (getInvalidStatResult(CentralTendency.MEAN), -1d)

StatisticsReport(
statResult,
alpha,
beta,
minValidSampleSize,
srmResult < $(srmAlpha),
$(srmAlpha),
srmResult,
controlSize,
treatmentSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,10 @@ class OutlierRemoveTransformerSpec extends AnyFlatSpec with SparkHelper with Mat
val clearData = new OutlierRemoveTransformer().transform(data)
assert(clearData.count() == 26)
}

"OutlierRemoveTransformerSpec with 0 lower percentile" should "be" in {
val data = generateDataForWelchTest()
val clearData = new OutlierRemoveTransformer().setLowerPercentile(0).transform(data)
assert(clearData.count() == 28)
}
}

0 comments on commit f463773

Please sign in to comment.