Skip to content

Commit

Permalink
add unit test for textsnake (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
cuhk-hbsun authored Apr 13, 2021
1 parent 344cc9a commit 5244984
Showing 1 changed file with 52 additions and 51 deletions.
103 changes: 52 additions & 51 deletions tests/test_models/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,54 +316,55 @@ def test_dbnet(cfg_file):
detector.show_result(img, results)


# @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
# @pytest.mark.parametrize(
# 'cfg_file', ['textdet/textsnake/'
# 'textsnake_r50_fpn_unet_1200e_ctw1500.py'])
# def test_textsnake(cfg_file):
# model = _get_detector_cfg(cfg_file)
# model['pretrained'] = None
# model['backbone']['norm_cfg']['type'] = 'BN'

# from mmocr.models import build_detector
# detector = build_detector(model)
# detector = detector.cuda()
# input_shape = (1, 3, 64, 64)
# num_kernels = 1
# mm_inputs = _demo_mm_inputs(num_kernels, input_shape)

# imgs = mm_inputs.pop('imgs')
# imgs = imgs.cuda()
# img_metas = mm_inputs.pop('img_metas')
# gt_text_mask = mm_inputs.pop('gt_text_mask')
# gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
# gt_mask = mm_inputs.pop('gt_mask')
# gt_radius_map = mm_inputs.pop('gt_radius_map')
# gt_sin_map = mm_inputs.pop('gt_sin_map')
# gt_cos_map = mm_inputs.pop('gt_cos_map')

# # Test forward train
# losses = detector.forward(
# imgs,
# img_metas,
# gt_text_mask=gt_text_mask,
# gt_center_region_mask=gt_center_region_mask,
# gt_mask=gt_mask,
# gt_radius_map=gt_radius_map,
# gt_sin_map=gt_sin_map,
# gt_cos_map=gt_cos_map)
# assert isinstance(losses, dict)

# # Test forward test
# with torch.no_grad():
# img_list = [g[None, :] for g in imgs]
# batch_results = []
# for one_img, one_meta in zip(img_list, img_metas):
# result = detector.forward([one_img], [[one_meta]],
# return_loss=False)
# batch_results.append(result)

# # Test show result
# results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
# img = np.random.rand(5, 5)
# detector.show_result(img, results)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize(
'cfg_file',
['textdet/textsnake/'
'textsnake_r50_fpn_unet_1200e_ctw1500.py'])
def test_textsnake(cfg_file):
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
model['backbone']['norm_cfg']['type'] = 'BN'

from mmocr.models import build_detector
detector = build_detector(model)
detector = detector.cuda()
input_shape = (1, 3, 64, 64)
num_kernels = 1
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)

imgs = mm_inputs.pop('imgs')
imgs = imgs.cuda()
img_metas = mm_inputs.pop('img_metas')
gt_text_mask = mm_inputs.pop('gt_text_mask')
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
gt_mask = mm_inputs.pop('gt_mask')
gt_radius_map = mm_inputs.pop('gt_radius_map')
gt_sin_map = mm_inputs.pop('gt_sin_map')
gt_cos_map = mm_inputs.pop('gt_cos_map')

# Test forward train
losses = detector.forward(
imgs,
img_metas,
gt_text_mask=gt_text_mask,
gt_center_region_mask=gt_center_region_mask,
gt_mask=gt_mask,
gt_radius_map=gt_radius_map,
gt_sin_map=gt_sin_map,
gt_cos_map=gt_cos_map)
assert isinstance(losses, dict)

# Test forward test
# with torch.no_grad():
# img_list = [g[None, :] for g in imgs]
# batch_results = []
# for one_img, one_meta in zip(img_list, img_metas):
# result = detector.forward([one_img], [[one_meta]],
# return_loss=False)
# batch_results.append(result)

# Test show result
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)

0 comments on commit 5244984

Please sign in to comment.