diff --git a/include/views/view.h b/include/views/view.h index 434259766..241abb4e2 100644 --- a/include/views/view.h +++ b/include/views/view.h @@ -57,6 +57,10 @@ struct VectorView { index_t size_; increment_t strd_; // never size_t, because it could be negative + // Start of the vector. + // If stride is negative, start at the end of the vector and move backward. + container_t ptr_; + VectorView(view_container_t data, view_increment_t strd, view_index_t size); VectorView(VectorView opV, view_increment_t strd, view_index_t size); @@ -97,17 +101,13 @@ struct VectorView { template SYCL_BLAS_INLINE typename std::enable_if::type eval( index_t i) { - return (strd_ == 1) ? *(data_ + i) - : (strd_ > 0) ? *(data_ + i * strd_) - : *(data_ + (size_ * -strd_) + ((i + 1) * strd_)); + return (strd_ == 1) ? *(ptr_ + i) : *(ptr_ + i * strd_); } template SYCL_BLAS_INLINE typename std::enable_if::type eval( index_t i) const { - return (strd_ == 1) ? *(data_ + i) - : (strd_ > 0) ? *(data_ + i * strd_) - : *(data_ + (size_ * -strd_) + ((i + 1) * strd_)); + return (strd_ == 1) ? *(ptr_ + i) : *(ptr_ + i * strd_); } SYCL_BLAS_INLINE value_t &eval(cl::sycl::nd_item<1> ndItem) { @@ -121,13 +121,13 @@ struct VectorView { template SYCL_BLAS_INLINE typename std::enable_if::type eval( index_t indx) { - return *(data_ + indx); + return *(ptr_ + indx); } template SYCL_BLAS_INLINE typename std::enable_if::type eval( index_t indx) const noexcept { - return *(data_ + indx); + return *(ptr_ + indx); } }; diff --git a/src/views/view.hpp b/src/views/view.hpp index 9d24b4907..0804b5cb7 100644 --- a/src/views/view.hpp +++ b/src/views/view.hpp @@ -45,7 +45,7 @@ SYCL_BLAS_INLINE VectorView<_container_t, _IndexType, _IncrementType>::VectorView(_container_t data, _IncrementType strd, _IndexType size) - : data_(data), size_(size), strd_(strd) {} + : data_(data), size_(size), strd_(strd), ptr_(strd > 0 ? data_ : data_ + (size_ - 1) * (-strd_)) {} /*! @brief Creates a view from an existing view. @@ -55,7 +55,7 @@ SYCL_BLAS_INLINE VectorView<_container_t, _IndexType, _IncrementType>::VectorView( VectorView<_container_t, _IndexType, _IncrementType> opV, _IncrementType strd, _IndexType size) - : data_(opV.get_data()), size_(size), strd_(strd) {} + : data_(opV.get_data()), size_(size), strd_(strd), ptr_(strd > 0 ? data_ : data_ + (size_ - 1) * (-strd_)) {} /*! * @brief Returns a reference to the container