diff --git a/cmake/FindDPCPP.cmake b/cmake/FindDPCPP.cmake index 41d853d4e1..2954d259d6 100644 --- a/cmake/FindDPCPP.cmake +++ b/cmake/FindDPCPP.cmake @@ -38,7 +38,7 @@ find_library(DPCPP_LIB_DIR NAMES sycl sycl6 PATHS "${DPCPP_BIN_DIR}/../lib") add_library(DPCPP::DPCPP INTERFACE IMPORTED) -set(DPCPP_FLAGS "-fsycl;-mllvm;-enable-global-offset=false;") +set(DPCPP_FLAGS "-fsycl;") if(NOT "${DPCPP_SYCL_TARGET}" STREQUAL "") list(APPEND DPCPP_FLAGS "-fsycl-targets=${DPCPP_SYCL_TARGET};") endif() @@ -51,9 +51,11 @@ if(NOT "${DPCPP_SYCL_ARCH}" STREQUAL "") endif() endif() +set(DPCPP_COMPILE_FLAGS "${DPCPP_FLAGS};-mllvm;-enable-global-offset=false") + if(UNIX) set_target_properties(DPCPP::DPCPP PROPERTIES - INTERFACE_COMPILE_OPTIONS "${DPCPP_FLAGS}" + INTERFACE_COMPILE_OPTIONS "${DPCPP_COMPILE_FLAGS}" INTERFACE_LINK_OPTIONS "${DPCPP_FLAGS}" INTERFACE_LINK_LIBRARIES ${DPCPP_LIB_DIR} INTERFACE_INCLUDE_DIRECTORIES "${DPCPP_BIN_DIR}/../include/sycl;${DPCPP_BIN_DIR}/../include") @@ -61,7 +63,7 @@ if(UNIX) message(STATUS "Using DPCPP flags: ${DPCPP_FLAGS}") else() set_target_properties(DPCPP::DPCPP PROPERTIES - INTERFACE_COMPILE_OPTIONS "${DPCPP_FLAGS}" + INTERFACE_COMPILE_OPTIONS "${DPCPP_COMPILE_FLAGS}" INTERFACE_LINK_LIBRARIES ${DPCPP_LIB_DIR} INTERFACE_INCLUDE_DIRECTORIES "${DPCPP_BIN_DIR}/../include/sycl") endif() @@ -88,8 +90,8 @@ function(add_sycl_to_target) endfunction() function(add_sycl_include_directories_to_target NAME) - target_include_directories(${NAME} + target_include_directories(${NAME} SYSTEM PUBLIC ${DPCPP_BIN_DIR}/../include/sycl - PUBLIC ${DPCPP_BIN_DIR}/../include> + PUBLIC ${DPCPP_BIN_DIR}/../include ) endfunction() diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 20767c5a05..7c34d49461 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -41,7 +41,7 @@ namespace cute #define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x #else #define SYCL_DEVICE_BUILTIN(x) \ - inline x { assert(false); } + inline x { CUTE_INVALID_CONTROL_PATH("Trying to use XE built-in on non-XE hardware"); } #endif enum class CacheControl { diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 7c5ad7b74e..c4a72b05e2 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -37,7 +37,7 @@ #ifdef __SYCL_DEVICE_ONLY__ #define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x #else -#define SYCL_DEVICE_OCL(x) inline x { assert(false); } +#define SYCL_DEVICE_OCL(x) inline x { CUTE_INVALID_CONTROL_PATH("Trying to use XE built-in on non-XE hardware"); } #endif SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc)); diff --git a/include/cute/config.hpp b/include/cute/config.hpp index a2add46d0c..e4c7db5ca6 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -132,6 +132,8 @@ // Fail and print a message. Typically used for notification of a compiler misconfiguration. #if defined(__CUDA_ARCH__) # define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __brkpt() +#elif defined(__has_builtin) && __has_builtin(__builtin_unreachable) +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __builtin_unreachable() #else # define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x) #endif diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp index 116158dc2f..725de8f1fd 100644 --- a/include/cute/util/sycl_vec.hpp +++ b/include/cute/util/sycl_vec.hpp @@ -38,9 +38,9 @@ namespace cute namespace intel { #ifdef __SYCL_DEVICE_ONLY__ -template using vector_t = typename sycl::vec::vector_t; +template using vector_t = T __attribute__((ext_vector_type(N))); #else -template using vector_t = sycl::vec; +template using vector_t = sycl::marray; #endif using float8 = vector_t; @@ -50,11 +50,11 @@ using int16 = vector_t; using uint8 = vector_t; using uint16 = vector_t; -typedef ushort __attribute__((ext_vector_type(8))) ushort8; -typedef ushort __attribute__((ext_vector_type(16))) ushort16; -typedef ushort __attribute__((ext_vector_type(32))) ushort32; -typedef ushort __attribute__((ext_vector_type(64))) ushort64; -typedef uint __attribute__((ext_vector_type(32))) uint32; +using ushort8 = vector_t; +using ushort16 = vector_t; +using ushort32 = vector_t; +using ushort64 = vector_t; +using uint32 = vector_t; using coord_t = vector_t; } // namespace intel end