From 212cf60f7cf4b418345e822fa75aab0dd2c79542 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Fri, 18 Dec 2020 11:25:50 +0800 Subject: [PATCH] wrap rm_epsilon_iterative_tropical to python (#523) --- k2/csrc/rm_epsilon.cu | 16 ++++-- k2/csrc/rm_epsilon.h | 10 ++-- k2/csrc/rm_epsilon_test.cu | 70 +++++++++++++++++--------- k2/python/csrc/torch/fsa_algo.cu | 44 ++++++++-------- k2/python/k2/__init__.py | 2 + k2/python/k2/fsa_algo.py | 41 +++++++++++++-- k2/python/tests/remove_epsilon_test.py | 23 +++++++++ 7 files changed, 151 insertions(+), 55 deletions(-) diff --git a/k2/csrc/rm_epsilon.cu b/k2/csrc/rm_epsilon.cu index 6939eb7be..e78f52118 100644 --- a/k2/csrc/rm_epsilon.cu +++ b/k2/csrc/rm_epsilon.cu @@ -1004,13 +1004,23 @@ void ComputeEpsilonClosureOneIter(FsaVec &epsilon_fsa, FsaVec *closure_fsa, *arc_map = Index(expand_arc_map, arc_map_indexes); } -void RemoveEpsilonsIterativeTropical(FsaVec &src_fsa, FsaVec *dest_fsa, +void RemoveEpsilonsIterativeTropical(FsaOrVec &src_fsa, FsaOrVec *dest_fsa, Ragged *arc_map_out) { NVTX_RANGE(K2_FUNC); K2_CHECK(dest_fsa != nullptr && arc_map_out != nullptr); - K2_CHECK_EQ(src_fsa.NumAxes(), 3); - ContextPtr &c = src_fsa.Context(); + K2_CHECK_GE(src_fsa.NumAxes(), 2); + K2_CHECK_LE(src_fsa.NumAxes(), 3); + if (src_fsa.NumAxes() == 2) { + // Turn single Fsa into FsaVec. + Fsa *srcs = &src_fsa; + FsaVec src_vec = CreateFsaVec(1, &srcs), dest_vec; + // Recurse.. + RemoveEpsilonsIterativeTropical(src_vec, &dest_vec, arc_map_out); + *dest_fsa = GetFsaVecElement(dest_vec, 0); + return; + } + ContextPtr &c = src_fsa.Context(); Array1 epsilons_state_map, epsilons_arc_map; FsaVec epsilon_fsa; ComputeEpsilonSubset(src_fsa, &epsilon_fsa, &epsilons_state_map, diff --git a/k2/csrc/rm_epsilon.h b/k2/csrc/rm_epsilon.h index 5b7b1aea0..2a6dd4425 100644 --- a/k2/csrc/rm_epsilon.h +++ b/k2/csrc/rm_epsilon.h @@ -141,12 +141,12 @@ void ComputeEpsilonClosureOneIter(FsaVec &epsilon_fsa, FsaVec *closure_fsa, Ragged *arc_map); /* - Remove epsilons from FsaVec in `src_fsa`, producing an FsaVec `dest_fsa` which - is equivalent (in tropical semiring). Uses an iterative algorithm which tries - to minimize the number of arcs in the resulting FSA (epsilons are combined - with either preceding or following arcs). + Remove epsilons from FsaOrVec in `src_fsa`, producing an FsaOrVec `dest_fsa` + which is equivalent (in tropical semiring). Uses an iterative algorithm which + tries to minimize the number of arcs in the resulting FSA (epsilons are + combined with either preceding or following arcs). */ -void RemoveEpsilonsIterativeTropical(FsaVec &src_fsa, FsaVec *dest_fsa, +void RemoveEpsilonsIterativeTropical(FsaOrVec &src_fsa, FsaOrVec *dest_fsa, Ragged *arc_map); } // namespace k2 diff --git a/k2/csrc/rm_epsilon_test.cu b/k2/csrc/rm_epsilon_test.cu index f9a28dcd6..c9ff27c21 100644 --- a/k2/csrc/rm_epsilon_test.cu +++ b/k2/csrc/rm_epsilon_test.cu @@ -367,30 +367,54 @@ TEST(RmEpsilon, RemoveEpsilonsIterativeTropicalSimple) { 4 5 -1 1 5 )"; - Fsa fsa1 = FsaFromString(s1); - Fsa fsa2 = FsaFromString(s2); - Fsa *fsa_array[] = {&fsa1, &fsa2}; - FsaVec fsa_vec = CreateFsaVec(2, &fsa_array[0]); - fsa_vec = fsa_vec.To(context); + { + // test with single Fsa + Fsa fsa = FsaFromString(s2); + fsa = fsa.To(context); - FsaVec dest; - Ragged arc_map; - RemoveEpsilonsIterativeTropical(fsa_vec, &dest, &arc_map); - EXPECT_EQ(dest.NumAxes(), 3); - EXPECT_EQ(arc_map.NumAxes(), 2); - K2_LOG(INFO) << dest; - K2_LOG(INFO) << arc_map; - Array1 properties; - int32_t p; - GetFsaVecBasicProperties(dest, &properties, &p); - EXPECT_EQ(p & kFsaPropertiesEpsilonFree, kFsaPropertiesEpsilonFree); - bool log_semiring = false; - float beam = std::numeric_limits::infinity(); - fsa_vec = fsa_vec.To(GetCpuContext()); - dest = dest.To(GetCpuContext()); - EXPECT_TRUE( - IsRandEquivalent(fsa_vec, dest, log_semiring, beam, true, 0.001)); - CheckArcMap(fsa_vec, dest, arc_map); + FsaVec dest; + Ragged arc_map; + RemoveEpsilonsIterativeTropical(fsa, &dest, &arc_map); + EXPECT_EQ(dest.NumAxes(), 2); + EXPECT_EQ(arc_map.NumAxes(), 2); + K2_LOG(INFO) << dest; + K2_LOG(INFO) << arc_map; + int32_t p = GetFsaBasicProperties(dest); + EXPECT_EQ(p & kFsaPropertiesEpsilonFree, kFsaPropertiesEpsilonFree); + bool log_semiring = false; + float beam = std::numeric_limits::infinity(); + fsa = fsa.To(GetCpuContext()); + dest = dest.To(GetCpuContext()); + EXPECT_TRUE(IsRandEquivalent(fsa, dest, log_semiring, beam, true, 0.001)); + CheckArcMap(fsa, dest, arc_map); + } + { + // test with FsaVec + Fsa fsa1 = FsaFromString(s1); + Fsa fsa2 = FsaFromString(s2); + Fsa *fsa_array[] = {&fsa1, &fsa2}; + FsaVec fsa_vec = CreateFsaVec(2, &fsa_array[0]); + fsa_vec = fsa_vec.To(context); + + FsaVec dest; + Ragged arc_map; + RemoveEpsilonsIterativeTropical(fsa_vec, &dest, &arc_map); + EXPECT_EQ(dest.NumAxes(), 3); + EXPECT_EQ(arc_map.NumAxes(), 2); + K2_LOG(INFO) << dest; + K2_LOG(INFO) << arc_map; + Array1 properties; + int32_t p; + GetFsaVecBasicProperties(dest, &properties, &p); + EXPECT_EQ(p & kFsaPropertiesEpsilonFree, kFsaPropertiesEpsilonFree); + bool log_semiring = false; + float beam = std::numeric_limits::infinity(); + fsa_vec = fsa_vec.To(GetCpuContext()); + dest = dest.To(GetCpuContext()); + EXPECT_TRUE( + IsRandEquivalent(fsa_vec, dest, log_semiring, beam, true, 0.001)); + CheckArcMap(fsa_vec, dest, arc_map); + } } } diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index e84b46194..6dcb4e130 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -17,6 +17,7 @@ #include "k2/csrc/fsa_algo.h" #include "k2/csrc/fsa_utils.h" #include "k2/csrc/host_shim.h" +#include "k2/csrc/rm_epsilon.h" #include "k2/python/csrc/torch/fsa_algo.h" #include "k2/python/csrc/torch/torch_util.h" @@ -99,9 +100,8 @@ static void PybindLinearFsa(py::module &m) { static void PybindIntersect(py::module &m) { m.def( "intersect", // works only on CPU - [](FsaOrVec &a_fsas, int32_t properties_a, - FsaOrVec &b_fsas, int32_t properties_b, - bool treat_epsilons_specially = true, + [](FsaOrVec &a_fsas, int32_t properties_a, FsaOrVec &b_fsas, + int32_t properties_b, bool treat_epsilons_specially = true, bool need_arc_map = true) -> std::tuple, torch::optional> { @@ -125,9 +125,8 @@ static void PybindIntersect(py::module &m) { } return std::make_tuple(ans, a_tensor, b_tensor); }, - py::arg("a_fsas"), py::arg("properties_a"), - py::arg("b_fsas"), py::arg("properties_b"), - py::arg("treat_epsilons_specially") = true, + py::arg("a_fsas"), py::arg("properties_a"), py::arg("b_fsas"), + py::arg("properties_b"), py::arg("treat_epsilons_specially") = true, py::arg("need_arc_map") = true, R"( If treat_epsilons_specially it will treat epsilons as epsilons; otherwise @@ -161,26 +160,22 @@ static void PybindIntersectDensePruned(py::module &m) { py::arg("max_active_states")); } - static void PybindIntersectDense(py::module &m) { m.def( "intersect_dense", - [](FsaVec &a_fsas, DenseFsaVec &b_fsas, - float output_beam) + [](FsaVec &a_fsas, DenseFsaVec &b_fsas, float output_beam) -> std::tuple { Array1 arc_map_a; Array1 arc_map_b; FsaVec out; - IntersectDense(a_fsas, b_fsas, output_beam, &out, - &arc_map_a, &arc_map_b); + IntersectDense(a_fsas, b_fsas, output_beam, &out, &arc_map_a, + &arc_map_b); return std::make_tuple(out, ToTensor(arc_map_a), ToTensor(arc_map_b)); }, - py::arg("a_fsas"), py::arg("b_fsas"), - py::arg("output_beam")); + py::arg("a_fsas"), py::arg("b_fsas"), py::arg("output_beam")); } - static void PybindConnect(py::module &m) { m.def( "connect", @@ -275,9 +270,18 @@ static void PybindRemoveEpsilon(py::module &m) { "remove_epsilon", [](FsaOrVec &src) -> std::pair> { FsaOrVec dest; - Ragged arc_derivs; - RemoveEpsilon(src, &dest, &arc_derivs); - return std::make_pair(dest, arc_derivs); + Ragged arc_map; + RemoveEpsilon(src, &dest, &arc_map); + return std::make_pair(dest, arc_map); + }, + py::arg("src")); + m.def( + "remove_epsilons_iterative_tropical", + [](FsaOrVec &src) -> std::pair> { + FsaOrVec dest; + Ragged arc_map; + RemoveEpsilonsIterativeTropical(src, &dest, &arc_map); + return std::make_pair(dest, arc_map); }, py::arg("src")); } @@ -287,9 +291,9 @@ static void PybindDeterminize(py::module &m) { "determinize", [](FsaOrVec &src) -> std::pair> { FsaOrVec dest; - Ragged arc_derivs; - Determinize(src, &dest, &arc_derivs); - return std::make_pair(dest, arc_derivs); + Ragged arc_map; + Determinize(src, &dest, &arc_map); + return std::make_pair(dest, arc_map); }, py::arg("src")); } diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index c431ac89b..8b04cee44 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -16,6 +16,7 @@ from .fsa_algo import intersect from .fsa_algo import linear_fsa from .fsa_algo import remove_epsilon +from .fsa_algo import remove_epsilons_iterative_tropical from .fsa_algo import shortest_path from .fsa_algo import top_sort from .fsa_properties import to_str as properties_to_str @@ -67,6 +68,7 @@ 'properties_to_str', 'random_ragged_shape', 'remove_epsilon', + 'remove_epsilons_iterative_tropical', 'shortest_path', 'simple_ragged_index_select', 'to_dot', diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 73487b920..4a10261d8 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -282,10 +282,43 @@ def remove_epsilon(fsa: Fsa) -> Fsa: if properties is not None and properties & fsa_properties.EPSILON_FREE != 0: return fsa - ragged_arc, arc_derivs = _k2.remove_epsilon(fsa.arcs) + ragged_arc, arc_map = _k2.remove_epsilon(fsa.arcs) aux_labels = None if hasattr(fsa, 'aux_labels'): - aux_labels = index_attr(fsa.aux_labels, arc_derivs) + aux_labels = index_attr(fsa.aux_labels, arc_map) + out_fsa = Fsa(ragged_arc, aux_labels) + + for name, value in fsa.named_non_tensor_attr(): + setattr(out_fsa, name, value) + + return out_fsa + + +def remove_epsilons_iterative_tropical(fsa: Fsa) -> Fsa: + '''Remove epsilons (symbol zero) in the input Fsa. + + Caution: + It doesn't support autograd for now. + + Args: + fsa: + The input FSA. It can be either a single FSA or an FsaVec. + It can be either top-sorted or non-top-sorted. + Returns: + The result Fsa, it's equivalent to the input `fsa` under + tropical semiring but will be epsilon-free. + It will be the same as the input `fsa` if the input + `fsa` is epsilon-free. Otherwise, a new epsilon-free fsa + is returned and the input `fsa` is NOT modified. + ''' + properties = getattr(fsa, 'properties', None) + if properties is not None and properties & fsa_properties.EPSILON_FREE != 0: + return fsa + + ragged_arc, arc_map = _k2.remove_epsilons_iterative_tropical(fsa.arcs) + aux_labels = None + if hasattr(fsa, 'aux_labels'): + aux_labels = index_attr(fsa.aux_labels, arc_map) out_fsa = Fsa(ragged_arc, aux_labels) for name, value in fsa.named_non_tensor_attr(): @@ -320,10 +353,10 @@ def determinize(fsa: Fsa) -> Fsa: and properties & fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0: # noqa return fsa - ragged_arc, arc_derivs = _k2.determinize(fsa.arcs) + ragged_arc, arc_map = _k2.determinize(fsa.arcs) aux_labels = None if hasattr(fsa, 'aux_labels'): - aux_labels = index_attr(fsa.aux_labels, arc_derivs) + aux_labels = index_attr(fsa.aux_labels, arc_map) out_fsa = Fsa(ragged_arc, aux_labels) for name, value in fsa.named_non_tensor_attr(): diff --git a/k2/python/tests/remove_epsilon_test.py b/k2/python/tests/remove_epsilon_test.py index eb882dc1a..d28bd4f05 100644 --- a/k2/python/tests/remove_epsilon_test.py +++ b/k2/python/tests/remove_epsilon_test.py @@ -44,5 +44,28 @@ def test1(self): self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring)) +class TestRemoveEpsilonsIterativeTropical(unittest.TestCase): + + def test1(self): + s = ''' + 0 1 0 1 1 + 1 2 0 2 1 + 2 3 0 3 1 + 3 4 4 4 1 + 3 5 -1 5 1 + 4 5 -1 6 1 + 5 + ''' + fsa = k2.Fsa.from_str(s) + print(fsa.aux_labels) + prop = fsa.properties + self.assertFalse(prop & k2.fsa_properties.EPSILON_FREE) + dest = k2.remove_epsilons_iterative_tropical(fsa) + prop = dest.properties + self.assertTrue(prop & k2.fsa_properties.EPSILON_FREE) + log_semiring = False + self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring)) + + if __name__ == '__main__': unittest.main()