Skip to content

Commit

Permalink
Merge branch 'master' into maddft_newexchange
Browse files Browse the repository at this point in the history
merge test_dc
  • Loading branch information
“hborchert” committed Aug 7, 2024
2 parents 9334e01 + c829e8b commit c4a0b26
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 8 deletions.
156 changes: 149 additions & 7 deletions src/madness/world/test_dc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@
fax: 865-572-0680
*/

//#define MAD_ARCHIVE_DEBUG_ENABLE

#include <algorithm>

#include <madness/world/MADworld.h>
#include <madness/world/worlddc.h>
#include <madness/world/worldmutex.h>
#include <madness/world/atomicint.h>

#include <madness/world/vector_archive.h>
#include <madness/world/parallel_archive.h>

using namespace madness;
using namespace std;

Expand Down Expand Up @@ -240,16 +248,150 @@ void test_local(World& world) {

}

namespace madness {
namespace archive {
/// Write container to parallel archive
template <class keyT, class valueT>
struct ArchiveStoreImpl< ParallelOutputArchive<VectorOutputArchive>, WorldContainer<keyT,valueT> > {
static void store(const ParallelOutputArchive<VectorOutputArchive>& ar, const WorldContainer<keyT,valueT>& t) {
using localarchiveT = VectorOutputArchive;
const long magic = -5881828; // Sitar Indian restaurant in Knoxville (negative to indicate parallel!)
typedef WorldContainer<keyT,valueT> dcT;
using const_iterator = typename dcT::const_iterator;

const size_t default_size = 100*1024*1024;

World* world = ar.get_world();
world->gop.fence();

std::vector<unsigned char> v;
v.reserve(default_size);
size_t count = 0;

class op : public TaskInterface {
const size_t ntasks;
const size_t taskid;
const dcT& t;
std::vector<unsigned char>& vtotal;
size_t& total_count;
Mutex& mutex;

public:
op(size_t ntasks, size_t taskid, const dcT& t, std::vector<unsigned char>& vtotal, size_t& total_count, Mutex& mutex)
: ntasks(ntasks), taskid(taskid), t(t), vtotal(vtotal), total_count(total_count), mutex(mutex) {}
void run(World& world) {
std::vector<unsigned char> v;
v.reserve(std::max(size_t(1024*1024),vtotal.capacity()/ntasks));
VectorOutputArchive var(v);
const_iterator it=t.begin();
size_t count = 0;
size_t n = 0;
while (it!=t.end()) {
if ((n%ntasks) == taskid) {
var & *it;
++count;
}
++it;
n++;
}

if (count) {
mutex.lock();
vtotal.insert(vtotal.end(), v.begin(), v.end());
total_count += count;
mutex.unlock();
}
}
};

Mutex mutex;
size_t ntasks = std::max(size_t(1), ThreadPool::size());
for (size_t taskid=0; taskid<ntasks; taskid++)
world->taskq.add(new op(ntasks, taskid, t, v, count, mutex));
world->gop.fence();

// Gather all buffers to process 0
// first gather all of the sizes and counts to a vector in process 0
int size = v.size();
std::vector<int> sizes(world->size());
MPI_Gather(&size, 1, MPI_INT, sizes.data(), 1, MPI_INT, 0, world->mpi.comm().Get_mpi_comm());
world->gop.sum(count); // just need total number of elements

// build the cumulative sum of sizes
std::vector<int> offsets(world->size());
offsets[0] = 0;
for (int i=1; i<world->size(); ++i) offsets[i] = offsets[i-1] + sizes[i-1];
int total_size = offsets.back() + sizes.back();

// gather the vector of data v from each process to process 0
unsigned char* all_data=0;
if (world->rank() == 0) {
all_data = new unsigned char[total_size];
}
MPI_Gatherv(v.data(), v.size(), MPI_BYTE, all_data, sizes.data(), offsets.data(), MPI_BYTE, 0, world->mpi.comm().Get_mpi_comm());

if (world->rank() == 0) {
auto& localar = ar.local_archive();
localar & magic & 1; // 1 client
// localar & t;
ArchivePrePostImpl<localarchiveT,dcT>::preamble_store(localar);
localar & -magic & count;
localar.store(all_data, total_size);
ArchivePrePostImpl<localarchiveT,dcT>::postamble_store(localar);

delete[] all_data;
}
world->gop.fence();
}
};
}
}

void test_florian(World& world) {
WorldContainer<Key,Node> c(world);

Key key1(1);
Node node1(1);

if (world.rank() == 0) {
for (int i=0; i<100; ++i) {
c.replace(Key(i),Node(i));
}
}
world.gop.fence();

std::vector<unsigned char> v;
{
archive::VectorOutputArchive var(v);
archive::ParallelOutputArchive ar(world,var);
ar & c;
}

WorldContainer<Key,Node> c2(world);
{
archive::VectorInputArchive var2(v);
archive::ParallelInputArchive ar2(world,var2);
ar2 & c2;
}

for (int i=0; i<100; ++i) {
MADNESS_CHECK(c2.find(Key(i)).get()->second.get() == i);
}

world.gop.fence();
print("test_florian passed");
}

int main(int argc, char** argv) {
initialize(argc, argv);
World world(SafeMPI::COMM_WORLD);

try {
test0(world);
test1(world);
test1(world);
test1(world);
test_local(world);
World& world = initialize(argc, argv);
// test0(world);
// test1(world);
// test1(world);
// test1(world);
// test_local(world);
test_florian(world);
}
catch (const SafeMPI::Exception& e) {
error("caught an MPI exception");
Expand Down
2 changes: 1 addition & 1 deletion src/madness/world/worldgop.h
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ namespace madness {
auto buf0 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);
auto buf1 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);

auto reduce_impl = [&,this](T* buf, int nelem) {
auto reduce_impl = [&,this](T* buf, size_t nelem) {
MADNESS_ASSERT(nelem <= nelem_per_maxmsg);
SafeMPI::Request req0, req1;
Tag gsum_tag = world_.mpi.unique_tag();
Expand Down

0 comments on commit c4a0b26

Please sign in to comment.