Skip to content

Commit

Permalink
Fix lint and add mock for credentials test
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Mar 6, 2024
1 parent 69d819c commit 33ec937
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 46 deletions.
2 changes: 1 addition & 1 deletion sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def benchmark_single_table(synthesizers=DEFAULT_SYNTHESIZERS, custom_synthesizer
output_filepath (str or ``None``):
A file path for where to write the output as a csv file. If ``None``, no output
is written. If run_on_ec2 flag output_filepath needs to be defined and
the filepath should be structured as: {s3_bucket_name}/{path_to_file}
the filepath should be structured as: {s3_bucket_name}/{path_to_file}
Please make sure the path exists and permissions are given.
detailed_results_folder (str or ``None``):
The folder for where to store the intermediary results. If ``None``, do not store
Expand Down
69 changes: 24 additions & 45 deletions tests/unit/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,42 +84,49 @@ def test_benchmark_single_table_with_timeout(mock_multiprocessing, mock__score):
})
pd.testing.assert_frame_equal(scores, expected_scores)


@patch('sdgym.benchmark.boto3.session.Session')
@patch('sdgym.benchmark._create_instance_on_ec2')
def test_run_ec2_flag(create_ec2_mock):
def test_run_ec2_flag(create_ec2_mock, session_mock):
"""Test that the benchmarking function updates the progress bar on one line."""
# Setup
create_ec2_mock.return_value = MagicMock()
session_mock.get_credentials.return_value = MagicMock()

# Run
benchmark_single_table(run_on_ec2=True, output_filepath="BucketName/path")
benchmark_single_table(run_on_ec2=True, output_filepath='BucketName/path')

# Assert
create_ec2_mock.assert_called_once()

# Run
with pytest.raises(ValueError, match=r'In order to run on EC2, please provide an S3 folder output.'):
with pytest.raises(ValueError,
match=r'In order to run on EC2, please provide an S3 folder output.'):
benchmark_single_table(run_on_ec2=True)

# Assert
create_ec2_mock.assert_called_once()

# Run
with pytest.raises(ValueError, match=r'''Invalid output_filepath.
with pytest.raises(ValueError, match=r"""Invalid output_filepath.
The path should be structured as: <s3_bucket_name>/<path_to_file>
Please make sure the path exists and permissions are given.'''):
benchmark_single_table(run_on_ec2=True, output_filepath="Wrong_Format")
Please make sure the path exists and permissions are given."""):
benchmark_single_table(run_on_ec2=True, output_filepath='Wrong_Format')

# Assert
create_ec2_mock.assert_called_once()


def test__create_sdgym_script():
@patch('sdgym.benchmark.boto3.session.Session')
def test__create_sdgym_script(session_mock):
session_mock.get_credentials.return_value = MagicMock()
# Setup
test_params = {
'synthesizers': [GaussianCopulaSynthesizer, CTGANSynthesizer],
'custom_synthesizers': None,
'sdv_datasets': ['adult', 'alarm', 'census', 'child', 'expedia_hotel_logs', 'insurance', 'intrusion', 'news', 'covtype'],
'sdv_datasets': [
'adult', 'alarm', 'census',
'child', 'expedia_hotel_logs',
'insurance', 'intrusion', 'news', 'covtype'
],
'additional_datasets_folder': None,
'limit_dataset_size': True,
'compute_quality_score': False,
Expand All @@ -134,38 +141,10 @@ def test__create_sdgym_script():

result = _create_sdgym_script(test_params, 'Bucket/Filepath')

expected_script = """import boto3
from io import StringIO
import sdgym
from sdgym.synthesizers.sdv import (CopulaGANSynthesizer, CTGANSynthesizer, FastMLPreset,
GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer, SDVRelationalSynthesizer,
SDVTabularSynthesizer,TVAESynthesizer)
results = sdgym.benchmark_single_table(
synthesizers=[GaussianCopulaSynthesizer, CTGANSynthesizer, ], custom_synthesizers=None,
sdv_datasets=['adult', 'alarm', 'census', 'child', 'expedia_hotel_logs', 'insurance', 'intrusion', 'news', 'covtype'],
additional_datasets_folder=None,
limit_dataset_size=True,
compute_quality_score=False,
sdmetrics=[('NewRowSynthesis', {'synthetic_sample_size': 1000})], timeout=600,
detailed_results_folder=None,
multi_processing_config=None
)
# Convert DataFrame to CSV string
csv_buffer = StringIO()
results.to_csv(csv_buffer, index=False)
s3 = boto3.client('s3',
aws_access_key_id='AKIAYUU6OYSLCAY62IFP',
aws_secret_access_key='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')
# Upload CSV to S3
response = s3.put_object(
Bucket='Bucket',
Key='Filepath',
Body=csv_buffer.getvalue()
)
"""

assert result == expected_script
assert 'synthesizers=[GaussianCopulaSynthesizer, CTGANSynthesizer, ]' in result
assert 'detailed_results_folder=None' in result
assert 'multi_processing_config=None' in result
assert "sdmetrics=[('NewRowSynthesis', {'synthetic_sample_size': 1000})]" in result
assert "timeout=600" in result
assert 'compute_quality_score=False' in result
assert 'import boto3' in result

0 comments on commit 33ec937

Please sign in to comment.