From aa878dcbe7e3e915a561e019dc6635ad1f8896d2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Dec 2020 23:05:28 +0800 Subject: [PATCH] Pruned intersection (#517) --- k2/csrc/algorithms.cu | 18 +- k2/csrc/algorithms.h | 36 +- k2/csrc/array.h | 12 +- k2/csrc/array_inl.h | 2 +- k2/csrc/array_ops.h | 7 + k2/csrc/array_ops_inl.h | 23 + k2/csrc/array_test.cu | 12 +- k2/csrc/benchmark/benchmark.cu | 7 +- k2/csrc/context.cu | 46 + k2/csrc/context.h | 37 +- k2/csrc/fsa_algo.h | 11 +- k2/csrc/fsa_utils.cu | 27 + k2/csrc/fsa_utils.h | 14 + k2/csrc/fsa_utils_test.cu | 17 + k2/csrc/hash.h | 10 +- k2/csrc/host/fsa_equivalent.cc | 9 + k2/csrc/host/fsa_equivalent_test.cc | 12 +- k2/csrc/host/properties.h | 3 + k2/csrc/host_shim_test.cu | 13 + k2/csrc/intersect_pruned.cu | 959 ++++++++++++------ k2/csrc/intersect_test.cu | 51 +- k2/csrc/log.h | 14 +- k2/csrc/ragged.h | 10 +- k2/csrc/ragged_ops.cu | 124 ++- k2/csrc/ragged_ops.h | 104 +- k2/csrc/ragged_shape_test.cu | 93 ++ k2/csrc/ragged_tensor_ops.h | 40 + k2/csrc/semaphore.h | 55 + k2/csrc/tensor.cu | 17 +- k2/csrc/tensor.h | 12 +- k2/csrc/utils.h | 9 +- k2/python/host/tests/fsa_equivalent_test.py | 31 +- k2/python/k2/version.py | 11 +- .../tests/intersect_dense_pruned_test.py | 51 + k2/python/tests/intersect_dense_test.py | 5 +- 35 files changed, 1524 insertions(+), 378 deletions(-) create mode 100644 k2/csrc/ragged_tensor_ops.h create mode 100644 k2/csrc/semaphore.h diff --git a/k2/csrc/algorithms.cu b/k2/csrc/algorithms.cu index b856e1f9e..36dcb51c7 100644 --- a/k2/csrc/algorithms.cu +++ b/k2/csrc/algorithms.cu @@ -26,7 +26,6 @@ void Renumbering::ComputeOld2New() { num_new_elems_ = old2new_.Back(); K2_CHECK_GE(num_new_elems_, 0); K2_CHECK_LE(num_new_elems_, keep_.Dim()); - old2new_ = old2new_.Range(0, keep_.Dim()); } namespace { @@ -59,4 +58,21 @@ void Renumbering::ComputeNew2Old() { new2old_ = new2old_.Range(0, num_new_elems_); } +Renumbering::Renumbering(const Array1 &keep, + const Array1 &old2new, + const Array1 
&new2old): + keep_(keep), old2new_(old2new), + num_new_elems_(new2old.Dim()), + new2old_(new2old) { } + + +Renumbering IdentityRenumbering(ContextPtr c, int32_t size) { + Array1 keep(c, size + 1); // uninitialized. + keep = keep.Arange(0, size); + Array1 range = Arange(c, 0, size + 1); + return Renumbering(keep, range, range.Arange(0, size)); +} + + + } // namespace k2 diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h index 2559d14bb..923e30f6c 100644 --- a/k2/csrc/algorithms.h +++ b/k2/csrc/algorithms.h @@ -28,6 +28,9 @@ class Renumbering { Renumbering(Renumbering &&src) = default; // move assignment Renumbering &operator=(Renumbering &&) = default; + // copy assignment + Renumbering &operator=(const Renumbering &) = default; + /* This constructor will allocate memory for `keep_` array with size @@ -56,6 +59,16 @@ class Renumbering { Init(c, num_old_elems, init_keep_with_zero); } + /* + This constructor is not intended for use by users; it is used by + IdentityRenumbering(). Just sets members to the provided arrays and + num_new_elems_ to new2old.Dim(). + */ + Renumbering(const Array1 &keep, + const Array1 &old2new, + const Array1 &new2old); + + void Init(ContextPtr c, int32_t num_old_elems, bool init_keep_with_zero = false) { NVTX_RANGE(K2_FUNC); @@ -100,17 +113,20 @@ class Renumbering { /* Return a mapping from old index to new index. This is created on demand (must only be called after the Keep() array has been populated). + @param [in] extra_element If true, will return the array of size + NumOldElems() + 1, which includes one more element; + otherwise it will return an array of size NumOldElems(). + This array is just the exclusive sum of Keep(). + It gives the mapping for indexes that are kept; element + i is kept if `Old2New()[i+1] > Old2New()[i]`. + @return Returns an array mapping the old indexes to the new indexes. - Its dimension is the number of old indexes (i.e. keep_.Dim() - or NumOldElems()). It is just the exclusive sum of Keep(). 
- It gives the mapping for indexes that are kept; ignore the - non-kept elements of it. - Will be allocated with the same context as keep_. */ - Array1 &Old2New() { + Array1 Old2New(bool extra_element = false) { NVTX_RANGE(K2_FUNC); if (!old2new_.IsValid()) ComputeOld2New(); - return old2new_; + if (extra_element) return old2new_; + else return old2new_.Arange(0, old2new_.Dim() - 1); } private: @@ -121,12 +137,16 @@ class Renumbering { Array1 keep_; // array of elements to keep; dimension is the // `num_old_elems` provided in the constructor but it // was allocated with one extra element. - Array1 old2new_; + Array1 old2new_; // note: dimension is num-old-elems + 1. int32_t num_new_elems_; // equals last element of old2new_; set when // old2new_ is created. Array1 new2old_; }; +// returns a Renumbering object that is the identity map. Caution; its Keep() +// elements are not set up. +Renumbering IdentityRenumbering(ContextPtr c, int32_t size); + } // namespace k2 #endif // K2_CSRC_ALGORITHMS_H_ diff --git a/k2/csrc/array.h b/k2/csrc/array.h index 83541f69a..16df4f797 100644 --- a/k2/csrc/array.h +++ b/k2/csrc/array.h @@ -103,7 +103,6 @@ class Array1 { @param [in] size Number of elements to include, 0 <= size <= Dim()-start */ Array1 Range(int32_t start, int32_t size) const { - NVTX_RANGE(K2_FUNC); K2_CHECK_GE(start, 0); K2_CHECK_LE(start, Dim()); K2_CHECK_GE(size, 0); @@ -120,7 +119,6 @@ class Array1 { start <= end <= Dim(). */ Array1 Arange(int32_t start, int32_t end) const { - NVTX_RANGE(K2_FUNC); K2_CHECK_GE(start, 0); K2_CHECK_LE(start, dim_); K2_CHECK_GE(end, start); @@ -344,9 +342,9 @@ class Array1 { Array1(const Array1 &) = default; // move constructor Array1(Array1 &&) = default; - // assignment operator + // assignment operator (shallow); see Assign() for assignment of elements. 
Array1 &operator=(const Array1 &) = default; - // move assignment operator + // move assignment operator (shallow) Array1 &operator=(Array1 &&) = default; /* @@ -523,7 +521,7 @@ class Array2 { } // return a row (indexing on the 0th axis) - Array1 operator[](int32_t i) { + Array1 Row(int32_t i) { NVTX_RANGE(K2_FUNC); K2_CHECK_GE(i, 0); K2_CHECK_LT(i, dim0_); @@ -567,9 +565,9 @@ class Array2 { Array2(const Array2 &other) = default; // move constructor Array2(Array2 &&other) = default; - // assignment operator + // assignment operator (shallow); see Assign() for assignment of elements. Array2 &operator=(const Array2 &other) = default; - // move assignment operator + // move assignment operator (shallow); Array2 &operator=(Array2 &&other) = default; /* stride on 1st axis is 1 (in elements). */ diff --git a/k2/csrc/array_inl.h b/k2/csrc/array_inl.h index 781dc2667..26572f32e 100644 --- a/k2/csrc/array_inl.h +++ b/k2/csrc/array_inl.h @@ -72,7 +72,7 @@ std::ostream &operator<<(std::ostream &stream, const Array2 &array) { Array2 array_cpu = array.To(GetCpuContext()); int32_t num_rows = array_cpu.Dim0(); for (int32_t i = 0; i < num_rows; ++i) { - stream << ToPrintable(array_cpu[i]); + stream << ToPrintable(array_cpu.Row(i)); if (i + 1 < num_rows) stream << '\n'; } return stream << "\n]"; diff --git a/k2/csrc/array_ops.h b/k2/csrc/array_ops.h index 83b31511d..24a0dfb8e 100644 --- a/k2/csrc/array_ops.h +++ b/k2/csrc/array_ops.h @@ -628,6 +628,13 @@ void Sort(Array1 *array, Array1 *index_map = nullptr); template void Assign(Array2 &src, Array2 *dest); +/* + Assign elements from `src` to `dest`; they must have the same Dim(). + */ +template +void Assign(Array1 &src, Array1 *dest); + + /* Merge an array of Array1 with a `merge_map` which indicates which items to get from which positions (doesn't do any checking of the merge_map values!) 
diff --git a/k2/csrc/array_ops_inl.h b/k2/csrc/array_ops_inl.h index d90848a11..92f763229 100644 --- a/k2/csrc/array_ops_inl.h +++ b/k2/csrc/array_ops_inl.h @@ -878,6 +878,30 @@ void Assign(Array2 &src, Array2 *dest) { } } + +template +void Assign(Array1 &src, Array1 *dest) { + K2_CHECK_EQ(src.Dim(), dest->Dim()); + int32_t dim = src.Dim(); + if (std::is_same::value) { + size_t num_bytes = dim * sizeof(S); + src.Context()->CopyDataTo(num_bytes, src.Data(), dest->Context(), + dest->Data()); + } else { + if (!src.Context()->IsCompatible(*dest->Context())) { + Array1 src_new = src.To(dest->Context()); + Assign(src_new, dest); + return; // the recursive call already filled *dest; don't fall through. + } + const S *src_data = src.Data(); + T *dest_data = dest->Data(); + K2_EVAL(src.Context(), dim, lambda_copy_data, (int32_t i) -> void { + dest_data[i] = src_data[i]; + }); + } +} + + template Array1 MergeWithMap(const Array1 &merge_map, int32_t num_srcs, const Array1 **src) { diff --git a/k2/csrc/array_test.cu b/k2/csrc/array_test.cu index 4fd2b1cf3..502fe1095 100644 --- a/k2/csrc/array_test.cu +++ b/k2/csrc/array_test.cu @@ -276,6 +276,7 @@ void TestArray2() { auto cpu_array = array.To(cpu); auto cuda_array = array.To(GetCudaContext()); + auto cpu_acc = cpu_array.Accessor(); ASSERT_EQ(cpu_array.ElemStride0(), cpu_array.Dim1()); ASSERT_EQ(cuda_array.ElemStride0(), cuda_array.Dim1()); @@ -285,8 +286,8 @@ void TestArray2() { for (auto c = 0; c != kDim1; ++c) { // WARNING: it's inefficient to access elements of Array2 // with operator [][] - EXPECT_EQ(cpu_array[r][c], k); - EXPECT_EQ(cuda_array[r][c], k); + EXPECT_EQ(cpu_acc(r, c), k); + EXPECT_EQ(cuda_array.Row(r)[c], k); ++k; } @@ -316,14 +317,15 @@ void TestArray2() { auto cpu_array = array.To(cpu); auto cuda_array = array.To(GetCudaContext()); + auto cpu_acc = cpu_array.Accessor(); auto k = 0; for (auto r = 0; r != kDim0; ++r) for (auto c = 0; c != kDim1; ++c) { // WARNING: it's inefficient to access elements of Array2 // with operator [][] - EXPECT_EQ(cpu_array[r][c], k); - 
EXPECT_EQ(cuda_array[r][c], k); + EXPECT_EQ(cpu_acc(r,c), k); + EXPECT_EQ(cuda_array.Row(r)[c], k); ++k; } @@ -361,7 +363,7 @@ void TestArray2() { { // test operator[] for (int32_t i = 0; i < array.Dim0(); ++i) { - Array1 sub_array = array[i]; + Array1 sub_array = array.Row(i); const T *sub_array_data = sub_array.Data(); ASSERT_EQ(sub_array.Dim(), array.Dim1()); std::vector sub_array_cpu_data(sub_array.Dim()); diff --git a/k2/csrc/benchmark/benchmark.cu b/k2/csrc/benchmark/benchmark.cu index f051343da..b8b09b938 100644 --- a/k2/csrc/benchmark/benchmark.cu +++ b/k2/csrc/benchmark/benchmark.cu @@ -179,9 +179,10 @@ void PrintEnvironemntInfo() { os << kPrefix << "torch CUDA version: " << kTorchCudaVersion << "\n"; os << kPrefix << "NVTX enabled: " << kEnableNvtx << "\n"; os << kPrefix << "Debug disabled: " << internal::kDisableDebug << "\n"; - os << kPrefix - << "cuda device sync enabled: " << internal::EnableCudaDeviceSync() - << "\n"; + os << kPrefix << "cuda device sync enabled: " + << internal::EnableCudaDeviceSync() << "\n"; + os << kPrefix << "Checks disabled: " << internal::DisableChecks() << "\n"; + // print it to stderr so that it can be redirected std::cerr << os.str() << "\n"; diff --git a/k2/csrc/context.cu b/k2/csrc/context.cu index 6d8b69e70..a94697b0f 100644 --- a/k2/csrc/context.cu +++ b/k2/csrc/context.cu @@ -128,4 +128,50 @@ void GetBlockSizesForLambda2(int32_t m, int32_t n, dim3 *block_dim, } } + +void Semaphore::Signal(ContextPtr c) { + DeviceType device_type = c->GetDeviceType(); + if (device_type_ == kUnk) + device_type_ = device_type; + else + K2_CHECK_EQ(device_type, device_type_) + << "Semaphore must always be used with the same device type."; + if (device_type == kCuda) { + cudaEvent_t event; + cudaError_t e = cudaEventCreateWithFlags(&event, cudaEventDisableTiming); + K2_CHECK_CUDA_ERROR(e) << "Error creating event"; + // Note: this stream is subject to being overridden by With(stream..), see + // class With. 
+ cudaStream_t stream = c->GetCudaStream(); + e = cudaEventRecord(event, stream); + K2_CHECK_CUDA_ERROR(e) << "Error recording event."; + std::lock_guard lock(events_mutex_); + events_.push_back(event); + } + semaphore_.release(); +} + +void Semaphore::Wait(ContextPtr c) { + DeviceType device_type = c->GetDeviceType(); + if (device_type_ == kUnk) + device_type_ = device_type; + else + K2_CHECK_EQ(device_type, device_type_) + << "Semaphore must always be used with the same device type."; + semaphore_.acquire(); + if (device_type == kCuda) { + cudaEvent_t event; + { + std::lock_guard lock(events_mutex_); + K2_CHECK(!events_.empty()); // would be code bug. + event = events_.front(); + events_.pop_front(); + } + cudaError_t e = cudaStreamWaitEvent(c->GetCudaStream(), event, 0); + K2_CHECK_CUDA_ERROR(e) << "Error waiting on event."; + e = cudaEventDestroy(event); // each event is single-use; don't leak it. + K2_CHECK_CUDA_ERROR(e) << "Error destroying event."; + } +} + } // namespace k2 diff --git a/k2/csrc/context.h b/k2/csrc/context.h index f7768c55f..1aec099a9 100644 --- a/k2/csrc/context.h +++ b/k2/csrc/context.h @@ -16,14 +16,17 @@ #include #include +#include #include #include +#include #include #include #include #include "k2/csrc/log.h" #include "k2/csrc/nvtx.h" +#include "k2/csrc/semaphore.h" namespace k2 { @@ -363,7 +366,10 @@ inline DeviceType DeviceOf(const T &t) { // This is for use by ParallelRunner and Context. Users probably should not // interact with this directly. The idea is that the Context object will call -// this to possibly override its default thread. The +// this to possibly override its default thread. The user would +// create a new stream by calling ParallelRunner's NewStream() method, and +// do `With w(stream);` which calls Push(stream), and later Pop(stream) when it +// goes out of scope. 
class CudaStreamOverride { public: inline cudaStream_t OverrideStream(cudaStream_t stream) { @@ -397,6 +403,35 @@ class With { cudaStream_t stream_; }; + +/* + Our class Semaphore is a slight extension of std::counting_semaphore that also + takes care of stream synchronization. The projected use-case is when two + threads (possibly with different CUDA streams, if we are using CUDA) have a + producer-consumer relationship, such that one is waiting for the other. + The projected use is: + - Construct semaphore + - Producing thread (maybe repeatedly) calls semaphore.Signal(ctx); + - Consuming thread (maybe repeatedly) calls semaphore.Wait(ctx); + */ +class Semaphore { + public: + Semaphore(): device_type_(kUnk), semaphore_(0) { } + + void Signal(ContextPtr c); + + void Wait(ContextPtr c); + + private: + DeviceType device_type_; // makes sure it's always used with the same device + // type. + k2std::counting_semaphore semaphore_; + std::mutex events_mutex_; + std::deque events_; +}; + + + /* Class ParallelRunner allows you to invoke CUDA kernels in parallel. It works for CUDA and CPU, but for CPU it currently just executes things diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 0bed15a67..7a5d2e005 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -161,8 +161,15 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, Dim0() as b_fsas. Elements of it may be empty if the composition was empty, either intrinsically or due to failure of pruned search. - @param[out] arc_map_a Vector of - + @param[out] arc_map_a Will be set to a vector with Dim() equal to + the number of arcs in `out`, whose elements contain + the corresponding arc_idx01 in a_fsas. 
+ @param[out] arc_map_b Will be set to a vector with Dim() equal to + the number of arcs in `out`, whose elements contain + the corresponding arc-index in b_fsas; this arc-index + is defined as the offset into b_fsas.scores, which is + well defined if the shape is known because we require + it to be contiguous. */ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, diff --git a/k2/csrc/fsa_utils.cu b/k2/csrc/fsa_utils.cu index 1078e4672..fdf3efa72 100644 --- a/k2/csrc/fsa_utils.cu +++ b/k2/csrc/fsa_utils.cu @@ -1789,4 +1789,32 @@ Ragged ComposeArcMaps(Ragged &step1_arc_map, return Index(step1_arc_map, step2_arc_map, true); } +void FixNumStates(FsaVec *fsas) { + K2_CHECK_EQ(fsas->NumAxes(), 3); + ContextPtr c = fsas->Context(); + int32_t num_fsas = fsas->Dim0(), + num_states = fsas->TotSize(1); + + Array1 changed(c, 1, 0); + Renumbering renumber_states(c, num_states); + renumber_states.Keep() = (char)1; // by default keep all states.. + + int32_t *changed_data = changed.Data(); + char *keep_data = renumber_states.Keep().Data(); + const int32_t *row_splits1_data = fsas->RowSplits(1).Data(); + K2_EVAL(c, num_fsas, lambda_set_must_remove, (int32_t i) -> void { + int32_t num_states = (row_splits1_data[i+1] - + row_splits1_data[i]); + if (num_states == 1) { // flag `changed` only when a state is dropped. + keep_data[row_splits1_data[i]] = 0; + changed_data[0] = 1; + } + }); + if (changed[0] == 0) + return; // an optimization.. 
+ fsas->shape = RemoveSomeEmptyLists(fsas->shape, 1, + renumber_states); +} + + } // namespace k2 diff --git a/k2/csrc/fsa_utils.h b/k2/csrc/fsa_utils.h index 63480b46e..515dbb8b1 100644 --- a/k2/csrc/fsa_utils.h +++ b/k2/csrc/fsa_utils.h @@ -447,6 +447,20 @@ FsaVec FsaVecFromArcIndexes(FsaVec &fsas, Ragged &best_arc_indexes); Ragged ComposeArcMaps(Ragged &step1_arc_map, Ragged &step2_arc_map); +/* + This function detects if there are any FSAs in an FsaVec that have exactly one + state (which is not allowed; the empty FSA may have either 0 or 2 states); and + it removes those states. These states cannot have any arcs leaving them; if + they do, it is an error and this function may crash or give undefined output. + + @param [in,out] fsas FsaVec to possibly modify; must have 3 axes. + + CAUTION: this is not used right now and I'm not sure if there are any + situations where it really should be used; think carefully before using it. + */ +void FixNumStates(FsaVec *fsas); + + } // namespace k2 #endif // K2_CSRC_FSA_UTILS_H_ diff --git a/k2/csrc/fsa_utils_test.cu b/k2/csrc/fsa_utils_test.cu index 4e8276f1b..0ac69cf48 100644 --- a/k2/csrc/fsa_utils_test.cu +++ b/k2/csrc/fsa_utils_test.cu @@ -1010,4 +1010,21 @@ TEST(FsaUtils, ComposeArcMapsTest) { } } + +TEST(FixNumStates, FixNumStates) { + FsaVec f("[ [ [] [] ] [ [] [] ] ]"), + g("[ [ [] ] [ [] [] ] ]"), + h("[ [ ] [ [] [] ] ]"); + + FsaVec f2(f), g2(g), h2(h); + + FixNumStates(&f2); + FixNumStates(&g2); + FixNumStates(&h2); + + EXPECT_EQ(Equal(f, f2), true); + EXPECT_EQ(Equal(h, g2), true); + EXPECT_EQ(Equal(h, h2), true); +} + } // namespace k2 diff --git a/k2/csrc/hash.h b/k2/csrc/hash.h index 80202924c..4695ec94f 100644 --- a/k2/csrc/hash.h +++ b/k2/csrc/hash.h @@ -73,11 +73,15 @@ class Hash32 { public: /* Constructor. Context can be for CPU or GPU. 
num_buckets must be a power of 2 with num_buckets >= 128 (an arbitrarily chosen cutoff) */ - Hash32(ContextPtr c, int32_t num_buckets): - data_(c, num_buckets, ~(uint64_t)0), buckets_num_bitsm1_(0) { + Hash32(ContextPtr c, int32_t num_buckets) { + std::ostringstream os; + os << K2_FUNC << ":num_buckets=" << num_buckets; + NVTX_RANGE(os.str().c_str()); + data_ = Array1(c, num_buckets, ~(uint64_t)0); K2_CHECK_GE(num_buckets, 128); int32_t n = 2; - for (; n < num_buckets; n *= 2, buckets_num_bitsm1_++) { } + for (buckets_num_bitsm1_ = 0; n < num_buckets; + n *= 2, buckets_num_bitsm1_++) { } K2_CHECK_EQ(num_buckets, 2 << buckets_num_bitsm1_) << " num_buckets must be a power of 2."; } diff --git a/k2/csrc/host/fsa_equivalent.cc b/k2/csrc/host/fsa_equivalent.cc index 83816b485..ae638c128 100644 --- a/k2/csrc/host/fsa_equivalent.cc +++ b/k2/csrc/host/fsa_equivalent.cc @@ -171,6 +171,10 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, bool treat_epsilons_specially /*=true*/, std::size_t npath /*=100*/) { NVTX_RANGE(K2_FUNC); + if (!IsValid(a) || !IsValid(b)) { + K2_LOG(WARNING) << "One or more of the inputs is not valid."; + return false; + } FsaCreator valid_a_storage, valid_b_storage; ::Connect(a, &valid_a_storage); ::Connect(b, &valid_b_storage); @@ -178,6 +182,7 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, ArcSort(&valid_b_storage.GetFsa()); const auto &valid_a = valid_a_storage.GetFsa(); const auto &valid_b = valid_b_storage.GetFsa(); + if (IsEmpty(valid_a) && IsEmpty(valid_b)) return true; if (IsEmpty(valid_a) || IsEmpty(valid_b)) return false; @@ -227,6 +232,10 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, float delta /*=1e-6*/, bool top_sorted /*=true*/, std::size_t npath /*= 100*/) { NVTX_RANGE(K2_FUNC); + if (!IsValid(a) || !IsValid(b)) { + K2_LOG(WARNING) << "One or more of the inputs is not valid."; + return false; + } K2_CHECK_GT(beam, 0); // TODO(haowen): for now we only support top-sorted input Fsas K2_CHECK(top_sorted); diff --git 
a/k2/csrc/host/fsa_equivalent_test.cc b/k2/csrc/host/fsa_equivalent_test.cc index 8a9226595..eff265948 100644 --- a/k2/csrc/host/fsa_equivalent_test.cc +++ b/k2/csrc/host/fsa_equivalent_test.cc @@ -96,9 +96,9 @@ TEST(FsaEquivalent, IsRandEquivalent) { { // same fsas std::vector arcs_a = { - {0, 1, 1, 0}, {0, 2, 2, 0}, {1, 2, 3, 0}, {1, 3, 4, 0}, {2, 3, 5, 0}, + {0, 1, 1, 0}, {0, 2, 2, 0}, {1, 2, 3, 0}, {1, 3, 4, 0}, {2, 3, 5, 0}, {3, 4, -1, 0} }; - FsaCreator fsa_creator_a(arcs_a, 3); + FsaCreator fsa_creator_a(arcs_a, 4); const auto &a = fsa_creator_a.GetFsa(); bool status = IsRandEquivalent(a, a); EXPECT_TRUE(status); @@ -106,15 +106,15 @@ TEST(FsaEquivalent, IsRandEquivalent) { { std::vector arcs_a = { - {0, 1, 1, 0}, {0, 2, 2, 0}, {0, 3, 8, 0}, {1, 4, 4, 0}, {2, 4, 5, 0}, + {0, 1, 1, 0}, {0, 2, 2, 0}, {0, 3, 8, 0}, {1, 4, 4, 0}, {2, 4, 5, 0}, { 4, 5, -1, 0} }; - FsaCreator fsa_creator_a(arcs_a, 4); + FsaCreator fsa_creator_a(arcs_a, 5); const auto &a = fsa_creator_a.GetFsa(); std::vector arcs_b = { - {0, 2, 1, 0}, {0, 1, 2, 0}, {0, 3, 9, 0}, {1, 4, 5, 0}, {2, 4, 4, 0}, + {0, 2, 1, 0}, {0, 1, 2, 0}, {0, 3, 9, 0}, {1, 4, 5, 0}, {2, 4, 4, 0}, {4, 5, -1, 0} }; - FsaCreator fsa_creator_b(arcs_b, 4); + FsaCreator fsa_creator_b(arcs_b, 5); const auto &b = fsa_creator_b.GetFsa(); bool status = IsRandEquivalent(a, b); diff --git a/k2/csrc/host/properties.h b/k2/csrc/host/properties.h index 00f7152d4..fe30db481 100644 --- a/k2/csrc/host/properties.h +++ b/k2/csrc/host/properties.h @@ -123,6 +123,9 @@ inline bool IsTopSortedAndConnected(const Fsa &fsa) { /* Returns true if `fsa` is empty. (Note: if `fsa` is not empty, it would contain at least two states, the start state and the final state). + + Caution: this is not always very meaningful, as an FSA with no states is + conceptually equivalent to an FSA with two states but no arcs. 
*/ inline bool IsEmpty(const Fsa &fsa) { return fsa.size1 == 0; } diff --git a/k2/csrc/host_shim_test.cu b/k2/csrc/host_shim_test.cu index e2f6bc25b..200634e86 100644 --- a/k2/csrc/host_shim_test.cu +++ b/k2/csrc/host_shim_test.cu @@ -30,6 +30,19 @@ TEST(HostShim, FsaToHostFsa) { // TODO(fangjun): check the content of host_fsa } + +TEST(HostShim, IsRandEquivalent) { + // check that empty FSAs with zero vs 2 states are equivalent. + FsaVec f("[ [ [] [] ] [] [] ]"), + g("[ [ [] [] ] [ [] [] ] [ [] [] ] ]"); + EXPECT_EQ(f.NumAxes(), 3); + EXPECT_EQ(g.NumAxes(), 3); + + EXPECT_EQ(IsRandEquivalent(f, g, true), true); +} + + + TEST(HostShim, FsaVecToHostFsa) { std::string s1 = R"( 0 1 1 1 1 2 2 2 diff --git a/k2/csrc/intersect_pruned.cu b/k2/csrc/intersect_pruned.cu index d49ff148f..6999ad200 100644 --- a/k2/csrc/intersect_pruned.cu +++ b/k2/csrc/intersect_pruned.cu @@ -10,6 +10,7 @@ */ #include +#include #include #include "k2/csrc/array_ops.h" @@ -58,11 +59,11 @@ struct ArcInfo { // for an arc that wasn't pruned away... // names for clarity. int32_t dest_a_fsas_state_idx01; // The destination-state as an index // into a_fsas_. - int32_t dest_info_state_idx01; // The destination-state as an index into - // the next FrameInfo's `arcs` or `states` - int32_t dest_info_state_idx1; // The destination-state as an index the - // next FrameInfo's `arcs` or `states`, - // this time omitting the FSA-index. + int32_t dest_info_state_idx1; // The destination-state as an idx1 into the + // next FrameInfo's `arcs` or `states`, + // omitting the FSA-index which can be worked + // out from the structure of this frame's + // ArcInfo. } u; float end_loglike; // loglike at the end of the arc just before // (conceptually) it joins the destination state. 
@@ -78,7 +79,9 @@ static std::ostream &operator<<(std::ostream &os, const StateInfo &s) { static std::ostream &operator<<(std::ostream &os, const ArcInfo &a) { os << "ArcInfo{" << a.a_fsas_arc_idx012 << "," << a.arc_loglike << "," - << a.u.dest_a_fsas_state_idx01 << "," << a.end_loglike << "}"; + << a.u.dest_a_fsas_state_idx01 << "," << a.end_loglike + << "[i=" << FloatToOrderedInt(a.end_loglike) << "]" + << "}"; return os; } */ @@ -136,6 +139,7 @@ class MultiGraphDenseIntersectPruned { dynamic_beams_(a_fsas.Context(), b_fsas.shape.Dim0(), search_beam) { NVTX_RANGE(K2_FUNC); c_ = GetContext(a_fsas.shape, b_fsas.shape); + T_ = b_fsas_.shape.MaxSize(1); K2_CHECK(b_fsas.scores.IsContiguous()); K2_CHECK_GT(search_beam, 0); K2_CHECK_GT(output_beam, 0); @@ -164,6 +168,38 @@ class MultiGraphDenseIntersectPruned { } int64_t num_keys = num_a_copies * (int64_t)a_fsas.TotSize(1); K2_CHECK(num_keys == (uint32_t)num_keys); + + { // set up do_pruning_after_ and prune_t_begin_end_. + + do_pruning_after_.resize(T_ + 1, (char)0); + + // each time we prune, prune 30 frames; but shift by 20 frames each + // time so there are 10 frames of overlap. + int32_t prune_num_frames = 30, + prune_shift = 20, + T = T_; + K2_CHECK_GT(prune_num_frames, prune_shift); + // The first begin_t is negative but will be rounded up to zero to get the + // start of the range. The motivation is: we don't want to wait until we + // have processed `prune_num_frames` frames to prune for the first time, + // because that first interval of not-pruning, being larger than normal, + // would dominate the maximum memory used by intersection. 
+ for (int32_t begin_t = prune_shift - prune_num_frames; ; + begin_t += prune_shift) { + int32_t prune_begin = std::max(0, begin_t), + prune_end = begin_t + prune_num_frames; + bool last = false; + if (prune_end >= T) { + prune_end = T; + last = true; + } + K2_CHECK_LT(prune_begin, prune_end); + do_pruning_after_[prune_end - 1] = (char)1; + prune_t_begin_end_.push_back({prune_begin, prune_end}); + if (last) + break; + } + } } // The information we have for each frame of the pruned-intersection (really: @@ -200,50 +236,54 @@ class MultiGraphDenseIntersectPruned { neural-net output. */ NVTX_RANGE(K2_FUNC); - int32_t T = b_fsas_.shape.MaxSize(1), num_fsas = b_fsas_.shape.Dim0(); + int32_t num_fsas = b_fsas_.shape.Dim0(), T = T_; std::ostringstream os; os << "Intersect:T=" << T << ",num_fsas=" << num_fsas << ",TotSize(1)=" << b_fsas_.shape.TotSize(1); NVTX_RANGE(os.str().c_str()); - frames_.reserve(T + 1); + std::thread backward_thread(BackwardPassStatic, this); + + // we'll initially populate frames_[0.. T+1], but discard the one at T+1, + // which has no arcs or states, the ones we use are from 0 to T. + frames_.reserve(T + 2); frames_.push_back(InitialFrameInfo()); for (int32_t t = 0; t <= T; t++) { frames_.push_back(PropagateForward(t, frames_.back().get())); + if (do_pruning_after_[t]) + semaphore_.Signal(c_); // let a phase of backward-pass pruning commence. } // The FrameInfo for time T+1 will have no states. We did that // last PropagateForward so that the 'arcs' member of frames_[T] // is set up (it has no arcs but we need the shape). frames_.pop_back(); - { - NVTX_RANGE("InitOshapeUnpruned.."); - // each of these have 3 axes. - std::vector arcs_shapes(T + 1); - for (int32_t t = 0; t <= T; t++) - arcs_shapes[t] = &(frames_[t]->arcs.shape); + backward_thread.join(); + } - // oshape_unpruned_ is a 4-axis ragged tensor which is indexed: - // oshape_unpruned_[fsa_index][t][state_idx][arc_idx] - // This is BEFORE BACKWARD PRUNING... 
oshape_pruned_ will - // be after backward pruning - int32_t axis = 1; - oshape_unpruned_ = Stack(axis, T + 1, &(arcs_shapes[0])); - } - renumber_output_states_.Init(c_, oshape_unpruned_.TotSize(2)); - renumber_output_arcs_.Init(c_, oshape_unpruned_.TotSize(3)); - - for (int32_t t = T; t >= 0; t--) { - // this writes to elements of renumber_output_states_.Keep() and - // renumber_output_arcs_.Keep(). - PropagateBackward(t, frames_[t].get(), - (t == T ? NULL : frames_[t + 1].get())); + void BackwardPass() { + int32_t num_fsas = b_fsas_.shape.Dim0(), + num_work_items = max_active_ * num_fsas * T_; + ParallelRunner pr(c_); + // if num_work_items is big enough, it will actually create a new stream. + cudaStream_t stream = pr.NewStream(num_work_items); + With w(stream); // This overrides whatever stream c_ contains with `stream`, if it's not + + + NVTX_RANGE(K2_FUNC); + for (size_t i = 0; i < prune_t_begin_end_.size(); i++) { + semaphore_.Wait(c_); + int32_t prune_t_begin = prune_t_begin_end_[i].first, + prune_t_end = prune_t_begin_end_[i].second; + PruneTimeRange(prune_t_begin, prune_t_end); } - oshape_pruned_ = SubsampleRaggedShape( - oshape_unpruned_, renumber_output_states_, renumber_output_arcs_); + } + + static void BackwardPassStatic(MultiGraphDenseIntersectPruned *c) { + c->BackwardPass(); } // Return FrameInfo for 1st frame, with `states` set but `arcs` not set. 
@@ -290,140 +330,148 @@ class MultiGraphDenseIntersectPruned { void FormatOutput(FsaVec *ofsa, Array1 *arc_map_a, Array1 *arc_map_b) { NVTX_RANGE("FormatOutput"); + + int32_t T = T_; + + ContextPtr c_cpu = GetCpuContext(); - int32_t T = b_fsas_.shape.MaxSize(1); - - int32_t *oshapeu_row_ids3 = oshape_unpruned_.RowIds(3).Data(), - *oshapeu_row_ids2 = oshape_unpruned_.RowIds(2).Data(), - *oshapeu_row_ids1 = oshape_unpruned_.RowIds(1).Data(), - *oshapeu_row_splits3 = oshape_unpruned_.RowSplits(3).Data(), - *oshapeu_row_splits2 = oshape_unpruned_.RowSplits(2).Data(), - *oshapeu_row_splits1 = oshape_unpruned_.RowSplits(1).Data(); - - int32_t *oshapep_row_ids3 = oshape_pruned_.RowIds(3).Data(), - *oshapep_row_ids2 = oshape_pruned_.RowIds(2).Data(), - *oshapep_row_ids1 = oshape_pruned_.RowIds(1).Data(), - *oshapep_row_splits3 = oshape_pruned_.RowSplits(3).Data(), - *oshapep_row_splits2 = oshape_pruned_.RowSplits(2).Data(), - *oshapep_row_splits1 = oshape_pruned_.RowSplits(1).Data(); - - // the 0123 and 012 express what type of indexes they are, see comment at - // top of utils.h - int32_t *new2old_arc_map0123 = renumber_output_arcs_.New2Old().Data(), - *old2new_state_map012 = renumber_output_states_.Old2New().Data(); - - Array1 ai_data_ptrs(c_cpu, T + 1); + Array1 arcs_data_ptrs(c_cpu, T + 1); Array1 arcs_row_splits1_ptrs(c_cpu, T + 1); - Array1 arcs_row_splits2_ptrs(c_cpu, T + 1); - for (int32_t t = 0; t <= T; t++) { - ai_data_ptrs.Data()[t] = frames_[t]->arcs.values.Data(); - arcs_row_splits1_ptrs.Data()[t] = - frames_[t]->arcs.shape.RowSplits(1).Data(); - arcs_row_splits2_ptrs.Data()[t] = - frames_[t]->arcs.shape.RowSplits(2).Data(); + arcs_data_ptrs.Data()[t] = frames_[t]->arcs.values.Data(); + arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data(); } // transfer to GPU if we're using a GPU - ai_data_ptrs = ai_data_ptrs.To(c_); + arcs_data_ptrs = arcs_data_ptrs.To(c_); + ArcInfo **arcs_data_ptrs_data = arcs_data_ptrs.Data(); arcs_row_splits1_ptrs = 
arcs_row_splits1_ptrs.To(c_); - arcs_row_splits2_ptrs = arcs_row_splits2_ptrs.To(c_); - ArcInfo **ai_data_ptrs_data = ai_data_ptrs.Data(); - int32_t **arcs_row_splits1_ptrs_data = arcs_row_splits1_ptrs.Data(), - **arcs_row_splits2_ptrs_data = arcs_row_splits2_ptrs.Data(); - - int32_t tot_arcs_pruned = oshape_pruned_.TotSize(3); - *arc_map_a = Array1(c_, tot_arcs_pruned); - *arc_map_b = Array1(c_, tot_arcs_pruned); + int32_t **arcs_row_splits1_ptrs_data = arcs_row_splits1_ptrs.Data(); + const int32_t *b_fsas_row_splits1 = b_fsas_.shape.RowSplits(1).Data(); + const int32_t *a_fsas_row_splits1 = a_fsas_.RowSplits(1).Data(); + int32_t a_fsas_stride = a_fsas_stride_; // 0 or 1 depending if the decoding + // graph is shared. + int32_t num_fsas = b_fsas_.shape.Dim0(); + + RaggedShape final_arcs_shape; + { /* This block populates `final_arcs_shape`. It is the shape of a ragged + tensor of arcs that conceptually would live at frames_[T+1]->arcs. It + contains no actual arcs, but may contain some states, that represent + "missing" final-states. The problem we are trying to solve is that + there was a start-state for an FSA but no final-state because it did + not survive pruning, and this could lead to an output FSA that is + invalid or is misinterpreted (because we are interpreting a non-final + state as a final state). + */ + Array1 num_extra_states(c_, num_fsas + 1); + int32_t *num_extra_states_data = num_extra_states.Data(); + K2_EVAL(c_, num_fsas, lambda_set_num_extra_states, (int32_t i) -> void { + int32_t final_t = b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; + int32_t *arcs_row_splits1_data = arcs_row_splits1_ptrs_data[final_t]; + int32_t num_states_final_t = arcs_row_splits1_data[i + 1] - + arcs_row_splits1_data[i]; + K2_CHECK_LE(num_states_final_t, 1); + + // has_start_state is 1 if there is a start-state; note, we don't prune + // the start-states, so they'll be present if they were present in a_fsas_. 
+ int32_t has_start_state = (a_fsas_row_splits1[i * a_fsas_stride] < + a_fsas_row_splits1[i * a_fsas_stride + 1]); + + // num_extra_states_data[i] will be 1 if there was a start state but no final-state; + // else, 0. + num_extra_states_data[i] = has_start_state * (1 - num_states_final_t); + }); + ExclusiveSum(num_extra_states, &num_extra_states); + + RaggedShape top_shape = RaggedShape2(&num_extra_states, nullptr, -1), + bottom_shape = RegularRaggedShape(c_, top_shape.NumElements(), 0); + final_arcs_shape = ComposeRaggedShapes(top_shape, bottom_shape); + } + + + RaggedShape oshape; + // see documentation of Stack() in ragged_ops.h for explanation. + Array1 oshape_merge_map; + + { + NVTX_RANGE("InitOshape"); + // each of these have 3 axes. + std::vector arcs_shapes(T + 2); + for (int32_t t = 0; t <= T; t++) + arcs_shapes[t] = &(frames_[t]->arcs.shape); + arcs_shapes[T + 1] = &final_arcs_shape; + + // oshape is a 4-axis ragged tensor which is indexed: + // oshape[fsa_index][t][state_idx][arc_idx] + int32_t axis = 1; + oshape = Stack(axis, T + 2, arcs_shapes.data(), &oshape_merge_map); + } + + + int32_t *oshape_row_ids3 = oshape.RowIds(3).Data(), + *oshape_row_ids2 = oshape.RowIds(2).Data(), + *oshape_row_ids1 = oshape.RowIds(1).Data(), + *oshape_row_splits3 = oshape.RowSplits(3).Data(), + *oshape_row_splits2 = oshape.RowSplits(2).Data(), + *oshape_row_splits1 = oshape.RowSplits(1).Data(); + + + int32_t num_arcs = oshape.NumElements(); + *arc_map_a = Array1(c_, num_arcs); + *arc_map_b = Array1(c_, num_arcs); int32_t *arc_map_a_data = arc_map_a->Data(), *arc_map_b_data = arc_map_b->Data(); - Array1 arcs_out(c_, tot_arcs_pruned); + Array1 arcs_out(c_, num_arcs); Arc *arcs_out_data = arcs_out.Data(); const Arc *a_fsas_arcs = a_fsas_.values.Data(); int32_t b_fsas_num_cols = b_fsas_.scores.Dim1(); const int32_t *b_fsas_row_ids1 = b_fsas_.shape.RowIds(1).Data(); - const int32_t *b_fsas_row_splits1 = b_fsas_.shape.RowSplits(1).Data(); - K2_EVAL( - c_, tot_arcs_pruned, 
lambda_format_arc_data, - (int32_t pruned_idx0123)->void { - int32_t unpruned_idx0123 = new2old_arc_map0123[pruned_idx0123]; - int32_t unpruned_idx012 = oshapeu_row_ids3[unpruned_idx0123], - unpruned_idx01 = oshapeu_row_ids2[unpruned_idx012], - unpruned_idx01x = oshapeu_row_splits2[unpruned_idx01], - unpruned_idx01xx = oshapeu_row_splits3[unpruned_idx01x], - unpruned_idx23 = unpruned_idx0123 - unpruned_idx01xx, - unpruned_idx0 = oshapeu_row_ids1[unpruned_idx01], // fsa-id - unpruned_idx0x = oshapeu_row_splits1[unpruned_idx0], - // unpruned_idx0xx = oshapeu_row_splits2[unpruned_idx0x], - unpruned_idx1 = unpruned_idx01 - unpruned_idx0x, // t - unpruned_idx01_next_t = unpruned_idx01 + 1, - unpruned_idx01x_next_t = - oshapeu_row_splits2[unpruned_idx01_next_t]; - - int32_t t = unpruned_idx1; - int32_t *arcs_row_splits1_data = arcs_row_splits1_ptrs_data[t], - *arcs_row_splits2_data = arcs_row_splits2_ptrs_data[t], - arcs_idx0x = arcs_row_splits1_data[unpruned_idx0], - arcs_idx0xx = arcs_row_splits2_data[arcs_idx0x]; - // below: axes 2,3 of the unpruned layout coincide with axes 1,2 of - // 'arcs'; these are state and arc indexes (within this frame - // of this FSA). - int32_t arcs_idx012 = arcs_idx0xx + unpruned_idx23; - ArcInfo *ai_data = ai_data_ptrs_data[t]; - ArcInfo arc_info = ai_data[arcs_idx012]; - - // we call it ind2 because the state-index is axis 2 of oshape. 
- int32_t unpruned_dest_state_idx2 = arc_info.u.dest_info_state_idx1, - unpruned_dest_state_idx012 = - unpruned_idx01x_next_t + unpruned_dest_state_idx2, - pruned_dest_state_idx012 = - old2new_state_map012[unpruned_dest_state_idx012], - pruned_dest_state_idx01 = - oshapep_row_ids2[pruned_dest_state_idx012], - pruned_dest_state_idx0 = - oshapep_row_ids1[pruned_dest_state_idx01], - pruned_dest_state_idx0x = - oshapep_row_splits1[pruned_dest_state_idx0], - pruned_dest_state_idx0xx = - oshapep_row_splits2[pruned_dest_state_idx0x], - pruned_dest_state_idx12 = - pruned_dest_state_idx012 - pruned_dest_state_idx0xx; - - // note: the src-state and dest-state have the same ind0 which is the - // FSA-id. - int32_t pruned_src_state_idx012 = - old2new_state_map012[unpruned_idx012], - pruned_src_state_idx12 = - pruned_src_state_idx012 - pruned_dest_state_idx0xx; + const uint32_t *oshape_merge_map_data = oshape_merge_map.Data(); + K2_EVAL( + c_, num_arcs, lambda_format_arc_data, + (int32_t oarc_idx0123)->void { // by 'oarc' we mean arc with shape `oshape`. + int32_t oarc_idx012 = oshape_row_ids3[oarc_idx0123], + oarc_idx01 = oshape_row_ids2[oarc_idx012], + oarc_idx0 = oshape_row_ids1[oarc_idx01], + oarc_idx0x = oshape_row_splits1[oarc_idx0], + oarc_idx0xx = oshape_row_splits2[oarc_idx0x], + oarc_idx1 = oarc_idx01 - oarc_idx0x, + oarc_idx01x_next = oshape_row_splits2[oarc_idx01 + 1]; + + int32_t m = oshape_merge_map_data[oarc_idx0123], + t = m % (T + 2), // actually we won't get t == T or t == T + 1 + // here since those frames have no arcs. + arcs_idx012 = m / (T + 2); // arc_idx012 into FrameInfo::arcs on time t, + // index of the arc on that frame. 
+ + K2_CHECK_EQ(t, oarc_idx1); + + const ArcInfo *arcs_data = arcs_data_ptrs_data[t]; + + ArcInfo arc_info = arcs_data[arcs_idx012]; Arc arc; - // The numbering for the dest-state in the output Arc is the numbering - // *within the FSA*, and we ignore the time index (1) because that - // index will be removed as the FSA format has no notion of time; - // that's why we use the indx12. + arc.src_state = oarc_idx012 - oarc_idx0xx; + // Note: the idx1 w.r.t. the frame's `arcs` is an idx2 w.r.t. `oshape`. + int32_t dest_state_idx012 = oarc_idx01x_next + + arc_info.u.dest_info_state_idx1; + arc.dest_state = dest_state_idx012 - oarc_idx0xx; + arc.label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; - arc_map_a_data[pruned_idx0123] = arc_info.a_fsas_arc_idx012; + int32_t fsa_id = oarc_idx0, + b_fsas_idx0x = b_fsas_row_splits1[fsa_id], + b_fsas_idx01 = b_fsas_idx0x + t, + b_fsas_idx2 = (arc.label + 1), + b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; - arc.src_state = pruned_src_state_idx12; - arc.dest_state = pruned_dest_state_idx12; - arc.label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; - K2_CHECK_LE(static_cast(arc.label + 1), - static_cast(b_fsas_num_cols)) - << "label out of range"; - int32_t fsa_id = unpruned_idx0, - b_fsas_idx0x = b_fsas_row_splits1[fsa_id], - b_fsas_idx01 = b_fsas_idx0x + t, - b_fsas_idx2 = (arc.label + 1), - b_fsas_arc_idx012 = - b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; arc.score = arc_info.arc_loglike; - arc_map_b_data[pruned_idx0123] = b_fsas_arc_idx012; - arcs_out_data[pruned_idx0123] = arc; + arc_map_a_data[oarc_idx0123] = arc_info.a_fsas_arc_idx012; + arc_map_b_data[oarc_idx0123] = b_fsas_arc_idx012; + arcs_out_data[oarc_idx0123] = arc; }); // Remove axis 1, which corresponds to time. 
- RaggedShape output_fsas_shape = RemoveAxis(oshape_pruned_, 1); - *ofsa = FsaVec(output_fsas_shape, arcs_out); + *ofsa = FsaVec(RemoveAxis(oshape, 1), arcs_out); } /* @@ -501,7 +549,7 @@ class MultiGraphDenseIntersectPruned { if (dynamic_beam > default_beam) dynamic_beam = default_beam; // Decrease the beam as long as we have more than // max_active active states. - dynamic_beam *= 0.9; + dynamic_beam *= 0.8; } dynamic_beams_data[i] = dynamic_beam; cutoffs_data[i] = best_loglike - dynamic_beam; @@ -757,6 +805,8 @@ class MultiGraphDenseIntersectPruned { StateInfo *kept_states_data = ans->states.values.Data(); { + int32_t *ans_states_row_splits1_data = ans->states.RowSplits(1).Data(); + NVTX_RANGE("LambdaSetStates"); K2_EVAL( c_, arc_info.NumElements(), lambda_set_arcs_and_states, @@ -775,11 +825,17 @@ class MultiGraphDenseIntersectPruned { state_idx01 = -1; // The destination state did not survive // pruning. - // state_idx01 is the index into ans->states, of the destination - // state. Note: multiple arcs may enter this state, which is why we - // had to set that in a separate kernel (lambda_modify_state_map). - info.u.dest_info_state_idx01 = state_idx01; - if (state_idx01 < 0) + int32_t state_idx1; + if (state_idx01 >= 0) { + int32_t state_idx0x = ans_states_row_splits1_data[fsa_id]; + state_idx1 = state_idx01 - state_idx0x; + } else { + state_idx1 = -1; // Meaning: invalid. + } + // state_idx1 is the idx1 into ans->states, of the destination + // state. + info.u.dest_info_state_idx1 = state_idx1; + if (state_idx1 < 0) return; // multiple threads may write the same value to the address written @@ -790,7 +846,7 @@ class MultiGraphDenseIntersectPruned { // Set the forward log-like of the dest state to the largest of any // of those of the incoming arcs. Note: we initialized this in // lambda_init_loglike above. 
- atomicMax(&(kept_states_data[state_idx01].forward_loglike), + AtomicMax(&(kept_states_data[state_idx01].forward_loglike), end_loglike_int); }); } @@ -802,7 +858,7 @@ class MultiGraphDenseIntersectPruned { (int32_t state_idx01)->void { int32_t a_fsas_state_idx01 = kept_states_data[state_idx01].a_fsas_state_idx01, - fsa_idx0 = next_states_row_ids1[state_idx01]; + fsa_idx0 = next_states_row_ids1[state_idx01]; int32_t state_map_idx = a_fsas_state_idx01 + fsa_idx0 * state_map_fsa_stride; state_map_acc.Delete(state_map_idx); @@ -811,44 +867,122 @@ class MultiGraphDenseIntersectPruned { return ans; } + /* - Does backward propagation of log-likes, which means setting the - backward_loglike field of the StateInfo variable. These backward log-likes - are normalized in such a way that you can add them with the forward - log-likes to produce the log-likelihood ratio vs the best path (this will be - non-positive). To do this, for the final state we have to set the backward - log-like to the negative of the forward log-like. + Sets backward_loglike fields of StateInfo to the negative of the forward + prob wherever the forward prob is finite, else to -infinity. + + This is used in computing the backward loglikes/scores for purposes of + pruning. This may be done after we're finished decoding/intersecting, + or while we are still decoding. + + Note: something similar to this (setting backward-prob == -forward-prob) is + also done in PropagateBackward() when we detect final-states. That's needed + because not all sequences have the same length, so some may have reached + their final state earlier. (Note: we only get to the final-state of a_fsas_ + if we've reached the final frame of the input, because for non-final frames + we always have -infinity as the log-prob corresponding to the symbol -1.)
+ + While we are still decoding, a background process will do pruning + concurrently with the forward computation, for purposes of reducing memory + usage (and so that most of the pruning can be made concurrent with the + forward computation). In this case we want to avoid pruning away anything + that wouldn't have been pruned away if we were to have waited to the end; + and it turns out that setting the backward probs to the negative of the + forward probs (i.e. for all states, not just final states) accomplishes + this. The issue was mentioned in the "Exact Lattice Generation.." paper and + also in the code for Kaldi's lattice-faster-decoder; search for "As in [3], + to save memory..." + + @param [in] cur_frame Frame on which to set the backward probs + */ + void SetBackwardProbsFinal(FrameInfo *cur_frame) { + NVTX_RANGE("SetBackwardProbsFinal"); + Ragged &cur_states = cur_frame->states; // 2 axes: fsa,state + int32_t num_states = cur_states.values.Dim(); + if (num_states == 0) + return; + StateInfo *cur_states_data = cur_states.values.Data(); + const int32_t *a_fsas_row_ids1_data = a_fsas_.shape.RowIds(1).Data(), + *a_fsas_row_splits1_data = a_fsas_.shape.RowSplits(1).Data(), + *cur_states_row_ids1_data = cur_states.RowIds(1).Data(); + double minus_inf = -std::numeric_limits::infinity(); + + K2_EVAL(c_, num_states, lambda_set_backward_prob, (int32_t state_idx01) -> void { + StateInfo *info = cur_states_data + state_idx01; + double backward_loglike, + forward_loglike = OrderedIntToFloat(info->forward_loglike); + if (forward_loglike - forward_loglike == 0) { // not -infinity... + // canonically we'd set this to zero, but setting it to the forward + // loglike when this is the final-state (in a_fsas_) has the effect of + // making the (forward+backward) probs equivalent to the logprob minus + // the best-path log-prob, which is convenient for pruning. 
If this + // is not actually the last frame of this sequence, which can happen + // if this was called before the forward decoding process was + // finished, what we are doing is a form of pruning that is guaranteed + // not to prune anything out that would not have been pruned out if we + // had waited until the real end of the file to do the pruning. + backward_loglike = -forward_loglike; + } else { + backward_loglike = minus_inf; + } + info->backward_loglike = backward_loglike; + }); + } - @param [in] t The time-index (on which to look up log-likes), - t >= 0 - @param [in] cur_frame The FrameInfo for the frame on which we want to - set the forward log-like - @param [in] next_frame NULL if this is is the last frame of the - sequence; otherwise the next frame's FrameInfo; - arcs on `cur_frame` have transitions to states - on `next_frame`. The `backward_loglike` values - in `next_frame` are assumed to already be set. - */ - void PropagateBackward(int32_t t, FrameInfo *cur_frame, - FrameInfo *next_frame) { + /* + Does backward propagation of log-likes, which means setting the + backward_loglike field of the StateInfo variable (for cur_frame); + and works out which arcs and which states are to be pruned + on cur_frame; this information is output to Array1's which + are supplied by the caller. + + These backward log-likes are normalized in such a way that you can add them + with the forward log-likes to produce the log-likelihood ratio vs the best + path (this will be non-positive). (To do this, for the final state we have + to set the backward log-like to the negative of the forward log-like; see + SetBackwardProbsFinal()). + + This function also marks for pruning arc-indexes and state-indexes + on `cur_frame` (it reads, but does not prune, states on `next_frame`).
+ + @param [in] t The time-index (on which to look up log-likes); + equals time index of `cur_frame`; t >= 0 + @param [in] cur_frame The FrameInfo for the frame on which we want to + set the backward log-like, and output pruning info + for arcs and states + @param [in] next_frame The next frame's FrameInfo, on which to look + up log-likes for the next frame; the + `backward_loglike` values of states on `next_frame` + are assumed to already be set, either by + SetBackwardProbsFinal() or a previous call to + PropagateBackward(). + @param [out] cur_frame_states_keep An array, created by the caller, + to which we'll write 1s for elements of cur_frame->states + which we need to keep, and 0s for others. + @param [out] cur_frame_arcs_keep An array, created by the caller, + to which we'll write 1s for elements of cur_frame->arcs + which we need to keep (because they survived pruning), + and 0s for others. + */ + void PropagateBackward(int32_t t, + FrameInfo *cur_frame, + FrameInfo *next_frame, + Array1 *cur_frame_states_keep, + Array1 *cur_frame_arcs_keep) { NVTX_RANGE("PropagateBackward"); int32_t num_states = cur_frame->states.NumElements(), num_arcs = cur_frame->arcs.NumElements(); - Ragged &cur_states = cur_frame->states; // 2 axes: fsa,state - StateInfo *cur_states_data = cur_states.values.Data(); + K2_CHECK_EQ(num_states, cur_frame_states_keep->Dim()); + K2_CHECK_EQ(num_arcs, cur_frame_arcs_keep->Dim()); - int32_t *a_fsas_row_ids1 = a_fsas_.shape.RowIds(1).Data(), - *a_fsas_row_splits1 = a_fsas_.shape.RowSplits(1).Data(); + int32_t *a_fsas_row_ids1_data = a_fsas_.shape.RowIds(1).Data(), + *a_fsas_row_splits1_data = a_fsas_.shape.RowSplits(1).Data(); float minus_inf = -std::numeric_limits::infinity(); - /* arc_backward_probs represents the backward-prob at the beginning of the - arc. Indexing is [state_idx01][arc_idx2], (where state_idx01 and - arc_idx2 are named w.r.t. frames_[t]->arcs.
*/ - RaggedShape sub_curr_frame_shape = RemoveAxis(cur_frame->arcs.shape, 0); - Array1 sub_curr_frame_values(c_, num_arcs); - Ragged arc_backward_prob(sub_curr_frame_shape, - sub_curr_frame_values); + Ragged arc_backward_prob(cur_frame->arcs.shape, + Array1(c_, cur_frame->arcs.NumElements())); float *arc_backward_prob_data = arc_backward_prob.values.Data(); ArcInfo *ai_data = cur_frame->arcs.values.Data(); @@ -858,70 +992,57 @@ class MultiGraphDenseIntersectPruned { *arcs_row_splits2 = cur_frame->arcs.shape.RowSplits(2).Data(); float output_beam = output_beam_; - int32_t *oshape_row_splits1 = oshape_unpruned_.RowSplits(1).Data(), - *oshape_row_splits2 = oshape_unpruned_.RowSplits(2).Data(), - *oshape_row_splits3 = oshape_unpruned_.RowSplits(3).Data(); - - // these have the "output" formatting where we number things with - // oshape_unpruned_, which is indexed [fsa][t][state][arc]. - char *keep_arcs_data = renumber_output_arcs_.Keep().Data(), - *keep_states_data = renumber_output_states_.Keep().Data(); - - if (next_frame != NULL) { - // compute arc backward probs, and set elements of 'keep_arcs' - - StateInfo *cur_states_data = cur_frame->states.values.Data(); - - // arc_row_ids maps from arc-idx to frame-state-idx, i.e. idx012 into - // `arcs` to idx01 into `arcs`. 
- - // next_states_row_splits1 maps from fsa_idx0 to state_idx01 - int32_t *next_states_row_splits1 = - next_frame->states.shape.RowSplits(1).Data(); - - StateInfo *next_states_data = next_frame->states.values.Data(); - K2_EVAL( - c_, arc_backward_prob.NumElements(), - lambda_set_arc_backward_prob_and_keep, (int32_t arcs_idx012)->void { - ArcInfo *arc = ai_data + arcs_idx012; - int32_t state_idx01 = arcs_rowids2[arcs_idx012], - fsa_idx0 = arcs_rowids1[state_idx01], - fsa_idx0x = arcs_row_splits1[fsa_idx0], - fsa_idx0xx = arcs_row_splits2[fsa_idx0x], - arcs_idx12 = arcs_idx012 - fsa_idx0xx; - - int32_t dest_state_idx01 = arc->u.dest_info_state_idx01; - char keep_this_arc = 0; - float backward_loglike = minus_inf; - if (dest_state_idx01 >= 0) { // Dest-state was not pruned out.. - int32_t next_state_idx0x = next_states_row_splits1[fsa_idx0], - dest_state_idx1 = dest_state_idx01 - next_state_idx0x; - arc->u.dest_info_state_idx1 = dest_state_idx1; - float arc_loglike = arc->arc_loglike; - float dest_state_backward_loglike = - next_states_data[dest_state_idx01].backward_loglike; - // 'backward_loglike' is the loglike at the beginning of the arc - backward_loglike = arc_loglike + dest_state_backward_loglike; - float src_state_forward_loglike = OrderedIntToFloat( - cur_states_data[arcs_rowids2[arcs_idx012]].forward_loglike); - keep_this_arc = (backward_loglike + src_state_forward_loglike >= - -output_beam); - } - int32_t oshape_arc_idx0x = oshape_row_splits1[fsa_idx0], - oshape_arc_idx01 = oshape_arc_idx0x + t, - oshape_arc_idx01x = oshape_row_splits2[oshape_arc_idx01], - oshape_arc_idx01xx = oshape_row_splits3[oshape_arc_idx01x], - oshape_arc_idx0123 = oshape_arc_idx01xx + arcs_idx12; - // note, for the previous line: indexes 1 and 2 of FrameInfo::arcs - // (==state,arc) become indexes 2 and 3 of oshape_unpruned_. 
- keep_arcs_data[oshape_arc_idx0123] = keep_this_arc; - arc_backward_prob_data[arcs_idx012] = backward_loglike; - }); - } else { - K2_DCHECK_EQ(arc_backward_prob.NumElements(), 0) - << "Caution: final frame has arcs; check that there were -infinities " - "in the right place on the last frame of the 'scores' matrix."; - } + // compute arc backward probs, and set elements of 'keep_cur_arcs_data' + int32_t next_num_states = next_frame->states.TotSize(1); + + char *keep_cur_arcs_data = cur_frame_arcs_keep->Data(), + *keep_cur_states_data = cur_frame_states_keep->Data(); + + const int32_t *next_states_row_splits1_data = + next_frame->states.RowSplits(1).Data(); + + StateInfo *next_states_data = next_frame->states.values.Data(); + StateInfo *cur_states_data = cur_frame->states.values.Data(); + + K2_EVAL(c_, num_arcs, lambda_set_arc_backward_prob_and_keep, + (int32_t arcs_idx012) -> void { + ArcInfo *arc = ai_data + arcs_idx012; + int32_t state_idx01 = arcs_rowids2[arcs_idx012], + seq_idx0 = arcs_rowids1[state_idx01], // 'seq' == fsa-idx in b + next_states_idx0x = next_states_row_splits1_data[seq_idx0]; + + // Note: if dest_state_idx1 == -1, dest_state_idx01 has a meaningless + // value below, but it's never referenced. + int32_t dest_state_idx1 = arc->u.dest_info_state_idx1, + dest_state_idx01 = next_states_idx0x + dest_state_idx1; + float backward_loglike = minus_inf; + char keep_this_arc = 0; + if (dest_state_idx1 == -1) { + // dest_state_idx1 == -1 means this arc was already pruned in + // the forward pass.. do nothing. + } else { + float arc_loglike = arc->arc_loglike; + float dest_state_backward_loglike = + next_states_data[dest_state_idx01].backward_loglike; + // 'backward_loglike' is the loglike at the beginning of the arc + backward_loglike = arc_loglike + dest_state_backward_loglike; + float src_state_forward_loglike = OrderedIntToFloat( + cur_states_data[arcs_rowids2[arcs_idx012]].forward_loglike); + + // should be <= 0.0, mathematically. 
+ K2_CHECK_LT(backward_loglike, -src_state_forward_loglike + 2.0); + if (backward_loglike + src_state_forward_loglike >= -output_beam) { + keep_this_arc = 1; + } else { + backward_loglike = minus_inf; // Don't let arcs outside beam + // contribute to their start-states's + // backward prob (we'll use that to + // prune the start-states away.) + } + } + keep_cur_arcs_data[arcs_idx012] = keep_this_arc; + arc_backward_prob_data[arcs_idx012] = backward_loglike; + }); /* note, the elements of state_backward_prob that don't have arcs leaving them will be set to the supplied default. */ @@ -939,11 +1060,8 @@ class MultiGraphDenseIntersectPruned { (int32_t state_idx01)->void { StateInfo *info = cur_states_data + state_idx01; int32_t fsas_state_idx01 = info->a_fsas_state_idx01, - a_fsas_idx0 = a_fsas_row_ids1[fsas_state_idx01], - states_idx0 = cur_states_row_ids1[state_idx01], - fsas_state_idx0x_next = a_fsas_row_splits1[a_fsas_idx0 + 1]; - // Note: a_fsas_idx0 and states_idx0 will be the same if - // a_fsas_.Dim0() >= b_fsas_.Dim0(). + a_fsas_idx0 = a_fsas_row_ids1_data[fsas_state_idx01], + fsas_state_idx0x_next = a_fsas_row_splits1_data[a_fsas_idx0 + 1]; float forward_loglike = OrderedIntToFloat(info->forward_loglike), backward_loglike; // `is_final_state` means this is the final-state in a_fsas. this @@ -957,41 +1075,300 @@ class MultiGraphDenseIntersectPruned { } else { backward_loglike = state_backward_prob_data[state_idx01]; } - char keep_this_state = - (backward_loglike + forward_loglike >= -output_beam); - - // we can use the arcs row-splits because the structure of - // FrameInfo::states is the same as the top level structure of - // FrameInfo::arcs. 
- int32_t states_idx0x = arcs_row_splits1[states_idx0], - states_idx1 = state_idx01 - states_idx0x; - - int32_t oshape_idx0x = oshape_row_splits1[states_idx0], - oshape_idx01 = oshape_idx0x + t, - oshape_idx01x = oshape_row_splits2[oshape_idx01], - oshape_idx012 = oshape_idx01x + states_idx1; - // note: axis 1 of 'states' corresponds to axis 2 of 'oshape'; it's - // the state index. Also, - - keep_states_data[oshape_idx012] = keep_this_state; - if (!keep_this_state) { - // The reason we set the backward_loglike to -infinity here if it's - // outside the beam, is to prevent disconnected states from - // appearing after pruning due to numerical roundoff effects near - // the boundary at `-beam`. It would otherwise be correct and - // harmless to omit this if-block. - backward_loglike = minus_inf; - } info->backward_loglike = backward_loglike; + keep_cur_states_data[state_idx01] = (backward_loglike != minus_inf); }); } + /* + This function does backward propagation and pruning of arcs and states for a + specific time range. + @param [in] begin_t Lowest `t` value to call PropagateBackward() for + and to prune its arcs and states. Require begin_t >= 0. + @param [in] end_t One-past-the-highest `t` value to call PropagateBackward() + and to prune its arcs and states. Require that + `frames_[end_t]` already be set up; this requires at least + end_t <= T. + Arcs on frames t >= end_t and states on frame t > end_t are ignored; the backward + probs on time end_t are set by SetBackwardProbsFinal(), see its documentation + to understand what this does if we haven't yet reached the end of one of the + sequences. + + After this function is done, the arcs for `frames_[t]` with begin_t <= t < end_t and + the states for `frames_[t]` with begin_t < t < end_t will have their numbering changed. + (We don't renumber the states on begin_t because that would require the dest-states + of the arcs on time `begin_t - 1` to be modified). TODO: check this...
+ */ + void PruneTimeRange(int32_t begin_t, + int32_t end_t) { + SetBackwardProbsFinal(frames_[end_t].get()); + ContextPtr cpu = GetCpuContext(); + int32_t num_fsas = b_fsas_.shape.Dim0(), + num_t = end_t - begin_t; + Array1 old_states_offsets(cpu, num_t + 1), + old_arcs_offsets(cpu, num_t + 1); + int32_t tot_states = 0, tot_arcs = 0; + { + int32_t *old_states_offsets_data = old_states_offsets.Data(), + *old_arcs_offsets_data = old_arcs_offsets.Data(); + for (int32_t i = 0; i <= num_t; i++) { + int32_t t = begin_t + i; + old_states_offsets_data[i] = tot_states; + old_arcs_offsets_data[i] = tot_arcs; + if (i < num_t) { + tot_states += frames_[t]->arcs.TotSize(1); + tot_arcs += frames_[t]->arcs.TotSize(2); + } + } + } + + + // contains respectively: row_splits1_ptrs, row_ids1_ptrs, + // row_splits2_ptrs, row_ids2_ptrs, + // old_states_ptrs (really type StateInfo*), + // old_arcs_ptrs (really type ArcInfo*). + Array1 old_all_ptrs(cpu, num_t * 6); + + Renumbering renumber_states(c_, tot_states), + renumber_arcs(c_, tot_arcs); + { + void **all_p = old_all_ptrs.Data(); + int32_t **old_row_splits1_ptrs_data = (int32_t**)all_p, + **old_row_ids1_ptrs_data = (int32_t**)all_p + num_t, + **old_row_splits2_ptrs_data = (int32_t**)all_p + 2 * num_t, + **old_row_ids2_ptrs_data = (int32_t**)all_p + 3 * num_t; + StateInfo **old_states_ptrs_data = (StateInfo**)all_p + 4 * num_t; + ArcInfo **old_arcs_ptrs_data = (ArcInfo**)all_p + 5 * num_t; + int32_t *old_states_offsets_data = old_states_offsets.Data(), + *old_arcs_offsets_data = old_arcs_offsets.Data(); + + for (int32_t t = end_t - 1; t >= begin_t; --t) { + int32_t i = t - begin_t; + Array1 this_states_keep = + renumber_states.Keep().Arange(old_states_offsets_data[i], + old_states_offsets_data[i + 1]), + this_arcs_keep = + renumber_arcs.Keep().Arange(old_arcs_offsets_data[i], + old_arcs_offsets_data[i + 1]); + FrameInfo *cur_frame = frames_[t].get(); + PropagateBackward(t, cur_frame, frames_[t+1].get(), + &this_states_keep,
&this_arcs_keep); + + old_row_splits1_ptrs_data[i] = cur_frame->arcs.RowSplits(1).Data(); + old_row_ids1_ptrs_data[i] = cur_frame->arcs.RowIds(1).Data(); + old_row_splits2_ptrs_data[i] = cur_frame->arcs.RowSplits(2).Data(); + old_row_ids2_ptrs_data[i] = cur_frame->arcs.RowIds(2).Data(); + old_arcs_ptrs_data[i] = cur_frame->arcs.values.Data(); + old_states_ptrs_data[i] = cur_frame->states.values.Data(); + + // We can't discard any states on t == begin_t because: if it is not t == + // 0, it would be inconvenient to map the dest-states of arcs on t - 1; + // and if it is t == 0, this may remove the start-state, which would make + // it more complex to avoid invalid FSAs (e.g. with an end-state but no + // start-state, or in which we incorrectly interpret a non-start state as + // the start state). + if (i == 0) // t == begin_t + this_states_keep = (char)1; // set all elements of the array + // `states_keep` to 1. + } + } + + old_states_offsets = old_states_offsets.To(c_); + old_arcs_offsets = old_arcs_offsets.To(c_); + Array1 new_states_offsets = renumber_states.Old2New(true)[old_states_offsets], + new_arcs_offsets = renumber_arcs.Old2New(true)[old_arcs_offsets]; + int32_t new_num_states = renumber_states.NumNewElems(), + new_num_arcs = renumber_arcs.NumNewElems(); + // These arrays map to the (t - begin_t) corresponding to this state or arc + // in the new numbering, i.e. the frame index minus begin_t. 
+ Array1 new_state_to_frame(c_, new_num_states), + new_arc_to_frame(c_, new_num_arcs); + RowSplitsToRowIds(new_states_offsets, &new_state_to_frame); + RowSplitsToRowIds(new_arcs_offsets, &new_arc_to_frame); + const int32_t *old_states_offsets_data = old_states_offsets.Data(), + *new_states_offsets_data = new_states_offsets.Data(), + *old_arcs_offsets_data = old_arcs_offsets.Data(), + *new_arcs_offsets_data = new_arcs_offsets.Data(), + *new_state_to_frame_data = new_state_to_frame.Data(), + *new_arc_to_frame_data = new_arc_to_frame.Data(), + *states_old2new_data = renumber_states.Old2New().Data(), + *states_new2old_data = renumber_states.New2Old().Data(), + *arcs_old2new_data = renumber_arcs.Old2New().Data(), + *arcs_new2old_data = renumber_arcs.New2Old().Data(); + + // Allocate the new row_splits and row_ids vectors for the shapes on the + // individual frames, and the new arc-info and state-info. + Array2 all_row_splits1(c_, num_t, num_fsas + 1); + auto all_row_splits1_acc = all_row_splits1.Accessor(); + Array1 all_row_ids1(c_, new_num_states); + // the "+ num_t" below is for the extra element of each row_splits array. + Array1 all_row_splits2(c_, new_num_states + num_t); + Array1 all_row_ids2(c_, new_num_arcs); + Array1 all_states(c_, new_num_states); + Array1 all_arcs(c_, new_num_arcs); + + int32_t *all_row_ids1_data = all_row_ids1.Data(), + *all_row_ids2_data = all_row_ids2.Data(), + *all_row_splits2_data = all_row_splits2.Data(); + StateInfo *all_states_data = all_states.Data(); + ArcInfo *all_arcs_data = all_arcs.Data(); + + old_all_ptrs = old_all_ptrs.To(c_); + void **all_p = old_all_ptrs.Data(); + + K2_EVAL2(c_, num_t, num_fsas + 1, + lambda_set_new_row_splits1, (int32_t t_offset, + int32_t seq_idx) -> void { + // note, t_offset is t - t_start. + int32_t *old_row_splits1 = (int32_t*) all_p[t_offset]; + int32_t old_idx0x = old_row_splits1[seq_idx]; + // "pos" means position in appended states vector + // old_start_pos means start for this `t`. 
+ int32_t old_start_pos = old_states_offsets_data[t_offset], + old_pos = old_start_pos + old_idx0x, + new_start_pos = states_old2new_data[old_start_pos], + new_pos = states_old2new_data[old_pos], + new_idx0x = new_pos - new_start_pos; + all_row_splits1_acc(t_offset, seq_idx) = new_idx0x; + // TODO: set elem zero of row-splits? + + if (seq_idx == 0) { + // We assign the `seq_idx == 0` version of the kernel to set the initial + // zero in each row_splits vector. + all_row_splits2_data[new_pos + t_offset] = 0; + } + }); + + K2_EVAL(c_, new_num_states, lambda_per_state, (int32_t new_i) -> void { + // new_i is position in appended vector of all states. + int32_t t_offset = new_state_to_frame_data[new_i], + old_state_start_pos = old_states_offsets_data[t_offset], + new_arc_start_pos = new_arcs_offsets_data[t_offset], + old_arc_start_pos = old_arcs_offsets_data[t_offset], + old_i = states_new2old_data[new_i], + old_state_idx01 = old_i - old_state_start_pos; + + + // this old_states_data is from its FrameInfo::states. + const StateInfo *old_states_data = (StateInfo*)all_p[4 * num_t + t_offset]; + const int32_t *old_row_ids1_data = (int32_t*)all_p[1 * num_t + t_offset], + *old_row_splits2_data = (int32_t*)all_p[2 * num_t + t_offset]; + + // set the row-ids1 (these contain FSA-ids). + all_row_ids1_data[new_i] = old_row_ids1_data[old_state_idx01]; + + + { // set the row-splits2. + // We make each kernel responsible for the *next* row_splits entry, + // i.e. for its new_state_idx01 plus one. This solves the problem of no + // kernel being responsible for the last row-splits entry. We + // separately wrote the zeros for the 1st row-splits entry, in a + // previous kernel. + // + // It's safe to use old_state_idx01+1 instead of doing the same mapping + // from new_i+1 that we do from new_i to old_state_idx01, because + // we know this state was kept (because it has a new_i index.) 
+ int32_t old_arc_idx01x_next = old_row_splits2_data[old_state_idx01+1], + old_arc_pos_next = old_arc_idx01x_next + old_arc_start_pos, + new_arc_pos_next = arcs_old2new_data[old_arc_pos_next], + new_arc_idx01x_next = new_arc_pos_next - new_arc_start_pos; + + // "+ t_offset" is to compensate for the extra element of each row_splits + // vector. The "+ 1" is about the "next", i.e. each kernel is responsible + // for the next row_splits element, and none is responsible for the initial zero; + // that is set in a previous kernel. + all_row_splits2_data[new_i + t_offset + 1] = new_arc_idx01x_next; + } + all_states_data[new_i] = old_states_data[old_state_idx01]; + }); + + K2_EVAL(c_, new_num_arcs, lambda_set_arcs, (int32_t new_i) -> void { + // new_i is position in appended vector of all arcs + int32_t t_offset = new_arc_to_frame_data[new_i], + new_state_start_pos = new_states_offsets_data[t_offset], + old_state_start_pos = old_states_offsets_data[t_offset], + next_old_state_start_pos = old_states_offsets_data[t_offset + 1], + old_arc_start_pos = old_arcs_offsets_data[t_offset], + old_i = arcs_new2old_data[new_i], + old_arc_idx012 = old_i - old_arc_start_pos; + + ArcInfo *old_info_data = (ArcInfo*)all_p[5 * num_t + t_offset]; + int32_t *old_row_ids2_data = (int32_t*)all_p[3 * num_t + t_offset], + *old_row_ids1_data = (int32_t*)all_p[1 * num_t + t_offset], + *next_old_row_splits1_data = (int32_t*)all_p[t_offset + 1]; + + int32_t old_src_state_idx01 = old_row_ids2_data[old_arc_idx012], + fsa_idx0 = old_row_ids1_data[old_src_state_idx01], + old_src_state_pos = old_src_state_idx01 + old_state_start_pos, + new_src_state_pos = states_old2new_data[old_src_state_pos], + new_src_state_idx01 = new_src_state_pos - new_state_start_pos; + + all_row_ids2_data[new_i] = new_src_state_idx01; + + ArcInfo info = old_info_data[old_arc_idx012]; + + if (t_offset + 1 == num_t) { + // Do nothing; this is the last frame of the batch of frames that we are + // pruning, so we don't need to 
renumber the destination-states of the + // arcs leaving it because the next frame's states have not been pruned + // (so the numbering stays the same). + } else { + // idx1 of the state in the next frame's `states` object. + int32_t dest_info_state_idx1 = info.u.dest_info_state_idx1; + + // the naming below is unusual; by "pos" we mean position in the old or + // new "all_states" or "all_arcs" vectors, which have all frames appended. + // (the new ones physically exist; the old ones don't, but they are the + // numberings used in renumber_states.Keep() and renumber_arcs.Keep().) + int32_t old_dest_state_idx0x = next_old_row_splits1_data[fsa_idx0], + old_dest_state_idx01 = old_dest_state_idx0x + dest_info_state_idx1, + old_dest_state_idx0x_pos = next_old_state_start_pos + old_dest_state_idx0x, + old_dest_state_idx01_pos = next_old_state_start_pos + old_dest_state_idx01, + new_dest_state_idx0x_pos = states_old2new_data[old_dest_state_idx0x_pos], + new_dest_state_idx01_pos = states_old2new_data[old_dest_state_idx01_pos], + new_dest_state_idx1 = new_dest_state_idx01_pos - new_dest_state_idx0x_pos; + info.u.dest_info_state_idx1 = new_dest_state_idx1; + } + all_arcs_data[new_i] = info; + }); + + // Now reconstruct the states and arcs for all the frames we pruned, from + // sub-parts of the arrays we just created. + new_states_offsets = new_states_offsets.To(cpu); + new_arcs_offsets = new_arcs_offsets.To(cpu); + new_states_offsets_data = new_states_offsets.Data(); + new_arcs_offsets_data = new_arcs_offsets.Data(); + for (int32_t i = 0; i < num_t; i++) { // i corresponds to "t_offset". + int32_t state_offset = new_states_offsets_data[i], + next_state_offset = new_states_offsets_data[i + 1], + arc_offset = new_arcs_offsets_data[i], + next_arc_offset = new_arcs_offsets_data[i + 1]; + + // next line: operator[] into Array2 gives Array1, one row. 
+ Array1 row_splits1 = all_row_splits1.Row(i), + row_ids1 = all_row_ids1.Arange(state_offset, next_state_offset), + row_splits2 = all_row_splits2.Arange(state_offset + i, next_state_offset + (i+1)), + row_ids2 = all_row_ids2.Arange(arc_offset, next_arc_offset); + Array1 arcs = all_arcs.Arange(arc_offset, next_arc_offset); + + RaggedShape arcs_shape = RaggedShape3(&row_splits1, &row_ids1, -1, + &row_splits2, &row_ids2, -1); + int32_t t = begin_t + i; + frames_[t]->arcs = Ragged(arcs_shape, arcs); + Array1 states = all_states.Arange(state_offset, next_state_offset); + RaggedShape states_shape = GetLayer(arcs_shape, 0); + frames_[t]->states = Ragged(states_shape, states); + } + } + + ContextPtr c_; FsaVec &a_fsas_; // Note: a_fsas_ has 3 axes. int32_t a_fsas_stride_; // 1 if we use a different FSA per sequence // (a_fsas_.Dim0() > 1), 0 if the decoding graph is // shared (a_fsas_.Dim0() == 1). DenseFsaVec &b_fsas_; + int32_t T_; // == b_fsas_.MaxSize(1). float search_beam_; float output_beam_; int32_t min_active_; @@ -1019,22 +1396,26 @@ class MultiGraphDenseIntersectPruned { // from active states to the position in the // `states` array. Between frames, all values // have -1 in them. - std::vector> frames_; - // This is a rearranged version of the info in 'frames', computed at the end - // of the forward pass before pruning. It is indexed [fsa_id][t][state][arc]. - RaggedShape oshape_unpruned_; - - // these two Renumbering objects dictate how we renumber oshape_unpruned_, - // i.e. which states and arcs we delete. The data in their Keep() members, - // which are vectors of chars, are written to in PropagateBackward(). - Renumbering renumber_output_states_; - Renumbering renumber_output_arcs_; - - // This is as oshape_unpruned_, but after the backward-pass pruning. - // It is indexed [fsa_id][t][state][arc]. 
- RaggedShape oshape_pruned_; + // logically an array of bool, of size T_ + 1; for each 0 <= t <= T, after the + // forward pass finishes propagation with cur_frame_ == t, if + // do_pruning_after_[t] is false it will continue as normal; otherwise (if + // true), it will signal `semaphore_`. + std::vector do_pruning_after_; + + // For each t for which do_pruning_after_[t] is true, there will be a + // pair (begin_t, end_t) in prune_t_begin_end giving the + // arguments for which we will invoke PruneTimeRange() after the forward-pass + // for time t has completed. The size of this array equals the sum + // of nonzero elements of do_pruning_after_. + std::vector > prune_t_begin_end_; + + // Each time the forward-pass finishes forward processing for a t value for + // which do_pruning_after_[t] is true, it will signal this semaphore; the + // backward-pass thread (which does pruning) will wait on it as many times as + // do_pruning_after_[t] is set to true. + Semaphore semaphore_; }; void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, diff --git a/k2/csrc/intersect_test.cu b/k2/csrc/intersect_test.cu index 0b793c01a..1479be59b 100644 --- a/k2/csrc/intersect_test.cu +++ b/k2/csrc/intersect_test.cu @@ -173,7 +173,7 @@ TEST(Intersect, RandomSingle) { for (int32_t n = 0; n < 10; n++) { int32_t i = RandInt(0, dfsavec.scores.Dim0() - 1); for (int32_t j = 0; j < dfsavec.scores.Dim1(); j++) { - dfsa_acc(i, j) += -2000.0; + dfsa_acc(i, j) += -100.0; } } } @@ -240,7 +240,7 @@ TEST(Intersect, RandomFsaVec) { for (int32_t n = 0; n < 10; n++) { int32_t i = RandInt(0, dfsavec.scores.Dim0() - 1); for (int32_t j = 0; j < dfsavec.scores.Dim1(); j++) { - dfsa_acc(i, j) += -2000.0; + dfsa_acc(i, j) += -100.0; } } } @@ -256,7 +256,10 @@ TEST(Intersect, RandomFsaVec) { float output_beam = 100000.0; // TODO(Dan) ... 
IntersectDense(fsavec, dfsavec, output_beam, &out_fsas, &arc_map_a, &arc_map_b); - K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b; + K2_LOG(INFO) << "out_fsas = " << out_fsas + << ", arc_map_a = " << arc_map_a + << ", arc_map_b = " << arc_map_b; + fsavec = fsavec.To(cpu); out_fsas = out_fsas.To(cpu); @@ -287,8 +290,20 @@ TEST(Intersect, RandomFsaVec) { ArcSort(&fsavec); // CAUTION if you later test the arc_maps: we arc-sort // here, so the input `fsa` is not the same as before. bool treat_epsilons_specially = false; - Intersect(fsavec, -1, fsas_b, -1, treat_epsilons_specially, - &out_fsas2, &arc_map_a2, &arc_map_b2); + + + { + Array1 arc_map_a2_temp, + arc_map_b2_temp; + FsaVec out_fsas2_temp; + Intersect(fsavec, -1, fsas_b, -1, treat_epsilons_specially, + &out_fsas2_temp, &arc_map_a2_temp, &arc_map_b2_temp); + Array1 connect_arc_map; + Connect(out_fsas2_temp, &out_fsas2, &connect_arc_map); + arc_map_a2 = arc_map_a2_temp[connect_arc_map]; + arc_map_b2 = arc_map_b2_temp[connect_arc_map]; + } + K2_LOG(INFO) << "out_fsas2 = " << out_fsas2 << ", arc_map_a2 = " << arc_map_a2 << ", arc_map_b2 = " << arc_map_b2; @@ -474,7 +489,10 @@ TEST(IntersectPruned, RandomSingle) { num_fsas = RandInt(2, 5); } - int32_t min_frames = 0, max_frames = 10, min_nsymbols = max_symbol + 1, + // set max_frames = 50 to be larger than the chunk sizes used for pruning + // in intersect_pruned.cu (see call to PruneTimeRange()). 
+ int32_t min_frames = 0, max_frames = 50, + min_nsymbols = max_symbol + 1, max_nsymbols = max_symbol + 4; float scores_scale = 1.0; DenseFsaVec dfsavec = @@ -492,7 +510,7 @@ TEST(IntersectPruned, RandomSingle) { for (int32_t n = 0; n < 10; n++) { int32_t i = RandInt(0, dfsavec.scores.Dim0() - 1); for (int32_t j = 0; j < dfsavec.scores.Dim1(); j++) { - dfsa_acc(i, j) += -2000.0; + dfsa_acc(i, j) += -100.0; } } } @@ -556,7 +574,7 @@ TEST(IntersectPruned, RandomFsaVec) { for (int32_t n = 0; n < 10; n++) { int32_t i = RandInt(0, dfsavec.scores.Dim0() - 1); for (int32_t j = 0; j < dfsavec.scores.Dim1(); j++) { - dfsa_acc(i, j) += -2000.0; + dfsa_acc(i, j) += -100.0; } } } @@ -573,7 +591,9 @@ TEST(IntersectPruned, RandomFsaVec) { int32_t min_active = 0, max_active = 10; IntersectDensePruned(fsavec, dfsavec, search_beam, output_beam, min_active, max_active, &out_fsas, &arc_map_a, &arc_map_b); - K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b; + K2_LOG(INFO) << "out_fsas = " << out_fsas + << ", arc_map_a = " << arc_map_a + << ", arc_map_b = " << arc_map_b; out_fsas = out_fsas.To(cpu); fsavec = fsavec.To(cpu); @@ -604,8 +624,17 @@ TEST(IntersectPruned, RandomFsaVec) { ArcSort(&fsavec); // CAUTION if you later test the arc_maps: we arc-sort // here, so the input `fsa` is not the same as before. 
bool treat_epsilons_specially = false; - Intersect(fsavec, -1, fsas_b, -1, treat_epsilons_specially, - &out_fsas2, &arc_map_a2, &arc_map_b2); + { + Array1 arc_map_a2_temp, + arc_map_b2_temp; + FsaVec out_fsas2_temp; + Intersect(fsavec, -1, fsas_b, -1, treat_epsilons_specially, + &out_fsas2_temp, &arc_map_a2_temp, &arc_map_b2_temp); + Array1 connect_arc_map; + Connect(out_fsas2_temp, &out_fsas2, &connect_arc_map); + arc_map_a2 = arc_map_a2_temp[connect_arc_map]; + arc_map_b2 = arc_map_b2_temp[connect_arc_map]; + } K2_LOG(INFO) << "out_fsas2 = " << out_fsas2 << ", arc_map_a2 = " << arc_map_a2 << ", arc_map_b2 = " << arc_map_b2; diff --git a/k2/csrc/log.h b/k2/csrc/log.h index ffd7a7c46..b37934a2b 100644 --- a/k2/csrc/log.h +++ b/k2/csrc/log.h @@ -174,6 +174,17 @@ inline bool EnableCudaDeviceSync() { return enable_cuda_sync; } +inline bool DisableChecks() { + // Currently this just disables the checks called in the constructor of + // RaggedShape, which can otherwise dominate the time when in debug mode. + static std::once_flag init_flag; + static bool disable_checks = false; + std::call_once(init_flag, []() { + disable_checks = (std::getenv("K2_DISABLE_CHECKS") != nullptr); + }); + return disable_checks; +} + } // namespace internal } // namespace k2 @@ -246,7 +257,8 @@ inline bool EnableCudaDeviceSync() { K2_CHECK_CUDA_ERROR(e); \ } while (0) -// ============================================================ + +// ------------------------------------------------------------ // For debug check // ------------------------------------------------------------ diff --git a/k2/csrc/ragged.h b/k2/csrc/ragged.h index 5035531a1..00c3539f7 100644 --- a/k2/csrc/ragged.h +++ b/k2/csrc/ragged.h @@ -155,7 +155,9 @@ class RaggedShape { explicit RaggedShape(const std::vector &layers, bool check = !internal::kDisableDebug) : layers_(layers) { - if (check) Check(); + // the check can be disabled by setting the environment variable + // K2_DISABLE_CHECKS.
+ if (check && !internal::DisableChecks()) Check(); } explicit RaggedShape(const std::string &src) { @@ -165,6 +167,12 @@ class RaggedShape { K2_LOG(FATAL) << "Failed to construct RaggedShape from string: " << src; } + // Construct from context and string. This uses delegating constructors, (a + // c++11 feature), and an explicitly constructed RaggedShape + // "RaggedShape(src)" + RaggedShape(ContextPtr context, const std::string &src): + RaggedShape(RaggedShape(src).To(context)) { } + // A RaggedShape constructed this way will not be a valid RaggedShape. // The constructor is provided so you can immediately assign to it. RaggedShape() = default; diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 128ddca5d..9af12fe3c 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -147,7 +147,8 @@ RaggedShape ComposeRaggedShapes(const RaggedShape &a, const RaggedShape &b) { std::size_t a_size = a_axes.size(), b_size = b_axes.size(); for (std::size_t i = 0; i < a_size; ++i) axes[i] = a_axes[i]; for (std::size_t i = 0; i < b_size; ++i) axes[i + a_size] = b_axes[i]; - return RaggedShape(axes); + bool validate = false; + return RaggedShape(axes, validate); } RaggedShape RaggedShape3(Array1 *row_splits1, @@ -1562,4 +1563,125 @@ RaggedShape GetLayer(const RaggedShape &src, int32_t layer) { } +void DecomposeRaggedShape(const RaggedShape &src, + int32_t axis, + RaggedShape *top, RaggedShape *bottom) { + K2_CHECK_GT(axis, 0); + K2_CHECK_LT(axis, src.NumAxes() - 1); + const std::vector &src_layers = src.Layers(); + std::vector top_layers(axis), + bottom_layers(src_layers.size() - axis); + int32_t src_size = src_layers.size(); + for (int32_t i = 0; i < axis; ++i) + top_layers[i] = src_layers[i]; + for (int32_t i = axis; i < src_size; ++i) + bottom_layers[i - axis] = src_layers[i]; + *top = RaggedShape(top_layers); + *bottom = RaggedShape(bottom_layers); +} + +RaggedShape RemoveEmptyLists(RaggedShape &src_shape, + int32_t axis, + Renumbering *renumbering_out) { 
+ if (axis == 0) { + return RemoveEmptyListsAxis0(src_shape, renumbering_out); + } + RaggedShape top_shape, bottom_shape; + DecomposeRaggedShape(src_shape, axis, &top_shape, &bottom_shape); + + Renumbering r_temp; + if (!renumbering_out) + renumbering_out = &r_temp; + bottom_shape = RemoveEmptyListsAxis0(bottom_shape, renumbering_out); + top_shape = SubsampleRaggedShape(top_shape, *renumbering_out); + return ComposeRaggedShapes(top_shape, bottom_shape); +} + + +RaggedShape RemoveSomeEmptyLists(RaggedShape &src_shape, + int32_t axis, + Renumbering &renumbering) { + if (axis == 0) { + return RenumberAxis0Simple(src_shape, renumbering); + } + RaggedShape top_shape, bottom_shape; + DecomposeRaggedShape(src_shape, axis, &top_shape, &bottom_shape); + + bottom_shape = RenumberAxis0Simple(bottom_shape, renumbering); + top_shape = SubsampleRaggedShape(top_shape, renumbering); + return ComposeRaggedShapes(top_shape, bottom_shape); +} + +// Keeps exactly the nonempty lists of axis 0. One job per list (num_lists, not +// num_lists + 1), so the row_splits_data[i+1] read stays within RowSplits(1). +RaggedShape RemoveEmptyListsAxis0(RaggedShape &src_shape, + Renumbering *renumbering_out) { + Renumbering r_temp; + if (!renumbering_out) + renumbering_out = &r_temp; + + ContextPtr c = src_shape.Context(); + int32_t num_lists = src_shape.Dim0(); + *renumbering_out = Renumbering(c, num_lists); + int32_t *row_splits_data = src_shape.RowSplits(1).Data(); + char *keep_data = renumbering_out->Keep().Data(); + K2_EVAL(c, num_lists, lambda_set_keep, (int32_t i) -> void { + keep_data[i] = (row_splits_data[i+1] != row_splits_data[i]); + }); + return RenumberAxis0Simple(src_shape, *renumbering_out); +} + +RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, + Renumbering &renumbering) { + K2_CHECK_EQ(renumbering.NumOldElems(), src_shape.Dim0()); + ContextPtr c = src_shape.Context(); + src_shape.RowIds(1); // make sure RowIds(1) is populated. + std::vector layers = src_shape.Layers(); + int32_t num_layers = layers.size(); + int32_t new_num_lists = renumbering.NumNewElems(), + num_elems = src_shape.TotSize(1); // unchanged old vs. new.
+ Array1 new_row_splits(c, new_num_lists + 1), + new_row_ids = renumbering.Old2New()[src_shape.RowIds(1)]; + int32_t *new_row_splits_data = new_row_splits.Data(); + const int32_t *old_row_splits_data = src_shape.RowSplits(1).Data(), + *new2old_data = renumbering.New2Old().Data(); + // set `new_row_splits_data`. + +#ifndef NDEBUG + { + Array1 is_ok(c, 1, 1); + int32_t *is_ok_data = is_ok.Data(); + int32_t old_num_lists = src_shape.Dim0(); + const int32_t *old2new_data = renumbering.Old2New().Data(); + K2_EVAL(c, old_num_lists, lambda_check_preconditions, (int32_t i) -> void { + if (old2new_data[i+1] == old2new_data[i]) { // This list not kept + if (old_row_splits_data[i+1] != old_row_splits_data[i]) { + // this list was nonempty... + is_ok_data[0] = 0; + } + } + }); + K2_CHECK_NE(is_ok[0], 0) << "RenumberAxis0Simple(): preconditions not met; " + "renumbering removes nonempty lists."; + } +#endif + + K2_EVAL(c, new_num_lists + 1, lambda_set_new_row_splits, (int32_t new_i) -> void { + int32_t j; + if (new_i == new_num_lists) { + j = num_elems; + } else { + int32_t old_i = new2old_data[new_i]; + j = old_row_splits_data[old_i]; + } + new_row_splits_data[new_i] = j; + }); + layers[0].row_splits = new_row_splits; + layers[0].row_ids = new_row_ids; + // no need to set its cached_tot_size; that didn't change. + return RaggedShape(layers); +} + + + } // namespace k2 diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 697714374..46898940c 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -314,6 +314,25 @@ Ragged RemoveAxis(Ragged &src, int32_t axis) { */ RaggedShape GetLayer(const RaggedShape &src, int32_t layer); + +/* + This is the inverse of ComposeRaggedShapes(); it splits up a RaggedShape + into two pieces such that `top->NumElements() == bottom->Dim0()`. + + @param [in] src Source RaggedShape + @param [in] axis Axis to split at; must satisfy + 0 < axis < src.NumAxes() - 1.
Axis `axis` of + the input will correspond to the last axis of + `top` and axis 0 of `bottom`. + @param [out] top Top layers of the RaggedShape + @param [out] bottom Bottom layers of the RaggedShape; will satisfy + `top->NumElements() == bottom->Dim0()` and + `Equal(src, ComposeRaggedShapes(*top, *bottom))` + */ +void DecomposeRaggedShape(const RaggedShape &src, + int32_t axis, + RaggedShape *top, RaggedShape *bottom); + /* Returns a CPU array of shape (src[0]->NumAxes() + 1) by (num_srcs + 1), where each row is the exclusive-sum of the TotSize() of the respective sources, @@ -493,7 +512,9 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, Notice the other version of this function below. */ -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); +RaggedShape SubsampleRaggedShape(RaggedShape &src, + Renumbering &renumbering); + /* Return ragged shape with only a subset of the elements on the last @@ -508,6 +529,87 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering_before_last, Renumbering &renumbering_last); +/* + Removes empty lists on a particular axis (not last axis) of a RaggedShape, + returning the modified shape with those lists removed. + @param [in] src_shape RaggedShape that possibly has empty lists + to be removed + @param [in] axis Axis that is not the last axis of `src_shape`, + i.e. with `axis + 1 < src_shape.NumAxes()`. + @param [out] renumbering If not nullptr, a renumbering object that maps + between old and new indexes on axis `axis` (e.g. if + `axis == 0` would map between idx0's and idx0's; if + `axis == 1`, would map between idx01's and idx01's). + @return Returns modified shape with ans.NumAxes() == src_shape.NumAxes(). + ans.TotSize(axis) may differ from src_shape.TotSize(axis), + but other TotSize() values, and the numbering on other axes, + will remain the same. 
+ */ +RaggedShape RemoveEmptyLists(RaggedShape &src_shape, + int32_t axis, + Renumbering *renumbering = nullptr); + +/* + Removes some subset of empty lists on a particular axis (not last axis) of + a RaggedShape, returning the modified shape with those lists removed. + + @param [in] src_shape RaggedShape that possibly has empty lists + to be removed + @param [in] axis Axis that is not the last axis of `src_shape`, + i.e. with `axis + 1 < src_shape.NumAxes()`. + @param [in] renumbering A renumbering object that maps + between old and new indexes on axis `axis` (e.g. if + `axis == 0` would map between idx0's and idx0's; if + `axis == 1`, would map between idx01's and idx01's). + It is assumed that this renumbering preserves + all lists that are nonempty. + @return Returns modified shape with ans.NumAxes() == src_shape.NumAxes(). + ans.TotSize(axis) may differ from src_shape.TotSize(axis), + but other TotSize() values, and the numbering on other axes, + will remain the same. + */ +RaggedShape RemoveSomeEmptyLists(RaggedShape &src_shape, + int32_t axis, + Renumbering &renumbering); + + +/* + Removes empty lists on axis 0 of a RaggedShape, returning the modified shape + with those lists removed. Note: a list containing empty lists is not empty. + + @param [in] src_shape RaggedShape that possibly has empty lists on its + axis 0 + @param [out] renumbering If not nullptr, a renumbering object that maps + between old and new indexes on axis 0, i.e. between + old and new idx0's. + @return Returns modified shape with ans.NumAxes() == src_shape.NumAxes(). + ans.Dim0() may differ from src_shape.Dim0(), but for axis > 0, + we have `ans.TotSize(axis) == src_shape.TotSize(axis)`. +*/ +RaggedShape RemoveEmptyListsAxis0(RaggedShape &src_shape, + Renumbering *renumbering = nullptr); + +/* + Removes some (but not necessarily all) empty lists on axis 0 of a RaggedShape, + returning the modified shape with those lists removed.
Note: a list + containing empty lists is not empty. (this is what we mean by the "Simple" + part of the name, as it means we only have to deal with one layer). + + @param [in] src_shape RaggedShape that possibly has empty lists on its + axis 0 + @param [in] renumbering A renumbering object that maps + between old and new indexes on axis 0, i.e. between + old and new idx0's. The removed lists must be empty. + + @return Returns modified shape with ans.NumAxes() == src_shape.NumAxes(). + ans.Dim0() may differ from src_shape.Dim0(), but for axis > 0, + we have `ans.TotSize(axis) == src_shape.TotSize(axis)`. + */ +RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, + Renumbering &renumbering); + + + /* Return ragged array with only a subset of the bottom-level elements kept. Require renumbering.NumOldElems() == src.NumElements(). Note: all diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index e1c0bca0b..cc4d2b586 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -254,6 +254,99 @@ TEST(RaggedShapeTest, RaggedShape) { } } } + + +TEST(RaggedShapeTest, DecomposeRaggedShape) { + ContextPtr cpu = GetCpuContext(); // will be used to copy data + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + { + RaggedShape s(c, "[ [ x x ] [ x ] ]"), + t(c, "[ [ x ] [ x x ] [ x x x x ] ]"), + u = ComposeRaggedShapes(s, t); + + RaggedShape s2, t2; + DecomposeRaggedShape(u, 1, &s2, &t2); + EXPECT_EQ(Equal(s, s2), true); + EXPECT_EQ(Equal(t, t2), true); + } + } +} + + + +TEST(RaggedShapeTest, RemoveEmptyListsAxis0) { + ContextPtr cpu = GetCpuContext(); // will be used to copy data + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + { + RaggedShape s(c, "[ [ [ x ] [x ] ] [ ] [ [ x ] ] [ ] ]"), + t(c, "[ [ [ x ] [x ] ] [ [ x ] ] ]"); + + Renumbering r; + RaggedShape t2 = RemoveEmptyListsAxis0(s, &r); + EXPECT_EQ(Equal(t, t2), true); + } + + { + RaggedShape s(c, "[ [ x x ] [ ] [ x ] [ ] ]"), + t(c, "[ [ x
x ] [ x ] ]"); + RaggedShape t2 = RemoveEmptyListsAxis0(s); + EXPECT_EQ(Equal(t, t2), true); + } + + { + RaggedShape s(c, "[ [ x x ] [ ] [ x ] [ ] ]"), + t(c, "[ [ x x ] [ x ] ]"); + Renumbering r; + RaggedShape t2 = RemoveEmptyLists(s, 0, &r); + EXPECT_EQ(Equal(t, t2), true); + } + + + { + RaggedShape s(c, "[ [ x x ] [ ] [ x ] [ ] ]"), + t(c, "[ [ x x ] [ ] [ x ] ]"); + + Array1 keep(c, std::vector({ (char)1, (char)1, (char)1, (char)0 })); + Renumbering r(c, 4); + Assign(keep, &r.Keep()); + RaggedShape t2 = RenumberAxis0Simple(s, r); + EXPECT_EQ(Equal(t, t2), true); + } + + { + RaggedShape s(c, "[ [ x x ] [ ] [ x ] [ ] ]"), + t(c, "[ [ x x ] [ ] [ x ] ]"); + + Array1 keep(c, std::vector({ (char)0, (char)1, (char)1, (char)0 })); + Renumbering r(c, 4); + Assign(keep, &r.Keep()); +#ifndef NDEBUG + ASSERT_DEATH(RenumberAxis0Simple(s, r), ""); +#endif + } + } +} + + +TEST(RaggedShapeTest, RemoveEmptyLists) { + ContextPtr cpu = GetCpuContext(); // will be used to copy data + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + { + RaggedShape s(c, "[ [ [ x ] [ ] ] [ ] [ [ x ] ] [ ] ]"), + t(c, "[ [ [ x ] ] [ ] [ [ x ] ] [ ] ]"); + + Renumbering r; + RaggedShape t2 = RemoveEmptyLists(s, 1, &r); + EXPECT_EQ(Equal(t, t2), true); + } + } +} + + + + + + TEST(RaggedShapeTest, RaggedShapeIterator) { // note RaggedShapeIndexIterator works only for CPU ContextPtr context = GetCpuContext(); diff --git a/k2/csrc/ragged_tensor_ops.h b/k2/csrc/ragged_tensor_ops.h new file mode 100644 index 000000000..dc7d26b8f --- /dev/null +++ b/k2/csrc/ragged_tensor_ops.h @@ -0,0 +1,40 @@ +/** + * @brief + * ragged_tensor_ops + * + * @copyright + * Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#ifndef K2_CSRC_RAGGED_TENSOR_OPS_H_ +#define K2_CSRC_RAGGED_TENSOR_OPS_H_ + +#include +#include + +#include "k2/csrc/algorithms.h" +#include "k2/csrc/array.h" +#include "k2/csrc/log.h" +#include 
"k2/csrc/ragged.h" +#include "k2/csrc/tensor.h" +#include "k2/csrc/utils.h" + +namespace k2 { +// This file declares ops involving RaggedShape, Ragged, +// and Tensor. They are implemented in ragged_tensor_ops.cu +// (they don't need to be in the header as Tensor doesn't have type +// information, so these functions are not templated). + + + + + + + +} // namespace k2 + + +#endif // K2_CSRC_RAGGED_TENSOR_OPS_H_ diff --git a/k2/csrc/semaphore.h b/k2/csrc/semaphore.h new file mode 100644 index 000000000..ab831cea2 --- /dev/null +++ b/k2/csrc/semaphore.h @@ -0,0 +1,55 @@ +/** + * @brief + * semaphore + * + * @copyright + * Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + */ + +#ifndef K2_CSRC_SEMAPHORE_H_ +#define K2_CSRC_SEMAPHORE_H_ + +#include +#include +#include "k2/csrc/log.h" + +// caution: this contains class k2std::counting_semaphore, but not class k2::Semaphore +// which is in context.h. + +// `k2std` is our replacement for `std` until we require C++20 that has +// `counting_semaphore` as part of the standard library. +namespace k2std { + + +// This is intended to implement a subset of the functionality of C++20's +// counting_semaphore (at the time of writing, we compile with C++14.) 
+class counting_semaphore { +public: + explicit counting_semaphore(int count = 0): count_(count) { } + + void release() { // could also be 'signal' + std::lock_guard lock(mutex_); + ++count_; + cv_.notify_one(); + } + void acquire() { // could also be 'wait' + std::unique_lock lock(mutex_); + while(count_ == 0) // loop, not `if`: re-check count_ after every wakeup + cv_.wait(lock); + --count_; + } + +private: + std::mutex mutex_; + std::condition_variable cv_; + int count_; +}; + + +} // namespace k2std + + +#endif // K2_CSRC_SEMAPHORE_H_ diff --git a/k2/csrc/tensor.cu b/k2/csrc/tensor.cu index ad28a2a65..600e2772d 100644 --- a/k2/csrc/tensor.cu +++ b/k2/csrc/tensor.cu @@ -56,25 +56,26 @@ Shape::Shape(const std::vector &dims, storage_size_ = ComputeStorageSize(); } -int32_t Shape::ComputeNumElement() const { +int64_t Shape::ComputeNumElement() const { NVTX_RANGE(K2_FUNC); if (num_axes_ == 0) return 0; - int32_t elements = 1; + int64_t elements = 1; for (int32_t i = 0; i < num_axes_; ++i) { elements *= dims_[i]; } return elements; } -int32_t Shape::ComputeStorageSize() const { +int64_t Shape::ComputeStorageSize() const { NVTX_RANGE(K2_FUNC); if (num_axes_ == 0) return 0; - int32_t size = 1; + int64_t size = 1; for (int32_t i = 0; i < num_axes_; ++i) { - size += (dims_[i] - 1) * strides_[i]; + size += (dims_[i] - 1) * (int64_t)strides_[i]; } + K2_CHECK_GE(size, 0); return size; } @@ -91,7 +92,7 @@ bool Shape::ComputeIsContiguous() const { } if (s == 0) return true; - int32_t z = 1; + int64_t z = 1; for (int32_t i = num_axes_ - 1; i >= 0; --i) { K2_CHECK_GE(strides_[i], z); if (dims_[i] != 1) { @@ -141,8 +142,8 @@ Tensor::Tensor(Dtype type, const Shape &shape, RegionPtr region, int32_t byte_offset) : impl_(std::make_shared()) { NVTX_RANGE(K2_FUNC); - int32_t storage_size = shape.StorageSize(); - int32_t element_size = TraitsOf(type).NumBytes(); + size_t storage_size = shape.StorageSize(); + size_t element_size = TraitsOf(type).NumBytes(); impl_->dtype = type; impl_->shape = shape; impl_->data = region; diff --git
a/k2/csrc/tensor.h b/k2/csrc/tensor.h index 1840c12ae..6db8d75d2 100644 --- a/k2/csrc/tensor.h +++ b/k2/csrc/tensor.h @@ -51,7 +51,7 @@ class Shape { return std::vector(strides_, strides_ + num_axes_); } - int32_t StorageSize() const { return storage_size_; } + int64_t StorageSize() const { return storage_size_; } bool IsContiguous() const { return is_contiguous_; } @@ -79,8 +79,8 @@ class Shape { static const int32_t kMaxDim = 4; // Will increase this as needed int32_t num_axes_ = 0; // Must be >= 0 - int32_t num_element_ = 0; - int32_t storage_size_ = 0; + int64_t num_element_ = 0; + int64_t storage_size_ = 0; bool is_contiguous_ = true; // elements of dims_ and strides_ >= num_axes_ are currently not set; @@ -89,8 +89,10 @@ class Shape { int32_t strides_[kMaxDim]; // Strides in elements // compute the number of elements - int32_t ComputeNumElement() const; - int32_t ComputeStorageSize() const; + int64_t ComputeNumElement() const; + // compute the size of storage needed to hold this tensor, in elements. + // (different from ComputeNumElement(), because of strides). + int64_t ComputeStorageSize() const; bool ComputeIsContiguous() const; }; diff --git a/k2/csrc/utils.h b/k2/csrc/utils.h index 5dddf8bfa..e9a570300 100644 --- a/k2/csrc/utils.h +++ b/k2/csrc/utils.h @@ -383,6 +383,7 @@ __global__ void eval_lambda_redirect(int32_t num_jobs, TaskRedirect *redirect, lambda(task_id, num_threads_this_task, thread_idx_of_task); } + template __global__ void eval_lambda_redirect_large(int32_t num_jobs, TaskRedirect *redirect, @@ -597,13 +598,17 @@ __host__ __device__ __forceinline__ float OrderedIntToFloat(int32_t i) { } /* - Host version of Cuda's atomicMax function, marked __host__ (the default) for + Wrapper for Cuda's atomicMax function, explicitly marked __host__ __device__ for clarity. So we can use this in lambdas that run on both host and device.
*/ -__host__ __forceinline__ int32_t atomicMax(int32_t *address, int32_t val) { +__host__ __device__ __forceinline__ int32_t AtomicMax(int32_t *address, int32_t val) { +#if defined(__CUDA_ARCH__) + return atomicMax(address, val); +#else int32_t old = *address; if (old < val) *address = val; return old; +#endif } // have to figure out if there's a better place to put this diff --git a/k2/python/host/tests/fsa_equivalent_test.py b/k2/python/host/tests/fsa_equivalent_test.py index 9d7692db1..e81dfe424 100644 --- a/k2/python/host/tests/fsa_equivalent_test.py +++ b/k2/python/host/tests/fsa_equivalent_test.py @@ -60,7 +60,7 @@ def test_bad_case_2(self): self.assertFalse(k2host.is_rand_equivalent(fsa_a, fsa_b, 100)) def test_good_case_1(self): - # both fsas will be empty after triming + # both fsas will be empty after trimming s_a = r''' 0 1 1 0 0 2 2 0 @@ -84,7 +84,8 @@ def test_good_case_2(self): 1 2 3 0 1 3 4 0 2 3 5 0 - 3 + 3 4 -1 0 + 4 ''' fsa_a = k2host.str_to_fsa(s_a) self.assertTrue(k2host.is_rand_equivalent(fsa_a, fsa_a)) @@ -96,7 +97,8 @@ def test_bad_case_2(self): 0 3 8 0 1 4 4 0 2 4 5 0 - 4 + 4 5 -1 0 + 5 ''' fsa_a = k2host.str_to_fsa(s_a) s_b = r''' @@ -105,7 +107,8 @@ def test_bad_case_2(self): 0 3 9 0 1 4 5 0 2 4 4 0 - 4 + 4 5 -1 0 + 5 ''' fsa_b = k2host.str_to_fsa(s_b) self.assertTrue(k2host.is_rand_equivalent(fsa_a, fsa_b)) @@ -195,7 +198,7 @@ def test_bad_case_2(self): 0 1 1 0 0 2 2 0 1 3 4 0 - 3 + 4 ''' fsa = k2host.str_to_fsa(s_a) rand_path = k2host.RandPath(fsa, False) @@ -217,7 +220,8 @@ def test_good_case_1(self): 2 4 5 0 3 4 7 0 4 5 9 0 - 5 + 5 6 -1 0 + 6 ''' fsa = k2host.str_to_fsa(s_a) rand_path = k2host.RandPath(fsa, False) @@ -233,7 +237,8 @@ def test_good_case_2(self): 0 1 1 0 1 2 3 0 2 3 4 0 - 3 + 3 4 -1 0 + 4 ''' fsa = k2host.str_to_fsa(s_a) rand_path = k2host.RandPath(fsa, False) @@ -245,10 +250,10 @@ def test_good_case_2(self): self.assertTrue(status) self.assertFalse(k2host.is_empty(path)) self.assertFalse(arc_map.empty()) - 
expected_arc_indexes = torch.IntTensor([0, 1, 2, 3, 3]) + expected_arc_indexes = torch.IntTensor([0, 1, 2, 3, 4, 4]) expected_arcs = torch.IntTensor([[0, 1, 1, 0], [1, 2, 3, 0], - [2, 3, 4, 0]]) - expected_arc_map = torch.IntTensor([0, 1, 2]) + [2, 3, 4, 0], [3, 4, -1, 0]]) + expected_arc_map = torch.IntTensor([0, 1, 2, 3]) self.assertTrue(torch.equal(path.indexes, expected_arc_indexes)) self.assertTrue(torch.equal(path.data, expected_arcs)) self.assertTrue(torch.equal(arc_map.data, expected_arc_map)) @@ -262,7 +267,8 @@ def test_eps_arc_1(self): 2 4 5 0 3 4 7 0 4 5 9 0 - 5 + 5 6 -1 0 + 6 ''' fsa = k2host.str_to_fsa(s_a) rand_path = k2host.RandPath(fsa, True) @@ -285,7 +291,8 @@ def test_eps_arc_2(self): 3 5 7 0 3 4 8 0 4 5 9 0 - 5 + 5 6 -1 0 + 6 ''' fsa = k2host.str_to_fsa(s_a) rand_path = k2host.RandPath(fsa, True) diff --git a/k2/python/k2/version.py b/k2/python/k2/version.py index 9b530856b..088b83095 100644 --- a/k2/python/k2/version.py +++ b/k2/python/k2/version.py @@ -38,14 +38,8 @@ def main(): torch_cuda_version = _k2.version.torch_cuda_version enable_nvtx = _k2.version.enable_nvtx disable_debug = _k2.version.disable_debug - sync_kernels = os.getenv('K2_SYNC_KERNELS', None) - - if sync_kernels is None: - sync_kernels = False - elif sync_kernels == '': - # It's enabled as long as it is defined, no matter - # what the value is - sync_kernels = True + sync_kernels = os.getenv('K2_SYNC_KERNELS', None) is not None + disable_checks = os.getenv('K2_DISABLE_CHECKS', None) is not None print(f''' k2 version: {version} @@ -65,6 +59,7 @@ def main(): NVTX enabled: {enable_nvtx} Disable debug: {disable_debug} Sync kernels : {sync_kernels} +Disable checks: {disable_checks} ''') diff --git a/k2/python/tests/intersect_dense_pruned_test.py b/k2/python/tests/intersect_dense_pruned_test.py index aa0abf05f..2b76f2c8c 100644 --- a/k2/python/tests/intersect_dense_pruned_test.py +++ b/k2/python/tests/intersect_dense_pruned_test.py @@ -167,5 +167,56 @@ def test_two_fsas(self): 
assert torch.allclose(expected_grad_log_prob, log_prob.grad) + def test_two_fsas_long_pruned(self): + # as test_two_fsas_long in intersect_dense_test.py, but with pruned intersection + s1 = ''' + 0 1 1 1.0 + 1 1 1 50.0 + 1 2 2 2.0 + 2 3 -1 3.0 + 3 + ''' + + s2 = ''' + 0 1 1 1.0 + 1 2 2 2.0 + 2 3 -1 3.0 + 3 + ''' + + devices = [torch.device('cpu')] + if torch.cuda.is_available(): + devices.append(torch.device('cuda', 0)) + for device in devices: + fsa1 = k2.Fsa.from_str(s1) + fsa2 = k2.Fsa.from_str(s2) + + fsa1.requires_grad_(True) + fsa2.requires_grad_(True) + + fsa_vec = k2.create_fsa_vec([fsa1, fsa2]) + log_prob = torch.rand((2, 100, 3), + dtype=torch.float32, + device=device, + requires_grad=True) + + supervision_segments = torch.tensor([[0, 0, 95], [1, 20, 50]], + dtype=torch.int32) + dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments) + fsa_vec = fsa_vec.to(device) + out_fsa = k2.intersect_dense_pruned(fsa_vec, + dense_fsa_vec, + search_beam=100, output_beam=100, + min_active_states=1, + max_active_states=10) + assert out_fsa.shape == (2, None, + None), 'There should be two FSAs!' 
+ + scores = k2.get_tot_scores(out_fsa, + log_semiring=False, + use_double_scores=False) + scores.sum().backward() + + if __name__ == '__main__': unittest.main() diff --git a/k2/python/tests/intersect_dense_test.py b/k2/python/tests/intersect_dense_test.py index e95e69c1f..8b4ab8cdd 100644 --- a/k2/python/tests/intersect_dense_test.py +++ b/k2/python/tests/intersect_dense_test.py @@ -187,12 +187,12 @@ def test_two_fsas_long(self): fsa2.requires_grad_(True) fsa_vec = k2.create_fsa_vec([fsa1, fsa2]) - log_prob = torch.rand((2, 500, 3), + log_prob = torch.rand((2, 100, 3), dtype=torch.float32, device=device, requires_grad=True) - supervision_segments = torch.tensor([[0, 0, 490], [1, 0, 300]], + supervision_segments = torch.tensor([[0, 0, 95], [1, 20, 50]], dtype=torch.int32) dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments) fsa_vec = fsa_vec.to(device) @@ -208,5 +208,6 @@ def test_two_fsas_long(self): scores.sum().backward() + if __name__ == '__main__': unittest.main()