Skip to content

Commit

Permalink
Apply black
Browse files Browse the repository at this point in the history
  • Loading branch information
rly committed Apr 4, 2024
1 parent 990a0d7 commit af9cab3
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 29 deletions.
4 changes: 1 addition & 3 deletions hdmf_ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ def __get_resources():
__schema_dir = "schema"

ret = dict()
ret["namespace_path"] = str(
__location_of_this_file / __schema_dir / __core_ns_file_name
)
ret["namespace_path"] = str(__location_of_this_file / __schema_dir / __core_ns_file_name)
return ret


Expand Down
12 changes: 3 additions & 9 deletions hdmf_ml/results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,13 @@ def __add_col(self, **kwargs):
# dim2 is a 1D array, so shape is N-D
shape = (self.n_samples, *dim2)
else:
ValueError(
f"Unrecognized type for dim2: {type(dim2)} - expected integer or 1-D array-like"
)
ValueError(f"Unrecognized type for dim2: {type(dim2)} - expected integer or 1-D array-like")

# create empty DataIO object
data = H5DataIO(shape=shape, dtype=dtype)

if name in self:
raise ValueError(
f"Column '{name}' already exists in ResultsTable '{self.name}'"
)
raise ValueError(f"Column '{name}' already exists in ResultsTable '{self.name}'")
if len(self.id) == 0:
self.id.extend(np.arange(len(data)))
elif len(self.id) != len(data):
Expand All @@ -124,9 +120,7 @@ def __add_col(self, **kwargs):
f"existings columns of length {len(self.id)}"
)

self.add_column(
data=data, name=name, description=description, col_cls=col_cls, **kwargs
)
self.add_column(data=data, name=name, description=description, col_cls=col_cls, **kwargs)

if self.__n_samples is None:
self.__n_samples = len(data)
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ exclude_lines = [

[tool.setuptools_scm]

# [tool.black]
# line-length = 120
# preview = true
# exclude = ".git|.mypy_cache|.tox|.venv|venv|.ipynb_checkpoints|_build/|dist/|__pypackages__|.ipynb"
[tool.black]
line-length = 120
preview = true
exclude = ".git|.mypy_cache|.tox|.venv|venv|.ipynb_checkpoints|_build/|dist/|__pypackages__|.ipynb"
# force-exclude = "docs/gallery"

[tool.ruff]
Expand Down
17 changes: 4 additions & 13 deletions tests/test_results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def get_hdf5io(self):
def test_add_col_diff_len(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_tvt_split([0, 1, 2, 0, 1])
msg = (
"New column true_label of length 4 is not the same length as "
"existings columns of length 5"
)
msg = "New column true_label of length 4 is not the same length as " "existings columns of length 5"
with self.assertRaisesRegex(ValueError, msg):
rt.add_true_label([0, 0, 0, 1])

Expand Down Expand Up @@ -78,9 +75,7 @@ def test_add_cv_split(self):

def test_add_cv_split_bad_splits(self):
rt = ResultsTable(name="foo", description="a test results table")
with self.assertRaisesRegex(
ValueError, "Got non-integer data for cross-validation split"
):
with self.assertRaisesRegex(ValueError, "Got non-integer data for cross-validation split"):
rt.add_cv_split([0.0, 0.1, 0.2, 0.3, 0.4])

def test_add_true_label(self):
Expand All @@ -99,9 +94,7 @@ def test_add_true_label_str(self):

def test_add_predicted_probability(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_predicted_probability(
[[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.5, 0.5]]
)
rt.add_predicted_probability([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])
with self.get_hdf5io() as io:
io.write(rt)

Expand Down Expand Up @@ -137,8 +130,6 @@ def test_add_topk_classes(self):

def test_add_topk_probabilities(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_topk_probabilities(
[[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
)
rt.add_topk_probabilities([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
with self.get_hdf5io() as io:
io.write(rt)

0 comments on commit af9cab3

Please sign in to comment.