Update count_wikipedia.py for Python 3 compatibility and parallel execution #250

Open · wants to merge 1 commit into master
112 changes: 77 additions & 35 deletions data-scripts/count_wikipedia.py
File mode changed: 100755 → 100644
@@ -3,16 +3,18 @@
 import sys
 import os
 import re
-import codecs
 import operator
 import datetime
-import nltk
 import warnings
+import multiprocessing
+import time
+import io
 
+import nltk
 from unidecode import unidecode
 
 def usage():
-    print '''
+    print('''
 tokenize a directory of text and count unigrams.
 
 usage:
@@ -48,14 +50,16 @@ def usage():
 Then run:
 ./WikiExtractor.py -o en_sents --no-templates enwiki-20151002-pages-articles.xml.bz2
 
-''' % sys.argv[0]
+''' % sys.argv[0])
 
-SENTENCES_PER_BATCH = 500000 # after each batch, delete all counts with count == 1 (hapax legomena)
-PRE_SORT_CUTOFF = 300 # before sorting, discard all words with less than this count
+
+SENTENCES_PER_BATCH = 500000 # after each batch, delete all counts with count == 1 (hapax legomena)
+PRE_SORT_CUTOFF = 300 # before sorting, discard all words with less than this count
 
 ALL_NON_ALPHA = re.compile(r'^[\W\d]*$', re.UNICODE)
 SOME_NON_ALPHA = re.compile(r'[\W\d]', re.UNICODE)
 
+
 class TopTokenCounter(object):
     def __init__(self):
         self.count = {}
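The two constants above capture the script's memory-bounding strategy: count tokens in batches of SENTENCES_PER_BATCH sentences, drop hapax legomena (count == 1) after each batch, and drop anything still under PRE_SORT_CUTOFF once before the final sort. A minimal sketch of that flow, using collections.Counter and a hypothetical iterable of token lists rather than the script's own TopTokenCounter:

    from collections import Counter

    SENTENCES_PER_BATCH = 500000  # prune hapax legomena after this many sentences
    PRE_SORT_CUTOFF = 300         # drop tokens rarer than this before the final sort

    def count_with_pruning(sentences):
        """sentences: any iterable of token lists (hypothetical input)."""
        counts = Counter()
        for i, tokens in enumerate(sentences, start=1):
            counts.update(tokens)
            if i % SENTENCES_PER_BATCH == 0:
                # batch prune: forget tokens seen exactly once so far
                for token in [t for t, c in counts.items() if c == 1]:
                    del counts[token]
        # pre-sort prune: keep only reasonably frequent tokens
        return {t: c for t, c in counts.items() if c >= PRE_SORT_CUTOFF}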
@@ -110,7 +114,7 @@ def batch_prune(self):
 
     def pre_sort_prune(self):
         under_cutoff = set()
-        for token, count in self.count.iteritems():
+        for token, count in self.count.items():
             if count < PRE_SORT_CUTOFF:
                 under_cutoff.add(token)
         for token in under_cutoff:
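One Python 3 detail behind the iteritems() to items() change above: in Python 3, dict.items() returns a live view, and deleting keys while iterating over it raises RuntimeError. That is why pre_sort_prune first collects the keys into a separate set and deletes afterwards. A small illustration (not part of the PR):

    counts = {'the': 1000, 'zyzzyva': 2}

    # Deleting inside the loop would raise
    # "RuntimeError: dictionary changed size during iteration" in Python 3:
    #     for token, count in counts.items():
    #         if count < 300:
    #             del counts[token]

    # Collect first, delete afterwards (the pattern pre_sort_prune uses):
    under_cutoff = {token for token, count in counts.items() if count < 300}
    for token in under_cutoff:
        del counts[token]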
@@ -127,43 +131,81 @@ def get_stats(self):
         ts = self.get_ts()
         return "%s keys(count): %d" % (ts, len(self.count))
 
+    def merge(self, other):
+        self.discarded |= other.discarded
+        self.legomena ^= other.legomena
+        for token, num in other.count.items():
+            if token in self.count:
+                self.count[token] += num
+            else:
+                self.count[token] = num
+
+
+def count_file(path):
+    """
+    Scan the file at given path, tokenize all lines and return the filled TopTokenCounter
+    and the number of processed lines.
+    """
+    counter = TopTokenCounter()
+    lines = 0
+    for line in io.open(path, 'r', encoding='utf8'):
+        with warnings.catch_warnings():
+            # unidecode() occasionally (rarely but enough to clog terminal outout)
+            # complains about surrogate characters in some wikipedia sentences.
+            # ignore those warnings.
+            warnings.simplefilter('ignore')
+            line = unidecode(line)
+        tokens = nltk.word_tokenize(line)
+        counter.add_tokens(tokens)
+        lines += 1
+    return counter, lines
+
+
 def main(input_dir_str, output_filename):
     counter = TopTokenCounter()
-    print counter.get_ts(), 'starting...'
+    print(counter.get_ts(), 'starting...')
+    tic = time.time()
+    pruned_lines = 0
     lines = 0
-    for root, dirs, files in os.walk(input_dir_str, topdown=True):
-        if not files:
-            continue
-        for fname in files:
-            path = os.path.join(root, fname)
-            for line in codecs.open(path, 'r', 'utf8'):
-                with warnings.catch_warnings():
-                    # unidecode() occasionally (rarely but enough to clog terminal outout)
-                    # complains about surrogate characters in some wikipedia sentences.
-                    # ignore those warnings.
-                    warnings.simplefilter('ignore')
-                    line = unidecode(line)
-                tokens = nltk.word_tokenize(line)
-                counter.add_tokens(tokens)
-                lines += 1
-                if lines % SENTENCES_PER_BATCH == 0:
-                    counter.batch_prune()
-                    print counter.get_stats()
-            print 'processing: %s' % path
-    print counter.get_stats()
-    print 'deleting tokens under cutoff of', PRE_SORT_CUTOFF
+    files = 0
+    process_pool = multiprocessing.Pool()
+    # Some python iterator magic: Pool.imap() maps the given function over the iterable
+    # using the process pool. The iterable is produced by creating the full path of every
+    # file in every directory (thus, the nested generator expression).
+    for fcounter, l in process_pool.imap(
+            count_file, (os.path.join(root, fname)
+                         for root, dirs, files in os.walk(input_dir_str, topdown=True)
+                         if files
+                         for fname in files), 4):
+        lines += l
+        files += 1
+        counter.merge(fcounter)
+        if (lines - pruned_lines) >= SENTENCES_PER_BATCH:
+            counter.batch_prune()
+            pruned_lines = lines
+            print(counter.get_stats())
+
+    toc = time.time()
+    print("Finished reading input data. Read %d files with %d lines in %.2fs."
+          % (files, lines, toc-tic))
+    print(counter.get_stats())
+
+    print('deleting tokens under cutoff of', PRE_SORT_CUTOFF)
     counter.pre_sort_prune()
-    print 'done'
-    print counter.get_stats()
-    print counter.get_ts(), 'sorting...'
+    print('done')
+    print(counter.get_stats())
+
+    print(counter.get_ts(), 'sorting...')
     sorted_pairs = counter.get_sorted_pairs()
-    print counter.get_ts(), 'done'
-    print 'writing...'
-    with codecs.open(output_filename, 'w', 'utf8') as f:
+    print(counter.get_ts(), 'done')
+
+    print('writing...')
+    with io.open(output_filename, 'w', encoding='utf8') as f:
         for token, count in sorted_pairs:
             f.write('%-18s %d\n' % (token, count))
     sys.exit(0)
 
+
 if __name__ == '__main__':
     if len(sys.argv) != 3:
         usage()
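For readers unfamiliar with the Pool.imap() pattern the new main() relies on, here is a stripped-down, self-contained sketch of the same idea. It is an illustration under simplified assumptions, not the PR's code: it swaps the script's TopTokenCounter for collections.Counter and NLTK tokenization for whitespace splitting. A nested generator expression lazily yields every file path under a directory tree, each worker process counts one file, and the parent merges the per-file counters as results stream back:

    import os
    import multiprocessing
    from collections import Counter

    def count_file(path):
        """Worker: count whitespace-separated tokens in one file (simplified stand-in)."""
        counter = Counter()
        with open(path, encoding='utf8', errors='ignore') as f:
            for line in f:
                counter.update(line.split())
        return counter

    def count_directory(input_dir):
        total = Counter()
        # Lazily yield the full path of every file under input_dir.
        paths = (os.path.join(root, fname)
                 for root, dirs, files in os.walk(input_dir)
                 for fname in files)
        with multiprocessing.Pool() as pool:
            # chunksize=4 hands paths to workers in groups of 4,
            # matching the trailing `4` argument in the PR's imap() call.
            for file_counter in pool.imap(count_file, paths, chunksize=4):
                total.update(file_counter)  # merge per-file results in the parent
        return total

    if __name__ == '__main__':
        print(count_directory('en_sents').most_common(10))

Unlike Pool.map(), imap() yields results lazily and in order, so the parent can merge and prune incrementally instead of holding every per-file result in memory at once.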