Skip to content

Commit

Permalink
Add input_tensor input type (davisking#2951)
Browse files Browse the repository at this point in the history
  • Loading branch information
kSkip authored May 12, 2024
1 parent fa0e3ff commit 51c7a35
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 0 deletions.
8 changes: 8 additions & 0 deletions dlib/cuda/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,14 @@ namespace dlib

// ----------------------------------------------------------------------------------------

inline void memcpy (
alias_tensor_instance&& dest,
const tensor& src
)
{
memcpy(static_cast<tensor&>(dest), src);
}

}

#endif // DLIB_DNn_TENSOR_H_
Expand Down
8 changes: 8 additions & 0 deletions dlib/cuda/tensor_abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,14 @@ namespace dlib
);
};

inline void memcpy (
alias_tensor_instance&& dest,
const tensor& src
) { memcpy(static_cast<tensor&>(dest), src); }
/*!
A convenient overload for copying from src to dest when you have a temporary alias tensor.
!*/

class alias_tensor_const_instance
{
/*!
Expand Down
87 changes: 87 additions & 0 deletions dlib/dnn/input.h
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,93 @@ namespace dlib
float avg_blue;
};

// ----------------------------------------------------------------------------------------

class input_tensor
{
public:
typedef tensor input_type;

input_tensor() {}
input_tensor(const input_tensor&) {}

template<typename forward_iterator>
void to_tensor(
forward_iterator ibegin,
forward_iterator iend,
resizable_tensor& data
) const
{
DLIB_CASSERT(std::distance(ibegin, iend) > 0);
const auto k = ibegin->k();
const auto nr = ibegin->nr();
const auto nc = ibegin->nc();
// make sure all the input tensors have the same dimensions
for (auto i = ibegin; i != iend; ++i)
{
DLIB_CASSERT(i->k() == k && i->nr() == nr && i->nc() == nc,
"\t input_tensor::to_tensor()"
<< "\n\t All tensor objects given to to_tensor() must have the same dimensions."
<< "\n\t k: " << k
<< "\n\t nr: " << nr
<< "\n\t nc: " << nc
<< "\n\t i->k(): " << i->k()
<< "\n\t i->nr(): " << i->nr()
<< "\n\t i->nc(): " << i->nc()
);
}

const auto num_samples = count_samples(ibegin, iend);
// initialize data to the right size to contain the stuff in the iterator range.
data.set_size(num_samples, k, nr, nc);

const size_t stride = k * nr * nc;
size_t offset = 0;
for (auto i = ibegin; i != iend; ++i)
{
alias_tensor slice(i->num_samples(), k, nr, nc);
memcpy(slice(data, offset), *i);
offset += slice.num_samples() * stride;
}
}

friend void serialize(const input_tensor&, std::ostream& out)
{
serialize("input_tensor", out);
}

friend void deserialize(input_tensor&, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "input_tensor")
throw serialization_error("Unexpected version found while deserializing dlib::input_tensor.");
}

friend std::ostream& operator<<(std::ostream& out, const input_tensor&)
{
out << "input_tensor";
return out;
}

friend void to_xml(const input_tensor&, std::ostream& out)
{
out << "<input_tensor/>\n";
}

private:

template<typename forward_iterator>
long long count_samples(
forward_iterator ibegin,
forward_iterator iend
) const
{
return std::accumulate(ibegin, iend, 0,
[](long long a, const auto& b) { return a + b.num_samples(); });
}
};

// ----------------------------------------------------------------------------------------

}
Expand Down
51 changes: 51 additions & 0 deletions dlib/dnn/input_abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,57 @@ namespace dlib

// ----------------------------------------------------------------------------------------

class input_tensor
{
/*!
WHAT THIS OBJECT REPRESENTS
This input layer works with dlib::tensor objects. It is very similar to
the dlib::input layer except that it allows for concatenating data that
already resides in GPU memory.
!*/

public:
typedef tensor input_type;

input_tensor(
);
/*!
ensures
- input_tensor objects are default constructable
!*/

input_tensor(
const input_tensor& item
);
/*!
ensures
- input_tensor objects are copy constructable
!*/

template <typename forward_iterator>
void to_tensor(
forward_iterator ibegin,
forward_iterator iend,
resizable_tensor& data
) const;
/*!
requires
- [ibegin, iend) is an iterator range over input_type objects.
- std::distance(ibegin,iend) > 0
- The input range should contain tensor objects that all have the same
dimensions.
ensures
- Copies the iterator range into #data. In particular, if the input tensors
have R rows, C columns, and K channels then we will have:
- #data.num_samples() == count_samples(ibegin,iend)
- #data.nr() == R
- #data.nc() == C
- #data.k() == K
This results in a tensor concatenation along the sample dimension.
!*/
};

// ----------------------------------------------------------------------------------------
}

#endif // DLIB_DNn_INPUT_ABSTRACT_H_
Expand Down
33 changes: 33 additions & 0 deletions dlib/test/dnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4276,6 +4276,38 @@ namespace
#endif
}

void test_input_tensor()
{
using namespace dlib::tt;
print_spinner();
tt::tensor_rand rnd;
std::vector<resizable_tensor> tensors(3);

for (auto& t : tensors) {
t.set_size(1, 3, 224, 224);
rnd.fill_gaussian(t);
}

resizable_tensor out;
input_tensor input_layer;

input_layer.to_tensor(tensors.begin(), tensors.end(), out);

DLIB_TEST(out.num_samples() == 3);
DLIB_TEST(out.k() == 3);
DLIB_TEST(out.nr() == 224);
DLIB_TEST(out.nc() == 224);
size_t stride = out.k() * out.nr() * out.nc();
size_t offset = 0;
int error = 0;

for (auto& t : tensors) {
error = memcmp(out.host() + offset, t.host(), sizeof(float) * t.size());
DLIB_TEST(error == 0);
offset += stride;
}
}

// ----------------------------------------------------------------------------------------

class dnn_tester : public tester
Expand Down Expand Up @@ -4386,6 +4418,7 @@ namespace
test_input_ouput_mappers();
test_fuse_layers();
test_reorg();
test_input_tensor();
}

void perform_test()
Expand Down

0 comments on commit 51c7a35

Please sign in to comment.