Skip to content

Commit

Permalink
Merge pull request #832 from RossBrunton/traitusm
Browse files Browse the repository at this point in the history
[NFC] USM test device support refactored
  • Loading branch information
bader authored Nov 17, 2023
2 parents bb20fb6 + bd3c9bb commit 0f7df58
Showing 1 changed file with 29 additions and 47 deletions.
76 changes: 29 additions & 47 deletions tests/usm/usm_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,35 @@ void check_values(const T *begin, const ReferenceIt &reference) {
}
}

/** @brief Trait for tests that have no additional device requirements
*
* Tests implementing this trait do not require any device support beyond
* the destination memory location.
*/
struct noAdditionalDeviceRequirements {
static bool supports_device(sycl_cts::util::logger&, const sycl::queue&) {
return true;
}
};

/** @brief Trait for tests that require support for a given USM alloc type
*
* Tests implementing this trait require device support for a specific
* type of allocation (e.g. as a source memory location).
*/
template <allocation allocationType>
struct requiresUsmAllocationSupport {
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return check_device_support<allocationType>(log, queue);
}
};

/** @brief Provides generic copy test logic for memcpy() and copy() tests
* @tparam sourceAllocation Allocation type to use for data source
*/
template <typename T, size_t count, allocation sourceAllocation>
class copyGeneric {
class copyGeneric : public requiresUsmAllocationSupport<sourceAllocation> {
protected:
using storage_t = storage<T, count, sourceAllocation>;
const typename storage_t::type source;
Expand Down Expand Up @@ -251,13 +275,6 @@ class copy : public copyGeneric<T, count, sourceAllocation> {
*/
static constexpr bool has_non_usm_support() { return true; }

/** @brief This test only works on devices with `sourceAllocation` support
*/
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return check_device_support<sourceAllocation>(log, queue);
}

template <allocation alloc>
static std::string description() {
return "copy from " + get_allocation_decription<sourceAllocation>() +
Expand All @@ -283,13 +300,6 @@ class memcpy : public copyGeneric<T, count, sourceAllocation> {
*/
static constexpr bool has_non_usm_support() { return true; }

/** @brief This test only works on devices with `sourceAllocation` support
*/
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return check_device_support<sourceAllocation>(log, queue);
}

template <allocation alloc>
static std::string description() {
return "memcpy from " + get_allocation_decription<sourceAllocation>() +
Expand Down Expand Up @@ -326,7 +336,7 @@ using memcpy_from_shared = detail::memcpy<T, count, allocation::shared>;
/** @brief Provides test logic for the fill() member function tests
*/
template <typename T, size_t count>
class fill {
class fill : public detail::noAdditionalDeviceRequirements {
const T value;

public:
Expand All @@ -336,13 +346,6 @@ class fill {
*/
static constexpr bool has_non_usm_support() { return false; }

/** @brief This test doesn't have any device requirements
*/
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return true;
}

template <allocation alloc>
static std::string description() {
return "fill using " + get_allocation_decription<alloc>();
Expand All @@ -364,7 +367,7 @@ class fill {
/** @brief Provides test logic for the memset() member function tests
*/
template <typename T, size_t count>
class memset {
class memset : public detail::noAdditionalDeviceRequirements {
const int value;

public:
Expand All @@ -376,13 +379,6 @@ class memset {
*/
static constexpr bool has_non_usm_support() { return false; }

/** @brief This test doesn't have any device requirements
*/
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return true;
}

template <allocation alloc>
static std::string description() {
return "memset using " + get_allocation_decription<alloc>();
Expand All @@ -404,7 +400,7 @@ class memset {
/** @brief Provides test logic for the prefetch() member function tests
*/
template <typename T, size_t count>
class prefetch {
class prefetch : public detail::noAdditionalDeviceRequirements {
public:
static constexpr size_t size = count * sizeof(T);

Expand All @@ -414,13 +410,6 @@ class prefetch {
*/
static constexpr bool has_non_usm_support() { return false; }

/** @brief This test doesn't have any device requirements
*/
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return true;
}

template <allocation alloc>
static std::string description() {
return "prefetch using " + get_allocation_decription<alloc>();
Expand All @@ -439,7 +428,7 @@ class prefetch {
/** @brief Provides test logic for the mem_advise() member function tests
*/
template <typename T, size_t count>
class mem_advise {
class mem_advise : public detail::noAdditionalDeviceRequirements {
public:
static constexpr size_t size = count * sizeof(T);

Expand All @@ -449,13 +438,6 @@ class mem_advise {
*/
static constexpr bool has_non_usm_support() { return false; }

/** @brief This test doesn't have any device requirements
*/
static bool supports_device(sycl_cts::util::logger& log,
const sycl::queue& queue) {
return true;
}

template <allocation alloc>
static std::string description() {
return "mem_advise using " + get_allocation_decription<alloc>();
Expand Down

0 comments on commit 0f7df58

Please sign in to comment.