Skip to content

Commit

Permalink
Merge pull request #99 from comparch-security/simplify-set-lock
Browse files Browse the repository at this point in the history
simplify cache set lock
  • Loading branch information
wsong83 authored Jun 27, 2024
2 parents d473a32 + 6a76722 commit 3c8bb13
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 226 deletions.
3 changes: 3 additions & 0 deletions cache/cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ class CacheBase : public CacheMonitorSupport
virtual void meta_return_buffer(CMMetadataBase *buf) = 0; // return a copy buffer, used to detect conflicts in copy buffer
__always_inline void lock_line(uint32_t ai, uint32_t s, uint32_t w) { access(ai, s, w)->lock(); }
__always_inline void unlock_line(uint32_t ai, uint32_t s, uint32_t w) { access(ai, s, w)->unlock(); }
__always_inline void set_mt_state(uint32_t ai, uint32_t s, uint16_t prio) { arrays[ai]->set_mt_state(s, prio); }
__always_inline void check_mt_state(uint32_t ai, uint32_t s, uint16_t prio) { arrays[ai]->check_mt_state(s, prio); }
__always_inline void reset_mt_state(uint32_t ai, uint32_t s, uint16_t prio) { arrays[ai]->reset_mt_state(s, prio); }

virtual std::tuple<int, int, int> size() const = 0; // return the size parameters of the cache
uint32_t get_id() const { return id; }
Expand Down
130 changes: 5 additions & 125 deletions cache/cache_multi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,85 +2,6 @@
#define CM_CACHE_CACHE_MULTI_HPP

#include "cache/cache.hpp"
#include <mutex>
#include <condition_variable>
#include <tuple>

// Multi-thread support for Cache Array
class CacheArrayMultiThreadSupport
{
public:
virtual std::vector<uint32_t> *get_status() = 0;
virtual std::mutex* get_mutex(uint32_t s) = 0; // get set mutex
virtual std::mutex* get_cacheline_mutex(uint32_t s, uint32_t w) = 0; // get cacheline mutex
virtual std::condition_variable* get_cv(uint32_t s) = 0; // get set cv
};

// Multi-thread Cache Array
// IW: index width, NW: number of ways, MT: metadata type, DT: data type (void if not in use)
template<int IW, int NW, typename MT, typename DT>
requires C_DERIVE<MT, CMMetadataCommon>
&& C_DERIVE_OR_VOID<DT, CMDataBase>
class CacheArrayMultiThread : public CacheArrayNorm<IW, NW, MT, DT, true>,
public CacheArrayMultiThreadSupport
{

typedef CacheArrayNorm<IW, NW, MT, DT, true> CacheAT;
protected:
std::vector<uint32_t> status; // record every set status
std::vector<std::mutex *> status_mtxs; // mutex for status
std::vector<std::mutex *> mutexs; // mutex array for meta
std::vector<std::condition_variable *> cvs; // cv array, used in conjunction with mutexes

public:
using CacheAT::nset;
using CacheAT::way_num;
CacheArrayMultiThread(unsigned int extra_way = 0, std::string name = "") : CacheAT(extra_way, name){
size_t meta_num = nset * way_num;
status.resize(nset);
for(uint32_t i = 0; i < nset; i++) status[i] = 0;

mutexs.resize(meta_num);
for(auto &t:mutexs) t = new std::mutex();

status_mtxs.resize(nset);
for(auto &s : status_mtxs) s = new std::mutex();

cvs.resize(nset);
for(auto &c : cvs) c = new std::condition_variable();
}

virtual ~CacheArrayMultiThread(){
for(auto t: mutexs) delete t;
for(auto s : status_mtxs) delete s;
for(auto c : cvs) delete c;
}

virtual std::mutex* get_cacheline_mutex(uint32_t s, uint32_t w) { return mutexs[s*(way_num) + w]; }

virtual std::vector<uint32_t> *get_status(){ return &status; }
virtual std::mutex* get_mutex(uint32_t s) { return status_mtxs[s]; }
virtual std::condition_variable* get_cv(uint32_t s) { return cvs[s]; }
};


// Multi-thread support for CacheBase
class CacheBaseMultiThreadSupport
{
public:
virtual std::vector<uint32_t> *get_status(uint32_t ai) = 0;
virtual std::mutex* get_mutex(uint32_t ai, uint32_t s) = 0;
virtual std::condition_variable* get_cv(uint32_t ai, uint32_t s) = 0;
virtual std::mutex* get_cacheline_mutex(uint32_t ai, uint32_t s, uint32_t w) = 0;

// get set's status, mutex and cv in one function call
virtual std::tuple<std::vector<uint32_t> *, std::mutex*, std::condition_variable*>
get_set_control(uint32_t ai, uint32_t s) = 0;

virtual bool hit(uint64_t addr, uint32_t *ai, uint32_t *s, uint32_t *w,
uint16_t priority, bool need_replace = false) = 0;
};


// Multi-thread Skewed Cache
// IW: index width, NW: number of ways, P: number of partitions
Expand All @@ -91,11 +12,10 @@ class CacheBaseMultiThreadSupport
template<int IW, int NW, int P, typename MT, typename DT, typename IDX, typename RPC, typename DLY, bool EnMon, bool EF = true>
requires C_DERIVE<MT, CMMetadataBase> && C_DERIVE_OR_VOID<DT, CMDataBase> && C_DERIVE<IDX, IndexFuncBase> &&
C_DERIVE_OR_VOID<DLY, DelayBase>
class CacheSkewedMultiThread : public CacheSkewed<IW, NW, P, MT, DT, IDX, RPC, DLY, EnMon, EF, true>,
public CacheBaseMultiThreadSupport
class CacheSkewedMultiThread : public CacheSkewed<IW, NW, P, MT, DT, IDX, RPC, DLY, EnMon, EF, true>
{
typedef CacheSkewed<IW, NW, P, MT, DT, IDX, RPC, DLY, EnMon, EF, true> CacheT;
typedef CacheArrayMultiThread<IW, NW, MT, DT> CacheAT;
typedef CacheArrayNorm<IW, NW, MT, DT, true> CacheAT;

protected:
using CacheT::arrays;
Expand All @@ -104,34 +24,7 @@ class CacheSkewedMultiThread : public CacheSkewed<IW, NW, P, MT, DT, IDX, RPC, D
using CacheT::replace;
public:
CacheSkewedMultiThread(std::string name = "", unsigned int extra_par = 0, unsigned int extra_way = 0)
: CacheT(name, extra_par, extra_way)
{
for(int i=0; i<P; i++) {
delete arrays[i];
arrays[i] = new CacheAT(extra_way);
}
}

virtual std::vector<uint32_t> *get_status(uint32_t ai){
return (static_cast<CacheAT*>(arrays[ai]))->get_status();
}
virtual std::mutex* get_mutex(uint32_t ai, uint32_t s){
return (static_cast<CacheAT*>(arrays[ai]))->get_mutex(s);
}
virtual std::condition_variable* get_cv(uint32_t ai, uint32_t s) {
return (static_cast<CacheAT*>(arrays[ai]))->get_cv(s);
}

virtual std::tuple<std::vector<uint32_t> *, std::mutex*, std::condition_variable*>
get_set_control(uint32_t ai, uint32_t s)
{
return std::make_tuple(get_status(ai), get_mutex(ai, s), get_cv(ai, s));
}

virtual std::mutex* get_cacheline_mutex(uint32_t ai, uint32_t s, uint32_t w){
return (static_cast<CacheAT*>(arrays[ai]))->get_cacheline_mutex(s, w);
}

: CacheT(name, extra_par, extra_way) {}

virtual bool hit(uint64_t addr, uint32_t *ai, uint32_t *s, uint32_t *w,
uint16_t priority, bool need_replace = false)
Expand All @@ -144,14 +37,7 @@ class CacheSkewedMultiThread : public CacheSkewed<IW, NW, P, MT, DT, IDX, RPC, D
bool hit = false;
for(*ai = 0; *ai < P; (*ai)++){
*s = indexer.index(addr, *ai);
uint32_t idx = *s;
auto [status, mtx, cv] = get_set_control(*ai, *s);
std::unique_lock lk(*mtx, std::defer_lock);
lk.lock();
/** Wait until the high priority thread ends (lower the priority of the set) */
cv->wait(lk, [idx, status, priority] { return ((*status)[idx] < priority);} );
(*status)[*s] |= priority;
lk.unlock();
this->set_mt_state(*ai, *s, priority);

for(*w = 0; *w < NW; (*w)++){
if(access(*ai, *s, *w)->match(addr)) { hit = true; break;}
Expand All @@ -164,13 +50,7 @@ class CacheSkewedMultiThread : public CacheSkewed<IW, NW, P, MT, DT, IDX, RPC, D
/** if don't replace, then *ai=P, else if replace occurs, then 0<=(*ai)< P */
for(uint32_t i = 0; i < P; i++){
if(i != *ai){
uint32_t s = indexer.index(addr, i);
auto [status, mtx, cv] = get_set_control(i, s);
std::unique_lock lk(*mtx, std::defer_lock);
lk.lock();
(*status)[s] &= ~(priority);
lk.unlock();
cv->notify_all();
this->reset_mt_state(i, indexer.index(addr, i), priority);
}
}
return hit;
Expand Down
Loading

0 comments on commit 3c8bb13

Please sign in to comment.