Skip to content

Commit

Permalink
add SetRandomGeneratorSeed
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Oct 24, 2020
1 parent 63211c1 commit 5d79c12
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 5 deletions.
4 changes: 4 additions & 0 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ def Load(self, model_file=None, model_proto=None):
# Register SentencePieceProcessor in _sentencepiece:
_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)


def SetRandomGeneratorSeed(seed):
return _sentencepiece.SetRandomGeneratorSeed(seed)
class SentencePieceTrainer(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")

Expand Down Expand Up @@ -516,6 +519,7 @@ def _batched_func(self, arg):

_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
set_random_generator_seed = SetRandomGeneratorSeed



1 change: 1 addition & 0 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -740,4 +740,5 @@ for m in [

_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
set_random_generator_seed = SetRandomGeneratorSeed
%}
95 changes: 95 additions & 0 deletions python/src/sentencepiece/sentencepiece_wrap.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -3301,6 +3301,70 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_Decod
"piece id is out of range.");
return self->DecodeIdsAsSerializedProto(ids);
}

SWIGINTERN int
SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val)
{
#if PY_VERSION_HEX < 0x03000000
if (PyInt_Check(obj)) {
long v = PyInt_AsLong(obj);
if (v >= 0) {
if (val) *val = v;
return SWIG_OK;
} else {
return SWIG_OverflowError;
}
} else
#endif
if (PyLong_Check(obj)) {
unsigned long v = PyLong_AsUnsignedLong(obj);
if (!PyErr_Occurred()) {
if (val) *val = v;
return SWIG_OK;
} else {
PyErr_Clear();
return SWIG_OverflowError;
}
}
#ifdef SWIG_PYTHON_CAST_MODE
{
int dispatch = 0;
unsigned long v = PyLong_AsUnsignedLong(obj);
if (!PyErr_Occurred()) {
if (val) *val = v;
return SWIG_AddCast(SWIG_OK);
} else {
PyErr_Clear();
}
if (!dispatch) {
double d;
int res = SWIG_AddCast(SWIG_AsVal_double (obj,&d));
if (SWIG_IsOK(res) && SWIG_CanCastAsInteger(&d, 0, ULONG_MAX)) {
if (val) *val = (unsigned long)(d);
return res;
}
}
}
#endif
return SWIG_TypeError;
}


SWIGINTERN int
SWIG_AsVal_unsigned_SS_int (PyObject * obj, unsigned int *val)
{
unsigned long v;
int res = SWIG_AsVal_unsigned_SS_long (obj, &v);
if (SWIG_IsOK(res)) {
if ((v > UINT_MAX)) {
return SWIG_OverflowError;
} else {
if (val) *val = static_cast< unsigned int >(v);
}
}
return res;
}

SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){
const auto _status = sentencepiece::SentencePieceTrainer::Train(arg);
if (!_status.ok()) throw _status;
Expand Down Expand Up @@ -4977,6 +5041,36 @@ SWIGINTERN PyObject *SentencePieceProcessor_swiginit(PyObject *SWIGUNUSEDPARM(se
return SWIG_Python_InitShadowInstance(args);
}

SWIGINTERN PyObject *_wrap_SetRandomGeneratorSeed(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
unsigned int arg1 ;
unsigned int val1 ;
int ecode1 = 0 ;
PyObject *swig_obj[1] ;

if (!args) SWIG_fail;
swig_obj[0] = args;
ecode1 = SWIG_AsVal_unsigned_SS_int(swig_obj[0], &val1);
if (!SWIG_IsOK(ecode1)) {
SWIG_exception_fail(SWIG_ArgError(ecode1), "in method '" "SetRandomGeneratorSeed" "', argument " "1"" of type '" "unsigned int""'");
}
arg1 = static_cast< unsigned int >(val1);
{
try {
sentencepiece::SetRandomGeneratorSeed(arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
resultobj = SWIG_Py_Void();
return resultobj;
fail:
return NULL;
}


SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
absl::string_view arg1 ;
Expand Down Expand Up @@ -5307,6 +5401,7 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL},
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
{ "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL},
Expand Down
3 changes: 3 additions & 0 deletions src/spm_encode_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ int main(int argc, char *argv[]) {
rest_args.push_back(absl::GetFlag(FLAGS_input));
}

if (absl::GetFlag(FLAGS_random_seed) != -1)
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));

if (rest_args.empty())
rest_args.push_back(""); // empty means that reading from stdin.

Expand Down
4 changes: 4 additions & 0 deletions src/spm_train_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
ABSL_FLAG(bool, train_extremely_large_corpus,
kDefaultTrainerSpec.train_extremely_large_corpus(),
"Increase bit depth for unigram tokenization.");
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");

int main(int argc, char *argv[]) {
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
Expand All @@ -148,6 +149,9 @@ int main(int argc, char *argv[]) {
CHECK(!absl::GetFlag(FLAGS_input).empty());
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());

if (absl::GetFlag(FLAGS_random_seed) != -1)
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));

auto load_lines = [](absl::string_view filename) {
std::vector<std::string> lines;
auto input = sentencepiece::filesystem::NewReadableFile(filename);
Expand Down
10 changes: 6 additions & 4 deletions src/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ void SetRandomGeneratorSeed(unsigned int seed) {
if (seed != kDefaultSeed) g_seed = seed;
}

uint32 GetRandomGeneratorSeed() {
return g_seed == kDefaultSeed ? std::random_device{}() : g_seed;
}

namespace string_util {

// mblen sotres the number of bytes consumed after decoding.
Expand Down Expand Up @@ -153,8 +157,7 @@ class RandomGeneratorStorage {
std::mt19937 *Get() {
auto *result = static_cast<std::mt19937 *>(pthread_getspecific(key_));
if (result == nullptr) {
result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}()
: g_seed);
result = new std::mt19937(GetRandomGeneratorSeed());
pthread_setspecific(key_, result);
}
return result;
Expand All @@ -172,8 +175,7 @@ std::mt19937 *GetRandomGenerator() {
}
#else
std::mt19937 *GetRandomGenerator() {
thread_local static std::mt19937 mt(
g_seed == kDefaultSeed ? std::random_device{}() : g_seed);
thread_local static std::mt19937 mt(GetRandomGeneratorSeed());
return &mt;
}
#endif
Expand Down
4 changes: 3 additions & 1 deletion src/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
return out;
}

uint32 GetRandomGeneratorSeed();

// String utilities
namespace string_util {

Expand Down Expand Up @@ -306,7 +308,7 @@ template <typename T>
class ReservoirSampler {
public:
explicit ReservoirSampler(std::vector<T> *sampled, size_t size)
: sampled_(sampled), size_(size), engine_(std::random_device{}()) {}
: sampled_(sampled), size_(size), engine_(GetRandomGeneratorSeed()) {}
explicit ReservoirSampler(std::vector<T> *sampled, size_t size, size_t seed)
: sampled_(sampled), size_(size), engine_(seed) {}
virtual ~ReservoirSampler() {}
Expand Down

0 comments on commit 5d79c12

Please sign in to comment.