Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread local prngs #4331

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/crypto/SecretKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ namespace stellar
// makes all signature-verification in the program faster and
// has no effect on correctness.

static std::mutex gVerifySigCacheMutex;
static RandomEvictionCache<Hash, bool> gVerifySigCache(0xffff);
static uint64_t gVerifyCacheHit = 0;
static uint64_t gVerifyCacheMiss = 0;
static thread_local RandomEvictionCache<Hash, bool> gVerifySigCache(0xffff);
static thread_local uint64_t gVerifyCacheHit = 0;
static thread_local uint64_t gVerifyCacheMiss = 0;

static Hash
verifySigCacheKey(PublicKey const& key, Signature const& signature,
Expand Down Expand Up @@ -316,14 +315,12 @@ SecretKey::fromStrKeySeed(std::string const& strKeySeed)
void
PubKeyUtils::clearVerifySigCache()
{
std::lock_guard<std::mutex> guard(gVerifySigCacheMutex);
gVerifySigCache.clear();
}

void
PubKeyUtils::flushVerifySigCacheCounts(uint64_t& hits, uint64_t& misses)
{
std::lock_guard<std::mutex> guard(gVerifySigCacheMutex);
hits = gVerifyCacheHit;
misses = gVerifyCacheMiss;
gVerifyCacheHit = 0;
Expand Down Expand Up @@ -438,7 +435,6 @@ PubKeyUtils::verifySig(PublicKey const& key, Signature const& signature,
auto cacheKey = verifySigCacheKey(key, signature, bin);

{
std::lock_guard<std::mutex> guard(gVerifySigCacheMutex);
if (gVerifySigCache.exists(cacheKey))
{
++gVerifyCacheHit;
Expand All @@ -453,7 +449,6 @@ PubKeyUtils::verifySig(PublicKey const& key, Signature const& signature,
bool ok =
(crypto_sign_verify_detached(signature.data(), bin.data(), bin.size(),
key.ed25519().data()) == 0);
std::lock_guard<std::mutex> guard(gVerifySigCacheMutex);
++gVerifyCacheMiss;
gVerifySigCache.put(cacheKey, ok);
return ok;
Expand Down
10 changes: 4 additions & 6 deletions src/herder/HerderImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1788,15 +1788,13 @@ HerderImpl::checkAndMaybeReanalyzeQuorumMap()
mLastQuorumMapIntersectionState.mCheckingQuorumMapHash = curr;
auto& cfg = mApp.getConfig();
releaseAssert(threadIsMain());
auto seed = gRandomEngine();
auto qic = QuorumIntersectionChecker::create(
qmap, cfg, mLastQuorumMapIntersectionState.mInterruptFlag, seed);
qmap, cfg, mLastQuorumMapIntersectionState.mInterruptFlag);
auto ledger = trackingConsensusLedgerIndex();
auto nNodes = qmap.size();
auto& hState = mLastQuorumMapIntersectionState;
auto& app = mApp;
auto worker = [curr, ledger, nNodes, qic, qmap, cfg, seed, &app,
&hState] {
auto worker = [curr, ledger, nNodes, qic, qmap, cfg, &app, &hState] {
try
{
ZoneScoped;
Expand All @@ -1809,8 +1807,8 @@ HerderImpl::checkAndMaybeReanalyzeQuorumMap()
// intersecting; if not intersecting we should finish ASAP
// and raise an alarm.
critical = QuorumIntersectionChecker::
getIntersectionCriticalGroups(
qmap, cfg, hState.mInterruptFlag, seed);
getIntersectionCriticalGroups(qmap, cfg,
hState.mInterruptFlag);
}
app.postOnMainThread(
[ok, curr, ledger, nNodes, split, critical, &hState] {
Expand Down
27 changes: 12 additions & 15 deletions src/herder/QuorumIntersectionChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,21 @@ class QuorumIntersectionChecker
static std::shared_ptr<QuorumIntersectionChecker>
create(QuorumTracker::QuorumMap const& qmap,
std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed, bool quiet = false);
std::atomic<bool>& interruptFlag, bool quiet = false);

static std::shared_ptr<QuorumIntersectionChecker>
create(QuorumSetMap const& qmap, std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed, bool quiet = false);

static std::set<std::set<NodeID>> getIntersectionCriticalGroups(
QuorumTracker::QuorumMap const& qmap,
std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed);

static std::set<std::set<NodeID>> getIntersectionCriticalGroups(
QuorumSetMap const& qmap, std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed);
std::atomic<bool>& interruptFlag, bool quiet = false);

static std::set<std::set<NodeID>>
getIntersectionCriticalGroups(QuorumTracker::QuorumMap const& qmap,
std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag);

static std::set<std::set<NodeID>>
getIntersectionCriticalGroups(QuorumSetMap const& qmap,
std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag);

virtual ~QuorumIntersectionChecker(){};
virtual bool networkEnjoysQuorumIntersection() const = 0;
Expand Down
43 changes: 18 additions & 25 deletions src/herder/QuorumIntersectionCheckerImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ QBitSet::getSuccessors(BitSet const& nodes, QGraph const& inner)

// Slightly tweaked variant of Lachowski's next-node function.
size_t
MinQuorumEnumerator::pickSplitNode(
stellar::stellar_default_random_engine& randEngine) const
MinQuorumEnumerator::pickSplitNode() const
{
std::vector<size_t>& inDegrees = mQic.mInDegrees;
inDegrees.assign(mQic.mGraph.size(), 0);
Expand All @@ -84,7 +83,7 @@ MinQuorumEnumerator::pickSplitNode(
// currDegree same as existing max: replace it
// only probabilistically.
maxCount++;
if (rand_uniform<size_t>(0, maxCount, randEngine) == 0)
if (rand_uniform<size_t>(0, maxCount, gRandomEngine) == 0)
{
// Not switching max element with max degree.
continue;
Expand Down Expand Up @@ -237,7 +236,7 @@ MinQuorumEnumerator::anyMinQuorumHasDisjointQuorum()
}

// Phase two: recurse into subproblems.
size_t split = pickSplitNode(mQic.mRand);
size_t split = pickSplitNode();
if (mQic.mLogTrace)
{
CLOG_TRACE(SCP, "recursing into subproblems, split={}", split);
Expand Down Expand Up @@ -269,14 +268,13 @@ MinQuorumEnumerator::anyMinQuorumHasDisjointQuorum()
QuorumIntersectionCheckerImpl::QuorumIntersectionCheckerImpl(
QuorumIntersectionChecker::QuorumSetMap const& qmap,
std::optional<Config> const& cfg, std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed, bool quiet)
bool quiet)
: mCfg(cfg)
, mLogTrace(Logging::logTrace("SCP"))
, mQuiet(quiet)
, mTSC()
, mInterruptFlag(interruptFlag)
, mCachedQuorums(MAX_CACHED_QUORUMS_SIZE)
, mRand(seed)
{
buildGraph(qmap);
// Awkwardly, the graph size is zero when we initialize mTSC. Update it
Expand Down Expand Up @@ -817,40 +815,35 @@ toQuorumIntersectionMap(QuorumTracker::QuorumMap const& qmap)
namespace stellar
{
std::shared_ptr<QuorumIntersectionChecker>
QuorumIntersectionChecker::create(
QuorumTracker::QuorumMap const& qmap, std::optional<Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed, bool quiet)
QuorumIntersectionChecker::create(QuorumTracker::QuorumMap const& qmap,
std::optional<Config> const& cfg,
std::atomic<bool>& interruptFlag, bool quiet)
{
return create(toQuorumIntersectionMap(qmap), cfg, interruptFlag, seed,
quiet);
return create(toQuorumIntersectionMap(qmap), cfg, interruptFlag, quiet);
}

std::shared_ptr<QuorumIntersectionChecker>
QuorumIntersectionChecker::create(
QuorumSetMap const& qmap, std::optional<Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed, bool quiet)
QuorumIntersectionChecker::create(QuorumSetMap const& qmap,
std::optional<Config> const& cfg,
std::atomic<bool>& interruptFlag, bool quiet)
{
return std::make_shared<QuorumIntersectionCheckerImpl>(
qmap, cfg, interruptFlag, seed, quiet);
qmap, cfg, interruptFlag, quiet);
}

std::set<std::set<NodeID>>
QuorumIntersectionChecker::getIntersectionCriticalGroups(
QuorumTracker::QuorumMap const& qmap, std::optional<Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed)
std::atomic<bool>& interruptFlag)
{
return getIntersectionCriticalGroups(toQuorumIntersectionMap(qmap), cfg,
interruptFlag, seed);
interruptFlag);
}

std::set<std::set<NodeID>>
QuorumIntersectionChecker::getIntersectionCriticalGroups(
QuorumSetMap const& qmap, std::optional<Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar_default_random_engine::result_type seed)
std::atomic<bool>& interruptFlag)
{
// We're going to search for "intersection-critical" groups, by considering
// each SCPQuorumSet S that (a) has no innerSets of its own and (b) occurs
Expand Down Expand Up @@ -936,9 +929,9 @@ QuorumIntersectionChecker::getIntersectionCriticalGroups(
}

// Check to see if this modified config is vulnerable to splitting.
auto checker = QuorumIntersectionChecker::create(test_qmap, cfg,
interruptFlag, seed,
/*quiet=*/true);
auto checker =
QuorumIntersectionChecker::create(test_qmap, cfg, interruptFlag,
/*quiet=*/true);
if (checker->networkEnjoysQuorumIntersection())
{
CLOG_DEBUG(SCP,
Expand Down
7 changes: 2 additions & 5 deletions src/herder/QuorumIntersectionCheckerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ class MinQuorumEnumerator
QuorumIntersectionCheckerImpl const& mQic;

// Select the next node in mRemaining to split recursive cases between.
size_t
pickSplitNode(stellar::stellar_default_random_engine& randEngine) const;
size_t pickSplitNode() const;

// Size limit for mCommitted beyond which we should stop scanning.
size_t maxCommit() const;
Expand Down Expand Up @@ -534,9 +533,7 @@ class QuorumIntersectionCheckerImpl : public stellar::QuorumIntersectionChecker
QuorumIntersectionCheckerImpl(
stellar::QuorumIntersectionChecker::QuorumSetMap const& qmap,
std::optional<stellar::Config> const& cfg,
std::atomic<bool>& interruptFlag,
stellar::stellar_default_random_engine::result_type seed,
bool quiet = false);
std::atomic<bool>& interruptFlag, bool quiet = false);
bool networkEnjoysQuorumIntersection() const override;

std::pair<std::vector<stellar::NodeID>, std::vector<stellar::NodeID>>
Expand Down
16 changes: 7 additions & 9 deletions src/herder/test/QuorumIntersectionTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,7 @@ TEST_CASE("quorum intersection interruption", "[herder][quorumintersection]")
Config cfg(getTestConfig());
cfg = configureShortNames(cfg, orgs);
std::atomic<bool> interruptFlag{false};
auto qic = QuorumIntersectionChecker::create(qm, cfg, interruptFlag,
gRandomEngine());
auto qic = QuorumIntersectionChecker::create(qm, cfg, interruptFlag);
std::thread canceller([&interruptFlag]() {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
interruptFlag = true;
Expand All @@ -897,9 +896,9 @@ TEST_CASE("quorum intersection interruption", "[herder][quorumintersection]")
std::this_thread::sleep_for(std::chrono::milliseconds(100));
interruptFlag = true;
});
REQUIRE_THROWS_AS(qic->getIntersectionCriticalGroups(qm, cfg, interruptFlag,
gRandomEngine()),
QuorumIntersectionChecker::InterruptedException);
REQUIRE_THROWS_AS(
qic->getIntersectionCriticalGroups(qm, cfg, interruptFlag),
QuorumIntersectionChecker::InterruptedException);
canceller2.join();
}

Expand Down Expand Up @@ -982,12 +981,11 @@ TEST_CASE("quorum intersection criticality",
cfg = configureShortNames(cfg, orgs);
debugQmap(cfg, qm);
std::atomic<bool> flag{false};
auto qic =
QuorumIntersectionChecker::create(qm, cfg, flag, gRandomEngine());
auto qic = QuorumIntersectionChecker::create(qm, cfg, flag);
REQUIRE(qic->networkEnjoysQuorumIntersection());

auto groups = QuorumIntersectionChecker::getIntersectionCriticalGroups(
qm, cfg, flag, gRandomEngine());
auto groups =
QuorumIntersectionChecker::getIntersectionCriticalGroups(qm, cfg, flag);
REQUIRE(groups.size() == 1);
REQUIRE(groups == std::set<std::set<PublicKey>>{{orgs[3][0]}});
}
Expand Down
44 changes: 35 additions & 9 deletions src/util/Math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
#include <autocheck/generator.hpp>
#include <catch.hpp>
#include <cmath>
#include <mutex>
#include <numeric>
#include <set>

namespace stellar
{

stellar_default_random_engine gRandomEngine;
thread_local stellar_default_random_engine
gRandomEngine(getLastGlobalStateSeed());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, it looks like it's very likely non-main threads will all get seeded with the same initial seed (unless reinitializeAllGlobalStateWithSeedInternal is called in between the threads spinning up). Is this an issue? I can imagine two potential issues:

  1. Attacker observes a PRNG value on thread 0 and now has advance knowledge of the next PRNG value on thread 1. (seems highly unlikely)
  2. Some work is being done. Part 1 of the work occurs on thread 0 and calls randomGenerate for some value. Part 2 of the work occurs on thread 1 and calls randomGenerate again, but receives the same value as before. Could this break an assumption that the work receives two different PRNG values?

Or am I like super overthinking this since it's PRNG and not RNG anyway...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too worried about this for production: because we don't spin up threads on demand, fairly quickly threads will diverge (under the assumption that we generate a fairly good number of random numbers).

For tests, it's a different story: tests tend to spin up some "app" that in turn spins up threads, so those threads will all be seeded the same way and will basically behave the exact same way in the context of that test. This has the potential to hide a large class of bugs.

A way to mitigate all those issues may just be to make seeds between threads different (but deterministic): when spinning up a new thread, generate a new seed from the main thread and use that seed to initialize the thread.

I guess this may be a new concern: until recently PRNGs were not really affected by the scheduling of threads; now I am not sure what the story is (and having TLS may be the way to get rid of the dependency on the OS scheduler).

std::uniform_real_distribution<double> uniformFractionDistribution(0.0, 1.0);

double
Expand Down Expand Up @@ -169,23 +171,52 @@ k_means(std::vector<double> const& points, uint32_t k)
return centroids;
}

// Serializes access to lastGlobalSeed: getLastGlobalStateSeed() reads the
// value while holding this mutex.
static std::mutex gGlobalSeedMutex;

// The most recently recorded global seed; stays 0 until a seed is stored.
static unsigned int lastGlobalSeed = 0;

// Returns a snapshot of the most recently recorded global seed. The read is
// performed while holding gGlobalSeedMutex, and the value is copied out
// before the lock is released.
unsigned int
getLastGlobalStateSeed()
{
    std::lock_guard<std::mutex> const seedLock{gGlobalSeedMutex};
    unsigned int const snapshot = lastGlobalSeed;
    return snapshot;
}

// Records `seed` as the last global seed and reseeds the global state touched
// here: clears the signature-verification cache, seeds the C library PRNG via
// srand(), reseeds the calling thread's gRandomEngine, and calls
// randHash::initialize(). Asserts (in builds where assert is active) that it
// is running on the main thread.
static void
reinitializeAllGlobalStateWithSeedInternal(unsigned int seed)
{
    {
        // lastGlobalSeed is only ever written while holding gGlobalSeedMutex:
        // getLastGlobalStateSeed() reads it under the same mutex, so an
        // unguarded write here would race with those readers.
        std::lock_guard<std::mutex> guard(gGlobalSeedMutex);
        lastGlobalSeed = seed;
    }
    PubKeyUtils::clearVerifySigCache();
    srand(seed);

    // gRandomEngine is a thread_local, and is initialized / initially seeded
    // with the result of calling getLastGlobalStateSeed(). This means that the
    // main thread's gRandomEngine will be initialized with the initial value of
    // lastGlobalSeed: 0. So we _reseed_ it here. But other non-main threads
    // will initialize their thread_local gRandomEngine instances as the threads
    // are launched, which should happen _after_ we've set lastGlobalSeed to
    // some nontrivial value (either test-provided, user-provided or from the
    // system PRNG), so they will pick up that nontrivial value and we don't
    // need to "reseed" them explicitly the same way.
    assert(threadIsMain());
    gRandomEngine.seed(seed);

    randHash::initialize();
}

void
initializeAllGlobalState()
{
releaseAssert(lastGlobalSeed == 0);
auto const seed = static_cast<unsigned int>(
std::chrono::system_clock::now().time_since_epoch().count());
// libstdc++ and libc++ accept and use "/dev/urandom" as the token
// identifying the system _nonblocking_ random number generator: we're not
// after strong cryptographic randomness here, just nonblocking best-effort.
//
// MSVC ignores this parameter and calls into the RtlGenRandom /
// CryptGenRandom complex (depending on Windows version) to get the seed.
auto const seed = std::random_device("/dev/urandom")();
reinitializeAllGlobalStateWithSeedInternal(seed);
// shortHash needs to be initialized with a strong random seed
shortHash::initialize();
Expand All @@ -203,10 +234,5 @@ reinitializeAllGlobalStateWithSeed(unsigned int seed)
autocheck::rng().seed(seed);
}

unsigned int
getLastGlobalStateSeed()
{
return lastGlobalSeed;
}
#endif
}
5 changes: 3 additions & 2 deletions src/util/Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ bool rand_flip();

typedef std::minstd_rand stellar_default_random_engine;

extern stellar_default_random_engine gRandomEngine;
extern thread_local stellar_default_random_engine gRandomEngine;

template <typename T>
T
Expand Down Expand Up @@ -76,7 +76,8 @@ void initializeAllGlobalState();
// shouldn't be resetting globals like this mid-run -- especially not things
// like hash function keys.
void reinitializeAllGlobalStateWithSeed(unsigned int seed);
unsigned int getLastGlobalStateSeed();
#endif

unsigned int getLastGlobalStateSeed();

}
Loading