forked from GoogleCloudPlatform/bigquery-utils
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mannwhitneyu.sqlx
63 lines (60 loc) · 2.46 KB
/
mannwhitneyu.sqlx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
config { hasOutput: true }
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Computes the U statistic and the p value of the Mann–Whitney U test
-- (also called Mann–Whitney–Wilcoxon) using the normal approximation with a
-- 0.5 continuity correction. No tie correction is applied to the standard
-- deviation.
--
-- Inputs:
--   x, y : arrays of samples, both one-dimensional (ARRAY<FLOAT64>).
--   alt  : the alternative hypothesis: 'two-sided', 'less' or 'greater'.
--          Any other value (including NULL) is handled as 'two-sided'.
--
-- Output: a STRUCT with two fields
--   U : the U statistic of sample x, i.e. (rank sum of x) - n1*(n1+1)/2,
--       where n1 = ARRAY_LENGTH(x). Tied values receive their average rank.
--   p : the p value for the requested alternative. NULL when the z-score is
--       undefined (for example when either input array is empty).
CREATE OR REPLACE FUNCTION ${self()}(x ARRAY<FLOAT64>, y ARRAY<FLOAT64>, alt STRING)
OPTIONS (
description="""Computes the U statistics and the p value of the Mann–Whitney U test (also called Mann–Whitney–Wilcoxon).
Inputs: x,y (arrays of samples, both should be one-dimensional, type: ARRAY<FLOAT64> )
alt (Defines the alternative hypothesis. The following options are available: 'two-sided', 'less', and 'greater'"""
)
AS (
(
  WITH
  -- Pool both samples, remembering which sample every value came from.
  pooled AS (
    SELECT TRUE AS is_x, xi AS val
    FROM UNNEST(x) AS xi
    UNION ALL
    SELECT FALSE AS is_x, yi AS val
    FROM UNNEST(y) AS yi
  ),
  -- Average ("mid") rank of every pooled value: RANK() yields the lowest rank
  -- of a tie group, adding (group size - 1)/2 moves it to the group's mean rank.
  -- The partition key is cast to STRING because FLOAT64 is not a valid
  -- partitioning type in BigQuery.
  ranked AS (
    SELECT
      is_x,
      RANK() OVER (ORDER BY val ASC)
        + (COUNT(*) OVER (PARTITION BY CAST(val AS STRING)) - 1) / 2.0 AS r
    FROM pooled
  ),
  -- Both U statistics (u_x + u_y = n1*n2) and the standard deviation of U
  -- under the null hypothesis. Only the ranks of sample x are summed.
  statistics AS (
    SELECT
      SUM(r) - ARRAY_LENGTH(x) * (ARRAY_LENGTH(x) + 1) / 2.0 AS u_x,
      ARRAY_LENGTH(x) * (ARRAY_LENGTH(y) + (ARRAY_LENGTH(x) + 1) / 2.0) - SUM(r) AS u_y,
      ARRAY_LENGTH(x) * ARRAY_LENGTH(y) AS n1n2,
      SQRT(ARRAY_LENGTH(x) * ARRAY_LENGTH(y) * (ARRAY_LENGTH(x) + ARRAY_LENGTH(y) + 1) / 12.0) AS den
    FROM ranked
    WHERE is_x
  ),
  -- Continuity-corrected z-score, oriented so that the p value is always the
  -- lower tail of the standard normal distribution (doubled for 'two-sided').
  -- NULLIF turns a zero standard deviation (an empty sample) into a NULL
  -- z-score instead of a division-by-zero error.
  normal_appr AS (
    SELECT
      CASE alt
        WHEN 'less' THEN (n1n2 / 2.0 + 0.5 - u_y) / NULLIF(den, 0)
        WHEN 'greater' THEN (n1n2 / 2.0 + 0.5 - u_x) / NULLIF(den, 0)
        ELSE -1.0 * ABS(n1n2 / 2.0 + 0.5 - GREATEST(u_x, u_y)) / NULLIF(den, 0)
      END AS z,
      -- Report the statistic that belongs to the first sample, x.
      u_x AS U,
      IF(alt = 'less' OR alt = 'greater', 1.0, 2.0) AS factor
    FROM statistics
  )
  SELECT STRUCT(
    U,
    -- Skip the CDF call entirely when z is undefined.
    IF(z IS NULL, NULL, factor * ${ref("normal_cdf")}(z, 0.0, 1.0)) AS p
  )
  FROM normal_appr
));