Skip to content

Commit

Permalink
Merge pull request #841 from KhronosGroup/steffen/bump_and_builtins
Browse files Browse the repository at this point in the history
Check new builtin guarntees
  • Loading branch information
bader authored Dec 4, 2023
2 parents 4993dc6 + 9c05219 commit 3b98e3e
Show file tree
Hide file tree
Showing 7 changed files with 1,100 additions and 313 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cts_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
matrix:
include:
- sycl-impl: dpcpp
version: 6bce7f64f51a4370052bffa3fa257ca16d8aad9e
version: 1dbee22f9c8a3a825deb871bab76937e04fa26fc
- sycl-impl: hipsycl
version: 3d8b1cd
steps:
Expand Down Expand Up @@ -114,7 +114,7 @@ jobs:
matrix:
include:
- sycl-impl: dpcpp
version: 6bce7f64f51a4370052bffa3fa257ca16d8aad9e
version: 1dbee22f9c8a3a825deb871bab76937e04fa26fc
- sycl-impl: hipsycl
version: 3d8b1cd
env:
Expand Down
5 changes: 5 additions & 0 deletions tests/math_builtin_api/math_builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,4 +374,9 @@ void test_function_multi_ptr_local(funT fun, argT arg) {
delete[] kernelResult;
}

template <typename T>
struct ImplicitlyConvertibleType {
operator T() const { return {}; }
};

#endif // CL_SYCL_CTS_MATH_BUILTIN_API_MATH_BUILTIN_H
1,307 changes: 998 additions & 309 deletions tests/math_builtin_api/modules/sycl_functions.py

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions tests/math_builtin_api/modules/sycl_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,11 +1471,29 @@ def create_types():
type_dic["sgentype"] = t_sgen_type


t_vgenfloatf_type = argtype("vgenfloatf", "NULL", "NULL", 0, ["sycl::vec<float, 2>","sycl::vec<float, 3>","sycl::vec<float, 4>","sycl::vec<float, 8>","sycl::vec<float, 16>"])
type_dic["vgenfloatf"] = t_vgenfloatf_type

t_vgenfloatd_type = argtype("vgenfloatd", "NULL", "NULL", 0, ["sycl::vec<double, 2>","sycl::vec<double, 3>","sycl::vec<double, 4>","sycl::vec<double, 8>","sycl::vec<double, 16>"])
type_dic["vgenfloatd"] = t_vgenfloatd_type

t_vgenfloath_type = argtype("vgenfloath", "NULL", "NULL", 0, ["sycl::vec<sycl::half, 2>","sycl::vec<sycl::half, 3>","sycl::vec<sycl::half, 4>","sycl::vec<sycl::half, 8>","sycl::vec<sycl::half, 16>"])
type_dic["vgenfloath"] = t_vgenfloath_type

t_vgenfloat_type = argtype("vgenfloat", "NULL", "NULL", 0, ["sycl::vec<float, 2>","sycl::vec<float, 3>","sycl::vec<float, 4>","sycl::vec<float, 8>","sycl::vec<float, 16>",
"sycl::vec<double, 2>","sycl::vec<double, 3>","sycl::vec<double, 4>","sycl::vec<double, 8>","sycl::vec<double, 16>",
"sycl::vec<sycl::half, 2>","sycl::vec<sycl::half, 3>","sycl::vec<sycl::half, 4>","sycl::vec<sycl::half, 8>","sycl::vec<sycl::half, 16>"])
type_dic["vgenfloat"] = t_vgenfloat_type

t_mgenfloatf_type = argtype("mgenfloatf", "NULL", "NULL", 0, ["sycl::marray<float, 2>","sycl::marray<float, 3>","sycl::marray<float, 4>","sycl::marray<float, 5>","sycl::marray<float, 17>"])
type_dic["mgenfloatf"] = t_mgenfloatf_type

t_mgenfloatd_type = argtype("mgenfloatd", "NULL", "NULL", 0, ["sycl::marray<double, 2>","sycl::marray<double, 3>","sycl::marray<double, 4>","sycl::marray<double, 5>","sycl::marray<double, 17>",])
type_dic["mgenfloatd"] = t_mgenfloatd_type

t_mgenfloath_type = argtype("mgenfloath", "NULL", "NULL", 0, ["sycl::marray<sycl::half, 2>","sycl::marray<sycl::half, 3>","sycl::marray<sycl::half, 4>","sycl::marray<sycl::half, 5>","sycl::marray<sycl::half, 17>"])
type_dic["mgenfloath"] = t_mgenfloath_type

t_mgenfloat_type = argtype("mgenfloat", "NULL", "NULL", 0, ["sycl::marray<float, 2>","sycl::marray<float, 3>","sycl::marray<float, 4>","sycl::marray<float, 5>","sycl::marray<float, 17>",
"sycl::marray<double, 2>","sycl::marray<double, 3>","sycl::marray<double, 4>","sycl::marray<double, 5>","sycl::marray<double, 17>",
"sycl::marray<sycl::half, 2>","sycl::marray<sycl::half, 3>","sycl::marray<sycl::half, 4>","sycl::marray<sycl::half, 5>","sycl::marray<sycl::half, 17>"])
Expand Down
67 changes: 65 additions & 2 deletions tests/math_builtin_api/modules/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,71 @@ def generate_arguments(sig, memory, decorated):
arg_index += 1
return (arg_names, arg_src)

template_args_template = Template("""
static_assert(std::is_same_v<decltype(${namespace}::${func_name}<${template_arg_types}>(${arg_names})), decltype(${namespace}::${func_name}(${arg_names}))>,
"Error: ${namespace}::${func_name}(${arg_types}) definition does not use the required template arguments.");
""")
implicit_convertible_args_template = Template("""
static_assert(std::is_same_v<decltype(${namespace}::${func_name}(${convertible_args})), decltype(${namespace}::${func_name}(${arg_names}))>,
"Error: ${namespace}::${func_name}(${arg_types}) cannot properly convert arguments.");
""")
def generate_additional_static_checks(sig, arg_names):
asc = ""
arg_types = [a.name for a in sig.arg_types]
if len(sig.template_arg_map) > 0:
# Check template arguments on functions with it.
template_arg_types = ["decltype(" + arg_names[i] + ")" for i in sig.template_arg_map]
asc += template_args_template.substitute(
namespace=sig.namespace,
func_name=sig.name,
arg_names=", ".join(arg_names),
ret_type=sig.ret_type.name,
template_arg_types=", ".join(template_arg_types),
arg_types=", ".join(arg_types))
else:
# Otherwise check for implcitly convertible arguments.
convertible_args = ["std::declval<ImplicitlyConvertibleType<" + a + ">>()" for a in arg_types]
asc += implicit_convertible_args_template.substitute(
namespace=sig.namespace,
func_name=sig.name,
arg_names=", ".join(arg_names),
ret_type=sig.ret_type.name,
convertible_args=", ".join(convertible_args),
arg_types=", ".join(arg_types))

# Detect all vec arguments.
vec_args = [re.search("^sycl::vec<.+,\s*(\d+)>$", at) for at in arg_types]
# Except the ones that are pointers.
for i in sig.pntr_indx:
if i > 0:
vec_args[i-1] = None

# If there are any vector arguments in the builtin we check that they also
# accept swizzles.
if not all(va is None for va in vec_args):
convertible_args = []
for (vec_arg, arg_name) in zip(vec_args, arg_names):
if vec_arg:
(num_elems,) = vec_arg.groups()
indices = [str(i) for i in range(int(num_elems))]
convertible_args.append(arg_name + ".swizzle<" + (",".join(indices)) + ">()")
else:
convertible_args.append(arg_name)
asc += implicit_convertible_args_template.substitute(
namespace=sig.namespace,
func_name=sig.name,
arg_names=", ".join(arg_names),
ret_type=sig.ret_type.name,
convertible_args=", ".join(convertible_args),
arg_types=", ".join(arg_types))
return asc


function_call_template = Template("""
${arg_src}
static_assert(std::is_same_v<decltype(${namespace}::${func_name}(${arg_names})), ${ret_type}>,
"Error: Wrong return type of ${namespace}::${func_name}(${arg_types}), not ${ret_type}");
${additional_static_checks}
return ${namespace}::${func_name}(${arg_names});
""")
def generate_function_call(sig, arg_names, arg_src):
Expand All @@ -174,7 +235,8 @@ def generate_function_call(sig, arg_names, arg_src):
func_name=sig.name,
arg_names=", ".join(arg_names),
ret_type=sig.ret_type.name,
arg_types=", ".join([a.name for a in sig.arg_types]))
arg_types=", ".join([a.name for a in sig.arg_types]),
additional_static_checks=generate_additional_static_checks(sig, arg_names))
return fc

function_private_call_template = Template("""
Expand Down Expand Up @@ -389,7 +451,8 @@ def expand_signature(types, signature):
new_sig = sycl_functions.funsig(signature.namespace, matched_typelists[signature.ret_type][i],
signature.name, [matched_typelists[signature.arg_types[j]][i]
for j in range(len(signature.arg_types))],
signature.accuracy, signature.comment, signature.pntr_indx[:])
signature.accuracy, signature.comment, signature.pntr_indx[:],
signature.mutations[:], signature.template_arg_map[:])
exp_sig.append(new_sig)

return exp_sig
Expand Down
1 change: 1 addition & 0 deletions util/math_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ double fract(double a, double *b) {
float nan(unsigned int a) { return std::nanf(std::to_string(a).c_str()); }
double nan(unsigned long a) { return std::nan(std::to_string(a).c_str()); }
double nan(unsigned long long a) { return std::nan(std::to_string(a).c_str()); }
sycl::half nan(unsigned short a) { return nan(unsigned(a)); }

sycl::half modf(sycl::half a, sycl::half *b) {
float resPtr;
Expand Down
11 changes: 11 additions & 0 deletions util/math_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -1281,10 +1281,16 @@ sycl::marray<T, N> modf(sycl::marray<T, N> a, sycl::marray<T, N> *b) {
}
#endif

sycl::half nan(unsigned short a);
float nan(unsigned int a);
double nan(unsigned long a);
double nan(unsigned long long a);
template <int N>
sycl::vec<sycl::half, N> nan(sycl::vec<unsigned short, N> a) {
return sycl_cts::math::run_func_on_vector<sycl::half, unsigned short, N>(
[](unsigned short x) { return nan(x); }, a);
}
template <int N>
sycl::vec<float, N> nan(sycl::vec<unsigned int, N> a) {
return sycl_cts::math::run_func_on_vector<float, unsigned int, N>(
[](unsigned int x) { return nan(x); }, a);
Expand All @@ -1300,6 +1306,11 @@ nan(sycl::vec<T, N> a) {
// FIXME: hipSYCL does not support marray
#ifndef SYCL_CTS_COMPILING_WITH_HIPSYCL
template <size_t N>
sycl::marray<sycl::half, N> nan(sycl::marray<unsigned short, N> a) {
return sycl_cts::math::run_func_on_marray<sycl::half, unsigned short, N>(
[](unsigned short x) { return nan(x); }, a);
}
template <size_t N>
sycl::marray<float, N> nan(sycl::marray<unsigned int, N> a) {
return sycl_cts::math::run_func_on_marray<float, unsigned int, N>(
[](unsigned int x) { return nan(x); }, a);
Expand Down

0 comments on commit 3b98e3e

Please sign in to comment.