Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 committed May 30, 2024
1 parent 4031348 commit 13c6cce
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 41 deletions.
67 changes: 26 additions & 41 deletions src/include/migraphx/op/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/sat_ops.hpp>

#include <algorithm>

Expand Down Expand Up @@ -70,70 +71,54 @@ struct reshape
shape dyn_1arg_compute_shape(shape s0) const
{
auto input_dyn_dims = s0.dyn_dims();
bool has_negative_dim_attr = false;
const auto neg_dim_num =
std::distance(this->dims.begin(), std::find(this->dims.begin(), this->dims.end(), -1));
const bool has_negative_dim_attr = neg_dim_num < dims.size();
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
for(int i = 0; i < dims.size(); ++i)
{
int64_t d = dims.at(i);
if(d == 0)
{
output_dyn_dims.at(i) = input_dyn_dims.at(i);
}
else if(d == -1)
{
has_negative_dim_attr = true;
output_dyn_dims.at(i) = {1, 1};
}
else
{
std::size_t u_dim = d;
output_dyn_dims.at(i) = {u_dim, u_dim};
}
}
std::transform(dims.begin(),
dims.end(),
input_dyn_dims.begin(),
output_dyn_dims.begin(),
[](auto dim, auto input_dyn_dim) -> shape::dynamic_dimension {
if(dim == 0)
{
return input_dyn_dim;
}
if(dim == -1)
{
return {1, 1};
}
std::size_t u_dim = dim;
return {u_dim, u_dim};
});

if(has_negative_dim_attr)
{
// comparing the -1 dimension against the other dimensions
auto neg_dim_num = std::distance(this->dims.begin(),
std::find(this->dims.begin(), this->dims.end(), -1));

// unsigned int wraparound check, false = no wraparound
auto uint_wraparound_check = [](std::size_t a, std::size_t b) {
std::size_t c = a * b;
return a != 0 and c / a != b;
};

// accumulate the minimum and maximum elements in the dimensions before the -1 dimension
std::size_t min_cur_elements = 1;
std::size_t max_cur_elements = 1;
std::size_t max_int = std::numeric_limits<std::size_t>::max();
for(const auto& dd : output_dyn_dims)
{
min_cur_elements = uint_wraparound_check(min_cur_elements, dd.min)
? max_int
: min_cur_elements * dd.min;
max_cur_elements = uint_wraparound_check(max_cur_elements, dd.max)
? max_int
: max_cur_elements * dd.max;
min_cur_elements = mul_sat(min_cur_elements, dd.min);
max_cur_elements = mul_sat(max_cur_elements, dd.max);
}
// accumulate the elements in the input dimensions
std::size_t min_input_elements = 1;
std::size_t max_input_elements = 1;
for(const auto& dd : input_dyn_dims)
{
min_input_elements = uint_wraparound_check(min_input_elements, dd.min)
? max_int
: min_input_elements * dd.min;
max_input_elements = uint_wraparound_check(max_input_elements, dd.max)
? max_int
: max_input_elements * dd.max;
min_input_elements = mul_sat(min_input_elements, dd.min);
max_input_elements = mul_sat(max_input_elements, dd.max);
}

// maximum dimensions should never accumulate to zero
assert(max_cur_elements != 0);

// hanle 0 dimension value (keep unknown lower bound)
std::size_t max_int = std::numeric_limits<std::size_t>::max();
// handle 0 dimension value (keep unknown lower bound)
std::size_t min_dim =
(min_cur_elements == 0) ? 0 : min_input_elements / min_cur_elements;
// handle maximum dimension value (keep unknown upper bound)
Expand Down
49 changes: 49 additions & 0 deletions src/include/migraphx/sat_ops.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SAT_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_SAT_OPS_HPP

#include <type_traits>
#include <limits>

template <class T>
constexpr T mul_sat(T a, T b) noexcept
{
T c;
if(not __builtin_mul_overflow(a, b, &c))
{
return c;
}
if constexpr(std::is_unsigned<T>{})
{
return std::numeric_limits<T>::max();
}
else if(a < 0 != b < 0)
{
return std::numeric_limits<T>::min();
}
return std::numeric_limits<T>::max();
}

#endif

0 comments on commit 13c6cce

Please sign in to comment.