Skip to content

Commit

Permalink
Merge pull request #140 from byu-dml/numeric_targets
Browse files Browse the repository at this point in the history
Update validation to allow numeric targets
  • Loading branch information
rlaboulaye authored Jan 25, 2019
2 parents 39ae6a9 + c09ad3e commit 4c276fa
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
5 changes: 4 additions & 1 deletion metalearn/metafeatures/metafeatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,10 @@ def _validate_n_folds(
raise ValueError(f"`n_folds` must be an integer, not {n_folds}")
if n_folds < 2:
raise ValueError(f"`n_folds` must be >= 2, but was {n_folds}")
if not Y is None and metafeature_ids is not None:
if (Y is not None and
column_types is not None and
column_types[Y.name] != self.NUMERIC and
metafeature_ids is not None):
# when computing landmarking metafeatures, there must be at least
# n_folds instances of each class of Y
landmarking_mfs = self.list_metafeatures(group="landmarking")
Expand Down
36 changes: 36 additions & 0 deletions test/metalearn/metafeatures/test_metafeatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,42 @@ def test_no_targets(self):

self._report_test_failures(test_failures, test_name)

def test_numeric_targets(self):
    """Test Metafeatures().compute() with numeric targets.

    For every dataset in ``self.datasets`` the target column is replaced
    by a random numeric series of the same length and name, and its
    column type is set to ``Metafeatures.NUMERIC``.  Each metafeature in
    the "target_dependent" group is then expected to report
    ``Metafeatures.NUMERIC_TARGETS`` with a compute time of 0; the
    remaining expected values are taken unchanged from the dataset's
    known metafeatures.
    """
    test_failures = {}
    test_name = inspect.stack()[0][3]
    # The group listing does not depend on the dataset, so fetch it once.
    target_dependent_metafeatures = Metafeatures.list_metafeatures(
        "target_dependent"
    )
    for dataset_filename, dataset in self.datasets.items():
        metafeatures = Metafeatures()
        target_name = dataset["Y"].name

        # Work on a copy so the dataset's own column types stay untouched.
        column_types = dataset["column_types"].copy()
        column_types[target_name] = metafeatures.NUMERIC

        numeric_targets = pd.Series(
            np.random.rand(dataset["Y"].shape[0]), name=target_name
        )
        computed_mfs = metafeatures.compute(
            X=dataset["X"], Y=numeric_targets, seed=CORRECTNESS_SEED,
            column_types=column_types
        )

        # Shallow-copy the expected values before overriding entries, so
        # the dataset's known metafeatures are not modified in place
        # (mirrors the handling of column_types above).
        known_mfs = dict(dataset["known_metafeatures"])
        for mf_name in target_dependent_metafeatures:
            known_mfs[mf_name] = {
                Metafeatures.VALUE_KEY: Metafeatures.NUMERIC_TARGETS,
                Metafeatures.COMPUTE_TIME_KEY: 0.
            }

        required_checks = {
            self._check_correctness: [
                computed_mfs, known_mfs, dataset_filename
            ],
            self._check_compare_metafeature_lists: [
                computed_mfs, known_mfs, dataset_filename
            ]
        }
        test_failures.update(self._perform_checks(required_checks))

    self._report_test_failures(test_failures, test_name)

def test_request_metafeatures(self):
SUBSET_LENGTH = 20
test_failures = {}
Expand Down

0 comments on commit 4c276fa

Please sign in to comment.