I tried running pytest -s text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py and got the following error:
> output_tokens[:, Sy : Sy + 1] = output[-1:] # Set the last output token
E RuntimeError: The expanded size of the tensor (1) must match the existing size (16) at non-singleton dimension 1. Target sizes: [16, 1]. Tensor sizes: [16]
I fixed it by changing line 176 of /text_recognizer/models/resnet_transformer.py to:
output_tokens[:, Sy : Sy + 1] = output[-1:].reshape((-1,1)) # Set the last output token
However, I'm not sure about the broader ramifications of this fix.
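For reference, the failure comes down to shapes. Assuming decode returns (Sy, B, C) as the comment in predict states, the argmax gives output of shape (Sy, B), so output[-1:] is a (1, B) row while the slice output_tokens[:, Sy : Sy + 1] is a (B, 1) column. The minimal sketch below reproduces this with plain tensors. The batch size, sequence length and token ids are made up, and torch.randint only stands in for the decoder's argmax output. It suggests the reshape is safe: turning (1, B) into (B, 1) keeps the B predictions in batch order, and it matches the alternative of assigning output[-1] (shape (B,)) to the single column output_tokens[:, Sy].

```python
import torch

# Hypothetical sizes and token ids, chosen only to reproduce the traceback's shapes.
B, S = 16, 5
PADDING_TOKEN, START_TOKEN = 3, 1


def make_tokens() -> torch.Tensor:
    """(B, S) buffer filled with padding, start token in column 0, mirroring predict()."""
    tokens = torch.full((B, S), PADDING_TOKEN, dtype=torch.long)
    tokens[:, 0] = START_TOKEN
    return tokens


Sy = 1
torch.manual_seed(0)
# Stand-in for torch.argmax(self.decode(x, y), dim=-1), shape (Sy, B).
output = torch.randint(0, 10, (Sy, B))
print(output[-1:].shape)  # torch.Size([1, 16]): a row, not a column

# Original line 176: a (1, B) source cannot be broadcast into a (B, 1) target slice.
original = make_tokens()
try:
    original[:, Sy : Sy + 1] = output[-1:]
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# Reporter's fix: reshape (1, B) -> (B, 1); the B values keep their order.
fixed = make_tokens()
fixed[:, Sy : Sy + 1] = output[-1:].reshape((-1, 1))

# Equivalent alternative: integer-index the column and assign a (B,) vector.
alt = make_tokens()
alt[:, Sy] = output[-1]

print(torch.equal(fixed, alt))  # True
```

Note that with B == 1 both shapes collapse to (1, 1), so the original line only fails once the batch holds more than one image.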
The full traceback is:
(fsdl-text-recognizer-2021) ubuntu@pc:~/pytorch-lab/lab8_eds$ PYTHONPATH=. pytest -s text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py
====================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.6.13, pytest-6.2.3, py-1.10.0, pluggy-0.13.1
rootdir: /home/ubuntu/pytorch-lab/lab8_eds, configfile: setup.cfg
plugins: anyio-2.2.0
collected 1 item
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py IAMParagraphs.setup(None): Loading IAM paragraph regions and lines...
Testing: 0it [00:00, ?it/s]F
=========================================================================================================== FAILURES ============================================================================================================
_______________________________________________________________________________________ TestEvaluateParagraphTextRecognizer.test_evaluate _______________________________________________________________________________________
self = <evaluate_paragraph_text_recognizer.TestEvaluateParagraphTextRecognizer testMethod=test_evaluate>
    @torch.no_grad()
    def test_evaluate(self):
        dataset = IAMParagraphs(argparse.Namespace(batch_size=16, num_workers=10))
        dataset.prepare_data()
        dataset.setup()
        text_recog = ParagraphTextRecognizer()
        trainer = pl.Trainer(gpus=1)
        start_time = time.time()
>       metrics = trainer.test(text_recog.lit_model, datamodule=dataset)
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py:28:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:910: in test
results = self.__test_given_model(model, test_dataloaders)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:970: in __test_given_model
results = self.fit(model)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:499: in fit
self.dispatch()
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:540: in dispatch
self.accelerator.start_testing(self)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py:76: in start_testing
self.training_type_plugin.start_testing(trainer)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:118: in start_testing
self._results = trainer.run_test()
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:786: in run_test
eval_loop_results, _ = self.run_evaluation()
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:725: in run_evaluation
output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py:162: in evaluation_step
output = self.trainer.accelerator.test_step(args)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py:195: in test_step
return self.training_type_plugin.test_step(*args)
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:134: in test_step
return self.lightning_module.test_step(*args, **kwargs)
text_recognizer/lit_models/transformer.py:63: in test_step
pred = self.model.predict(x)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = ResnetTransformer(
(resnet): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), b... (dropout2): Dropout(p=0.4, inplace=False)
(dropout3): Dropout(p=0.4, inplace=False)
)
)
)
)
x = tensor([[[-0.9087, 1.0898, -0.4428, ..., -0.5195, -1.6015, 0.9281],
[-1.4115, 0.9965, -0.9382, ..., -0.8...49, -0.8971, 1.5430],
[-1.8523, -0.1957, -0.5435, ..., 0.7036, -0.6454, 0.3995]]],
device='cuda:0')
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x
            (B, H, W) image
        Returns
        -------
        torch.Tensor
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x) # (Sx, B, E)
        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S)
        output_tokens[:, 0] = self.start_token # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy] # (B, Sy)
            output = self.decode(x, y) # (Sy, B, C)
            output = torch.argmax(output, dim=-1) # (Sy, B)
>           output_tokens[:, Sy : Sy + 1] = output[-1:] # Set the last output token
E           RuntimeError: The expanded size of the tensor (1) must match the existing size (16) at non-singleton dimension 1. Target sizes: [16, 1]. Tensor sizes: [16]
text_recognizer/models/resnet_transformer.py:176: RuntimeError
------------------------------------------------------------------------------------------------------- Captured log call -------------------------------------------------------------------------------------------------------
INFO pytorch_lightning.utilities.distributed:distributed.py:56 GPU available: True, used: True
INFO pytorch_lightning.utilities.distributed:distributed.py:56 TPU available: False, using: 0 TPU cores
INFO pytorch_lightning.accelerators.gpu:gpu.py:51 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
======================================================================================================= warnings summary ========================================================================================================
../../anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/fsspec/__init__.py:43
/home/ubuntu/anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/fsspec/__init__.py:43: DeprecationWarning: SelectableGroups dict interface is deprecated. Use select.
for spec in entry_points.get("fsspec.specs", []):
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py::TestEvaluateParagraphTextRecognizer::test_evaluate
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py::TestEvaluateParagraphTextRecognizer::test_evaluate
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py::TestEvaluateParagraphTextRecognizer::test_evaluate
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py::TestEvaluateParagraphTextRecognizer::test_evaluate
text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py::TestEvaluateParagraphTextRecognizer::test_evaluate
/home/ubuntu/anaconda3/envs/fsdl-text-recognizer-2021/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: DeprecationWarning: This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`. It will be removed in v1.5.0
warnings.warn(*args, **kwargs)
-- Docs: https://docs.pytest.org/en/stable/warnings.html
==================================================================================================== short test summary info ====================================================================================================
FAILED text_recognizer/evaluation/evaluate_paragraph_text_recognizer.py::TestEvaluateParagraphTextRecognizer::test_evaluate - RuntimeError: The expanded size of the tensor (1) must match the existing size (16) at non-singl...
================================================================================================ 1 failed, 6 warnings in 45.72s =================================================================================================
Testing: 0%| | 0/15 [00:00<?, ?it/s]