Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup invoke #29

Merged
merged 8 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nautilus/include/nautilus/common/traceing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ value_ref traceCast(value_ref state, Type resultType);

std::array<uint8_t, 256>& getVarRefMap();

value_ref traceCall(const std::string& functionName, void* fptn, Type resultType, std::vector<tracing::value_ref> arguments);
value_ref traceCall(void* fptn, const std::type_info& ti, Type resultType, const std::vector<tracing::value_ref>& arguments);

std::ostream& operator<<(std::ostream& os, const Op& operation);

Expand Down
169 changes: 71 additions & 98 deletions nautilus/include/nautilus/function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,18 @@

#include "nautilus/val.hpp"
#include "nautilus/val_ptr.hpp"
#include <cxxabi.h>
#include <dlfcn.h>
#include <functional>
#include <type_traits>
#include <utility>

namespace nautilus {

template <typename R, typename... FunctionArguments>
std::string getFunctionName(R (*fnptr)(FunctionArguments...)) {
Dl_info info;
dladdr(reinterpret_cast<void*>(fnptr), &info);
if (info.dli_sname != nullptr) {
return info.dli_sname;
}
return "xxx";
}

template <typename R, typename... FunctionArguments>
class CallableNautilusFunction {
public:
CallableNautilusFunction(R (*fnptr)(FunctionArguments...)) : fnptr(fnptr) {};

template <typename Arg>
auto transform(Arg argument) {
return make_value(argument);
}

template <typename... FunctionArgumentsRaw>
auto operator()(FunctionArgumentsRaw... args) {
// we are in the nautilus context
// keep current tracing context and continue tracing
// TODO implement polymorphic inline cache
return fnptr((args)...);
}

template <is_integral... FunctionArgumentsRaw>
auto operator()(FunctionArgumentsRaw... args) {
// function is called from an external context.
auto result = fnptr(transform((args))...);
return details::getRawValue(result);
}

template <is_integral... FunctionArgumentsRaw>
auto invoke(FunctionArgumentsRaw... args) {
return (*this)(args...);
}

private:
R (*fnptr)(FunctionArguments...);
};

template <is_traceable_value Arg>
tracing::value_ref getRefs(Arg& argument) {
return argument.state;
}

template <typename... ValueArguments>
auto getArgumentReferences(std::string_view, void*, const ValueArguments&... arguments) {
std::vector<tracing::value_ref> functionArgumentReferences = {};
functionArgumentReferences.reserve(sizeof...(ValueArguments));
auto getArgumentReferences(const ValueArguments&... arguments) {
std::vector<tracing::value_ref> functionArgumentReferences;
if constexpr (sizeof...(ValueArguments) > 0) {
for (const tracing::value_ref& p : {getRefs(arguments)...}) {
functionArgumentReferences.reserve(sizeof...(ValueArguments));
for (const tracing::value_ref& p : {details::getState(arguments)...}) {
functionArgumentReferences.emplace_back(p);
}
}
Expand All @@ -73,86 +23,109 @@ auto getArgumentReferences(std::string_view, void*, const ValueArguments&... arg
template <typename R, typename... FunctionArguments>
class CallableRuntimeFunction {
public:
explicit CallableRuntimeFunction(R (*fnptr)(FunctionArguments...), std::string& functionName)
: fnptr(fnptr), functionName(functionName) {};

template <typename Arg>
auto transform(Arg argument) {
return details::getRawValue(argument);
explicit CallableRuntimeFunction(R (*fnptr)(FunctionArguments...)) : fnptr(fnptr) {
}

template <typename... FunctionArgumentsRaw>
requires(!std::is_void_v<R>)
auto operator()(FunctionArgumentsRaw... args) {
// function is called from an external context.
auto operator()(FunctionArgumentsRaw&&... args) {
#ifdef ENABLE_TRACING
if (tracing::inTracer()) {
auto ptr = (void*) fnptr;
auto functionArgumentReferences = getArgumentReferences(functionName, (void*) fnptr, args...);
auto resultRef = tracing::traceCall(functionName, ptr, tracing::to_type<R>(), functionArgumentReferences);
const std::type_info &ti = typeid(fnptr);
auto functionArgumentReferences = getArgumentReferences(std::forward<FunctionArgumentsRaw>(args)...);
auto resultRef = tracing::traceCall(reinterpret_cast<void*>(fnptr), ti, tracing::to_type<R>(), functionArgumentReferences);
return val<R>(resultRef);
}
#endif
return val<R>(fnptr(transform((args))...));
return val<R>(fnptr(details::getRawValue(std::forward<FunctionArgumentsRaw>(args))...));
}

template <typename... FunctionArgumentsRaw>
requires std::is_void_v<R>
void operator()(FunctionArgumentsRaw... args) {
// function is called from an external context.
void operator()(FunctionArgumentsRaw&&... args) {
#ifdef ENABLE_TRACING
if (tracing::inTracer()) {
auto ptr = (void*) fnptr;
auto functionArgumentReferences = getArgumentReferences("functionName", (void*) fnptr, args...);
tracing::traceCall(functionName, ptr, Type::v, functionArgumentReferences);
return;
const std::type_info &ti = typeid(fnptr);
auto functionArgumentReferences = getArgumentReferences(std::forward<FunctionArgumentsRaw>(args)...);
tracing::traceCall(reinterpret_cast<void*>(fnptr), ti, Type::v, functionArgumentReferences);
return;
}
#endif
fnptr(transform((args))...);
fnptr(details::getRawValue(std::forward<FunctionArgumentsRaw>(args))...);
}

template <is_integral... FunctionArgumentsRaw>
auto invoke(FunctionArgumentsRaw... args) {
return (*this)(args...);
auto invoke(FunctionArgumentsRaw&&... args) {
return (*this)(std::forward<FunctionArgumentsRaw>(args)...);
}

private:
R (*fnptr)(FunctionArguments...);
std::string functionName;
};

template <typename R, typename... FunctionArguments, typename... ValueArguments>
auto invoke(R (*fnptr)(FunctionArguments...), ValueArguments... args) {
[[maybe_unused]] auto name = getFunctionName(fnptr);
return CallableRuntimeFunction<R, FunctionArguments...>(fnptr, name)(args...);
auto invoke(R (*fnptr)(FunctionArguments...), ValueArguments&&... args) {
return CallableRuntimeFunction<R, FunctionArguments...>(fnptr)(std::forward<ValueArguments>(args)...);
}

template <typename R, typename... FunctionArguments, typename... ValueArguments>
auto invoke(std::function<R(FunctionArguments...)> func, ValueArguments... args) {
typedef R (*DecisionFn)(FunctionArguments...);
DecisionFn fnptr = func.template target<R(FunctionArguments...)>();
[[maybe_unused]] auto name = getFunctionName(fnptr);
return CallableRuntimeFunction<R, FunctionArguments...>(fnptr, name)(args...);
auto invoke(std::function<R(FunctionArguments...)> func, ValueArguments&&... args) {
auto fnptr = func.template target<R(FunctionArguments...)>();
return CallableRuntimeFunction<R, FunctionArguments...>(fnptr)(std::forward<ValueArguments>(args)...);
}

template <is_fundamental... FunctionArguments, typename... ValueArguments>
void invoke(void (*fnptr)(FunctionArguments...), ValueArguments... args) {
[[maybe_unused]] auto name = getFunctionName(fnptr);
auto func = CallableRuntimeFunction<void, FunctionArguments...>(fnptr, name);
func(args...);
void invoke(void (*fnptr)(FunctionArguments...), ValueArguments&&... args) {
auto func = CallableRuntimeFunction<void, FunctionArguments...>(fnptr);
func(std::forward<ValueArguments>(args)...);
}

template <class>
constexpr bool is_reference_wrapper_v = false;
template <class U>
constexpr bool is_reference_wrapper_v<std::reference_wrapper<U>> = true;
template <typename R, typename... FunctionArguments>
auto function(R (*fnptr)(FunctionArguments...)) {
return CallableRuntimeFunction<R, FunctionArguments...>(fnptr);
}

template <class T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
class MemberFuncWrapper {};

template <typename R, typename... FunctionArguments>
auto Function(R (*fnptr)(FunctionArguments...)) {
return CallableNautilusFunction<R, FunctionArguments...>(fnptr);
template <typename T, typename Rp, typename Tp>
class MemberFuncWrapperImpl : public MemberFuncWrapper {
public:
MemberFuncWrapperImpl(T func)
: func(func), callableRuntimeFunction(function(+[](MemberFuncWrapper* ptr, Tp* clazzPtr) -> auto {
auto p = static_cast<MemberFuncWrapperImpl<T, Rp, Tp>*>(ptr);
// return p->func(clazzPtr);
Rp (Tp::*func)() = p->func;
return (*clazzPtr.*func)();
// return std::invoke(p->func, clazzPtr);
})) {};

template <typename... FunctionArgumentsRaw>
auto operator()(FunctionArgumentsRaw... args) {
auto state = val<MemberFuncWrapper*>(this);
return callableRuntimeFunction(state, args...);
}
T func;
CallableRuntimeFunction<Rp, MemberFuncWrapper*, Tp*> callableRuntimeFunction;
};

template <typename T>
struct member_function_traits;

template <typename C, typename R, typename... Args>
struct member_function_traits<R (C::*)(Args...)> {
using class_type = C;
using return_type = R;
using arg_types = std::tuple<Args...>;
};

template <typename T>
auto& memberFunc(T func) {
using traits = member_function_traits<T>;
using ClassType = typename traits::class_type;
using ReturnType = typename traits::return_type;
//using ArgTypes = typename traits::arg_types;
auto ptr = new MemberFuncWrapperImpl<T, ReturnType, ClassType>(func);
return *ptr;
}

} // namespace nautilus
8 changes: 4 additions & 4 deletions nautilus/include/nautilus/val.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace nautilus {

namespace details {
template <typename LHS>
LHS getRawValue(val<LHS>& val);
LHS getRawValue(const val<LHS>& val);

#define COMMON_RETURN_TYPE val<typename std::common_type<typename LHS::basic_type, typename RHS::basic_type>::type>

Expand Down Expand Up @@ -164,10 +164,9 @@ class val<ValueType> {
const tracing::TypedValueRefHolder state;
#endif
private:
friend ValueType details::getRawValue<ValueType>(val<ValueType>& left);
friend ValueType details::getRawValue<ValueType>(const val<ValueType>& left);
ValueType value;


template <is_arithmetic LHS, is_arithmetic RHS>
friend COMMON_RETURN_TYPE mul(val<LHS>& left, val<RHS>& right);

Expand Down Expand Up @@ -405,10 +404,11 @@ val<LHS> neg(val<LHS>& val) {
}

template <typename LHS>
LHS inline getRawValue(val<LHS>& val) {
LHS inline getRawValue(const val<LHS>& val) {
return val.value;
}


} // namespace details

#define DEFINE_BINARY_OPERATOR(OP, FUNC) \
Expand Down
2 changes: 1 addition & 1 deletion nautilus/include/nautilus/val_enum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class val<T> {
#endif

private:
friend T details::getRawValue<T>(val<T>& left);
friend T details::getRawValue<T>(const val<T>& left);
const T value;
};

Expand Down
3 changes: 1 addition & 2 deletions nautilus/include/nautilus/val_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#pragma once

#include <utility>
#include "nautilus/function.hpp"
#include "nautilus/val.hpp"

namespace nautilus {
Expand Down Expand Up @@ -94,7 +93,7 @@ class base_ptr_val {
#endif

#ifdef ENABLE_TRACING
tracing::TypedValueRefHolder state;
const tracing::TypedValueRefHolder state;
#endif
ValuePtrType value;
};
Expand Down
3 changes: 2 additions & 1 deletion nautilus/src/nautilus/api/std/cstring.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

#include <cstring>
#include <nautilus/function.hpp>
#include <nautilus/std/cstring.h>
namespace nautilus {
val<void*> memcpy(val<void*> dest, val<const void*> src, val<size_t> count) {
Expand Down Expand Up @@ -112,4 +113,4 @@ val<size_t> strlen(val<const char*> s) {
return invoke<>(
+[](const char* s) { return std::strlen(s); }, s);
}
} // namespace nautilus
} // namespace nautilus
2 changes: 1 addition & 1 deletion nautilus/src/nautilus/compiler/JITCompiler.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "nautilus/JITCompiler.hpp"
#include "nautilus/Executable.hpp"
#include "nautilus/compiler/backends/CompilationBackend.hpp"
#include "nautilus/config.hpp"
#include "nautilus/exceptions/RuntimeException.hpp"
#include <utility>

#ifdef ENABLE_COMPILER

#include "nautilus/compiler/backends/CompilationBackend.hpp"
#include "nautilus/tracing/TraceContext.hpp"
#include "nautilus/tracing/phases/SSACreationPhase.hpp"
#include "nautilus/tracing/phases/TraceToIRConversionPhase.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ namespace nautilus::compiler::ir {
ProxyCallOperation::ProxyCallOperation(OperationIdentifier identifier, const std::vector<Operation*>& inputArguments, Type resultType) : Operation(Operation::OperationType::ProxyCallOp, identifier, resultType, inputArguments) {
}

ProxyCallOperation::ProxyCallOperation(std::string functionSymbol, void* functionPtr, OperationIdentifier identifier, std::vector<Operation*> inputArguments, Type resultType)
: Operation(Operation::OperationType::ProxyCallOp, identifier, resultType, std::move(inputArguments)), mangedFunctionSymbol(std::move(functionSymbol)), functionPtr(functionPtr) {
ProxyCallOperation::ProxyCallOperation(const std::string& functionSymbol, const std::string& functionName, void* functionPtr, OperationIdentifier identifier, std::vector<Operation*> inputArguments, Type resultType)
: Operation(Operation::OperationType::ProxyCallOp, identifier, resultType, std::move(inputArguments)), mangedFunctionSymbol(functionSymbol), functionName(functionName), functionPtr(functionPtr) {
}

const std::vector<Operation*>& ProxyCallOperation::getInputArguments() const {
Expand All @@ -23,8 +23,7 @@ std::string ProxyCallOperation::toString() {
if (!identifier.toString().empty()) {
baseString = identifier.toString() + " = ";
}
baseString = baseString + "(";
// baseString = baseString + getFunctionSymbol() + "(";
baseString = baseString + getFunctionName() + "(";
if (!inputs.empty()) {
// baseString += inputArguments[0].lock()->getIdentifier().toString();
// for (int i = 1; i < (int) inputArguments.size(); ++i) {
Expand All @@ -41,7 +40,11 @@ std::string ProxyCallOperation::toString() {
return baseString + ")";
}

std::string ProxyCallOperation::getFunctionSymbol() {
const std::string& ProxyCallOperation::getFunctionName() {
return functionName;
}

const std::string& ProxyCallOperation::getFunctionSymbol() {
return mangedFunctionSymbol;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,24 @@ class ProxyCallOperation : public Operation {
public:
ProxyCallOperation(OperationIdentifier identifier, const std::vector<Operation*>& inputArguments, Type resultType);

ProxyCallOperation(std::string mangedFunctionSymbol, void* functionPtr, OperationIdentifier identifier,
std::vector<Operation*> inputArguments, Type resultType);
ProxyCallOperation(const std::string& functionSymbol, const std::string& functionName, void* functionPtr, OperationIdentifier identifier, std::vector<Operation*> inputArguments, Type resultType);

~ProxyCallOperation() override = default;

const std::vector<Operation*>& getInputArguments() const;

void setInputArguments(std::vector<Operation*>& newInputArguments);

std::string getFunctionSymbol();
const std::string& getFunctionSymbol();
const std::string& getFunctionName();

std::string toString() override;

void* getFunctionPtr();

private:
std::string mangedFunctionSymbol;
const std::string mangedFunctionSymbol;
const std::string functionName;
void* functionPtr;
};
} // namespace nautilus::compiler::ir
Loading