tools: add some data probe (#152)
zhijianma authored Dec 25, 2023
1 parent b53d2dc commit b105ed0
Showing 3 changed files with 217 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tools/data_probe/collector.py
@@ -0,0 +1,67 @@
from itertools import chain

import torch
from torch.distributions import Categorical
from transformers import AutoTokenizer

from data_juicer.format import load_formatter


class TextTokenDistCollector(object):
"""Tokenize and collect distribution of tokens for given
dataset with a specified tokenizer.
"""

def __init__(self, tokenizer):
"""
Initialization method.
:param tokenizer: tokenizer name on huggingface
"""
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer,
trust_remote_code=True)
self.vocab_size = len(self.tokenizer)

def collect(self, data_path, text_key, num_proc=1) -> 'Categorical':
"""
Tokenize and collect tokens distribution of input dataset
:param data_path: path to input dataset.
:param text_key: field keys that will be considered into token counts.
:param num_proc: number of processes to count tokens.
:return: token distribution.
"""

formatter = load_formatter(data_path)
dataset = formatter.load_dataset(num_proc=num_proc)
        assert text_key in dataset.features, f'{text_key} not found in dataset'

def prepare_tokenizer(
tokenizer,
text_key,
):
"""
Prepare a tokenizer function for dataset.
:param tokenizer: a tokenizer to tokenize sample.
:param text_key: field keys that will be
considered into token counts.
"""

            def _tokenize_fn(example):
example = tokenizer(example[text_key],
add_special_tokens=False)
return example

return _tokenize_fn

tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
dataset = dataset.map(tokenize_proc,
num_proc=num_proc,
desc=f'tokenize {data_path.split("/")[-1]}')

        # Count occurrences of every token id across the whole dataset.
        token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
        token_ids = torch.tensor(
            list(chain.from_iterable(dataset['input_ids'])))
        indices, counts = token_ids.unique(return_counts=True)
        token_count.scatter_(0, indices, counts.to(token_count.dtype))
        # Categorical normalizes the raw counts into a probability
        # distribution over the vocabulary.
        dist = Categorical(token_count)
return dist
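
A minimal usage sketch for the collector (not part of this commit; the tokenizer name, dataset path, and text key are placeholder assumptions, and the import assumes the repository root is on PYTHONPATH):

import torch

from tools.data_probe.collector import TextTokenDistCollector

# Hypothetical example: collect the token distribution of a local
# dataset, then persist the probability vector so the measures in
# measure.py can load it back from disk as a `torch binary file`.
collector = TextTokenDistCollector('gpt2')
dist = collector.collect('path/to/dataset.jsonl', text_key='text',
                         num_proc=4)
torch.save(dist.probs, 'token_dist.pt')
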
41 changes: 41 additions & 0 deletions tools/data_probe/draw.py
@@ -0,0 +1,41 @@
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def draw_heatmap(data, xlabels, ylabels=None, figsize=None, triangle=False):
    """
    Draw a heatmap of the input data with the specified labels.
    :param data: input data, now supporting
        [`list`, `tuple`, `numpy array`, `torch tensor`].
    :param xlabels: x-axis labels.
    :param ylabels: y-axis labels; if None, use xlabels.
    :param figsize: figure size.
    :param triangle: if True, only display the lower triangle.
    :return: a plot figure.
    """
figsize = figsize if figsize else (8 * 2.5, 6 * 2.5)
_, ax = plt.subplots(figsize=figsize)
mask = None
if triangle:
mask = np.triu(np.ones_like(data))
ax.tick_params(
right=True,
top=True,
labelright=True,
labeltop=True,
)
sns.heatmap(data,
ax=ax,
cmap='Oranges',
annot=True,
mask=mask,
linewidths=.05,
square=True,
xticklabels=xlabels,
                yticklabels=ylabels,
annot_kws={'size': 8})
plt.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95)
fig = plt.gcf()
plt.show()
return fig
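
A minimal sketch of calling draw_heatmap (illustrative only; the matrix and labels below are made up):

import numpy as np

from tools.data_probe.draw import draw_heatmap

# Hypothetical 3x3 pairwise-score matrix between three datasets.
labels = ['dataset_a', 'dataset_b', 'dataset_c']
scores = np.random.rand(3, 3)
fig = draw_heatmap(scores, xlabels=labels, triangle=True)
fig.savefig('heatmap.png')
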
109 changes: 109 additions & 0 deletions tools/data_probe/measure.py
@@ -0,0 +1,109 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical


class Measure(object):
"""Base class for Measure distribution.
"""
name = 'base'

def measure(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs):
return self.measure(*args, **kwargs)

def _convert_to_tensor(self, p):
"""
Convert input data to torch tensor.
:param p: input data, now support
        [`scalar`, `list`, `tuple`, `torch binary file`, and `Categorical`].
:return: torch tensor
"""
if isinstance(p, Tensor):
return p
elif isinstance(p, Categorical):
return p.probs
elif isinstance(p, str):
return torch.load(p)
else:
return torch.tensor(p)

def _convert_to_categorical(self, p):
"""
Convert input data to torch Categorical.
:param p: input data, now support
        [`scalar`, `list`, `tuple`, `torch binary file`, and `Categorical`].
:return: torch Categorical
"""
if isinstance(p, Categorical):
return p
elif isinstance(p, Tensor):
return Categorical(p)
elif isinstance(p, str):
return Categorical(torch.load(p))
else:
return Categorical(torch.tensor(p))


class KLDivMeasure(Measure):
"""
Measure Kullback-Leibler divergence.
"""
name = 'kl_divergence'

def measure(self, p, q):
p = self._convert_to_categorical(p)
q = self._convert_to_categorical(q)
        assert p.probs.shape == q.probs.shape, \
            'The two inputs have different shape: ' \
            f'{p.probs.shape} != {q.probs.shape} in {self.name}'
        # F.kl_div takes log-probabilities as its first argument and
        # probabilities as its second, so this computes KL(p || q).
        return F.kl_div(q.logits, p.probs, log_target=False, reduction='sum')


class JSDivMeasure(Measure):
"""
Measure Jensen-Shannon divergence.
"""
name = 'js_divergence'

def measure(self, p, q):
p = self._convert_to_tensor(p)
q = self._convert_to_tensor(q)
        assert p.shape == q.shape, \
            'The two inputs have different shape: ' \
            f'{p.shape} != {q.shape} in {self.name}'

m = 0.5 * (p + q)
kl_p = KLDivMeasure()(p, m)
kl_q = KLDivMeasure()(q, m)
js = 0.5 * (kl_p + kl_q)
return js


class CrossEntropyMeasure(Measure):
"""
Measure Cross-Entropy.
"""
name = 'cross_entropy'

def measure(self, p, q):
p = self._convert_to_categorical(p)
q = self._convert_to_categorical(q)
assert p.probs.shape == q.probs.shape, \
'The two inputs have different shape: '\
f'{p.probs.shape} != {q.probs.shape} in {self.name}'
return F.cross_entropy(q.logits, p.probs, reduction='sum')


class EntropyMeasure(Measure):
"""
Measure Entropy.
"""
name = 'entropy'

def measure(self, p):
p = self._convert_to_categorical(p)
return p.entropy()
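
A minimal end-to-end sketch combining the three tools (not part of this commit; the paths, tokenizer name, and text key are placeholder assumptions):

import numpy as np

from tools.data_probe.collector import TextTokenDistCollector
from tools.data_probe.draw import draw_heatmap
from tools.data_probe.measure import JSDivMeasure

# Hypothetical datasets to compare.
paths = ['data/corpus_a.jsonl', 'data/corpus_b.jsonl']
collector = TextTokenDistCollector('gpt2')
dists = [collector.collect(p, text_key='text') for p in paths]

# Pairwise Jensen-Shannon divergence between the token distributions.
js = JSDivMeasure()
n = len(dists)
scores = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        scores[i, j] = js(dists[i], dists[j]).item()

draw_heatmap(scores, xlabels=paths, triangle=True)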
