Metric updates #195

Status: Open. Wants to merge 8 commits into base: main.
dlio_benchmark/data_loader/synthetic_data_loader.py (8 changes: 4 additions & 4 deletions)

@@ -36,6 +36,9 @@ class SyntheticDataLoader(BaseDataLoader):
     @dlp.log_init
     def __init__(self, format_type, dataset_type, epoch):
         super().__init__(format_type, dataset_type, epoch, DataLoaderType.SYNTHETIC)
+        shape = self._args.resized_image.shape
+        self.batch = np.zeros((self.batch_size, shape[0], shape[1]))
+        #self.batch = 1
 
     @dlp.log
     def read(self, init=False):
@@ -48,10 +51,7 @@ def next(self):
         step = 0
         self.read(True)
         while step < self.num_samples // self.batch_size:
-            batch = []
-            for i in range(self.batch_size):
-                batch.append(self._args.resized_image)
-            yield batch
+            yield self.batch
             step += 1
 
     @dlp.log
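
Note: the loader now allocates one zero-filled NumPy buffer in __init__ and yields that same array every step, instead of rebuilding a Python list of resized images per batch, so steady-state iteration does no per-step allocation. A minimal standalone sketch of the pattern (hypothetical names; in the PR, batch_size and the image shape come from ConfigArguments):

import numpy as np

def synthetic_batches(batch_size, image_shape, num_batches):
    # Allocate once; every iteration yields the same reused buffer.
    batch = np.zeros((batch_size, image_shape[0], image_shape[1]))
    for _ in range(num_batches):
        yield batch

# Usage: 10 batches of 4 synthetic 224x224 samples.
for b in synthetic_batches(4, (224, 224), 10):
    assert b.shape == (4, 224, 224)
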
dlio_benchmark/framework/framework.py (7 changes: 5 additions & 2 deletions)

@@ -20,11 +20,13 @@
 from dlio_benchmark.common.enumerations import DatasetType
 from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory
 from dlio_benchmark.storage.storage_factory import StorageFactory
-from dlio_benchmark.utils.utility import utcnow
+from dlio_benchmark.utils.utility import utcnow, DLIOMPI
+comm = DLIOMPI.get_instance().comm()
 
 from time import sleep
 import os
 import logging
+from multiprocessing import Process
 
 from dlio_benchmark.utils.config import ConfigArguments
 
@@ -69,8 +71,9 @@ def stop_framework_profiler(self):
     def trace_object(self, string, step, r):
         pass
 
-    def model(epoch, epoch_number, step, computation_time):
+    def model(epoch, x, computation_time):
         sleep(computation_time)
+        comm.barrier()
 
     @abstractmethod
     def compute(self, x, epoch_number, step, computation_time):
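
Note: the emulated model step now sleeps for the configured compute time and then blocks on an MPI barrier, so no rank starts its next step until the slowest rank finishes the current one; the compute durations that statscounter measures therefore include straggler wait time. A minimal sketch of the same pattern using bare mpi4py (an assumption for illustration; the PR goes through DLIOMPI.get_instance().comm()):

from time import sleep, time
from mpi4py import MPI

comm = MPI.COMM_WORLD

def model(x, computation_time):
    # Emulate the forward/backward pass, then synchronize all ranks.
    sleep(computation_time)
    comm.barrier()

t0 = time()
model(None, 0.05)
# Measured step time >= 0.05 s, plus any wait on slower ranks.
print(f"rank {comm.rank}: step took {time() - t0:.3f} s")
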
dlio_benchmark/framework/tf_framework.py (3 changes: 1 addition & 2 deletions)

@@ -18,7 +18,6 @@
 import os
 import logging
 from time import time, sleep
-
 from dlio_benchmark.common.constants import MODULE_AI_FRAMEWORK
 from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory
 from dlio_benchmark.utils.utility import utcnow, DLIOMPI
@@ -87,7 +86,7 @@ def trace_object(self, string, step, r):
 
     @dlp.log
     def compute(self, x, epoch_number, step, computation_time):
-        sleep(computation_time)
+        return self.model(x, computation_time)
         # tf.function(self.model)(epoch_number, step, computation_time)
 
     @dlp.log
dlio_benchmark/framework/torch_framework.py (2 changes: 1 addition & 1 deletion)

@@ -93,7 +93,7 @@ def trace_object(self, string, step, r):
 
     @dlp.log
     def compute(self, x, epoch_number, step, computation_time):
-        torch_sleep(computation_time)
+        return self.model(x, computation_time)
 
     @dlp.log
     def get_loader(self, dataset_type=DatasetType.TRAIN):
dlio_benchmark/main.py (30 changes: 12 additions & 18 deletions)

@@ -227,22 +227,23 @@ def _eval(self, epoch):
         step = 1
         total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size)
         loader = self.framework.get_loader(DatasetType.VALID)
-        t0 = time()
+        self.stats.start_loading()
         for batch in dlp.iter(loader.next()):
-            self.stats.eval_batch_loaded(epoch, step, t0)
+            self.stats.eval_batch_loaded(epoch, step)
             eval_time = 0.0
             if self.eval_time > 0:
                 if self.eval_time_stdev > 0:
                     eval_time = random.normal(self.eval_time, self.eval_time_stdev)
                 else:
                     eval_time = self.eval_time
+            self.stats.start_compute()
             self.framework.compute(batch, epoch, step, eval_time)
-            self.stats.eval_batch_processed(epoch, step, t0, eval_time)
+            self.stats.eval_batch_processed(epoch, step)
 
             step += 1
             if step > total:
                 break
-            t0 = time()
+            self.stats.start_loading()
         return step - 1
 
     @dlp.log
@@ -259,22 +260,15 @@ def _train(self, epoch):
         self.stats.start_block(epoch, block)
 
         loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN)
-        t0 = time()
-        for batch in dlp.iter(loader.next()):
-            self.stats.batch_loaded(epoch, overall_step, block, t0)
+        self.stats.start_loading()
+        for batch in loader.next():
+            self.stats.batch_loaded(epoch, overall_step, block)
             # Log a new block, unless it's the first one which we've already logged before the loop
             if block_step == 1 and block != 1:
                 self.stats.start_block(epoch, block)
-            computation_time = self.computation_time
             if self.computation_time > 0:
                 self.framework.trace_object("Train", overall_step, 1)
-                if self.computation_time_stdev > 0:
-                    computation_time = random.normal(self.computation_time, self.computation_time_stdev)
-                else:
-                    computation_time = self.computation_time
-            self.framework.compute(batch, epoch, block_step, computation_time)
-            self.stats.batch_processed(epoch, overall_step, block, t0, computation_time)
-            self.comm.barrier()
+            self.stats.start_compute()
+            self.framework.compute(batch, epoch, block_step, self.computation_time)
+            self.stats.batch_processed(epoch, overall_step, block)
             if self.do_checkpoint and (
                     self.steps_between_checkpoints >= 0) and overall_step == self.next_checkpoint_step:
                 self.stats.end_block(epoch, block, block_step)
@@ -295,7 +289,7 @@ def _train(self, epoch):
                 self.stats.end_block(epoch, block, block_step - 1)
                 break
             overall_step += 1
-            t0 = time()
+            self.stats.start_loading()
         self.comm.barrier()
         if self.do_checkpoint and (self.steps_between_checkpoints < 0) and (epoch == self.next_checkpoint_epoch):
            self.stats.end_block(epoch, block, block_step)
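
Note: the loops now bracket each phase with explicit markers instead of threading a t0 timestamp through every call: start_loading() before waiting on the iterator, batch_loaded() once the batch arrives, start_compute() immediately before compute(), and batch_processed() after it. A minimal sketch of that marker protocol (hypothetical Stats class; the real bookkeeping lives in statscounter.py below):

from time import perf_counter

class Stats:
    def start_loading(self):
        self._t_load = perf_counter()

    def start_compute(self):
        self._t_compute = perf_counter()

    def batch_loaded(self):
        # Time spent waiting on the data loader for this step.
        return perf_counter() - self._t_load

    def batch_processed(self):
        end = perf_counter()
        compute = end - self._t_compute   # pure compute segment
        total = end - self._t_load        # load + compute for the step
        return compute, total

stats = Stats()
stats.start_loading()
for batch in range(3):                    # stand-in for loader.next()
    load_time = stats.batch_loaded()
    stats.start_compute()
    # ... framework.compute(batch, ...) ...
    compute_time, step_time = stats.batch_processed()
    stats.start_loading()                 # re-arm before the next batch
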
dlio_benchmark/reader/tf_reader.py (3 changes: 2 additions & 1 deletion)

@@ -67,7 +67,7 @@ def _parse_image(self, serialized):
             'image': tf.io.FixedLenFeature([], tf.string),
             'size': tf.io.FixedLenFeature([], tf.int64)
         }
-        parsed_example = tf.io.parse_example(serialized=serialized, features=features)
+        #parsed_example = tf.io.parse_example(serialized=serialized, features=features)
         # Get the image as raw bytes.
         #image_raw = parsed_example['image']
         #dimension = tf.cast(parsed_example['size'], tf.int32).numpy()
@@ -109,6 +109,7 @@ def next(self):
         self._dataset = self._dataset.repeat()
         total = math.floor(len(self._file_list)/self._args.comm_size / self.batch_size * self._args.num_samples_per_file)
         return self._dataset.take(total*self._args.epochs).prefetch(buffer_size=self._args.prefetch_size)
+
     @dlp.log
     def read_index(self, image_idx, step):
         return super().read_index(image_idx, step)
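
Note: with parse_example commented out, the reader hands back raw serialized records, so the benchmark measures I/O cost rather than decode cost; next() then bounds the repeat()-ed (infinite) dataset with take() and keeps prefetch() running ahead of the consumer. A toy sketch of a pipeline with that shape (an illustration, not the PR's exact reader):

import math
import tensorflow as tf

def make_stream(file_list, comm_size, batch_size, samples_per_file, epochs, prefetch_size):
    ds = tf.data.TFRecordDataset(file_list)         # raw bytes, no parsing/decoding
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat()                                # infinite; take() below bounds it
    total = math.floor(len(file_list) / comm_size / batch_size * samples_per_file)
    return ds.take(total * epochs).prefetch(buffer_size=prefetch_size)
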
dlio_benchmark/utils/config.py (2 changes: 0 additions & 2 deletions)

@@ -496,8 +496,6 @@ def LoadConfig(args, config):
         args.seed_change_epoch = config['train']['seed_change_epoch']
     if 'computation_time' in config['train']:
         args.computation_time = config['train']['computation_time']
-    if 'computation_time_stdev' in config['train']:
-        args.computation_time_stdev = config['train']['computation_time_stdev']
     if 'seed' in config['train']:
         args.seed = config['train']['seed']
dlio_benchmark/utils/statscounter.py (49 changes: 30 additions & 19 deletions)

@@ -174,13 +174,17 @@ def end_run(self):
             metric = metric + f"[METRIC] Training Accelerator Utilization [AU] (%): {np.mean(train_au):.4f} ({np.std(train_au):.4f})\n"
             metric = metric + f"[METRIC] Training Throughput (samples/second): {np.mean(train_throughput):.4f} ({np.std(train_throughput):.4f})\n"
             metric = metric + f"[METRIC] Training I/O Throughput (MB/second): {np.mean(train_throughput)*self.record_size/1024/1024:.4f} ({np.std(train_throughput)*self.record_size/1024/1024:.4f})\n"
-            metric = metric + f"[METRIC] train_au_meet_expectation: {self.summary['metric']['train_au_meet_expectation']}\n"
+            metric = metric + f"[METRIC] **Expected Throughputs if compute-bound\n"
+            metric = metric + f"[METRIC] Training Throughput (expected) (samples/second): {np.mean(train_throughput/train_au)*100:.4f}\n"
+            metric = metric + f"[METRIC] Training I/O Throughput (expected) (MB/second): {np.mean(train_throughput/train_au)*100*self.record_size/1024/1024:.4f}\n"
 
             if self.args.do_eval:
                 metric = metric + f"[METRIC] Eval Accelerator Utilization [AU] (%): {np.mean(eval_au):.4f} ({np.std(eval_au):.4f})\n"
                 metric = metric + f"[METRIC] Eval Throughput (samples/second): {np.mean(eval_throughput):.6f} ({np.std(eval_throughput):.6f})\n"
                 metric = metric + f"[METRIC] Eval Throughput (MB/second): {np.mean(eval_throughput)*self.record_size/1024/1024:.6f} ({np.std(eval_throughput)*self.record_size/1024/1024:.6f})\n"
-                metric = metric + f"[METRIC] eval_au_meet_expectation: {self.summary['metric']['eval_au_meet_expectation']}\n"
+                metric = metric + f"[METRIC] **Expected Throughputs if compute-bound\n"
+                metric = metric + f"[METRIC] Eval Throughput (expected) (samples/second): {np.mean(eval_throughput/eval_au)*100:.4f}\n"
+                metric = metric + f"[METRIC] Eval I/O Throughput (expected) (MB/second): {np.mean(eval_throughput/eval_au)*100*self.record_size/1024/1024:.4f}\n"
             metric+="[METRIC] ==========================================================\n"
             logging.info(metric)
     def start_train(self, epoch):
@@ -284,6 +288,7 @@ def end_block(self, epoch, block, steps_taken):
         self.per_epoch_stats[epoch][f'block{block}']['duration'] = duration
         logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Accelerator Utilization [AU] (%): {self.output[epoch]['au'][f'block{block}']:.4f}")
         logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Throughput (samples/second): {self.output[epoch]['throughput'][f'block{block}']*self.comm_size:.4f}")
+        logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Computation time per step (second): {np.mean(self.output[epoch]['compute'][f'block{block}'][1:-1]):.4f}+/-{np.std(self.output[epoch]['compute'][f'block{block}'][1:-1]):.4f} (set value: {self.args.computation_time})")
 
     def start_ckpt(self, epoch, block, steps_taken):
         if self.my_rank == 0:
@@ -303,42 +308,47 @@ def end_ckpt(self, epoch, block):
         self.per_epoch_stats[epoch][f'ckpt{block}']['end'] = ts
         self.per_epoch_stats[epoch][f'ckpt{block}']['duration'] = duration
 
-    def batch_loaded(self, epoch, step, block, t0):
-        duration = time() - t0
+    def start_loading(self):
+        self.start_time_loading = time()
+    def start_compute(self):
+        self.start_time_compute = time()
+    def batch_loaded(self, epoch, step, block):
+        duration = time() - self.start_time_loading
         key = f'block{block}'
         if key in self.output[epoch]['load']:
             self.output[epoch]['load'][key].append(duration)
         else:
             self.output[epoch]['load'][key] = [duration]
         logging.debug(f"{utcnow()} Rank {self.my_rank} step {step}: loaded {self.batch_size} samples in {duration} s")
 
-
-    def batch_processed(self, epoch, step, block, t0, computation_time):
-        duration = time() - t0
+    def batch_processed(self, epoch, step, block):
+        current_time = time()
+        duration = current_time - self.start_time_loading
         key = f'block{block}'
+        self.computation_time = current_time - self.start_time_compute
         if key in self.output[epoch]['proc']:
             self.output[epoch]['proc'][key].append(duration)
-            self.output[epoch]['compute'][key].append(computation_time)
+            self.output[epoch]['compute'][key].append(self.computation_time)
         else:
             self.output[epoch]['proc'] = [duration]
-            self.output[epoch]['compute']=[computation_time]
-        logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size} samples in {duration} s")
+            self.output[epoch]['compute']=[self.computation_time]
+        logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size} samples in {duration} s")
 
     def compute_metrics_train(self, epoch, block):
         key = f"block{block}"
-        total_compute_time = np.sum(self.output[epoch]['compute'][key][1:])
+        total_compute_time = np.sum(self.output[epoch]['compute'][key][1:-1])
+        total_time = self.end_timestamp - self.start_timestamp - self.output[epoch]['proc'][key][0] - self.output[epoch]['proc'][key][-1]
        if (total_compute_time==0):
             au=0.0
         else:
-            total_time = self.end_timestamp - self.start_timestamp - self.output[epoch]['proc'][key][0]
             au = total_compute_time / total_time
-        throughput = len(self.output[epoch]['compute'][key])/(self.end_timestamp - self.start_timestamp)*self.batch_size
+        throughput = (len(self.output[epoch]['compute'][key]) - 2)/(total_time)*self.batch_size
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
 
     def compute_metrics_eval(self, epoch):
         key = 'eval'
-        total_compute_time = np.sum(self.output[epoch]['compute'][key][1:])
+        total_compute_time = np.sum(self.output[epoch]['compute'][key][1:-1])
         if (total_compute_time==0):
             au=0.0
         else:
@@ -348,14 +358,15 @@ def compute_metrics_eval(self, epoch):
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
 
-    def eval_batch_loaded(self, epoch, step, t0):
-        duration = time() - t0
+    def eval_batch_loaded(self, epoch, step):
+        duration = time() - self.start_time_loading
         self.output[epoch]['load']['eval'].append(duration)
         logging.debug(f"{utcnow()} Rank {self.my_rank} step {step} loaded {self.batch_size_eval} samples in {duration} s")
 
-
-    def eval_batch_processed(self, epoch, step, t0, computation_time):
-        duration = time() - t0
+    def eval_batch_processed(self, epoch, step):
+        current_time = time()
+        duration = current_time - self.start_time_loading
+        computation_time = current_time - self.start_time_compute
         self.output[epoch]['proc']['eval'].append(duration)
         self.output[epoch]['compute']['eval'].append(computation_time)
         logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size_eval} samples in {duration} s")
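
Note: the metric changes above all follow one definition. AU is total compute time divided by total elapsed time, with the first and last steps excluded (the [1:-1] slices, the -2 in the step count, and the subtraction of the first and last proc durations) so warm-up and tail effects are dropped, and the new "expected" lines report measured throughput divided by AU, i.e. what the run would sustain if it were purely compute-bound (AU is stored as a percent, hence the *100). A small worked example of the arithmetic, with made-up numbers:

import numpy as np

compute = np.array([0.30, 0.10, 0.10, 0.10, 0.10, 0.25])  # per-step compute time (s)
proc    = np.array([0.50, 0.15, 0.15, 0.15, 0.15, 0.40])  # per-step load+compute (s)
batch_size = 32
start, end = 0.0, 1.50                                    # block timestamps (s)

total_compute = compute[1:-1].sum()                       # 0.40 s (first/last dropped)
total_time = (end - start) - proc[0] - proc[-1]           # 0.60 s
au_pct = total_compute / total_time * 100                 # ~66.7%

throughput = (len(compute) - 2) / total_time * batch_size # ~213.3 samples/s
expected = throughput / au_pct * 100                      # ~320 samples/s if compute-bound
print(f"AU {au_pct:.1f}%  measured {throughput:.1f}  expected {expected:.1f} samples/s")
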
docs/source/config.rst (3 changes: 0 additions & 3 deletions)

@@ -253,9 +253,6 @@ train
    * - computation_time
      - 0.0
      - emulated computation time per step in second
-   * - computation_time_stdev
-     - 0.0
-     - standard deviation of the emulated computation time per step in second
    * - total_training_steps
      - -1
      - number of training steps to simulate, assuming running the benchmark less than one epoch.
docs/source/testedsystems.rst (2 changes: 1 addition & 1 deletion)

@@ -4,4 +4,4 @@ Tested systems
 ================
 So far we have tested DLIO on the following systems:
   * Personal workstation, laptops including both MacOSX and Linux OS system.
-  * Supercomputers (Linux), such as Theta @ ALCF, Summit @ OLCF, Lassen @ LLNL (please turn to: `instructions_lassen.rst`_ for instructions)
+  * Supercomputers (Linux), such as Polaris @ ALCF, Summit @ OLCF, Lassen @ LLNL (please turn to: `instructions_lassen.rst`_ for instructions)