TensorFlow test error: as_list() is not defined on an unknown TensorShape #132

Open
sigpro opened this issue Dec 11, 2018 · 2 comments

sigpro commented Dec 11, 2018

ERROR: test_multiple_batches_cpu (test_warpctc_op.WarpCTCTest)

Traceback (most recent call last):
  File "/home/work/speech/libs/warp-ctc/tensorflow_binding/tests/test_warpctc_op.py", line 124, in test_multiple_batches_cpu
    self._test_multiple_batches(use_gpu=False)
  File "/home/work/speech/libs/warp-ctc/tensorflow_binding/tests/test_warpctc_op.py", line 121, in _test_multiple_batches
    use_gpu=use_gpu)
  File "/home/work/speech/libs/warp-ctc/tensorflow_binding/tests/test_warpctc_op.py", line 28, in _run_ctc
    self.assertShapeEqual(expected_costs, costs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/test_util.py", line 1639, in assertShapeEqual
    np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py", line 903, in as_list
    raise ValueError("as_list() is not defined on an unknown TensorShape.")
ValueError: as_list() is not defined on an unknown TensorShape.

When I print the shape of costs, it is unknown. But if I evaluate costs, both the shape and the values are correct. So what does the CTC op actually return?
My environment is TensorFlow 1.10.1 on Ubuntu 16.04.
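The mismatch comes from TensorFlow's two notions of shape: the static shape known while the graph is built, which is what get_shape().as_list() reads, and the concrete shape of the value produced when the op runs. An op registered without a shape function, as here, leaves its outputs with an unknown static shape. The toy model below is plain Python with no TensorFlow dependency; StaticShape and shapes_match are hypothetical names, not TensorFlow API, and only sketch why assertShapeEqual fails in that case.

```python
class StaticShape:
    """Toy stand-in for a graph-time (static) shape; not TensorFlow's class.

    dims=None means the rank is unknown; a None entry means one unknown dimension.
    """

    def __init__(self, dims=None):
        self.dims = None if dims is None else list(dims)

    def as_list(self):
        # Same failure mode as the ValueError in the traceback.
        if self.dims is None:
            raise ValueError("as_list() is not defined on an unknown TensorShape.")
        return list(self.dims)


def shapes_match(runtime_shape, static_shape):
    """Rough analogue of assertShapeEqual: concrete shape vs. static shape."""
    return list(runtime_shape) == static_shape.as_list()


# Without a shape function, the op's outputs have an unknown static shape...
no_shape_fn = StaticShape()
# ...whereas a shape function can report, e.g., a batch of 3 at graph time.
with_shape_fn = StaticShape([3])

print(shapes_match((3,), with_shape_fn))  # True
try:
    shapes_match((3,), no_shape_fn)
except ValueError as err:
    print("error:", err)
```

Evaluating the tensor always yields a concrete array shape, which is why the values look right at run time even though the graph-time check fails.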

fginter commented Jan 15, 2019

I had a similar problem with the missing shape on the latest TensorFlow. The following modification to tensorflow_binding/src/warpctc_op.cc fixed it for me: add the #include and the .SetShapeFn(...) call, then recompile. The code is copied and adapted from https://www.tensorflow.org/guide/extend/op

#include "tensorflow/core/framework/shape_inference.h"

REGISTER_OP("WarpCTC")
    .Input("activations: float32")
    .Input("flat_labels: int32")
    .Input("label_lengths: int32")
    .Input("input_lengths: int32")
    .Attr("blank_label: int = 0")
    .Output("costs: float32")
    .Output("gradients: float32")
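    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        // costs: one value per batch entry, same shape as input_lengths (input 3).
        c->set_output(0, c->input(3));
        // gradients: same shape as activations (input 0).
        c->set_output(1, c->input(0));
        return ::tensorflow::Status::OK();
    });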
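As a sketch of what this shape function computes, here is a plain-Python model. The helper name warpctc_shape_fn_simple is hypothetical; a shape is a list whose None entries are unknown dimensions, and None as a whole means an unknown rank. Output 0 (costs) copies the shape of input 3 (input_lengths), and output 1 (gradients) copies the shape of input 0 (activations).

```python
def warpctc_shape_fn_simple(activations, flat_labels, label_lengths, input_lengths):
    """Toy model of the simple SetShapeFn: each output copies one input's shape.

    Each argument is a shape: a list of ints/None, or None for an unknown rank.
    flat_labels and label_lengths are unused, as in the C++ version.
    Returns (costs_shape, gradients_shape).
    """
    costs = None if input_lengths is None else list(input_lengths)  # c->input(3)
    gradients = None if activations is None else list(activations)  # c->input(0)
    return costs, gradients


print(warpctc_shape_fn_simple([50, None, 29], [None], [None], [None]))
# ([None], [50, None, 29])
```

Because the shapes are copied without any checks, an input with the wrong rank would pass through silently.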

NMAC427 commented Apr 7, 2019

For stricter shape inference, I modified @fginter's answer based on TensorFlow's own .SetShapeFn implementation for the CTCLoss op: it checks the rank of every input and merges the batch dimension of activations and input_lengths.

#include "tensorflow/core/framework/shape_inference.h"

using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
using ::tensorflow::Status;

REGISTER_OP("WarpCTC")
    .Input("activations: float32")
    .Input("flat_labels: int32")
    .Input("label_lengths: int32")
    .Input("input_lengths: int32")
    .Attr("blank_label: int = 0")
    .Output("costs: float32")
    .Output("gradients: float32")
    .SetShapeFn([](InferenceContext* c) {
        ShapeHandle activations;
        ShapeHandle flat_labels;
        ShapeHandle label_lengths;
        ShapeHandle input_lengths;

        // Validate ranks: activations is 3-D, the other inputs are vectors.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &activations));
        TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &flat_labels));
        TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &label_lengths));
        TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &input_lengths));

        // Get the batch size from activations (dimension 1) and input_lengths,
        // and update activations with the merged batch size, since it is
        // returned as the shape of the gradients.
        DimensionHandle batch_size;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(activations, 1), c->Dim(input_lengths, 0), &batch_size));
        TF_RETURN_IF_ERROR(c->ReplaceDim(activations, 1, batch_size, &activations));

        c->set_output(0, c->Vector(batch_size));
        c->set_output(1, activations);
        return Status::OK();
    });
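The rank checks and batch-size merge in this version can be modeled in plain Python. In the sketch below (not TensorFlow's API; with_rank, merge_dim, and warpctc_shape_fn are hypothetical names), a shape is a list whose None entries are unknown dimensions, and None as a whole is an unknown rank. The helpers loosely mirror InferenceContext's WithRank, Merge, and ReplaceDim, and dimension 1 of activations is treated as the batch dimension, as in the C++ code.

```python
def with_rank(shape, rank, name):
    """Toy WithRank: an unknown rank becomes `rank` unknown dims; a known rank must match."""
    if shape is None:
        return [None] * rank
    if len(shape) != rank:
        raise ValueError(f"{name} must be rank {rank}, got rank {len(shape)}")
    return list(shape)


def merge_dim(a, b):
    """Toy Merge: None is unknown; two known dimensions must be equal."""
    if a is None:
        return b
    if b is None or a == b:
        return a
    raise ValueError(f"dimensions must be equal, but are {a} and {b}")


def warpctc_shape_fn(activations, flat_labels, label_lengths, input_lengths):
    """Toy model of the stricter SetShapeFn; returns (costs_shape, gradients_shape)."""
    acts = with_rank(activations, 3, "activations")
    with_rank(flat_labels, 1, "flat_labels")
    with_rank(label_lengths, 1, "label_lengths")
    lengths = with_rank(input_lengths, 1, "input_lengths")

    # Dimension 1 of activations is the batch dimension, as in the C++ code.
    batch_size = merge_dim(acts[1], lengths[0])
    acts[1] = batch_size  # like ReplaceDim(activations, 1, batch_size, ...)
    return [batch_size], acts


print(warpctc_shape_fn([50, None, 29], None, None, [8]))  # ([8], [50, 8, 29])
try:
    warpctc_shape_fn([50, 4, 29], None, None, [8])
except ValueError as err:
    print("error:", err)  # error: dimensions must be equal, but are 4 and 8
```

The merge lets a batch size known from input_lengths fill in an unknown batch dimension of activations, while a known mismatch or a wrong rank fails at graph-construction time instead of surfacing later.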
