Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add special build for testing serialization via a serialization roundtrip in JIT compilation and fix serialization leaks #7763

Merged
merged 36 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c0cf44d
add back JIT testing, enclosed in #ifdef blocks
TH3CHARLie Aug 15, 2023
d674662
fix typo
TH3CHARLie Aug 15, 2023
3a2de83
nits
TH3CHARLie Aug 16, 2023
de155f1
WITH_SERIALIZATION_JIT->WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
steven-johnson Aug 23, 2023
01fc657
Merge branch 'main' into pr/7763
steven-johnson Aug 23, 2023
bad7fb0
Merge branch 'main' into pr/7763
steven-johnson Aug 23, 2023
7260cd9
fix self-reference leaks: now uses weak function ptr in reverse funct…
TH3CHARLie Aug 25, 2023
7559e3b
Merge branch 'main' into pr/7763
steven-johnson Aug 28, 2023
951c714
Merge branch 'xuanda/serialization-testing' of https://github.com/TH3…
steven-johnson Aug 28, 2023
ee10d4c
Move clang-tidy checks back to Linux
steven-johnson Aug 28, 2023
20cf471
bogus
steven-johnson Aug 28, 2023
b4fff91
Update Generator.cpp
steven-johnson Aug 29, 2023
8c5f3ac
Merge branch 'srj/tidy-revert' into pr/7763
steven-johnson Aug 29, 2023
9b812b8
Update Generator.cpp
steven-johnson Aug 29, 2023
6c4ef67
Merge branch 'srj/tidy-revert' into pr/7763
steven-johnson Aug 29, 2023
67de0ad
Merge branch 'main' into pr/7763
steven-johnson Aug 30, 2023
0748868
Merge branch 'main' into pr/7763
steven-johnson Aug 30, 2023
e1fb1c3
Merge branch 'main' into pr/7763
steven-johnson Sep 13, 2023
4bd7179
call copy_to_host before serializing buffers
TH3CHARLie Sep 14, 2023
0986319
throw an error if we serialize on-device buffer
TH3CHARLie Sep 14, 2023
f2977c7
Merge branch 'main' into pr/7763
steven-johnson Sep 15, 2023
48f479f
Skip specialize_to_gpu
steven-johnson Sep 15, 2023
50cfc54
Merge branch 'main' into pr/7763
steven-johnson Sep 18, 2023
955a6d7
Merge branch 'main' into pr/7763
steven-johnson Sep 18, 2023
8c98d59
Update Pipeline.cpp
steven-johnson Sep 18, 2023
feacc37
Merge branch 'main' into pr/7763
steven-johnson Sep 19, 2023
004b8ee
Skip two more tests
steven-johnson Sep 25, 2023
de8818d
Merge remote-tracking branch 'origin/main' into xuanda/serialization-…
TH3CHARLie Sep 28, 2023
d488510
use serialize to memory during jit testing
TH3CHARLie Sep 28, 2023
ca33de5
makefile update
TH3CHARLie Sep 29, 2023
1d273d3
makefile fix
TH3CHARLie Sep 29, 2023
0428ad1
skip the tutorial if flatc is not there
TH3CHARLie Oct 2, 2023
441bce4
fix
TH3CHARLie Oct 2, 2023
a538bc0
fix signature
TH3CHARLie Oct 5, 2023
ef7e91b
fix makefile
TH3CHARLie Oct 5, 2023
12583c0
trigger buildbot
TH3CHARLie Nov 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ LLVM_SHARED_LIBS = -Wl,-rpath=$(LLVM_LIBDIR) -L $(LLVM_LIBDIR) -lLLVM
LLVM_LIBS_FOR_SHARED_LIBHALIDE=$(if $(WITH_LLVM_INSIDE_SHARED_LIBHALIDE),$(LLVM_STATIC_LIBS),$(LLVM_SHARED_LIBS))

TUTORIAL_CXX_FLAGS ?= -std=c++17 -g -fno-omit-frame-pointer $(RTTI_CXX_FLAGS) -I $(ROOT_DIR)/tools $(SANITIZER_FLAGS) $(LLVM_CXX_FLAGS_LIBCPP)
ifneq (,$(shell which flatc))
TUTORIAL_CXX_FLAGS += -DWITH_SERIALIZATION -I $(BUILD_DIR) -I $(shell which flatc | sed 's/bin.flatc/include/')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this is necessary? These flags are the ones that would be set by users of Halide. It would be bad if they needed to set any of this to use serialization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should just remove this part, we already skipping tutorial 23 if no flatc found.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.

endif
# The tutorials contain example code with warnings that we don't want
# to be flagged as errors, so the test flags are the tutorial flags
# plus our warning flags.
Expand Down Expand Up @@ -2083,6 +2086,12 @@ tutorial_%: $(BIN_DIR)/tutorial_% $(TMP_DIR)/images/rgb.png $(TMP_DIR)/images/gr
cd $(TMP_DIR) ; $(CURDIR)/$<
@-echo

# Skip the serialization tutorial, if we didn't build -DWITH_SERIALIZATION
ifeq (,$(shell which flatc))
tutorial_lesson_23_serialization:
@echo "Skipping tutorial lesson 23 (serialization not enabled) ..."
endif

test_mullapudi2016: $(MULLAPUDI2016_TESTS:$(ROOT_DIR)/test/autoschedulers/mullapudi2016/%.cpp=mullapudi2016_%)

mullapudi2016_%: $(BIN_DIR)/mullapudi2016_% $(BIN_MULLAPUDI2016)
Expand Down
7 changes: 7 additions & 0 deletions cmake/HalideTestHelpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ function(add_halide_test TARGET)
CXX_VISIBILITY_PRESET hidden
VISIBILITY_INLINES_HIDDEN TRUE)


if (WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
if (WITH_SERIALIZATION)
target_compile_definitions(${TARGET} PRIVATE WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
endif ()
endif ()

# Add a meta-target for each group, to allow us to build by group easily
foreach (GROUP IN LISTS args_GROUPS)
set(META_TARGET build_${GROUP})
Expand Down
9 changes: 9 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,15 @@ if (WITH_SERIALIZATION)
target_compile_definitions(Halide PRIVATE WITH_SERIALIZATION)
endif ()

# Enable serialization testing by intercepting JIT compilation with a serialization roundtrip;
# This is used only for special builds made specifically for testing, and must be disabled by default.
option(WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING "Intercepting JIT compilation with a serialization roundtrip, for test only" OFF)
if (WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
if (WITH_SERIALIZATION)
target_compile_definitions(Halide PRIVATE WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
endif ()
endif ()

add_library(Halide::Halide ALIAS Halide)

target_link_libraries(Halide PRIVATE Halide::LLVM)
Expand Down
7 changes: 6 additions & 1 deletion src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,12 @@ void Deserializer::build_reverse_function_mappings(const std::vector<Function> &
}
int count = 0;
for (const auto &f : functions) {
this->reverse_function_mappings[count++] = f.get_contents();
// The reverse function mappings are used in places where only weak references are needed.
FunctionPtr ptr;
ptr.strong = nullptr;
ptr.weak = f.get_contents().group();
ptr.idx = f.get_contents().idx;
this->reverse_function_mappings[count++] = ptr;
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,24 @@ void Pipeline::compile_jit(const Target &target_arg) {
// Clear all cached info in case there is an error.
contents->invalidate_cache();

#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
std::map<std::string, Parameter> external_params;
std::vector<uint8_t> data;
serialize_pipeline(*this, data, external_params);
Pipeline deserialized_pipe = deserialize_pipeline(data, external_params);
std::vector<Function> outputs;
for (const Func &f : deserialized_pipe.outputs()) {
outputs.push_back(f.function());
}
// We save the original output functions and requirements,
// and restore them once all lowering is done,
// so that reschedule/reorder storage can be properly handled.
std::vector<Function> origin_outputs = contents->outputs;
std::vector<Internal::Stmt> origin_requirements = contents->requirements;
contents->outputs = outputs;
contents->requirements = deserialized_pipe.requirements();
#endif

// Infer an arguments vector
infer_arguments();

Expand All @@ -596,6 +614,11 @@ void Pipeline::compile_jit(const Target &target_arg) {
Module module = compile_to_module(args, generate_function_name(), target).resolve_submodules();
std::map<std::string, JITExtern> lowered_externs = contents->jit_externs;
contents->jit_cache = compile_jit_cache(module, std::move(args), contents->outputs, contents->jit_externs, target);
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
// Restore the original outputs and requirements.
contents->outputs = origin_outputs;
contents->requirements = origin_requirements;
#endif
}

Callable Pipeline::compile_to_callable(const std::vector<Argument> &args_in, const Target &target_arg) {
Expand Down
10 changes: 7 additions & 3 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Serializer {

Offset<Serialize::ExternFuncArgument> serialize_extern_func_argument(FlatBufferBuilder &builder, const ExternFuncArgument &extern_func_argument);

Offset<Serialize::Buffer> serialize_buffer(FlatBufferBuilder &builder, const Buffer<> &buffer);
Offset<Serialize::Buffer> serialize_buffer(FlatBufferBuilder &builder, Buffer<> &buffer);
TH3CHARLie marked this conversation as resolved.
Show resolved Hide resolved

std::vector<Offset<Serialize::WrapperRef>> serialize_wrapper_refs(FlatBufferBuilder &builder, const std::map<std::string, FunctionPtr> &wrappers);

Expand Down Expand Up @@ -1352,10 +1352,14 @@ Offset<Serialize::ExternFuncArgument> Serializer::serialize_extern_func_argument
}
}

Offset<Serialize::Buffer> Serializer::serialize_buffer(FlatBufferBuilder &builder, const Buffer<> &buffer) {
Offset<Serialize::Buffer> Serializer::serialize_buffer(FlatBufferBuilder &builder, Buffer<> &buffer) {
if (!buffer.defined()) {
return Serialize::CreateBuffer(builder, false);
}
if (buffer.device_dirty()) {
user_error << "Cannot serialize on-device buffer: " << buffer.name() << "\n";
}
buffer.copy_to_host();
const auto name_serialized = serialize_string(builder, buffer.name());
const auto type_serialized = serialize_type(builder, buffer.type());
const int32_t dimensions = buffer.dimensions();
Expand Down Expand Up @@ -1447,7 +1451,7 @@ void Serializer::serialize(const Pipeline &pipeline, std::vector<uint8_t> &resul

std::vector<Offset<Serialize::Buffer>> buffers_serialized;
buffers_serialized.reserve(buffers_in_pipeline.size());
for (const auto &buffer : buffers_in_pipeline) {
for (auto &buffer : buffers_in_pipeline) {
buffers_serialized.push_back(serialize_buffer(builder, buffer.second));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
using namespace Halide;

int main(int argc, char **argv) {
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n");
return 0;
#endif

Target t(get_jit_target_from_environment());
if (!t.has_gpu_feature()) {
printf("[SKIP] No GPU target enabled.\n");
Expand Down
4 changes: 4 additions & 0 deletions test/correctness/leak_device_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ void halide_print(JITUserContext *user_context, const char *str) {
}

int main(int argc, char **argv) {
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n");
return 0;
#endif

Target target = get_jit_target_from_environment();

Expand Down
5 changes: 5 additions & 0 deletions test/correctness/specialize_to_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
using namespace Halide;

int main(int argc, char **argv) {
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n");
return 0;
#endif

if (!get_jit_target_from_environment().has_gpu_feature()) {
printf("[SKIP] No GPU target enabled.\n");
return 0;
Expand Down