diff --git a/pyproject.toml b/pyproject.toml
index 9125259..da754b5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
     'appdirs>=1.3',
     'boto3>=1.28,<2',
     'botocore>=1.31,<2',
+    'cloudpickle>=2.1.0',
     'compress-pickle>=1.2.0',
     'humanfriendly>=8.2',
     "numpy>=1.21.0,<2.0.0;python_version<'3.10'",
diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py
index 95aedc9..0660a3f 100644
--- a/sdgym/benchmark.py
+++ b/sdgym/benchmark.py
@@ -7,10 +7,12 @@ import pickle
 import tracemalloc
 import warnings
 
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
 
 import boto3
+import cloudpickle
 import compress_pickle
 import numpy as np
 import pandas as pd
@@ -318,6 +320,26 @@ def _score(
     return output
 
 
+@contextmanager
+def multiprocessing_context():
+    """Override multiprocessing ForkingPickler to use cloudpickle."""
+    original_dump = multiprocessing.reduction.ForkingPickler.dumps
+    original_load = multiprocessing.reduction.ForkingPickler.loads
+    original_method = multiprocessing.get_start_method()
+
+    multiprocessing.set_start_method('spawn', force=True)
+    multiprocessing.reduction.ForkingPickler.dumps = cloudpickle.dumps
+    multiprocessing.reduction.ForkingPickler.loads = cloudpickle.loads
+
+    try:
+        yield
+    finally:
+        # Restore original methods
+        multiprocessing.set_start_method(original_method, force=True)
+        multiprocessing.reduction.ForkingPickler.dumps = original_dump
+        multiprocessing.reduction.ForkingPickler.loads = original_load
+
+
 def _score_with_timeout(
     timeout,
     synthesizer,
@@ -329,32 +351,33 @@ def _score_with_timeout(
     modality=None,
     dataset_name=None,
 ):
-    with multiprocessing.Manager() as manager:
-        output = manager.dict()
-        process = multiprocessing.Process(
-            target=_score,
-            args=(
-                synthesizer,
-                data,
-                metadata,
-                metrics,
-                output,
-                compute_quality_score,
-                compute_diagnostic_score,
-                modality,
-                dataset_name,
-            ),
-        )
+    with multiprocessing_context():
+        with multiprocessing.Manager() as manager:
+            output = manager.dict()
+            process = multiprocessing.Process(
+                target=_score,
+                args=(
+                    synthesizer,
+                    data,
+                    metadata,
+                    metrics,
+                    output,
+                    compute_quality_score,
+                    compute_diagnostic_score,
+                    modality,
+                    dataset_name,
+                ),
+            )
 
-        process.start()
-        process.join(timeout)
-        process.terminate()
+            process.start()
+            process.join(timeout)
+            process.terminate()
 
-        output = dict(output)
-        if output.get('timeout'):
-            LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
+            output = dict(output)
+            if output.get('timeout'):
+                LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
 
-        return output
+            return output
 
 
 def _format_output(
diff --git a/sdgym/synthesizers/generate.py b/sdgym/synthesizers/generate.py
index 63baced..718fe7b 100644
--- a/sdgym/synthesizers/generate.py
+++ b/sdgym/synthesizers/generate.py
@@ -124,7 +124,7 @@ def get_trained_synthesizer(self, data, metadata):
             obj:
                 The trained synthesizer.
             """
-            return get_trained_synthesizer_fn(data, metadata)
+            return self.synthesizer_fn['get_trained_synthesizer_fn'](data, metadata)
 
         def sample_from_synthesizer(self, synthesizer, num_samples):
             """Sample the desired number of samples from the given synthesizer.
@@ -139,11 +139,22 @@
                 pandas.DataFrame:
                     The synthetic data.
             """
-            return sample_from_synthesizer_fn(synthesizer, num_samples)
-
-    NewSynthesizer.__name__ = f'Custom:{display_name}'
-
-    return NewSynthesizer
+            return self.synthesizer_fn['sample_from_synthesizer_fn'](synthesizer, num_samples)
+
+    CustomSynthesizer = type(
+        f'Custom:{display_name}',
+        (NewSynthesizer,),
+        {
+            'synthesizer_fn': {
+                'get_trained_synthesizer_fn': get_trained_synthesizer_fn,
+                'sample_from_synthesizer_fn': sample_from_synthesizer_fn,
+            },
+        },
+    )
+    CustomSynthesizer.__name__ = f'Custom:{display_name}'
+    CustomSynthesizer.__module__ = 'sdgym.synthesizers.generate'
+    globals()[f'Custom:{display_name}'] = CustomSynthesizer
+    return CustomSynthesizer
 
 
 def create_multi_table_synthesizer(