Skip to content

Commit

Permalink
Limit fixture scope to function and remove unnecessary copy.deepcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Nov 7, 2024
1 parent d8c595a commit d328182
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tests/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TIMEOUT = 30


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def gt_data_2d():
path = "downloads/Fluo-N2DL-HeLa/01_GT/TRA"
return load_ctc_data(
Expand All @@ -27,7 +27,7 @@ def gt_data_2d():
)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def gt_data_3d():
path = "downloads/Fluo-N3DH-CE/01_GT/TRA"
return load_ctc_data(
Expand All @@ -37,34 +37,34 @@ def gt_data_3d():
)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def pred_data_2d(gt_data_2d):
# For now this is also GT data.
return copy.deepcopy(gt_data_2d)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def pred_data_3d(gt_data_3d):
# For now this is also GT data.
return copy.deepcopy(gt_data_3d)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def ctc_matched_2d(gt_data_2d, pred_data_2d):
return CTCMatcher().compute_mapping(gt_data_2d, pred_data_2d)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def ctc_matched_3d(gt_data_3d, pred_data_3d):
return CTCMatcher().compute_mapping(gt_data_3d, pred_data_3d)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def iou_matched_2d(gt_data_2d, pred_data_2d):
return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data_2d, pred_data_2d)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def iou_matched_3d(gt_data_3d, pred_data_3d):
return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data_3d, pred_data_3d)

Expand Down Expand Up @@ -161,7 +161,7 @@ def test_ctc_metrics(benchmark, ctc_matched, request):
ctc_matched = request.getfixturevalue(ctc_matched)

def run_compute():
return CTCMetrics().compute(copy.deepcopy(ctc_matched))
return CTCMetrics().compute(ctc_matched)

benchmark.pedantic(run_compute, rounds=1, iterations=1)

Expand Down Expand Up @@ -196,6 +196,6 @@ def test_iou_div_metrics(benchmark, iou_matched, request):
iou_matched = request.getfixturevalue(iou_matched)

def run_compute():
return DivisionMetrics().compute(copy.deepcopy(iou_matched))
return DivisionMetrics().compute(iou_matched)

benchmark.pedantic(run_compute, rounds=1, iterations=1)

0 comments on commit d328182

Please sign in to comment.