hash.py
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.ml.feature import MinHashLSH
import numpy as np


def process(spark, input_path, num_hashes):
    # build the chunk dataframe
    df = (
        # load parquet
        spark.read.parquet(input_path)
        # keep only the columns the model needs
        .select(
            F.col("doi"),
            F.col("index"),
            F.col("features")
        )
        # drop rows whose feature vector is all zeros
        .filter(
            F.udf(
                lambda x: bool(np.sum(x.toArray()) != 0),
                T.BooleanType()
            )("features")
        )
    )
return (
MinHashLSH(inputCol="features", outputCol="hash", numHashTables=num_hashes)
# fit minhash model on text chunks
.fit(df)
# calculate hashes for each chunk
.transform(df)
# drop original features
.drop("features")
)
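

# A minimal sketch, not part of the original pipeline: a fitted
# MinHashLSHModel also supports approximate similarity self-joins on
# Jaccard distance, a common next step for near-duplicate detection.
# find_near_duplicates and its threshold default are hypothetical names
# introduced here; the column names match those used in process() above,
# and df must still carry the raw "features" vectors.
def find_near_duplicates(df, num_hashes, threshold=0.8):
    model = MinHashLSH(
        inputCol="features", outputCol="hash", numHashTables=num_hashes
    ).fit(df)
    return (
        # pairs of rows whose Jaccard distance falls below the threshold
        model.approxSimilarityJoin(df, df, threshold, distCol="jaccard")
        # drop trivial self-matches
        .filter(F.col("datasetA.index") != F.col("datasetB.index"))
    )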


def run(spark, args):
    # positional args: input path, output path, number of hash tables
    input_path = args[0]
    output_path = args[1]
    num_hashes = int(args[2])
    # hash the chunks, preview, and persist
    df = process(spark, input_path, num_hashes)
    df.show()
    df.write.mode("overwrite").parquet(output_path)
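
# Hypothetical example: run() is never called below, so it presumably
# expects an external driver to pass sys.argv-style arguments, e.g.
#   run(spark, ["stereo-vectorized.parquet/*", "stereo-hashed.parquet", "5"])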


if __name__ == "__main__":
    # hard-coded args for a standalone run
    INPUT = "stereo-vectorized.parquet/*"
    OUTPUT = "stereo-hashed.parquet"
    NUM_HASHES = 5
    # spark session
    spark = (
        SparkSession
        .builder
        .config(conf=SparkConf())
        .getOrCreate()
    )
    # hash and write out
    (
        process(spark, INPUT, NUM_HASHES)
        .write
        .mode("overwrite")
        .parquet(OUTPUT)
    )
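
# Note on the output schema: MinHashLSH emits `hash` as an array of
# numHashTables one-element DenseVectors, so consumers of the written
# parquet would likely flatten or index into these vectors before
# grouping on hash values (an assumption about downstream use).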