tools: add some data probe #152

Merged 1 commit on Dec 25, 2023
67 changes: 67 additions & 0 deletions tools/data_probe/collector.py
@@ -0,0 +1,67 @@
from itertools import chain

import torch
from torch.distributions import Categorical
from transformers import AutoTokenizer

from data_juicer.format import load_formatter


class TextTokenDistCollector(object):
"""Tokenize and collect distribution of tokens for given
dataset with a specified tokenizer.
"""

def __init__(self, tokenizer):
"""
Initialization method.

        :param tokenizer: name of the tokenizer on HuggingFace.
"""
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer,
trust_remote_code=True)
self.vocab_size = len(self.tokenizer)

def collect(self, data_path, text_key, num_proc=1) -> 'Categorical':
"""
        Tokenize the input dataset and collect its token distribution.
        :param data_path: path to the input dataset.
        :param text_key: field key of the text to be tokenized and counted.
        :param num_proc: number of processes used to count tokens.
        :return: token distribution.
"""

formatter = load_formatter(data_path)
dataset = formatter.load_dataset(num_proc=num_proc)
        assert text_key in dataset.features, \
            f'{text_key} not found in dataset'

def prepare_tokenizer(
tokenizer,
text_key,
):
"""
            Prepare a tokenization function for the dataset.
            :param tokenizer: a tokenizer to tokenize samples.
            :param text_key: field key of the text to be tokenized.
"""

            def _tokenize_fn(example):
example = tokenizer(example[text_key],
add_special_tokens=False)
return example

return _tokenize_fn

tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
dataset = dataset.map(tokenize_proc,
num_proc=num_proc,
desc=f'tokenize {data_path.split("/")[-1]}')

token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
token_ids = torch.tensor(
list(chain.from_iterable(dataset['input_ids'])))
        # count the occurrences of each token id across the whole dataset
        indices, counts = token_ids.unique(return_counts=True)
        token_count.scatter_(0, indices, counts.to(token_count.dtype))
dist = Categorical(token_count)
return dist
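
A minimal usage sketch for the collector above; the tokenizer name, data path, and text field are placeholders for illustration, not part of this PR:

    collector = TextTokenDistCollector('gpt2')
    dist = collector.collect('demo-data.jsonl', 'text', num_proc=4)
    probs = dist.probs  # normalized token frequencies, shape [vocab_size]
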
41 changes: 41 additions & 0 deletions tools/data_probe/draw.py
@@ -0,0 +1,41 @@
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def draw_heatmap(data, xlabels, ylabels=None, figsize=None, triangle=False):
"""
    Draw a heatmap of the input data with the specified labels.
    :param data: input data; supported types are
        [`list`, `tuple`, `numpy array`, `torch tensor`].
    :param xlabels: x-axis labels.
    :param ylabels: y-axis labels; if None, use xlabels.
    :param figsize: figure size.
    :param triangle: if True, only display the lower triangle.
    :return: a plot figure.
"""
    figsize = figsize if figsize else (8 * 2.5, 6 * 2.5)
    # fall back to xlabels when no y-axis labels are given
    if ylabels is None:
        ylabels = xlabels
    _, ax = plt.subplots(figsize=figsize)
    mask = None
    if triangle:
        # mask the upper triangle (including the diagonal)
        mask = np.triu(np.ones_like(data))
ax.tick_params(
right=True,
top=True,
labelright=True,
labeltop=True,
)
sns.heatmap(data,
ax=ax,
cmap='Oranges',
annot=True,
mask=mask,
linewidths=.05,
square=True,
xticklabels=xlabels,
                yticklabels=ylabels,
annot_kws={'size': 8})
plt.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95)
fig = plt.gcf()
plt.show()
return fig
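
A minimal sketch of how draw_heatmap might be called on a pairwise-divergence matrix; the dataset names and values below are made up for illustration:

    import numpy as np

    names = ['data-a', 'data-b', 'data-c']
    mat = np.array([[0.0, 0.8, 0.3],
                    [0.8, 0.0, 0.5],
                    [0.3, 0.5, 0.0]])  # symmetric, so the triangle view suffices
    fig = draw_heatmap(mat, names, triangle=True)
    fig.savefig('divergence_heatmap.png')
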
109 changes: 109 additions & 0 deletions tools/data_probe/measure.py
@@ -0,0 +1,109 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical


class Measure(object):
"""Base class for Measure distribution.
"""
name = 'base'

def measure(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs):
return self.measure(*args, **kwargs)

def _convert_to_tensor(self, p):
"""
Convert input data to torch tensor.
        :param p: input data; supported types are
            [`scalar`, `list`, `tuple`, `torch binary file`, and `Categorical`].
:return: torch tensor
"""
if isinstance(p, Tensor):
return p
elif isinstance(p, Categorical):
return p.probs
elif isinstance(p, str):
return torch.load(p)
else:
return torch.tensor(p)

def _convert_to_categorical(self, p):
"""
Convert input data to torch Categorical.
        :param p: input data; supported types are
            [`scalar`, `list`, `tuple`, `torch binary file`, and `Categorical`].
:return: torch Categorical
"""
if isinstance(p, Categorical):
return p
elif isinstance(p, Tensor):
return Categorical(p)
elif isinstance(p, str):
return Categorical(torch.load(p))
else:
return Categorical(torch.tensor(p))


class KLDivMeasure(Measure):
"""
Measure Kullback-Leibler divergence.
"""
name = 'kl_divergence'

def measure(self, p, q):
p = self._convert_to_categorical(p)
q = self._convert_to_categorical(q)
        assert p.probs.shape == q.probs.shape, \
            'The two inputs have different shape: ' \
            f'{p.probs.shape} != {q.probs.shape} in {self.name}'
        # F.kl_div(input, target) computes KL(target || input) when the
        # input is given as log-probabilities, so this returns KL(p || q)
        return F.kl_div(q.logits, p.probs, log_target=False, reduction='sum')


class JSDivMeasure(Measure):
"""
Measure Jensen-Shannon divergence.
"""
name = 'js_divergence'

def measure(self, p, q):
p = self._convert_to_tensor(p)
q = self._convert_to_tensor(q)
        assert p.shape == q.shape, \
            'The two inputs have different shape: ' \
            f'{p.shape} != {q.shape} in {self.name}'

m = 0.5 * (p + q)
kl_p = KLDivMeasure()(p, m)
kl_q = KLDivMeasure()(q, m)
js = 0.5 * (kl_p + kl_q)
return js


class CrossEntropyMeasure(Measure):
"""
Measure Cross-Entropy.
"""
name = 'cross_entropy'

def measure(self, p, q):
p = self._convert_to_categorical(p)
q = self._convert_to_categorical(q)
assert p.probs.shape == q.probs.shape, \
'The two inputs have different shape: '\
f'{p.probs.shape} != {q.probs.shape} in {self.name}'
        # q.logits are already normalized log-probabilities, so this
        # computes the cross-entropy H(p, q) = -sum(p * log q)
        return F.cross_entropy(q.logits, p.probs, reduction='sum')


class EntropyMeasure(Measure):
"""
Measure Entropy.
"""
name = 'entropy'

def measure(self, p):
p = self._convert_to_categorical(p)
return p.entropy()
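
A minimal sketch exercising the measures above; the probability vectors are toy values (unnormalized inputs are normalized by Categorical):

    p = [0.1, 0.2, 0.3, 0.4]
    q = [0.25, 0.25, 0.25, 0.25]

    kl = KLDivMeasure()(p, q)         # KL(p || q)
    js = JSDivMeasure()(p, q)        # symmetric and bounded
    ce = CrossEntropyMeasure()(p, q)  # H(p, q) = H(p) + KL(p || q)
    h = EntropyMeasure()(p)           # H(p)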