From 57ddc145d4f4da361c6cfd4896da7df4ac588ca1 Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Thu, 19 Oct 2023 16:24:26 -0400 Subject: [PATCH] Delete mup/mup directory --- mup/mup/__init__.py | 7 - mup/mup/coord_check.py | 603 --------------------------------------- mup/mup/infshape.py | 141 --------- mup/mup/init.py | 202 ------------- mup/mup/layer.py | 84 ------ mup/mup/optim.py | 143 ---------- mup/mup/shape.py | 209 -------------- mup/mup/test/__main__.py | 379 ------------------------ mup/mup/test/models.py | 146 ---------- 9 files changed, 1914 deletions(-) delete mode 100644 mup/mup/__init__.py delete mode 100644 mup/mup/coord_check.py delete mode 100644 mup/mup/infshape.py delete mode 100644 mup/mup/init.py delete mode 100644 mup/mup/layer.py delete mode 100644 mup/mup/optim.py delete mode 100644 mup/mup/shape.py delete mode 100644 mup/mup/test/__main__.py delete mode 100644 mup/mup/test/models.py diff --git a/mup/mup/__init__.py b/mup/mup/__init__.py deleted file mode 100644 index f11535b1f..000000000 --- a/mup/mup/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -name = "mup" - -from mup.shape import * -from mup.infshape import * -from mup.init import * -from mup.layer import * -from mup.optim import * \ No newline at end of file diff --git a/mup/mup/coord_check.py b/mup/mup/coord_check.py deleted file mode 100644 index 9321f759e..000000000 --- a/mup/mup/coord_check.py +++ /dev/null @@ -1,603 +0,0 @@ -# Copyright 2022 Microsoft Corporation. -''' -Helper functions for performing coord check. -''' -import os -from copy import copy -from itertools import product - -import numpy as np -import pandas as pd -import torch -import torch.nn.functional as F - - -def cov(x): - '''Treat `x` as a collection of vectors and its Gram matrix. - Input: - x: If it has shape [..., d], then it's treated as - a collection of d-dimensional vectors - Output: - cov: a matrix of size N x N where N is the product of - the non-last dimensions of `x`. - ''' - if x.nelement() == 1: - width = 1 - xx = x.reshape(1, 1) - else: - width = x.shape[-1] - xx = x.reshape(-1, x.shape[-1]) - return xx @ xx.T / width - -def covoffdiag(x): - '''Get off-diagonal entries of `cov(x)` in a vector. - Input: - x: If it has shape [..., d], then it's treated as - a collection of d-dimensional vectors - Output: - Off-diagonal entries of `cov(x)` in a vector.''' - c = cov(x) - return c[~torch.eye(c.shape[0], dtype=bool)] - -#: dict of provided functions for use in coord check -FDICT = { - 'l1': lambda x: torch.abs(x).mean(dtype=torch.float32), - 'l2': lambda x: (x**2).mean(dtype=torch.float32)**0.5, - 'mean': lambda x: x.mean(dtype=torch.float32), - 'std': lambda x: x.std(dtype=torch.float32), - 'covl1': lambda x: torch.abs(cov(x)).mean(dtype=torch.float32), - 'covl2': lambda x: (cov(x)**2).mean(dtype=torch.float32)**0.5, - 'covoffdiagl1': lambda x: torch.abs(covoffdiag(x)).mean(dtype=torch.float32), - 'covoffdiagl2': lambda x: (covoffdiag(x)**2).mean(dtype=torch.float32)**0.5 -} - -def convert_fdict(d): - '''convert a dict `d` with string values to function values. - Input: - d: a dict whose values are either strings or functions - Output: - a new dict, with the same keys as `d`, but the string values are - converted to functions using `FDICT`. - ''' - return dict([ - ((k, FDICT[v]) if isinstance(v, str) else (k, v)) - for k, v in d.items()]) - -def _record_coords(records, width, modulename, t, - output_fdict=None, input_fdict=None, param_fdict=None): - '''Returns a forward hook that records coordinate statistics. 
- - Returns a forward hook that records statistics regarding the output, input, - and/or parameters of a `nn.Module`. This hook is intended to run only once, - on the timestep specified by `t`. - - On forward pass, the returned hook calculates statistics specified in - `output_fdict`, `input_fdict`, and `param_fdict`, such as the normalized l1 - norm, of output, input, and/or parameters of the module. The statistics are - recorded along with the `width`, `modulename`, and `t` (the time step) as a - dict and inserted into `records` (which should be a list). More precisely, - for each output, input, and/or parameter, the inserted dict is of the form - - { - 'width': width, 'module': modified_modulename, 't': t, - # keys are keys in fdict - 'l1': 0.241, 'l2': 0.420, 'mean': 0.0, ... - } - - where `modified_modulename` is a string that combines the `modulename` with - an indicator of which output, input, or parameter tensor is the statistics - computed over. - - The `*_fdict` inputs should be dictionaries with string keys and whose - values can either be functions or strings. The string values are converted - to functions via `convert_fdict`. The default values of `*_dict` inputs are - converted to `output_fdict = dict(l1=FDICT['l1'])`, `input_fdict = {}`, - `param_fdict = {}`, i.e., only the average coordinate size (`l1`) of the - output activations are recorded. - - Inputs: - records: - list to append coordinate data to - width: - width of the model. This is used only for plotting coord check later - on, so it can be any notion of width. - modulename: - string name of the module. This is used only for plotting coord check. - t: - timestep of training. This is used only for plotting coord check. - output_fdict, input_fdict, param_fdict: - dicts with string keys and whose values can either be functions or - strings. The string values are converted to functions via - `convert_fdict` - Output: - a forward hook that records statistics regarding the output, input, - and/or parameters of a `nn.Module`, as discussed above. 
- ''' - if output_fdict is None: - output_fdict = dict(l1=FDICT['l1']) - else: - output_fdict = convert_fdict(output_fdict) - if input_fdict is None: - input_fdict = {} - else: - input_fdict = convert_fdict(input_fdict) - if param_fdict is None: - param_fdict = {} - else: - param_fdict = convert_fdict(param_fdict) - def f(module, input, output): - def get_stat(d, x, fdict): - if isinstance(x, (tuple, list)): - for i, _x in enumerate(x): - _d = copy(d) - _d['module'] += f'[{i}]' - get_stat(_d, _x, fdict) - elif isinstance(x, dict): - for name, _x in x.items(): - _d = copy(d) - _d['module'] += f'[{name}]' - get_stat(_d, _x, fdict) - elif isinstance(x, torch.Tensor): - _d = copy(d) - for fname, f in fdict.items(): - _d[fname] = f(x).item() - records.append(_d) - elif x is None: - pass - else: - raise NotImplementedError(f'Unexpected output type: {type(x)}') - with torch.no_grad(): - ret = { - 'width': width, - 'module': modulename, - 't': t - } - - # output stats - if isinstance(output, (tuple, list)): - for i, out in enumerate(output): - _ret = copy(ret) - _ret['module'] += f':out[{i}]' - get_stat(_ret, out, output_fdict) - elif isinstance(output, dict): - for name, out in output.items(): - _ret = copy(ret) - _ret['module'] += f':out[{name}]' - get_stat(_ret, out, output_fdict) - elif isinstance(output, torch.Tensor): - _ret = copy(ret) - for fname, f in output_fdict.items(): - _ret[fname] = f(output).item() - records.append(_ret) - else: - raise NotImplementedError(f'Unexpected output type: {type(output)}') - - # input stats - if input_fdict: - if isinstance(input, (tuple, list)): - for i, out in enumerate(input): - _ret = copy(ret) - _ret['module'] += f':in[{i}]' - get_stat(_ret, out, input_fdict) - elif isinstance(input, dict): - for name, out in input.items(): - _ret = copy(ret) - _ret['module'] += f':in[{name}]' - get_stat(_ret, out, input_fdict) - elif isinstance(input, torch.Tensor): - _ret = copy(ret) - for fname, f in input_fdict.items(): - _ret[fname] = f(input).item() - records.append(_ret) - else: - raise NotImplementedError(f'Unexpected output type: {type(input)}') - - # param stats - if param_fdict: - for name, p in module.named_parameters(): - _ret = copy(ret) - _ret['module'] += f':param[{name}]' - for fname, f in param_fdict.items(): - _ret[fname] = f(p).item() - records.append(_ret) - - return f - -def _get_coord_data(models, dataloader, optcls, nsteps=3, - dict_in_out=False, flatten_input=False, flatten_output=False, - output_name='loss', lossfn='xent', filter_module_by_name=None, - fix_data=True, cuda=True, nseeds=1, - output_fdict=None, input_fdict=None, param_fdict=None, - show_progress=True, one_hot_target=False): - '''Inner method for `get_coord_data`. - - Train the models in `models` with optimizer given by `optcls` and data from - `dataloader` for `nsteps` steps, and record coordinate statistics specified - by `output_fdict`, `input_fdict`, `param_fdict`. By default, only `l1` is - computed for output activations of each module. - - Inputs: - models: - a dict of lazy models, where the keys are numbers indicating width. - Each entry of `models` is a function that instantiates a model given - nothing. - dataloader: - an iterator whose elements are either Huggingface style dicts, if - `dict_in_out` is True, or (input, label). If `fix_data` is True - (which is the default), then only the first element of `dataloader` - is used in a loop and the rest of `dataloder` is ignored. - optcls: - a function so that `optcls(model)` gives an optimizer used to train - the model. 
- nsteps: - number of steps to train the model - dict_in_out: - whether the data loader contains Huggingface-style dict input and - output. Default: False - flatten_input: - if not `dict_in_out`, reshape the input to be - `input.view(input.shape[0], -1)`. Typically used for testing MLPs. - flatten_output: - if not `dict_in_out`, reshape the label to be `label.view(-1, - input.shape[-1])`. - output_name: - if `dict_in_out`, this is the key for the loss value if the output - is a dict. If the output is not a dict, then we assume the first - element of the output is the loss. - lossfn: - loss function to use if not `dict_in_out`. Can be either a string from - [`xent`, 'mse', 'nll', 'l1'] or a python `callable` such that - `lossfn(output, target)` returns the loss value. Examples of valid - `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is - `torch.nn.functional`. Default: 'xent' - filter_module_by_name: - a function that returns a bool given module names (from - `model.named_modules()`), or None. If not None, then only modules - whose name yields True will be recorded. - cuda: - whether to use cuda or not. Default: True - nseeds: - number of times to repeat the training, each with different seeds. - output_fdict, input_fdict, param_fdict: - function dicts to be used in `_record_coords`. By default, only `l1` - is computed for output activations of each module. - show_progress: - show progress using tqdm. Default: True - one_hot_target: - convert target label into a one-hot vector. This typically is only - used for `'mse'` or `'l1'` losses in classification tasks. - Default: False - Output: - a pandas DataFrame containing recorded results. The column names are - `'width', 'module', 't'` as well as names of statistics recorded, such - as `'l1'` (see `FDICT` for other premade statistics that can be - collected). - - Breaking Changes: - In v1.0.0, when `lossfn=='mse'`, the target is automatically converted - to a one hot vector before loss computation. Starting in v1.1.0, this - behavior is turned off, and the user needs to explicitly turn on this - behavior by setting `one_hot_target=True`. 
- - ''' - df = [] - if fix_data: - batch = next(iter(dataloader)) - dataloader = [batch] * nsteps - if show_progress: - from tqdm import tqdm - pbar = tqdm(total=nseeds * len(models)) - - for i in range(nseeds): - torch.manual_seed(i) - for width, model in models.items(): - model = model() - model = model.train() - if cuda: - model = model.cuda() - optimizer = optcls(model) - for batch_idx, batch in enumerate(dataloader, 1): - remove_hooks = [] - # add hooks - for name, module in model.named_modules(): - if filter_module_by_name and not filter_module_by_name(name): - continue - remove_hooks.append(module.register_forward_hook( - _record_coords(df, width, name, batch_idx, - output_fdict=output_fdict, - input_fdict=input_fdict, - param_fdict=param_fdict))) - if dict_in_out: - if cuda: - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.cuda() - outputs = model(**batch) - loss = outputs[output_name] if isinstance(outputs, dict) else outputs[0] - else: - (data, target) = batch - if cuda: - data, target = data.cuda(), target.cuda() - if flatten_input: - data = data.view(data.size(0), -1) - output = model(data) - if flatten_output: - output = output.view(-1, output.shape[-1]) - if one_hot_target: - target = F.one_hot(target, - num_classes=output.size(-1)).float() - if lossfn == 'xent': - loss = F.cross_entropy(output, target) - elif lossfn == 'mse': - loss = F.mse_loss(output, target) - elif lossfn == 'nll': - loss = F.nll_loss(output, target) - elif lossfn == 'l1': - loss = F.l1_loss(output, target) - elif callable(lossfn): - loss = lossfn(output, target) - else: - raise NotImplementedError(f'unknown `lossfn`: {lossfn}') - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # remove hooks - for handle in remove_hooks: - handle.remove() - - if batch_idx == nsteps: break - if show_progress: - pbar.update(1) - if show_progress: - pbar.close() - return pd.DataFrame(df) - - -def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True, - filter_trainable_by_name=None, - **kwargs): - '''Get coord data for coord check. - - Train the models in `models` with data from `dataloader` and optimizer - specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate - statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By - default, only `l1` is computed for output activations of each module. - - This function wraps around `_get_coord_data`, with the main difference being - user can specify common optimizers via a more convenient interface. - - Inputs: - models: - a dict of lazy models, where the keys are numbers indicating width. - Each entry of `models` is a function that instantiates a model given - nothing. - dataloader: - an iterator whose elements are either Huggingface style dicts, if - `dict_in_out` is True, or (input, label). If `fix_data` is True - (which is the default), then only the first element of `dataloader` - is used in a loop and the rest of `dataloder` is ignored. - optimizer: - a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`. - lr: - learning rate. By default is 0.1 for `'sgd'` and 1e-3 for others. - mup: - If True, then use the optimizer from `mup.optim`; otherwise, use the - one from `torch.optim`. - filter_trainable_by_name: - a function that returns a bool given module names (from - `model.named_modules()`), or None. If not None, then only modules - whose name yields True will be trained. 
- nsteps: - number of steps to train the model - dict_in_out: - whether the data loader contains Huggingface-style dict input and - output. Default: False - flatten_input: - if not `dict_in_out`, reshape the input to be - `input.view(input.shape[0], -1)`. Typically used for testing MLPs. - flatten_output: - if not `dict_in_out`, reshape the label to be `label.view(-1, - input.shape[-1])`. - output_name: - if `dict_in_out`, this is the key for the loss value if the output - is a dict. If the output is not a dict, then we assume the first - element of the output is the loss. - lossfn: - loss function to use if not `dict_in_out`. Can be either a string from - [`xent`, 'mse', 'nll', 'l1'] or a python `callable` such that - `lossfn(output, target)` returns the loss value. Examples of valid - `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is - `torch.nn.functional`. Default: 'xent' - filter_module_by_name: - a function that returns a bool given module names (from - `model.named_modules()`), or None. If not None, then only modules - whose name yields True will be recorded. - cuda: - whether to use cuda or not. Default: True - nseeds: - number of times to repeat the training, each with different seeds. - output_fdict, input_fdict, param_fdict: - function dicts to be used in `_record_coords`. By default, only `l1` - is computed for output activations of each module. - show_progress: - show progress using tqdm. Default: True - one_hot_target: - convert target label into a one-hot vector. This typically is only - used for `'mse'` or `'l1'` losses in classification tasks. - Default: False - Output: - a pandas DataFrame containing recorded results. The column names are - `'width', 'module', 't'` as well as names of statistics recorded, such - as `'l1'` (see `FDICT` for other premade statistics that can be - collected). - - Breaking Changes: - In v1.0.0, when `lossfn=='mse'`, the target is automatically converted - to a one hot vector before loss computation. Starting in v1.1.0, this - behavior is turned off, and the user needs to explicitly turn on this - behavior by setting `one_hot_target=True`. - ''' - if lr is None: - lr = 0.1 if optimizer == 'sgd' else 1e-3 - if mup: - from mup.optim import MuAdam as Adam - from mup.optim import MuAdamW as AdamW - from mup.optim import MuSGD as SGD - else: - from torch.optim import SGD, Adam, AdamW - def get_trainable(model): - params = model.parameters() - if filter_trainable_by_name is not None: - params = [] - for name, p in model.named_parameters(): - if filter_trainable_by_name(name): - params.append(p) - return params - if optimizer == 'sgd': - optcls = lambda model: SGD(get_trainable(model), lr=lr) - elif optimizer == 'adam': - optcls = lambda model: Adam(get_trainable(model), lr=lr) - elif optimizer == 'adamw': - optcls = lambda model: AdamW(get_trainable(model), lr=lr) - elif optimizer is None: - raise ValueError('optimizer should be sgd|adam|adamw or a custom function') - - data = _get_coord_data(models, dataloader, optcls, **kwargs) - data['optimizer'] = optimizer - data['lr'] = lr - return data - - -def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module', - legend='full', name_contains=None, name_not_contains=None, module_list=None, - loglog=True, logbase=2, face_color=None, subplot_width=5, - subplot_height=4): - '''Plot coord check data `df` obtained from `get_coord_data`. - - Input: - df: - a pandas DataFrame obtained from `get_coord_data` - y: - the column of `df` to plot on the y-axis. 
Default: `'l1'` - save_to: - path to save the resulting figure, or None. Default: None. - suptitle: - The title of the entire figure. - x: - the column of `df` to plot on the x-axis. Default: `'width'` - hue: - the column of `df` to represent as color. Default: `'module'` - legend: - 'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`. - name_contains, name_not_contains: - only plot modules whose name contains `name_contains` and does not contain `name_not_contains` - module_list: - only plot modules that are given in the list, overrides `name_contains` and `name_not_contains` - loglog: - whether to use loglog scale. Default: True - logbase: - the log base, if using loglog scale. Default: 2 - face_color: - background color of the plot. Default: None (which means white) - subplot_width, subplot_height: - The width and height for each timestep's subplot. More precisely, - the figure size will be - `(subplot_width*number_of_time_steps, subplot_height)`. - Default: 5, 4 - - Output: - the `matplotlib` figure object - ''' - ### preprocessing - df = copy(df) - # nn.Sequential has name '', which duplicates the output layer - df = df[df.module != ''] - if module_list is not None: - df = df[df['module'].isin(module_list)] - else: - if name_contains is not None: - df = df[df['module'].str.contains(name_contains)] - if name_not_contains is not None: - df = df[~(df['module'].str.contains(name_not_contains))] - # for nn.Sequential, module names are numerical - try: - df['module'] = pd.to_numeric(df['module']) - except ValueError: - pass - - ts = df.t.unique() - - import matplotlib.pyplot as plt - import seaborn as sns - sns.set() - - def tight_layout(plt): - plt.tight_layout(rect=[0, 0.03, 1, 0.95]) - - ### plot - fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height)) - hue_order = sorted(set(df['module'])) - if face_color is not None: - fig.patch.set_facecolor(face_color) - ymin, ymax = min(df[y]), max(df[y]) - for t in ts: - t = int(t) - plt.subplot(1, len(ts), t) - sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None) - plt.title(f't={t}') - if t != 1: - plt.ylabel('') - if loglog: - plt.loglog(base=logbase) - ax = plt.gca() - ax.set_ylim([ymin, ymax]) - if suptitle: - plt.suptitle(suptitle) - tight_layout(plt) - if save_to is not None: - plt.savefig(save_to) - print(f'coord check plot saved to {save_to}') - - return fig - -# example of how to plot coord check results -# for the CNN and MLP models in mup.test -def example_plot_coord_check( - arch='mlp', optimizer='sgd', lr=None, widths=None, mup=True, - nsteps=3, nseeds=10, plotdir='', batchnorm=False, batch_size=1, - init='kaiming_fan_in_normal', download_cifar=True, legend='full', - dict_in_out=False, name_contains=None, name_not_contains=None): - - from mup.test.models import get_lazy_models, get_train_loader - if batchnorm: - batch_size = 5 - train_loader = get_train_loader(batch_size=batch_size, download=download_cifar) - - if widths is None: - widths = 2**np.arange(7, 14) if arch == 'mlp' else 2**np.arange(3, 10) - models = get_lazy_models(arch, widths, mup=mup, batchnorm=batchnorm, init=init, readout_zero_init=True) - df = get_coord_data(models, train_loader, mup=mup, lr=lr, optimizer=optimizer, flatten_input=arch == 'mlp', nseeds=nseeds, nsteps=nsteps, dict_in_out=dict_in_out) - - prm = 'μP' if mup else 'SP' - bn = 'on' if batchnorm else 'off' - if lr is None: - lr = 0.1 if optimizer == 'sgd' else 1e-3 - return plot_coord_data(df, legend=legend, - 
name_contains=name_contains, name_not_contains=name_not_contains, - save_to=os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_lr{lr}_nseeds{nseeds}_bn{int(batchnorm)}_coord.png'), - suptitle=f'{prm} {arch.upper()} {optimizer} lr={lr} bn={bn} nseeds={nseeds}', - face_color='xkcd:light grey' if not mup else None) - - -if __name__ == '__main__': - import os - os.makedirs('coord_checks', exist_ok=True) - plotdir = 'coord_checks' - - nseeds = 5 - - for arch, opt, bn, mup in product(['mlp', 'cnn'], ['sgd', 'adam'], [False, True], [False, True]): - example_plot_coord_check(arch, opt, batchnorm=bn, mup=mup, nseeds=nseeds, download_cifar=True, legend=None, plotdir=plotdir) - - diff --git a/mup/mup/infshape.py b/mup/mup/infshape.py deleted file mode 100644 index d8fa335ac..000000000 --- a/mup/mup/infshape.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2022 Microsoft Corporation. - -from copy import copy - - -class InfDim: - '''A dimension with a base dimension, used for calculating μP scaling. - - An `InfDim` object is made up of 2 numbers: a dimension and a base - dimension. If the base dimension is None, then this object represents a - "finite", or "non-width" dimension. Otherwise, it represents an "infinite", - or "width" dimension. - ''' - - def __init__(self, base_dim, dim): - self.base_dim = base_dim - self.dim = dim - - def isinf(self): - return self.base_dim is not None - - def width_mult(self): - '''Width multiplier used for calculating μP scaling. - - If finite, return 1. - If infinite, return dim / base_dim. - ''' - if self.isinf(): - return self.dim / self.base_dim - return 1 - - def __repr__(self): - return f'InfDim({self.base_dim}, {self.dim})' - - def __str__(self): - if self.isinf(): - return repr(self) - return f'FinDim({self.dim})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InfDim): - return False - return self.base_dim == other.base_dim and \ - self.dim == other.dim - - -class InfShape(tuple): - '''A tuple of `InfDim`s. - - This is intended to be attached to each parameter tensor `p` as `p.infshape`. 
- ''' - - def __init__(self, *args, **kwargs): - tuple.__init__(*args, **kwargs) - for dim in self: - if not isinstance(dim, InfDim): - raise ValueError('Elements of InfShape needs to be of class InfDim') - # set main to be the last dimension that is infinite - # for inf x inf this is fanin - # for inf x fin or fin x inf it's the unique inf dim - # user can set this manually if necessary - self.main_idx = self.main = None - for i, dim in list(enumerate(self))[::-1]: - if dim.isinf(): - self.main_idx = i - self.main = dim - break - - def fanin_fanout(self): - assert len(self) >= 2, 'fanin, fanout undefined for 1-dimensional weights' - return self[1], self[0] - - def fanin_fanout_mult_ratio(self): - fanin, fanout = self.fanin_fanout() - return fanin.width_mult() / fanout.width_mult() - - def ninf(self): - return sum(1 for dim in self if dim.isinf()) - - def width_mult(self): - if self.main is not None: - return self.main.width_mult() - return 1 - - def base_shape(self): - return [d.base_dim for d in self] - - def shape(self): - return [d.dim for d in self] - - def __repr__(self): - r = tuple.__repr__(self)[1:-1] - return f'InfShape([{r}])' - - def serialize(self): - d = {'base_shape': [], 'shape': []} - for infdim in self: - d['shape'].append(infdim.dim) - d['base_shape'].append(infdim.base_dim) - return d - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InfShape): - return False - return all(d == dd for d, dd in zip(self, other)) - - @classmethod - def deserialize(cls, d): - infshape = [] - for base_dim, dim in zip(d['base_shape'], d['shape']): - infshape.append(InfDim(base_dim, dim)) - return InfShape(infshape) - - @classmethod - def from_base_shape(cls, bsh): - return InfShape([InfDim(bd, None) for bd in bsh]) - -def zip_infshape(base_dims, dims, fin_if_same=True): - infshape = [] - for bd, d in zip(base_dims, dims): - if isinstance(bd, InfDim): - # retain bd's base_dim but overwrite dim - infdim = copy(bd) - infdim.dim = d - infshape.append(infdim) - elif isinstance(bd, int): - if bd == d and fin_if_same: - infshape.append(InfDim(None, d)) - else: - infshape.append(InfDim(bd, d)) - else: - raise ValueError(f'unhandled base_dim type: {type(bd)}') - return InfShape(infshape) - -if __name__ == '__main__': - infshape = InfShape([InfDim(None, 100), InfDim(128, 1024), InfDim(128, 128)]) - print(infshape) - print(f'{infshape.ninf()} dims are inf') - print(f'width_mult {infshape.width_mult()}') - - print(zip_infshape([64, 128, 1024], [32, 128, 2048])) \ No newline at end of file diff --git a/mup/mup/init.py b/mup/mup/init.py deleted file mode 100644 index 637236c98..000000000 --- a/mup/mup/init.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2022 Microsoft Corporation. -''' -Initializer functions mirroring those of `torch.nn.init`. They serve as -drop-in replacements after the user has called `set_base_shapes` on their -model. - -All of the initializers here are designed to 1) behave exactly the same -as the torch versions when the model shapes are equal to their base shapes, -and 2) to scale with width correctly (according to μP), when the model shapes -differ from the base shapes. In general, this means deviating from the -torch version behaviors. 
-''' -import math -import warnings - -import torch -from torch.nn.init import (_calculate_correct_fan, - _calculate_fan_in_and_fan_out, _no_grad_fill_, - _no_grad_normal_, _no_grad_uniform_, calculate_gain) - - -def constant_std_init_(tensor, sampler_): - assert hasattr(tensor, 'infshape'), 'Please call set_base_shapes(...)' - if tensor.infshape.ninf() <= 1: - sampler_(tensor) - elif tensor.infshape.ninf() == 2: - sampler_(tensor, scale=tensor.infshape.width_mult()**-0.5) - else: - raise NotImplementedError() - return tensor - -def uniform_(tensor, a=0, b=1): - '''Drop-in replacement of `torch.nn.init.uniform_`. - Note: - - if using this function, ensure `a` and `b` do not depend on fan-in, - fan-out, or other notions of width, e.g. if a = 0, b = 1. - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - assert hasattr(tensor, 'infshape'), 'Please call set_base_shapes(...)' - if a != -b: - assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' - def sampler_(tensor, scale=1): - _no_grad_uniform_(tensor, a * scale, b * scale) - return constant_std_init_(tensor, sampler_) - -def normal_(tensor, mean=0, std=1): - '''Drop-in replacement of `torch.nn.init.normal_`. - Note: - - if using this function, ensure `mean` and `std` do not depend on - fan-in, fan-out, or other notions of width, e.g. if mean = 0, std = - 1. - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - if mean != 0: - assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' - def sampler_(tensor, scale=1): - _no_grad_normal_(tensor, mean=mean*scale, std=std*scale) - return constant_std_init_(tensor, sampler_) - -def ones_(tensor): - '''Same as `torch.nn.init.ones_`. - Note: - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' - def sampler_(tensor, scale=1): - _no_grad_fill_(tensor, scale) - return constant_std_init_(tensor, sampler_) - -def eye_(tensor): - '''Same as `torch.nn.init.eye_`. - Note: - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' - return torch.nn.init.eye_(tensor) - - -def _inf_fan_adjust_xavier(scale, tensor): - fan_out, fan_in = tensor.infshape[:2] - # following are needed to accomodate SP models where all infshapes are finite so base_dims are Nones - fan_out_base_dim = fan_out.base_dim or fan_out.dim - fan_in_base_dim = fan_in.base_dim or fan_in.dim - scale *= math.sqrt( - (fan_out.dim + fan_in.dim) - / (fan_out_base_dim + fan_in_base_dim)) - if tensor.infshape.ninf() <= 1: - # should have fixed scale - pass - elif tensor.infshape.ninf() == 2: - # should scale like fanin - assert fan_out.isinf() and fan_in.isinf() - scale /= math.sqrt(fan_in.width_mult()) - else: - raise NotImplementedError('can only handle 2 inf dimensions currently') - return scale - - -def xavier_uniform_(tensor, gain=1.): - '''Drop-in replacement of `torch.nn.init.xavier_uniform_`. - Note: - - if using this function, ensure `gain` does not depend on fan-in, - fan-out, or other notions of width, e.g. if gain = 1. - - `tensor` should have `infshape` attribute set by `set_base_shapes`. 
- ''' - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - std = _inf_fan_adjust_xavier(std, tensor) - a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - return _no_grad_uniform_(tensor, -a, a) - - -def xavier_normal_(tensor, gain=1.): - '''Drop-in replacement of `torch.nn.init.xavier_normal_`. - Note: - - if using this function, ensure `gain` does not depend on fan-in, - fan-out, or other notions of width, e.g. if gain = 1. - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - std = _inf_fan_adjust_xavier(std, tensor) - return _no_grad_normal_(tensor, 0., std) - - -def _inf_fan_adjust_kaiming(scale, tensor, mode): - fan_out, fan_in = tensor.infshape[:2] - if tensor.infshape.ninf() == 0: - return scale - elif tensor.infshape.ninf() == 1: - # should have fixed scale - if mode == 'fan_in' and fan_in.isinf(): - scale *= fan_in.width_mult()**0.5 - elif mode == 'fan_out' and fan_out.isinf(): - scale *= fan_out.width_mult()**0.5 - elif tensor.infshape.ninf() == 2: - # should scale like fanin - assert fan_out.isinf() and fan_in.isinf() - if mode == 'fan_out': - scale *= math.sqrt(fan_out.width_mult() / fan_in.width_mult()) - else: - raise NotImplementedError('can only handle <=2 inf dimensions currently') - return scale - -def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): - '''Drop-in replacement of `torch.nn.init.kaiming_normal_`. - Note: - - if using this function, ensure `a` does not depend on fan-in, - fan-out, or other notions of width, e.g. if a = 0. - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = _calculate_correct_fan(tensor, mode) - gain = calculate_gain(nonlinearity, a) - std = _inf_fan_adjust_kaiming(gain / math.sqrt(fan), tensor, mode) - with torch.no_grad(): - return tensor.normal_(0, std) - - -def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): - '''Drop-in replacement of `torch.nn.init.kaiming_uniform_`. - Note: - - if using this function, ensure `a` does not depend on fan-in, - fan-out, or other notions of width, e.g. if a = 0. - - `tensor` should have `infshape` attribute set by `set_base_shapes`. - ''' - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = _calculate_correct_fan(tensor, mode) - gain = calculate_gain(nonlinearity, a) - std = _inf_fan_adjust_kaiming(gain / math.sqrt(fan), tensor, mode) - bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - with torch.no_grad(): - return tensor.uniform_(-bound, bound) - - -try: - from torch.nn.init import _no_grad_trunc_normal_ - def trunc_normal_(tensor, mean=0, std=1, a=-2, b=2): - '''Drop-in replacement of `torch.nn.init.trunc_normal_`. - Note: - - if using this function, ensure `mean`, `std`, `a`, `b` do not - depend on fan-in, fan-out, or other notions of width, e.g. if - mean = 0, std = 1, a = -2, b = 2. - - `tensor` should have `infshape` attribute set by - `set_base_shapes`. 
- ''' - if mean != 0 or a != -b: - assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' - def sampler_(tensor, scale=1): - _no_grad_trunc_normal_(tensor, mean=mean*scale, std=std*scale, a=a*scale, b=b*scale) - return constant_std_init_(tensor, sampler_) -except: - warnings.warn( - 'Failed to import _no_grad_trunc_normal_ from torch.nn.init; ' - 'you might be running an older version of torch. trunc_normal_ will not work.') - def trunc_normal_(tensor, mean=0, std=1, a=-2, b=2): - warnings.warn('Please upgrade your Pytorch version before using truncated normal.') - pass diff --git a/mup/mup/layer.py b/mup/mup/layer.py deleted file mode 100644 index 518a33bd2..000000000 --- a/mup/mup/layer.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2022 Microsoft Corporation. -from torch.nn import Linear - - -class MuReadout(Linear): - '''Drop-in replacement for all output linear layers. - - An "output" linear layer is one that maps from a width dimension (e.g., - `d_model` in a Transformer) to a non-width dimension (e.g., vocab size). - - This layer implements the version of μP with a 1/width multiplier and a - constant variance initialization for both weights and biases. - ''' - def __init__(self, *args, readout_zero_init=False, output_mult=1.0, **kwargs): - self.output_mult = output_mult - self.readout_zero_init = readout_zero_init - super().__init__(*args, **kwargs) - - def reset_parameters(self) -> None: - if self.readout_zero_init: - self.weight.data[:] = 0 - if self.bias is not None: - self.bias.data[:] = 0 - else: - super().reset_parameters() - - def width_mult(self): - assert hasattr(self.weight, 'infshape'), ( - 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' - 'switch to distributed training with ' - 'torch.nn.parallel.DistributedDataParallel instead' - ) - return self.weight.infshape.width_mult() - - def _rescale_parameters(self): - '''Rescale parameters to convert SP initialization to μP initialization. - - Warning: This method is NOT idempotent and should be called only once - unless you know what you are doing. - ''' - if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params: - raise RuntimeError( - "`_rescale_parameters` has been called once before already. " - "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" - "If you called `set_base_shapes` on a model loaded from a checkpoint, " - "or just want to re-set the base shapes of an existing model, " - "make sure to set the flag `rescale_params=False`.\n" - "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.") - if self.bias is not None: - self.bias.data *= self.width_mult()**0.5 - self.weight.data *= self.width_mult()**0.5 - self._has_rescaled_params = True - - def forward(self, x): - return super().forward( - self.output_mult * x / self.width_mult()) - - -class MuSharedReadout(MuReadout): - '''`MuReadout` with weights shared with an `nn.Embedding` layer. - - Inputs: - weight: should be weight of an `nn.Embedding` layer - other inputs are fed to `MuReadout` - ''' - def __init__(self, weight, bias=True, **kwargs): - super().__init__(*weight.shape, bias=bias, **kwargs) - self.weight = weight - -def rescale_linear_bias(linear): - '''Rescale bias in nn.Linear layers to convert SP initialization to μP initialization. - - Warning: This method is NOT idempotent and should be called only once - unless you know what you are doing. 
- ''' - if hasattr(linear, '_has_rescaled_params') and linear._has_rescaled_params: - raise RuntimeError("`rescale_linear_bias` has been called once before already. Unless you know what you are doing, usually you should not be calling `rescale_linear_bias` more than once.\n" - "If you called `set_base_shapes` on a model loaded from a checkpoint, or just want to re-set the base shapes of an existing model, make sure to set the flag `rescale_params=False`.\n" - "To bypass this error and *still rescale biases*, set `linear._has_rescaled_params=False` before this call.") - if linear.bias is None: - return - fanin_mult = linear.weight.infshape[1].width_mult() - linear.bias.data *= fanin_mult**0.5 - linear._has_rescaled_params = True diff --git a/mup/mup/optim.py b/mup/mup/optim.py deleted file mode 100644 index bf5822478..000000000 --- a/mup/mup/optim.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2022 Microsoft Corporation. -''' -Optimizers with μP scaling. - -Here we provide 3 ready-to-go optimizers MuAdam, MuAdamW, and MuSGD. -However, the user can easily convert their own optimizer to a μP -optimizer: if your `optimizer` is "Adam-like", such as RMSProp and Adagrad, -that involves normalizing the gradient entrywise, then the following creates -the desired μP optimizer: - - def MuOptimizer(params, **kwargs): - return MuAdam(params, impl=optimizer, **kwargs) - -On the other hand, if your `optimizer` is "SGD-like", such as ASGD, then -the following creates the desired μP optimizer: - - def MuOptimizer(params, **kwargs): - return MuSGD(params, impl=optimizer, **kwargs) - -See Appendix B in our paper for discussions of other optimizers. -''' -from collections import defaultdict - -from torch.optim import SGD, Adam, AdamW - - -def process_param_groups(params, **kwargs): - param_groups = list(params) - if not isinstance(param_groups[0], dict): - param_groups = [{'params': param_groups}] - for param_group in param_groups: - if 'lr' not in param_group: - param_group['lr'] = kwargs['lr'] - if 'weight_decay' not in param_group: - param_group['weight_decay'] = kwargs.get('weight_decay', 0.) - return param_groups - -def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs): - '''Adam with μP scaling. - - Note for this to work properly, your model needs to have its base shapes set - already using `mup.set_base_shapes`. - - Inputs: - impl: the specific Adam-like optimizer implementation from torch.optim or - elsewhere - decoupled_wd: if True, skips the mup scaling for weight decay, which should - be used for optimizer implementations that decouple weight decay from - learning rate. See https://github.com/microsoft/mup/issues/1 for a use case. - Outputs: - An instance of `impl` with refined parameter groups, each of which has the correctly - scaled learning rate according to mup. - ''' - new_param_groups = [] - for param_group in process_param_groups(params, **kwargs): - # For every existing param group, we split into several new groups - def new_group(): - new_g = {k:v for k, v in param_group.items() if k != 'params'} - new_g['params'] = [] - return new_g - # The matrix-like weights might need multiple groups since weights - # might have different width multipliers - matrix_like_p = defaultdict(new_group) # key is width_mult - vector_like_p = new_group() - for p in param_group['params']: - assert hasattr(p, 'infshape'), ( - f'A parameter with shape {p.shape} does not have `infshape` attribute. 
' - 'Did you forget to call `mup.set_base_shapes` on the model?') - if p.infshape.ninf() == 2: - matrix_like_p[p.infshape.width_mult()]['params'].append(p) - elif p.infshape.ninf() > 2: - raise NotImplementedError('more than 2 inf dimensions') - else: - vector_like_p['params'].append(p) - for width_mult, group in matrix_like_p.items(): - # Scale learning rate and weight decay accordingly - group['lr'] /= width_mult - group['width_mult'] = width_mult - if not decoupled_wd: - group['weight_decay'] *= width_mult - new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p]) - return impl(new_param_groups, **kwargs) - -def MuAdamW(params, **kwargs): - '''AdamW with μP scaling. - - Note for this to work properly, your model needs to have its base shapes set - already using `mup.set_base_shapes`. - ''' - return MuAdam(params, impl=AdamW, **kwargs) - -def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs): - '''SGD with μP scaling. - - Note for this to work properly, your model needs to have its base shapes set - already using `mup.set_base_shapes`. - - Inputs: - impl: the specific SGD-like optimizer implementation from torch.optim or - elsewhere - decoupled_wd: if True, skips the mup scaling for weight decay, which should - be used for optimizer implementations that decouple weight decay from - learning rate. See https://github.com/microsoft/mup/issues/1 for a use case. - Outputs: - An instance of `impl` with refined parameter groups, each of which has the correctly - scaled learning rate according to mup. - ''' - new_param_groups = [] - for param_group in process_param_groups(params, **kwargs): - # For every existing param group, we split into several new groups - def new_group(): - new_g = {k:v for k, v in param_group.items() if k != 'params'} - new_g['params'] = [] - return new_g - # The matrix-like weights might need multiple groups since weights - # might have different width multipliers - vector_like_p = defaultdict(new_group) # key is width mult - matrix_like_p = defaultdict(new_group) # key is fan_in/out ratio - fixed_p = new_group() - for p in param_group['params']: - assert hasattr(p, 'infshape'), ( - f'A parameter with shape {p.shape} does not have `infshape` attribute. ' - 'Did you forget to call `mup.set_base_shapes` on the model?') - if p.infshape.ninf() == 1: - vector_like_p[p.infshape.width_mult()]['params'].append(p) - elif p.infshape.ninf() == 2: - matrix_like_p[p.infshape.fanin_fanout_mult_ratio()]['params'].append(p) - elif p.infshape.ninf() > 2: - raise NotImplementedError('more than 2 inf dimensions') - else: - fixed_p['params'].append(p) - for width_mult, group in vector_like_p.items(): - # Scale learning rate and weight decay accordingly - group['lr'] *= width_mult - if not decoupled_wd: - group['weight_decay'] /= width_mult - for shape_ratio, group in matrix_like_p.items(): - group['lr'] /= shape_ratio - if not decoupled_wd: - group['weight_decay'] *= shape_ratio - new_param_groups.extend(list(matrix_like_p.values()) + \ - list(vector_like_p.values()) + [fixed_p]) - return impl(new_param_groups, **kwargs) diff --git a/mup/mup/shape.py b/mup/mup/shape.py deleted file mode 100644 index 6889e0bc6..000000000 --- a/mup/mup/shape.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2022 Microsoft Corporation. 
-from copy import deepcopy - -import yaml -from torch import nn -from torch.nn import Linear -from torch.nn.modules.conv import _ConvNd - -from mup.infshape import InfShape, zip_infshape -from mup.layer import MuReadout, rescale_linear_bias - -__BSH_COMMENT__ = '''\ -# This is a base shape file encoded in yaml -# - `null` indicates a dimension is "finite", i.e. a non-"width" dimension -# - a number indicates the base dimension of an "infinite" dimension, i.e. some notion of "width" -''' - -def get_shapes(model): - # If you want to implement a custom shapes function, you can use this name - if hasattr(model, "get_shapes"): - return model.get_shapes() - return {name: param.shape for name, param in model.named_parameters()} - -def get_infshapes(model): - return {name: param.infshape for name, param in model.named_parameters()} - -def save_base_shapes(model_or_shapes, file): - if isinstance(model_or_shapes, nn.Module): - sh = get_infshapes(model_or_shapes) - elif isinstance(model_or_shapes, dict): - sh = deepcopy(model_or_shapes) - else: - raise ValueError() - sh = {k: s.base_shape() for k, s in sh.items()} - s = yaml.dump(sh, None, indent=4) - s = __BSH_COMMENT__ + s - with open(file, 'w') as f: - f.write(s) - -def load_base_shapes(filename): - '''Get a dict of `InfShape` from a filename.''' - with open(filename, 'r') as f: - d = yaml.safe_load(f) - return {k: InfShape.from_base_shape(v) for k, v in d.items()} - -def _dataparallel_hack(base_shapes, shapes): - '''Fix module name discrepancy caused by (Distributed)DataParallel module. - - The parameters of a (Distributed)DataParallel module all have names that - start with 'module'. This causes a mismatch from non-DataParallel modules. - This function tries to match `base_shapes` to `shapes`: if the latter starts - with 'module', then make the former too; likewise if not. - ''' - if all(k.startswith('module.') for k in shapes) and \ - all(not k.startswith('module.') for k in base_shapes): - return {'module.' + k: v for k, v in base_shapes.items()}, shapes - if all(not k.startswith('module.') for k in shapes) and \ - all(k.startswith('module.') for k in base_shapes): - return {k.strip('module.'): v for k, v in base_shapes.items()}, shapes - return base_shapes, shapes - - -def _extract_shapes(x): - ''' - Input: - x: can be any of the following: - - `nn.Module` - - dict of shapes - - dict of `InfShape` - - str of path to a base shapes (.bsh) file - Output: - If `x` is dict of `InfShape`, then output itself. - If `x` is path, then output a dict of `InfShapes` loaded from `x`. - Else, output the shapes (not `InfShape`) associated to `x` - ''' - if isinstance(x, nn.Module): - x_shapes = get_shapes(x) - elif isinstance(x, dict): - x_shapes = deepcopy(x) - elif isinstance(x, str): - # x is file name - x_shapes = load_base_shapes(x) - else: - raise ValueError(f'unhandled x type: {type(x)}') - return x_shapes - -def _zip_infshape_dict(base_shapes, shapes): - '''make a dict of `InfShape` from two dicts of shapes. - Inputs: - base_shapes: dict of base shapes or InfShape objects - shapes: dict of shapes - Output: - dict of `InfShape` using `zip_infshape` - ''' - base_shapes, shapes = _dataparallel_hack(base_shapes, shapes) - basenames = set(base_shapes.keys()) - names = set(shapes.keys()) - assert basenames == names, ( - f'`base_shapes` has extra names {basenames - names}. ' - f'`shapes` has extra names {names - basenames}.' 
- ) - infshapes = {} - for name, bsh in base_shapes.items(): - infshapes[name] = zip_infshape(bsh, shapes[name]) - return infshapes - -def zip_infshapes(base, target): - '''make a dict of `InfShape` from models or dicts. - Inputs: - base: a base `nn.Module` or a dict of shapes - target: a target `nn.Module` or a dict of shapes - Output: - dict of `InfShape` using `zip_infshape` - ''' - base_shapes = _extract_shapes(base) - target_shapes = _extract_shapes(target) - return _zip_infshape_dict(base_shapes, target_shapes) - -def clear_dims(infshape_dict): - ''' - Input: - infshape_dict: dict of `InfShape` - Output: - the same dict but where all `InfDim` in all `InfShape` - have their `dim` attribute set to None - ''' - d = deepcopy(infshape_dict) - for _, v in d.items(): - for infdim in v: - infdim.dim = None - return d - -def make_base_shapes(base_shapes, delta_shapes, savefile=None): - '''Make a base shape object from a base model/shapes and a delta model/shapes. - - Inputs: - base: - a base `nn.Module` or a dict of shapes - delta: - a "delta" model or a dict of shapes, for the sole purpose of - determining which dimensions are "width" and will be scaled up and - down in the target model. - savefile: - if a string, then the resulting base shape object is serialized to - this location via yaml encoding. - Outputs: - base infshapes - ''' - bsh = clear_dims(zip_infshapes(base_shapes, delta_shapes)) - if savefile is not None: - save_base_shapes(bsh, savefile) - return bsh - - -def apply_infshapes(model, infshapes): - for name, p in model.named_parameters(): - p.infshape = infshapes[name] - -def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True): - '''Sets the `p.infshape` attribute for each parameter `p` of `model`. - - Inputs: - model: nn.Module instance - base: The base model. - Can be nn.Module, a dict of shapes, a str, or None. - If None, then defaults to `model` - If str, then treated as filename for yaml encoding of a dict of base shapes. - rescale_params: - assuming the model is initialized using the default pytorch init (or - He initialization etc that scale the same way with fanin): If True - (default), rescales parameters to have the correct (μP) variances. - do_assert: - Output: - same object as `model`, after setting the `infshape` attribute of each parameter. - ''' - if base is None: - base = model - base_shapes = _extract_shapes(base) - if delta is not None: - delta_shapes = _extract_shapes(delta) - base_shapes = _zip_infshape_dict(base_shapes, delta_shapes) - shapes = get_shapes(model) - infshapes = _zip_infshape_dict(base_shapes, shapes) - if savefile is not None: - save_base_shapes(infshapes, savefile) - apply_infshapes(model, infshapes) - if do_assert: - assert_hidden_size_inf(model) - if rescale_params: - for name, module in model.named_modules(): - if isinstance(module, MuReadout): - module._rescale_parameters() - elif isinstance(module, (Linear, _ConvNd)): - rescale_linear_bias(module) - return model - -def assert_hidden_size_inf(model): - ''' - This tests for any `nn.Linear` whose output dimension is finite but input - dimension is infinite and is not of type `MuReadout`. Such `nn.Linear` - modules should not exist in a correctly parametrized models. 
- ''' - for name, module in model.named_modules(): - if isinstance(module, Linear) and not isinstance(module, MuReadout): - if not module.weight.infshape[0].isinf() and module.weight.infshape[1].isinf(): - assert False, ( - f'{name} has infinite fan-in and finite fan-out dimensions but is not type `MuReadout`. ' - 'To resolve this, either change the module to `MuReadout` or change the fan-out to an infinite dimension.' - ) diff --git a/mup/mup/test/__main__.py b/mup/mup/test/__main__.py deleted file mode 100644 index 5b448a8c2..000000000 --- a/mup/mup/test/__main__.py +++ /dev/null @@ -1,379 +0,0 @@ -import itertools -import unittest -from functools import partial -from itertools import cycle - -import numpy as np -import pandas as pd -import torch -import torch.nn.functional as F -from mup.coord_check import get_coord_data -from mup.optim import MuAdam, MuSGD -from mup.shape import get_infshapes, get_shapes, make_base_shapes, set_base_shapes -from mup.test.models import (generate_CNN, generate_MLP, _generate_MLP, get_lazy_models, - get_train_loader, init_methods) - -train_loader = get_train_loader(batch_size=32, num_workers=4, download=True) - -def reset_seed(): - torch.manual_seed(0) - -class SetBaseShapeCase(unittest.TestCase): - mlp_base_shapes_file = 'mlp64.bsh.test' - - def get_mlp_infshapes1(self): - base_model = _generate_MLP(64, True, True, True) - delta_model = _generate_MLP(65, True, True, True) - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file) - return get_infshapes(target_model) - - def get_mlp_infshapes1meta(self): - base_model = _generate_MLP(64, True, True, True, device='meta') - delta_model = _generate_MLP(65, True, True, True, device='meta') - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file) - return get_infshapes(target_model) - - def get_mlp_infshapes2(self): - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, self.mlp_base_shapes_file) - return get_infshapes(target_model) - - def get_mlp_infshapes3(self): - base_model = _generate_MLP(64, True, True, True) - delta_model = _generate_MLP(65, True, True, True) - base_infshapes = make_base_shapes(base_model, delta_model) - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, base_infshapes) - return get_infshapes(target_model) - - def get_mlp_infshapes3meta(self): - base_model = _generate_MLP(64, True, True, True, device='meta') - delta_model = _generate_MLP(65, True, True, True, device='meta') - base_infshapes = make_base_shapes(base_model, delta_model) - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, base_infshapes) - return get_infshapes(target_model) - - def get_mlp_infshapes4(self): - base_model = _generate_MLP(64, True, True, True) - delta_model = _generate_MLP(65, True, True, True) - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model)) - return get_infshapes(target_model) - - def get_mlp_infshapes4meta(self): - base_model = _generate_MLP(64, True, True, True) - delta_model = _generate_MLP(65, True, True, True, device='meta') - target_model = _generate_MLP(128, True, True, True, device='meta') - set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model)) - return get_infshapes(target_model) - - def 
get_mlp_infshapes5(self): - delta_model = _generate_MLP(65, True, True, True) - target_model = _generate_MLP(128, True, True, True) - # `delta` here doesn't do anything because of base shape file - set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model)) - return get_infshapes(target_model) - - def get_mlp_infshapes5meta(self): - delta_model = _generate_MLP(65, True, True, True, device='meta') - target_model = _generate_MLP(128, True, True, True) - # `delta` here doesn't do anything because of base shape file - set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model)) - return get_infshapes(target_model) - - def get_mlp_infshapes_bad(self): - base_model = _generate_MLP(64, True, True, True) - target_model = _generate_MLP(128, True, True, True) - set_base_shapes(target_model, base_model, delta=base_model) - return get_infshapes(target_model) - - def test_set_base_shape(self): - self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes1meta()) - self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes2()) - self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes2()) - self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes4()) - self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes3meta()) - self.assertEqual(self.get_mlp_infshapes4(), self.get_mlp_infshapes4meta()) - self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes4()) - self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes5meta()) - self.assertNotEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes_bad()) - - -class BackwardCompatibleCase(unittest.TestCase): - - def gen_model(self, arch, width, batchnorm=False, mup=True): - if arch == 'mlp': - return generate_MLP(width=width, batchnorm=batchnorm, readout_zero_init=False, base_width=256, mup=mup) - elif arch == 'cnn': - return generate_CNN(width=width, batchnorm=batchnorm, readout_zero_init=False, base_width=8, mup=mup) - else: - raise ValueError() - - def test_MLP_CNN_at_base_width(self): - for arch, batchnorm in itertools.product(['mlp', 'cnn'], [False, True]): - for init_name, init in init_methods.items(): - reset_seed() - mup_model = self.gen_model('mlp', 256, mup=True, batchnorm=batchnorm) - reset_seed() - init(mup_model) - reset_seed() - SP_model = self.gen_model('mlp', 256, mup=False, batchnorm=batchnorm) - reset_seed() - init(SP_model) - for (name, mup_param), (_, SP_param) in zip( - mup_model.named_parameters(), SP_model.named_parameters()): - with self.subTest(name=f'{arch}, {name}, {init_name}, bn={batchnorm}'): - self.assertEqual((mup_param.data - SP_param.data).abs().sum().item(), 0) - - def test_MLP_at_diff_width_init(self): - for init_name, init in init_methods.items(): - reset_seed() - mup_model = self.gen_model('mlp', 128, mup=True) - reset_seed() - init(mup_model) - reset_seed() - SP_model = self.gen_model('mlp', 128, mup=False) - reset_seed() - init(SP_model) - - mup_params = dict(mup_model.named_parameters()) - SP_params = dict(SP_model.named_parameters()) - - if init_name == 'default' or 'fan_in' in init_name: - diff_names = ['2.bias', '4.bias', '4.weight'] - same_names = ['0.weight', '0.bias', '2.weight'] - elif 'fan_out' in init_name: - diff_names = ['2.bias', '4.bias', '0.weight'] - same_names = ['4.weight', '0.bias', '2.weight'] - elif 'xavier' in init_name: - diff_names = ['2.bias', '4.bias', '0.weight', '4.weight'] - same_names = ['0.bias', '2.weight'] - elif 'const' in init_name: - diff_names = ['2.bias', '4.bias', 
'2.weight'] - same_names = ['0.weight', '0.bias', '4.weight'] - else: - raise ValueError() - - for name in diff_names: - with self.subTest(name=f'{name}, {init_name}'): - self.assertNotEqual( - (mup_params[name] - SP_params[name]).abs().sum().item(), 0) - for name in same_names: - with self.subTest(name=f'{name}, {init_name}'): - self.assertEqual( - (mup_params[name] - SP_params[name]).abs().sum().item(), 0) - - def test_CNN_at_diff_width_init(self): - for init_name, init in init_methods.items(): - reset_seed() - mup_model = self.gen_model('cnn', 16, mup=True) - reset_seed() - init(mup_model) - reset_seed() - SP_model = self.gen_model('cnn', 16, mup=False) - reset_seed() - init(SP_model) - - mup_params = dict(mup_model.named_parameters()) - SP_params = dict(SP_model.named_parameters()) - - if init_name == 'default' or 'fan_in' in init_name: - diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '11.weight'] - same_names = ['0.bias', '0.weight', '3.weight', '7.weight', '9.weight'] - elif 'fan_out' in init_name: - diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '0.weight'] - same_names = ['0.bias', '3.weight', '7.weight', '9.weight', '11.weight'] - elif 'xavier' in init_name: - diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '0.weight', '11.weight'] - same_names = ['0.bias', '3.weight', '7.weight', '9.weight'] - elif 'const' in init_name: - diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '3.weight', '7.weight', '9.weight'] - same_names = ['0.bias', '0.weight', '11.weight'] - else: - raise ValueError() - - for name in diff_names: - with self.subTest(name=f'{name}, {init_name}'): - self.assertNotEqual( - (mup_params[name] - SP_params[name]).abs().sum().item(), 0) - for name in same_names: - with self.subTest(name=f'{name}, {init_name}'): - self.assertEqual( - (mup_params[name] - SP_params[name]).abs().sum().item(), 0) - -def train_model(model, train_loader, step=-1, optcls=MuSGD, lr=0.1, flatten_input=False, cuda=True): - model.train() - train_loss = 0 - train_losses = [] - optimizer = optcls(model.parameters(), lr=lr) - for batch_idx, (data, target) in enumerate(cycle(iter(train_loader)), 1): - if cuda: - data, target = data.cuda(), target.cuda() - optimizer.zero_grad() - if flatten_input: - data = data.view(data.size(0), -1) - output = model(data) - loss = F.cross_entropy(output, target) - loss.backward() - train_loss += loss.item() - train_losses.append(train_loss / batch_idx) - optimizer.step() - if batch_idx == step: break - # train_loss /= batch_idx - return train_losses - -train_model_MuSGD = partial(train_model, optcls=MuSGD, lr=0.1) -train_model_MuAdam = partial(train_model, optcls=MuAdam, lr=1e-3) - -class CoordCheckCase(unittest.TestCase): - - def test_MLP_CNN(self): - combos = list(itertools.product(['mlp', 'cnn'], [True], [False, True], ['sgd', 'adam'], init_methods.keys())) - # comment out the following 2 lines to do all tests - idx = np.random.choice(np.arange(len(combos)), size=10) - combos = np.array(combos)[idx] - for arch, mup, batchnorm, optimizer, init in combos: - widths = [128, 512] if arch == 'cnn' else [1000, 4000] - models = get_lazy_models(arch, widths, mup=mup, batchnorm=batchnorm, init=init) - df = get_coord_data(models, train_loader, mup=mup, optimizer=optimizer, flatten_input=arch == 'mlp') - df = df[df.module != ''] - df['module'] = pd.to_numeric(df['module']) - for t, module in itertools.product([1, 2, 3], df['module'].unique()): - with self.subTest( - name=f'{arch}, mup={mup}, bn={batchnorm}, {optimizer}, {init}, t={t}, 
module={module}'): - data = df[(df['module'] == module) & (df['t'] == t)] - std0 = data[data.width==widths[0]]['l1'].unique()[0] - std1 = data[data.width==widths[1]]['l1'].unique()[0] - if t == 1 and module == df['module'].max(): - self.assertTrue(std0 == std1 == 0, - f'output should be 0 due to readout_zero_init: {std0}, {std1}') - else: - tol = 1.2 - self.assertGreater(std1/std0, 1/tol, f'{std0}, {std1}') - self.assertLess(std1/std0, tol, f'{std0}, {std1}') - - -class MLPTrainCase(unittest.TestCase): - - def train_adam(self, model, step): - return train_model_MuAdam(model, train_loader, step=step, flatten_input=True) - - def train_sgd(self, model, step): - return train_model_MuSGD(model, train_loader, step=step, flatten_input=True) - - def setUp(self): - self.models = {w: generate_MLP(w, bias=True, readout_zero_init=True, base_width=256, init='kaiming_fan_in_normal', bias_zero_init=True).cuda() for w in [64, 256, 1024]} - - def test_init(self): - stds = {} - for w, model in self.models.items(): - for i, module in enumerate(list(model.modules())[1::2]): - stds[(w, i+1, 'weight')] = module.weight.data.std() - stds[(w, i+1, 'bias')] = module.bias.data.std() - - for w in [64, 256]: - self.assertLess( - torch.abs( - stds[(1024, 1, 'weight')] - stds[(w, 1, 'weight')] - ) / stds[(1024, 1, 'weight')], 3e-3) - # for l in [1, 2]: - # self.assertLess( - # torch.abs( - # stds[(1024, l, 'bias')] - stds[(w, l, 'bias')] - # ) / stds[(1024, l, 'bias')], 1e-1) - self.assertTrue( - stds[(1024, 2, 'weight')] < stds[(256, 2, 'weight')] < stds[(64, 2, 'weight')]) - for w in [64, 256, 1024]: - self.assertEqual(stds[(w, 3, 'weight')], 0) - self.assertEqual(stds[(w, 3, 'bias')], 0) - - def _test_train(self, opt): - loss = {w: getattr(self, f'train_{opt}')(model, 201) for w, model in self.models.items()} - with self.subTest(name=f'{opt}, step 1'): - self.assertTrue( - loss[64][0] == loss[256][0] == loss[1024][0], - {k: v[0] for k, v in loss.items()}) - for t in [100, 200]: - with self.subTest(name=f'{opt}, step {t+1}'): - self.assertTrue( - loss[64][t] > loss[256][t] > loss[1024][t], - {k: v[t] for k, v in loss.items()}) - - def test_sgd(self): - self._test_train('sgd') - - def test_adam(self): - self._test_train('adam') - -class CNNTrainCase(unittest.TestCase): - - def train_adam(self, model, step): - return train_model_MuAdam(model, train_loader, step=step, flatten_input=False) - - def train_sgd(self, model, step): - return train_model_MuSGD(model, train_loader, step=step, flatten_input=False) - - def setUp(self): - self.models = {w: generate_CNN(w, mup=True, bias=True, readout_zero_init=True, base_width=8, init='kaiming_fan_in_normal', bias_zero_init=False).cuda() for w in [8, 32, 128]} - - def test_init(self): - stds = {} - names = [0, 3, 7, 9, 11] - for w, model in self.models.items(): - for i, module in enumerate(model): - if i in names: - stds[(w, i, 'weight')] = module.weight.data.std() - stds[(w, i, 'bias')] = module.bias.data.std() - - for w in [8, 32]: - self.assertLess( - torch.abs( - stds[(128, 0, 'weight')] - stds[(128, 0, 'weight')] - ) / stds[(128, 0, 'weight')], 3e-3) - for name in names[:-1]: - self.assertLess( - torch.abs( - stds[(128, 0, 'bias')] - stds[(w, 0, 'bias')] - ) / stds[(128, 0, 'bias')], 2e-1) - for name in names[1:-1]: - self.assertTrue( - stds[(128, name, 'weight')] < stds[(32, name, 'weight')] < stds[(8, name, 'weight')]) - for w in [8, 32, 128]: - self.assertEqual(stds[(w, 11, 'weight')], 0) - self.assertEqual(stds[(w, 11, 'bias')], 0) - - def _test_train(self, opt): - loss = 
{w: getattr(self, f'train_{opt}')(model, 201) for w, model in self.models.items()} - with self.subTest(name=f'{opt}, step 1'): - self.assertTrue( - loss[8][0] == loss[32][0] == loss[128][0], - {k: v[0] for k, v in loss.items()}) - for t in [200]: - with self.subTest(name=f'{opt}, step {t+1}'): - losses = {k: v[t] for k, v in loss.items()} - # print(losses) - self.assertTrue( - loss[8][t] > loss[32][t] > loss[128][t], - losses) - - def test_sgd(self): - self._test_train('sgd') - - def test_adam(self): - self._test_train('adam') - -def suite(): - suite = unittest.TestSuite() - suite.addTests(unittest.makeSuite(BackwardCompatibleCase)) - suite.addTests(unittest.makeSuite(MLPTrainCase)) - suite.addTests(unittest.makeSuite(CNNTrainCase)) - suite.addTests(unittest.makeSuite(CoordCheckCase)) - suite.addTests(unittest.makeSuite(SetBaseShapeCase)) - return suite - -if __name__ == '__main__': - runner = unittest.TextTestRunner(failfast=False) - runner.run(suite()) diff --git a/mup/mup/test/models.py b/mup/mup/test/models.py deleted file mode 100644 index d931fbf8e..000000000 --- a/mup/mup/test/models.py +++ /dev/null @@ -1,146 +0,0 @@ - -import torch -from torchvision import transforms, datasets -from mup.shape import set_base_shapes -from torch import nn -from torch.nn import Linear -from mup.layer import MuReadout -from functools import partial -from mup.init import (kaiming_normal_, kaiming_uniform_, normal_, - trunc_normal_, uniform_, xavier_normal_, - xavier_uniform_) -from torch.nn.modules.conv import _ConvNd - -samplers = { - 'default': lambda x: x, - 'const_uniform': partial(uniform_, a=-0.1, b=0.1), - 'const_normal': partial(normal_, std=0.1), - 'const_trunc_normal': partial(trunc_normal_, std=0.1, a=-0.2, b=0.2), - 'xavier_uniform': xavier_uniform_, - 'xavier_normal': xavier_normal_, - 'kaiming_fan_in_uniform': partial(kaiming_uniform_, mode='fan_in'), - 'kaiming_fan_in_normal': partial(kaiming_normal_, mode='fan_in'), - 'kaiming_fan_out_uniform': partial(kaiming_uniform_, mode='fan_out'), - 'kaiming_fan_out_normal': partial(kaiming_normal_, mode='fan_out') -} - - -def init_model(model, sampler): - for param in model.parameters(): - if len(param.shape) >= 2: - sampler(param) - return model - -init_methods = { - k: partial(init_model, sampler=s) for k, s in samplers.items() -} - -def _generate_MLP(width, bias=True, mup=True, batchnorm=False, device='cpu'): - mods = [Linear(3072, width, bias=bias, device=device), - nn.ReLU(), - Linear(width, width, bias=bias, device=device), - nn.ReLU() - ] - if mup: - mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False, device=device)) - else: - mods.append(Linear(width, 10, bias=bias, device=device)) - if batchnorm: - mods.insert(1, nn.BatchNorm1d(width, device=device)) - mods.insert(4, nn.BatchNorm1d(width, device=device)) - model = nn.Sequential(*mods) - return model - -def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=256): - if not mup: - model = _generate_MLP(width, bias, mup, batchnorm) - # set base shapes to model's own shapes, so we get SP - return set_base_shapes(model, None) - # it's important we make `model` first, because of random seed - model = _generate_MLP(width, bias, mup, batchnorm) - base_model = _generate_MLP(base_width, bias, mup, batchnorm, device='meta') - set_base_shapes(model, base_model) - init_methods[init](model) - if readout_zero_init: - readout = list(model.modules())[-1] - readout.weight.data.zero_() - if readout.bias is not 
None: - readout.bias.data.zero_() - if bias_zero_init: - for module in model.modules(): - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - return model - - -def _generate_CNN(width, bias=True, mup=True, batchnorm=False, device='cpu'): - mods = [ - nn.Conv2d(3, width, kernel_size=5, bias=bias, device=device), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(width, 2*width, kernel_size=5, bias=bias, device=device), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Flatten(), - nn.Linear(2*width*25, width*16, bias=bias, device=device), - nn.ReLU(inplace=True), - nn.Linear(width*16, width*10, bias=bias, device=device), - nn.ReLU(inplace=True), - ] - if mup: - mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False, device=device)) - else: - mods.append(nn.Linear(width*10, 10, bias=bias, device=device)) - if batchnorm: - mods.insert(1, nn.BatchNorm2d(width, device=device)) - mods.insert(5, nn.BatchNorm2d(2*width, device=device)) - mods.insert(10, nn.BatchNorm1d(16*width, device=device)) - mods.insert(13, nn.BatchNorm1d(10*width, device=device)) - return nn.Sequential(*mods) - -def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=8): - if not mup: - model = _generate_CNN(width, bias, mup, batchnorm) - # set base shapes to model's own shapes, so we get SP - return set_base_shapes(model, None) - # it's important we make `model` first, because of random seed - model = _generate_CNN(width, bias, mup, batchnorm) - base_model = _generate_CNN(base_width, bias, mup, batchnorm, device='meta') - set_base_shapes(model, base_model) - init_methods[init](model) - if readout_zero_init: - readout = list(model.modules())[-1] - readout.weight.data.zero_() - if readout.bias is not None: - readout.bias.data.zero_() - if bias_zero_init: - for module in model.modules(): - if isinstance(module, (nn.Linear, _ConvNd)) and module.bias is not None: - module.bias.data.zero_() - return model - -def get_lazy_models(arch, widths, mup=True, init='kaiming_fan_in_normal', readout_zero_init=True, batchnorm=True, base_width=None): - '''if mup is False, then `init`, `readout_zero_init`, `base_width` don't matter.''' - if arch == 'mlp': - base_width = base_width or 256 - generate = generate_MLP - elif arch == 'cnn': - base_width = base_width or 8 - generate = generate_CNN - def gen(w): - def f(): - model = generate(w, mup=mup, init=init, readout_zero_init=readout_zero_init, batchnorm=batchnorm, base_width=base_width) - return model - return f - return {w: gen(w) for w in widths} - - -def get_train_loader(batch_size, num_workers=0, shuffle=False, train=True, download=False): - - transform = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - trainset = datasets.CIFAR10(root='dataset', train=train, - download=download, transform=transform) - return torch.utils.data.DataLoader(trainset, batch_size=batch_size, - shuffle=shuffle, num_workers=num_workers)
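
For anyone who relied on this vendored copy, the workflow its tests exercised maps directly onto the `mup` package itself, which exposes the same names the deleted files import (MuReadout, set_base_shapes, MuSGD, get_coord_data). The following is a minimal, illustrative sketch of that workflow, not part of this patch; it assumes the upstream `mup` distribution is installed and behaves like the modules removed here.

# Illustrative only -- not part of this patch. Assumes the upstream `mup`
# package, which exposes the same API the deleted files imported.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mup import MuReadout, MuSGD, set_base_shapes


def make_mlp(width, device='cpu'):
    # Same layout as the deleted _generate_MLP: two hidden Linear layers whose
    # width scales, plus a MuReadout so the output layer gets muP scaling.
    return nn.Sequential(
        nn.Linear(3072, width, device=device), nn.ReLU(),
        nn.Linear(width, width, device=device), nn.ReLU(),
        MuReadout(width, 10, device=device),
    )


model = make_mlp(width=1024)
base = make_mlp(width=256, device='meta')  # base widths only; no real storage
set_base_shapes(model, base)               # record infshapes for muP rescaling
# (the deleted tests then re-sampled weights with the mup.init initializers)

opt = MuSGD(model.parameters(), lr=0.1)    # width-aware per-parameter lr
x = torch.randn(32, 3072)
y = torch.randint(0, 10, (32,))
loss = F.cross_entropy(model(x), y)
loss.backward()
opt.step()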
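
The coord check run by the deleted CoordCheckCase can likewise be reproduced with mup.coord_check.get_coord_data, called with the same keyword arguments the test passed; the synthetic DataLoader and the cuda=False flag below are stand-ins for the CIFAR-10 loader and GPU the tests assumed, and the cuda keyword in particular is an assumption about the upstream signature.

# Illustrative only; reuses make_mlp / set_base_shapes from the sketch above.
from torch.utils.data import DataLoader, TensorDataset
from mup.coord_check import get_coord_data

loader = DataLoader(
    TensorDataset(torch.randn(256, 3072), torch.randint(0, 10, (256,))),
    batch_size=32)

# One lazily constructed model per width, as get_lazy_models did.
models = {
    w: (lambda w=w: set_base_shapes(make_mlp(w), make_mlp(256, device='meta')))
    for w in (256, 1024, 4096)
}

df = get_coord_data(models, loader, mup=True, optimizer='sgd',
                    flatten_input=True, cuda=False)
# `df` records per-module activation statistics (e.g. l1) by width and step;
# under muP these should stay roughly flat as width grows, which is what the
# deleted CoordCheckCase asserted via its std1/std0 ratio check.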