Skip to content

Commit

Permalink
Merge pull request #31 from tenstorrent/30-ioctl-bugs
Browse files Browse the repository at this point in the history
Fix ioctl input & output sizes
  • Loading branch information
alewycky-tenstorrent authored Sep 6, 2024
2 parents 3ae2a24 + ce03953 commit 715a5d7
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 12 deletions.
10 changes: 5 additions & 5 deletions chardev.c
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ static long ioctl_get_device_info(struct chardev_private *priv,
const struct pci_dev *pdev = priv->device->pdev;
u32 bytes_to_copy;

struct tenstorrent_get_device_info_out in;
struct tenstorrent_get_device_info_in in;
struct tenstorrent_get_device_info_out out;
memset(&in, 0, sizeof(in));
memset(&out, 0, sizeof(out));
Expand All @@ -132,7 +132,7 @@ static long ioctl_get_device_info(struct chardev_private *priv,

bytes_to_copy = min(in.output_size_bytes, (u32)sizeof(out));

if (copy_to_user(&arg->out, &out, sizeof(out)) != 0)
if (copy_to_user(&arg->out, &out, bytes_to_copy) != 0)
return -EFAULT;

return 0;
Expand All @@ -143,7 +143,7 @@ static long ioctl_get_driver_info(struct chardev_private *priv,
{
u32 bytes_to_copy;

struct tenstorrent_get_driver_info_out in;
struct tenstorrent_get_driver_info_in in;
struct tenstorrent_get_driver_info_out out;
memset(&in, 0, sizeof(in));
memset(&out, 0, sizeof(out));
Expand All @@ -159,7 +159,7 @@ static long ioctl_get_driver_info(struct chardev_private *priv,

bytes_to_copy = min(in.output_size_bytes, (u32)sizeof(out));

if (copy_to_user(&arg->out, &out, sizeof(out)) != 0)
if (copy_to_user(&arg->out, &out, bytes_to_copy) != 0)
return -EFAULT;

return 0;
Expand Down Expand Up @@ -200,7 +200,7 @@ static long ioctl_reset_device(struct chardev_private *priv,

bytes_to_copy = min(in.output_size_bytes, (u32)sizeof(out));

if (copy_to_user(&arg->out, &out, sizeof(out)) != 0)
if (copy_to_user(&arg->out, &out, bytes_to_copy) != 0)
return -EFAULT;

return 0;
Expand Down
4 changes: 3 additions & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ all::

PROG := ttkmd_test

TEST_SOURCES := get_device_info.cpp query_mappings.cpp dma_buf.cpp pin_pages.cpp config_space.cpp lock.cpp hwmon.cpp map_peer_bar.cpp
TEST_SOURCES := get_driver_info.cpp get_device_info.cpp query_mappings.cpp \
dma_buf.cpp pin_pages.cpp config_space.cpp lock.cpp hwmon.cpp map_peer_bar.cpp \
ioctl_overrun.cpp ioctl_zeroing.cpp

CORE_SOURCES := enumeration.cpp util.cpp devfd.cpp main.cpp test_failure.cpp
SOURCES := $(CORE_SOURCES) $(TEST_SOURCES)
Expand Down
37 changes: 37 additions & 0 deletions test/get_driver_info.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
// SPDX-License-Identifier: GPL-2.0-only

#include <string>

#include <sys/ioctl.h>

#include "ioctl.h"

#include "util.h"
#include "test_failure.h"
#include "enumeration.h"
#include "devfd.h"

void TestGetDriverInfo(const EnumeratedDevice &dev)
{
DevFd dev_fd(dev.path);

tenstorrent_get_driver_info get_driver_info{};
get_driver_info.in.output_size_bytes = sizeof(get_driver_info.out);

if (ioctl(dev_fd.get(), TENSTORRENT_IOCTL_GET_DRIVER_INFO, &get_driver_info) != 0)
THROW_TEST_FAILURE("TENSTORRENT_IOCTL_GET_DRIVER_INFO failed on " + dev.path);

std::size_t min_get_driver_info_out
= offsetof(tenstorrent_get_driver_info_out, driver_version)
+ sizeof(get_driver_info.out.driver_version);

if (get_driver_info.out.output_size_bytes < min_get_driver_info_out)
THROW_TEST_FAILURE("GET_DRIVER_INFO output is too small.");

if (get_driver_info.out.output_size_bytes > sizeof(get_driver_info.out))
THROW_TEST_FAILURE("GET_DRIVER_INFO output is too large. (Test may be out of date.)");

if (get_driver_info.out.driver_version != TENSTORRENT_DRIVER_VERSION)
THROW_TEST_FAILURE("GET_DRIVER_INFO reports an unexpected driver version.");
}
212 changes: 212 additions & 0 deletions test/ioctl_overrun.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Try to catch ioctls that read or write the wrong amount of data.
//
// When an ioctl input has output_size_bytes, we align the input to the end of the page
// and set output_size_bytes = 0. This should result in no output being written and no error.
// This catches read and write overruns.
//
// When an ioctl input doesn't have output_size_bytes, we align the entire structure to the
// end of the page. This catches write overruns.
// If hardware had support for PROT_WRITE without PROT_READ we could also check for read overruns.

#include <memory>
#include <string>

#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

#include <sys/ioctl.h>
#include <sys/mman.h>

#include "devfd.h"
#include "enumeration.h"
#include "ioctl.h"
#include "test_failure.h"
#include "util.h"

namespace
{

// Allocate data aligned to the end of a page, guaranteeing that the next page is unmapped.
template <class T>
class EndOfPage
{
public:
EndOfPage(const T& init = {});
~EndOfPage();

EndOfPage(const EndOfPage<T>&) = delete;
void operator = (const EndOfPage<T>&) = delete;

T *get();

private:
void *mapping = nullptr;
T *value = nullptr;

static std::size_t mapping_size();
};

template <class T>
std::size_t EndOfPage<T>::mapping_size()
{
return round_up(sizeof(T), page_size()) + page_size();
}

template <class T>
EndOfPage<T>::EndOfPage(const T& init)
{
auto size = mapping_size();

mapping = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
if (mapping == MAP_FAILED)
throw_system_error("end-of-page mapping allocation failed");

void *final_page = reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(mapping) + size - page_size());

if (mprotect(final_page, page_size(), PROT_NONE) != 0)
throw_system_error("failed to disable access to overrun detection page");

void *p = reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(final_page) - sizeof(T));
value = new (p) T(init);
}

template <class T>
EndOfPage<T>::~EndOfPage()
{
value->~T();
munmap(mapping, mapping_size());
}

template <class T>
T *EndOfPage<T>::get()
{
return value;
}

// The assumption is that the ioctl_data is aligned to the end of the page and no EFAULT should occur.
#define CHECK_IOCTL_OVERRUN(fd, ioctl_name, ioctl_data) CheckIoctlOverrun(fd, ioctl_name, #ioctl_name, ioctl_data)
#define CHECK_IOCTL_OVERRUN_ERROR(fd, ioctl_name, ioctl_data, expected_error) CheckIoctlOverrun(fd, ioctl_name, #ioctl_name, ioctl_data, expected_error)

template <class IoctlData>
void CheckIoctlOverrun(int fd, unsigned long ioctl_code, const char *ioctl_name, const IoctlData& ioctl_data, int expected_error = 0)
{
EndOfPage<IoctlData> aligned_ioctl_data(ioctl_data);

int result = ioctl(fd, ioctl_code, aligned_ioctl_data.get());

if (result != 0)
{
if (errno == EFAULT)
THROW_TEST_FAILURE(std::string(ioctl_name) + " failed overrun check.");
else if (errno != expected_error)
THROW_TEST_FAILURE(std::string(ioctl_name) + " overrun check failed other than EFAULT.");
}
}

void TestGetDeviceInfoOverrun(int fd)
{
tenstorrent_get_device_info_in in{};
in.output_size_bytes = 0;

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_GET_DEVICE_INFO, in);
}

void TestQueryMappingsOverrun(int fd)
{
tenstorrent_query_mappings_in in{};
in.output_mapping_count = 0;

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_QUERY_MAPPINGS, in);
}

void TestAllocateDmaBufOverrun(int fd)
{
tenstorrent_allocate_dma_buf alloc_buf{};

alloc_buf.in.requested_size = page_size();
alloc_buf.in.buf_index = 0;

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF, alloc_buf);
}

void TestFreeDmaBufOverrun(int fd)
{
tenstorrent_free_dma_buf free_buf{};

CHECK_IOCTL_OVERRUN_ERROR(fd, TENSTORRENT_IOCTL_FREE_DMA_BUF, free_buf, EINVAL);
}

void TestGetDriverInfoOverrun(int fd)
{
tenstorrent_get_driver_info_in in{};
in.output_size_bytes = 0;

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_GET_DRIVER_INFO, in);
}

void TestResetDeviceOverrun(int fd)
{
tenstorrent_reset_device_in in{};
in.output_size_bytes = 0;
in.flags = TENSTORRENT_RESET_DEVICE_RESTORE_STATE;

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_RESET_DEVICE, in);
}

void TestPinPagesOverrun(int fd)
{
std::unique_ptr<void, Freer> page(std::aligned_alloc(page_size(), page_size()));

tenstorrent_pin_pages_in in{};
in.output_size_bytes = 0;
in.virtual_address = reinterpret_cast<std::uintptr_t>(page.get());
in.size = page_size();

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_PIN_PAGES, in);
}

void TestLockCtlOverrun(int fd)
{
tenstorrent_lock_ctl_in in{};
in.output_size_bytes = 0;
in.flags = TENSTORRENT_LOCK_CTL_TEST;
in.index = 0;

CHECK_IOCTL_OVERRUN(fd, TENSTORRENT_IOCTL_LOCK_CTL, in);
}

void TestMapPeerBarOverrun(int fd)
{
// TENSTORRENT_IOCTL_MAP_PEER_BAR requires 2 devices and doesn't have output_size_bytes
// so we can only test that it rejects the input without EFAULT.

tenstorrent_map_peer_bar_in in{};

in.peer_fd = fd;
in.peer_bar_index = 0;
in.peer_bar_offset = 0;
in.peer_bar_length = page_size();
in.flags = 0;

CHECK_IOCTL_OVERRUN_ERROR(fd, TENSTORRENT_IOCTL_MAP_PEER_BAR, in, EINVAL);
}

}

void TestIoctlOverrun(const EnumeratedDevice &dev)
{
DevFd dev_fd(dev.path);

TestGetDeviceInfoOverrun(dev_fd.get());
// TENSTORRENT_IOCTL_GET_HARVESTING simply fails.
TestQueryMappingsOverrun(dev_fd.get());
TestAllocateDmaBufOverrun(dev_fd.get());
TestFreeDmaBufOverrun(dev_fd.get());
TestGetDriverInfoOverrun(dev_fd.get());
TestResetDeviceOverrun(dev_fd.get());
TestPinPagesOverrun(dev_fd.get());
TestLockCtlOverrun(dev_fd.get());
TestMapPeerBarOverrun(dev_fd.get());
}
Loading

0 comments on commit 715a5d7

Please sign in to comment.