Skip to content

Commit

Permalink
wrap rm_epsilon_iterative_tropical to python (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
qindazhu authored Dec 18, 2020
1 parent 029d90c commit 212cf60
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 55 deletions.
16 changes: 13 additions & 3 deletions k2/csrc/rm_epsilon.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1004,13 +1004,23 @@ void ComputeEpsilonClosureOneIter(FsaVec &epsilon_fsa, FsaVec *closure_fsa,
*arc_map = Index(expand_arc_map, arc_map_indexes);
}

void RemoveEpsilonsIterativeTropical(FsaVec &src_fsa, FsaVec *dest_fsa,
void RemoveEpsilonsIterativeTropical(FsaOrVec &src_fsa, FsaOrVec *dest_fsa,
Ragged<int32_t> *arc_map_out) {
NVTX_RANGE(K2_FUNC);
K2_CHECK(dest_fsa != nullptr && arc_map_out != nullptr);
K2_CHECK_EQ(src_fsa.NumAxes(), 3);
ContextPtr &c = src_fsa.Context();
K2_CHECK_GE(src_fsa.NumAxes(), 2);
K2_CHECK_LE(src_fsa.NumAxes(), 3);
if (src_fsa.NumAxes() == 2) {
// Turn single Fsa into FsaVec.
Fsa *srcs = &src_fsa;
FsaVec src_vec = CreateFsaVec(1, &srcs), dest_vec;
// Recurse..
RemoveEpsilonsIterativeTropical(src_vec, &dest_vec, arc_map_out);
*dest_fsa = GetFsaVecElement(dest_vec, 0);
return;
}

ContextPtr &c = src_fsa.Context();
Array1<int32_t> epsilons_state_map, epsilons_arc_map;
FsaVec epsilon_fsa;
ComputeEpsilonSubset(src_fsa, &epsilon_fsa, &epsilons_state_map,
Expand Down
10 changes: 5 additions & 5 deletions k2/csrc/rm_epsilon.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,12 @@ void ComputeEpsilonClosureOneIter(FsaVec &epsilon_fsa, FsaVec *closure_fsa,
Ragged<int32_t> *arc_map);

/*
Remove epsilons from FsaVec in `src_fsa`, producing an FsaVec `dest_fsa` which
is equivalent (in tropical semiring). Uses an iterative algorithm which tries
to minimize the number of arcs in the resulting FSA (epsilons are combined
with either preceding or following arcs).
Remove epsilons from `src_fsa` (a single Fsa or an FsaVec), producing
`dest_fsa` of the same kind (Fsa in, Fsa out; FsaVec in, FsaVec out) which is
equivalent (in tropical semiring). Uses an iterative algorithm which tries to
minimize the number of arcs in the resulting FSA (epsilons are combined with
either preceding or following arcs).
*/
void RemoveEpsilonsIterativeTropical(FsaVec &src_fsa, FsaVec *dest_fsa,
void RemoveEpsilonsIterativeTropical(FsaOrVec &src_fsa, FsaOrVec *dest_fsa,
Ragged<int32_t> *arc_map);
} // namespace k2

Expand Down
70 changes: 47 additions & 23 deletions k2/csrc/rm_epsilon_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -367,30 +367,54 @@ TEST(RmEpsilon, RemoveEpsilonsIterativeTropicalSimple) {
4 5 -1 1
5
)";
Fsa fsa1 = FsaFromString(s1);
Fsa fsa2 = FsaFromString(s2);
Fsa *fsa_array[] = {&fsa1, &fsa2};
FsaVec fsa_vec = CreateFsaVec(2, &fsa_array[0]);
fsa_vec = fsa_vec.To(context);
{
// test with single Fsa
Fsa fsa = FsaFromString(s2);
fsa = fsa.To(context);

FsaVec dest;
Ragged<int32_t> arc_map;
RemoveEpsilonsIterativeTropical(fsa_vec, &dest, &arc_map);
EXPECT_EQ(dest.NumAxes(), 3);
EXPECT_EQ(arc_map.NumAxes(), 2);
K2_LOG(INFO) << dest;
K2_LOG(INFO) << arc_map;
Array1<int32_t> properties;
int32_t p;
GetFsaVecBasicProperties(dest, &properties, &p);
EXPECT_EQ(p & kFsaPropertiesEpsilonFree, kFsaPropertiesEpsilonFree);
bool log_semiring = false;
float beam = std::numeric_limits<float>::infinity();
fsa_vec = fsa_vec.To(GetCpuContext());
dest = dest.To(GetCpuContext());
EXPECT_TRUE(
IsRandEquivalent(fsa_vec, dest, log_semiring, beam, true, 0.001));
CheckArcMap(fsa_vec, dest, arc_map);
FsaVec dest;
Ragged<int32_t> arc_map;
RemoveEpsilonsIterativeTropical(fsa, &dest, &arc_map);
EXPECT_EQ(dest.NumAxes(), 2);
EXPECT_EQ(arc_map.NumAxes(), 2);
K2_LOG(INFO) << dest;
K2_LOG(INFO) << arc_map;
int32_t p = GetFsaBasicProperties(dest);
EXPECT_EQ(p & kFsaPropertiesEpsilonFree, kFsaPropertiesEpsilonFree);
bool log_semiring = false;
float beam = std::numeric_limits<float>::infinity();
fsa = fsa.To(GetCpuContext());
dest = dest.To(GetCpuContext());
EXPECT_TRUE(IsRandEquivalent(fsa, dest, log_semiring, beam, true, 0.001));
CheckArcMap(fsa, dest, arc_map);
}
{
// test with FsaVec
Fsa fsa1 = FsaFromString(s1);
Fsa fsa2 = FsaFromString(s2);
Fsa *fsa_array[] = {&fsa1, &fsa2};
FsaVec fsa_vec = CreateFsaVec(2, &fsa_array[0]);
fsa_vec = fsa_vec.To(context);

FsaVec dest;
Ragged<int32_t> arc_map;
RemoveEpsilonsIterativeTropical(fsa_vec, &dest, &arc_map);
EXPECT_EQ(dest.NumAxes(), 3);
EXPECT_EQ(arc_map.NumAxes(), 2);
K2_LOG(INFO) << dest;
K2_LOG(INFO) << arc_map;
Array1<int32_t> properties;
int32_t p;
GetFsaVecBasicProperties(dest, &properties, &p);
EXPECT_EQ(p & kFsaPropertiesEpsilonFree, kFsaPropertiesEpsilonFree);
bool log_semiring = false;
float beam = std::numeric_limits<float>::infinity();
fsa_vec = fsa_vec.To(GetCpuContext());
dest = dest.To(GetCpuContext());
EXPECT_TRUE(
IsRandEquivalent(fsa_vec, dest, log_semiring, beam, true, 0.001));
CheckArcMap(fsa_vec, dest, arc_map);
}
}
}

Expand Down
44 changes: 24 additions & 20 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "k2/csrc/fsa_algo.h"
#include "k2/csrc/fsa_utils.h"
#include "k2/csrc/host_shim.h"
#include "k2/csrc/rm_epsilon.h"
#include "k2/python/csrc/torch/fsa_algo.h"
#include "k2/python/csrc/torch/torch_util.h"

Expand Down Expand Up @@ -99,9 +100,8 @@ static void PybindLinearFsa(py::module &m) {
static void PybindIntersect(py::module &m) {
m.def(
"intersect", // works only on CPU
[](FsaOrVec &a_fsas, int32_t properties_a,
FsaOrVec &b_fsas, int32_t properties_b,
bool treat_epsilons_specially = true,
[](FsaOrVec &a_fsas, int32_t properties_a, FsaOrVec &b_fsas,
int32_t properties_b, bool treat_epsilons_specially = true,
bool need_arc_map =
true) -> std::tuple<FsaOrVec, torch::optional<torch::Tensor>,
torch::optional<torch::Tensor>> {
Expand All @@ -125,9 +125,8 @@ static void PybindIntersect(py::module &m) {
}
return std::make_tuple(ans, a_tensor, b_tensor);
},
py::arg("a_fsas"), py::arg("properties_a"),
py::arg("b_fsas"), py::arg("properties_b"),
py::arg("treat_epsilons_specially") = true,
py::arg("a_fsas"), py::arg("properties_a"), py::arg("b_fsas"),
py::arg("properties_b"), py::arg("treat_epsilons_specially") = true,
py::arg("need_arc_map") = true,
R"(
If treat_epsilons_specially it will treat epsilons as epsilons; otherwise
Expand Down Expand Up @@ -161,26 +160,22 @@ static void PybindIntersectDensePruned(py::module &m) {
py::arg("max_active_states"));
}


static void PybindIntersectDense(py::module &m) {
m.def(
"intersect_dense",
[](FsaVec &a_fsas, DenseFsaVec &b_fsas,
float output_beam)
[](FsaVec &a_fsas, DenseFsaVec &b_fsas, float output_beam)
-> std::tuple<FsaVec, torch::Tensor, torch::Tensor> {
Array1<int32_t> arc_map_a;
Array1<int32_t> arc_map_b;
FsaVec out;

IntersectDense(a_fsas, b_fsas, output_beam, &out,
&arc_map_a, &arc_map_b);
IntersectDense(a_fsas, b_fsas, output_beam, &out, &arc_map_a,
&arc_map_b);
return std::make_tuple(out, ToTensor(arc_map_a), ToTensor(arc_map_b));
},
py::arg("a_fsas"), py::arg("b_fsas"),
py::arg("output_beam"));
py::arg("a_fsas"), py::arg("b_fsas"), py::arg("output_beam"));
}


static void PybindConnect(py::module &m) {
m.def(
"connect",
Expand Down Expand Up @@ -275,9 +270,18 @@ static void PybindRemoveEpsilon(py::module &m) {
"remove_epsilon",
[](FsaOrVec &src) -> std::pair<FsaOrVec, Ragged<int32_t>> {
FsaOrVec dest;
Ragged<int32_t> arc_derivs;
RemoveEpsilon(src, &dest, &arc_derivs);
return std::make_pair(dest, arc_derivs);
Ragged<int32_t> arc_map;
RemoveEpsilon(src, &dest, &arc_map);
return std::make_pair(dest, arc_map);
},
py::arg("src"));
m.def(
"remove_epsilons_iterative_tropical",
[](FsaOrVec &src) -> std::pair<FsaOrVec, Ragged<int32_t>> {
FsaOrVec dest;
Ragged<int32_t> arc_map;
RemoveEpsilonsIterativeTropical(src, &dest, &arc_map);
return std::make_pair(dest, arc_map);
},
py::arg("src"));
}
Expand All @@ -287,9 +291,9 @@ static void PybindDeterminize(py::module &m) {
"determinize",
[](FsaOrVec &src) -> std::pair<FsaOrVec, Ragged<int32_t>> {
FsaOrVec dest;
Ragged<int32_t> arc_derivs;
Determinize(src, &dest, &arc_derivs);
return std::make_pair(dest, arc_derivs);
Ragged<int32_t> arc_map;
Determinize(src, &dest, &arc_map);
return std::make_pair(dest, arc_map);
},
py::arg("src"));
}
Expand Down
2 changes: 2 additions & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .fsa_algo import intersect
from .fsa_algo import linear_fsa
from .fsa_algo import remove_epsilon
from .fsa_algo import remove_epsilons_iterative_tropical
from .fsa_algo import shortest_path
from .fsa_algo import top_sort
from .fsa_properties import to_str as properties_to_str
Expand Down Expand Up @@ -67,6 +68,7 @@
'properties_to_str',
'random_ragged_shape',
'remove_epsilon',
'remove_epsilons_iterative_tropical',
'shortest_path',
'simple_ragged_index_select',
'to_dot',
Expand Down
41 changes: 37 additions & 4 deletions k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,43 @@ def remove_epsilon(fsa: Fsa) -> Fsa:
if properties is not None and properties & fsa_properties.EPSILON_FREE != 0:
return fsa

ragged_arc, arc_derivs = _k2.remove_epsilon(fsa.arcs)
ragged_arc, arc_map = _k2.remove_epsilon(fsa.arcs)
aux_labels = None
if hasattr(fsa, 'aux_labels'):
aux_labels = index_attr(fsa.aux_labels, arc_derivs)
aux_labels = index_attr(fsa.aux_labels, arc_map)
out_fsa = Fsa(ragged_arc, aux_labels)

for name, value in fsa.named_non_tensor_attr():
setattr(out_fsa, name, value)

return out_fsa


def remove_epsilons_iterative_tropical(fsa: Fsa) -> Fsa:
'''Remove epsilons (symbol zero) in the input Fsa.
Caution:
It doesn't support autograd for now.
Args:
fsa:
The input FSA. It can be either a single FSA or an FsaVec.
It can be either top-sorted or non-top-sorted.
Returns:
The result Fsa, it's equivalent to the input `fsa` under
tropical semiring but will be epsilon-free.
It will be the same as the input `fsa` if the input
`fsa` is epsilon-free. Otherwise, a new epsilon-free fsa
is returned and the input `fsa` is NOT modified.
'''
properties = getattr(fsa, 'properties', None)
if properties is not None and properties & fsa_properties.EPSILON_FREE != 0:
return fsa

ragged_arc, arc_map = _k2.remove_epsilons_iterative_tropical(fsa.arcs)
aux_labels = None
if hasattr(fsa, 'aux_labels'):
aux_labels = index_attr(fsa.aux_labels, arc_map)
out_fsa = Fsa(ragged_arc, aux_labels)

for name, value in fsa.named_non_tensor_attr():
Expand Down Expand Up @@ -320,10 +353,10 @@ def determinize(fsa: Fsa) -> Fsa:
and properties & fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0: # noqa
return fsa

ragged_arc, arc_derivs = _k2.determinize(fsa.arcs)
ragged_arc, arc_map = _k2.determinize(fsa.arcs)
aux_labels = None
if hasattr(fsa, 'aux_labels'):
aux_labels = index_attr(fsa.aux_labels, arc_derivs)
aux_labels = index_attr(fsa.aux_labels, arc_map)
out_fsa = Fsa(ragged_arc, aux_labels)

for name, value in fsa.named_non_tensor_attr():
Expand Down
23 changes: 23 additions & 0 deletions k2/python/tests/remove_epsilon_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,28 @@ def test1(self):
self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))


class TestRemoveEpsilonsIterativeTropical(unittest.TestCase):
    '''Tests for `k2.remove_epsilons_iterative_tropical`.'''

    def test1(self):
        '''Removing epsilons from a single FSA must give an epsilon-free FSA
        that is randomly equivalent to the input (with log_semiring=False).
        '''
        # The first three arcs have 0 in the third column; presumably the
        # columns are `src dest label aux_label score` -- TODO confirm
        # against `k2.Fsa.from_str`.
        s = '''
0 1 0 1 1
1 2 0 2 1
2 3 0 3 1
3 4 4 4 1
3 5 -1 5 1
4 5 -1 6 1
5
'''
        fsa = k2.Fsa.from_str(s)

        # Precondition: the input must not already be flagged epsilon-free,
        # otherwise `remove_epsilons_iterative_tropical` would just hand the
        # input back and the test would exercise nothing.
        prop = fsa.properties
        self.assertFalse(prop & k2.fsa_properties.EPSILON_FREE)

        dest = k2.remove_epsilons_iterative_tropical(fsa)

        # Postcondition 1: the result is flagged epsilon-free.
        prop = dest.properties
        self.assertTrue(prop & k2.fsa_properties.EPSILON_FREE)

        # Postcondition 2: the result is equivalent to the input.
        log_semiring = False
        self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))


if __name__ == '__main__':
unittest.main()

0 comments on commit 212cf60

Please sign in to comment.