Skip to content

Commit

Permalink
Merge remote-tracking branch 'dan/master' into fix-ci-2
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 8, 2021
2 parents 75b82b9 + 2cb3eea commit f1005ee
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}")

project(k2 ${languages})

set(K2_VERSION "1.10")
set(K2_VERSION "1.11")

# ----------------- Supported build types for K2 project -----------------
set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
Expand Down
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include LICENSE

6 changes: 3 additions & 3 deletions k2/csrc/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool modified /*= false*/,
case 2: // the arc pointing to the next symbol state
arc.label = next_symbol;
aux_labels_value = sym_state_idx01 + 1 == sym_final_state ?
0 : next_symbol;
-1 : next_symbol;
arc.dest_state = state_idx1 + 2;
break;
default:
Expand All @@ -720,8 +720,8 @@ FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool modified /*= false*/,
K2_CHECK_LT(arc_idx2, 2);
arc.label = arc_idx2 == 0 ? 0 : current_symbol;
arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1;
aux_labels_value = (arc_idx2 == 0 || final_state) ?
0 : current_symbol;
aux_labels_value = arc_idx2 == 0 ? 0 : current_symbol;
if (final_state && arc_idx2 != 0) aux_labels_value = -1;
}
arcs_data[arc_idx012] = arc;
if (aux_labels) aux_labels_data[arc_idx012] = aux_labels_value;
Expand Down
8 changes: 4 additions & 4 deletions k2/csrc/fsa_algo_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1293,8 +1293,8 @@ TEST(FsaAlgo, TestCtcGraph) {
" [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] "
" [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]");
Array1<int32_t> aux_labels_ref(c, "[ 0 1 0 0 2 0 2 0 0 0 2 0 0 3 "
" 0 3 0 0 0 0 0 0 1 0 0 2 0 2 "
" 0 0 3 0 3 0 0 0 0 0 ]");
" 0 3 0 0 -1 0 -1 0 1 0 0 2 0 2 "
" 0 0 3 0 3 0 0 -1 0 -1 ]");
K2_CHECK(Equal(graph, graph_ref));
K2_CHECK(Equal(aux_labels, aux_labels_ref));
}
Expand All @@ -1315,8 +1315,8 @@ TEST(FsaAlgo, TestCtcGraphSimplified) {
" [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] "
" [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]");
Array1<int32_t> aux_labels_ref(c, "[ 0 1 0 0 2 0 2 0 0 2 0 2 0 "
" 0 3 0 3 0 0 0 0 0 0 1 0 0 2 "
" 0 2 0 0 3 0 3 0 0 0 0 0 ]");
" 0 3 0 3 0 0 -1 0 -1 0 1 0 0 2 "
" 0 2 0 0 3 0 3 0 0 -1 0 -1 ]");
K2_CHECK(Equal(graph, graph_ref));
K2_CHECK(Equal(aux_labels, aux_labels_ref));
}
Expand Down
63 changes: 44 additions & 19 deletions k2/csrc/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,43 @@ unsigned long long int __forceinline__ __host__ __device__ AtomicCAS(
How class Hash works:
- It can function as a map from key=uint32_t to value=uint32_t, or from
key=uint64_t to value=uint64_t where you choose NUM_KEY_BITS and
`key` must have only up to NUM_KEY_BITS set and `value` must have
only up to (64-NUM_KEY_BITS) set. You decide NUM_KEY_BITS when
you call Hash::Accessor<NUM_KEY_BITS>()
- You can store any (key,value) pair except the pair where all the bits of
key=uint64_t to value=uint64_t, but you cannot use all 64 bits in the
key and value because we compress both of them into a single 64-bit
integer. There are several different modes of using this hash,
depending which accessor objects you use. The modes are:
- Use Accessor<NUM_KEY_BITS> with num_key_bits known at compile time;
the number of values bits will be 64 - NUM_KEY_BITS.
- Use GenericAccessor, which is like Accessor but the number of
key bits is not known at compile time; and they both must still
sum to 64.
- Use PackedAccessor, which allows you to have the number of key
plus value bits greater than 64; the rest of the bits are
implicit in groups of buckets (the number of buckets must
be >= 32 * 1 << (num_key_bits + num_value_bits - 64).
- You must decide the number of key and value bits, and the number of
buckets, when you create the hash, but you can resize it (manually)
and when you resize it you can change the number of key and value bits.
Some constraints:
- You can store any (key,value) pair allowed by the number of key and value
bits, except the pair where all the bits of
both and key and value are set [that is used to mean "nothing here"]
- The number of buckets is a power of 2 provided by the user to the constructor;
currently no resizing is supported.
- The number of buckets must always be a power of 2.
- When deleting values from the hash you must delete them all at
once (necessary because there is no concept of a "tombstone".
Some notes on usage:
You use it by: constructing it, obtaining its Accessor with GetAccessor()
with appropriate template args depending on your chosen accessor type; and
inside kernels (or host code), calling functions Insert(), Find() or Delete()
of the Accessor object. Resizing is not automatic; it is the user's
responsibility to make sure the hash does not get too full (which could cause
assertion failures in kernels, and will be very slow).
Some implementation notes:
- When accessing hash[key], we use bucket_index == key % num_buckets,
bucket_inc = 1 | (((key * 2) / num_buckets) ^ key).
- If the bucket at `bucket_index` is occupied, we look in locations
Expand All @@ -72,15 +101,7 @@ unsigned long long int __forceinline__ __host__ __device__ AtomicCAS(
being odd ensures we eventually try all locations (of course for
reasonable hash occupancy levels, we shouldn't ever have to try
more than two or three).
- When deleting values from the hash you must delete them all at
once (necessary because there is no concept of a "tombstone".
You use it by: constructing it, obtaining its Accessor with
GetAccessor<NUM_KEY_BITS>(), and inside kernels (or host code), calling
functions Insert(), Find() or Delete() of the Accessor object. There is no
resizing; sizing it correctly is the caller's responsibility and if the hash
gets full the code will just loop forever (of course it will get extremely
slow before it reaches that point).
*/
class Hash {
public:
Expand All @@ -94,10 +115,14 @@ class Hash {
@param [in] num_key_bits Number of bits in the key of the hash;
must satisfy 0 < num_key_bits < 64, and keys used must
be less than (1<<num_key_bits)-1.
@param [in] num_value_bits Number of bits in the value of the hash.
If not specified it defaults to 64 - num_key_bits; in future
we'll allow more bits than that, by making some bits of
the key implicit in the bucket index.
@param [in] num_value_bits Number of bits in the value of the hash;
if not specified, will be set to 64 - num_key_bits. There
are constraints on the num_value_bits, it interacts with
which accessor you use. For Accessor<> or GenericAccessor,
we require that num_key_bits + num_value_bits == 64.
For PackedAccessor we allow that num_key_bits + num_value_bits > 64,
but with the constraint that
(num_buckets >> (64 - num_key_bits - num_value_bits)) >= 32
*/
Hash(ContextPtr c,
int32_t num_buckets,
Expand Down
14 changes: 7 additions & 7 deletions k2/python/tests/ctc_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def test(self):
'0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0',
'1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0',
'3 3 2 0 0', '4 4 0 0 0', '4 5 2 2 0', '5 6 0 0 0',
'5 5 2 0 0', '5 7 -1 0 0', '6 6 0 0 0', '6 7 -1 0 0', '7'
'5 5 2 0 0', '5 7 -1 -1 0', '6 6 0 0 0', '6 7 -1 -1 0', '7'
])
expected_str1 = '\n'.join([
'0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0',
'1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0',
'3 3 2 0 0', '3 5 3 3 0', '4 4 0 0 0', '4 5 3 3 0',
'5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', '6 6 0 0 0',
'6 7 -1 0 0', '7'
'5 6 0 0 0', '5 5 3 0 0', '5 7 -1 -1 0', '6 6 0 0 0',
'6 7 -1 -1 0', '7'
])
actual_str_ragged0 = k2.to_str_simple(fsa_vec_ragged[0].to('cpu'))
actual_str_ragged1 = k2.to_str_simple(fsa_vec_ragged[1].to('cpu'))
Expand All @@ -81,15 +81,15 @@ def test_simplified(self):
'0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0',
'1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0',
'3 3 2 0 0', '3 5 2 2 0', '4 4 0 0 0', '4 5 2 2 0',
'5 6 0 0 0', '5 5 2 0 0', '5 7 -1 0 0', '6 6 0 0 0',
'6 7 -1 0 0', '7'
'5 6 0 0 0', '5 5 2 0 0', '5 7 -1 -1 0', '6 6 0 0 0',
'6 7 -1 -1 0', '7'
])
expected_str1 = '\n'.join([
'0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0',
'1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0',
'3 3 2 0 0', '3 5 3 3 0', '4 4 0 0 0', '4 5 3 3 0',
'5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', '6 6 0 0 0',
'6 7 -1 0 0', '7'
'5 6 0 0 0', '5 5 3 0 0', '5 7 -1 -1 0', '6 6 0 0 0',
'6 7 -1 -1 0', '7'
])
actual_str_ragged0 = k2.to_str_simple(fsa_vec_ragged[0].to('cpu'))
actual_str_ragged1 = k2.to_str_simple(fsa_vec_ragged[1].to('cpu'))
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def get_short_description():
packages=['k2', 'k2.ragged', 'k2.sparse', 'k2.version'],
install_requires=install_requires,
extras_require={'dev': dev_requirements},
data_files=[('', ['LICENSE'])],
ext_modules=[cmake_extension('_k2')],
cmdclass={
'build_ext': BuildExtension,
Expand Down

0 comments on commit f1005ee

Please sign in to comment.