From bfdc93590d468dcadecc95a6479c4cfeb5d4a117 Mon Sep 17 00:00:00 2001
From: zhijianma
Date: Thu, 23 Nov 2023 22:48:15 +0800
Subject: [PATCH] tools: add data probe tools

---
 tools/data_probe/collector.py |  67 +++++++++++++++++++++
 tools/data_probe/draw.py      |  41 +++++++++++++
 tools/data_probe/measure.py   | 109 ++++++++++++++++++++++++++++++++++
 3 files changed, 217 insertions(+)
 create mode 100644 tools/data_probe/collector.py
 create mode 100644 tools/data_probe/draw.py
 create mode 100644 tools/data_probe/measure.py

diff --git a/tools/data_probe/collector.py b/tools/data_probe/collector.py
new file mode 100644
index 000000000..c55cde352
--- /dev/null
+++ b/tools/data_probe/collector.py
@@ -0,0 +1,67 @@
+from itertools import chain
+
+import torch
+from torch.distributions import Categorical
+from transformers import AutoTokenizer
+
+from data_juicer.format import load_formatter
+
+
+class TextTokenDistCollector(object):
+    """Tokenize a given dataset with a specified tokenizer and
+    collect the distribution of its tokens.
+    """
+
+    def __init__(self, tokenizer):
+        """
+        Initialization method.
+
+        :param tokenizer: name of the Hugging Face tokenizer to load.
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer,
+                                                       trust_remote_code=True)
+        self.vocab_size = len(self.tokenizer)
+
+    def collect(self, data_path, text_key, num_proc=1) -> Categorical:
+        """
+        Tokenize the input dataset and collect its token distribution.
+        :param data_path: path to the input dataset.
+        :param text_key: field key whose text is tokenized and counted.
+        :param num_proc: number of processes used to count tokens.
+        :return: token distribution as a `Categorical`.
+        """
+
+        formatter = load_formatter(data_path)
+        dataset = formatter.load_dataset(num_proc=num_proc)
+        assert text_key in dataset.features, f'[{text_key}] not found in dataset'
+
+        def prepare_tokenizer(
+            tokenizer,
+            text_key,
+        ):
+            """
+            Prepare a tokenizer function for the dataset.
+            :param tokenizer: a tokenizer to tokenize samples.
+            :param text_key: field key whose text is
+                tokenized and counted.
+            """
+
+            def _tokenize_fn(example):
+                example = tokenizer(example[text_key],
+                                    add_special_tokens=False)
+                return example
+
+            return _tokenize_fn
+
+        tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
+        dataset = dataset.map(tokenize_proc,
+                              num_proc=num_proc,
+                              desc=f'tokenize {data_path.split("/")[-1]}')
+
+        token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
+        token_ids = torch.tensor(
+            list(chain.from_iterable(dataset['input_ids'])))
+        indices, counts = token_ids.unique(return_counts=True)
+        token_count.scatter_(0, indices, counts.to(token_count.dtype))
+        dist = Categorical(token_count.float())
+        return dist
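A minimal usage sketch of the collector above, for context; the tokenizer name ('gpt2'), dataset path, and text key below are illustrative assumptions, not part of this patch:

    from tools.data_probe.collector import TextTokenDistCollector

    # any Hugging Face tokenizer name works; 'gpt2' is only an example
    collector = TextTokenDistCollector('gpt2')
    # 'demo.jsonl' and the 'text' field are assumed sample inputs
    dist = collector.collect('demo.jsonl', text_key='text', num_proc=4)
    print(dist.probs.shape)  # torch.Size([vocab_size]); probs sum to 1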
+ """ + figsize = figsize if figsize else (8 * 2.5, 6 * 2.5) + _, ax = plt.subplots(figsize=figsize) + mask = None + if triangle: + mask = np.triu(np.ones_like(data)) + ax.tick_params( + right=True, + top=True, + labelright=True, + labeltop=True, + ) + sns.heatmap(data, + ax=ax, + cmap='Oranges', + annot=True, + mask=mask, + linewidths=.05, + square=True, + xticklabels=xlabels, + yticklabels=ylables, + annot_kws={'size': 8}) + plt.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95) + fig = plt.gcf() + plt.show() + return fig diff --git a/tools/data_probe/measure.py b/tools/data_probe/measure.py new file mode 100644 index 000000000..80e7c42f8 --- /dev/null +++ b/tools/data_probe/measure.py @@ -0,0 +1,109 @@ +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import Categorical + + +class Measure(object): + """Base class for Measure distribution. + """ + name = 'base' + + def measure(self, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs): + return self.measure(*args, **kwargs) + + def _convert_to_tensor(self, p): + """ + Convert input data to torch tensor. + :param p: input data, now support + [`scalar`,`list`, `tuple`, `torch binary file`, and `Categorical`]. + :return: torch tensor + """ + if isinstance(p, Tensor): + return p + elif isinstance(p, Categorical): + return p.probs + elif isinstance(p, str): + return torch.load(p) + else: + return torch.tensor(p) + + def _convert_to_categorical(self, p): + """ + Convert input data to torch Categorical. + :param p: input data, now support + [`scalar`,`list`, `tuple`, `torch binary file`, and `Categorical`]. + :return: torch Categorical + """ + if isinstance(p, Categorical): + return p + elif isinstance(p, Tensor): + return Categorical(p) + elif isinstance(p, str): + return Categorical(torch.load(p)) + else: + return Categorical(torch.tensor(p)) + + +class KLDivMeasure(Measure): + """ + Measure Kullback-Leibler divergence. + """ + name = 'kl_divergence' + + def measure(self, p, q): + p = self._convert_to_categorical(p) + q = self._convert_to_categorical(q) + assert p.probs.shape == q.probs.shape, \ + 'The two inputs have different shape:' \ + f'{p.probs.shape} != {q.probs.shape} in {self.name}' + return F.kl_div(q.logits, p.probs, log_target=False, reduction='sum') + + +class JSDivMeasure(Measure): + """ + Measure Jensen-Shannon divergence. + """ + name = 'js_divergence' + + def measure(self, p, q): + p = self._convert_to_tensor(p) + q = self._convert_to_tensor(q) + assert p.shape == q.shape, \ + 'The two inputs have different shape:' \ + f'{p.shape} != {q.shape} in {self.name}' + + m = 0.5 * (p + q) + kl_p = KLDivMeasure()(p, m) + kl_q = KLDivMeasure()(q, m) + js = 0.5 * (kl_p + kl_q) + return js + + +class CrossEntropyMeasure(Measure): + """ + Measure Cross-Entropy. + """ + name = 'cross_entropy' + + def measure(self, p, q): + p = self._convert_to_categorical(p) + q = self._convert_to_categorical(q) + assert p.probs.shape == q.probs.shape, \ + 'The two inputs have different shape: '\ + f'{p.probs.shape} != {q.probs.shape} in {self.name}' + return F.cross_entropy(q.logits, p.probs, reduction='sum') + + +class EntropyMeasure(Measure): + """ + Measure Entropy. + """ + name = 'entropy' + + def measure(self, p): + p = self._convert_to_categorical(p) + return p.entropy()