From 2ace87eb0866dcf6ed5d5a0a439f3d66668773da Mon Sep 17 00:00:00 2001 From: "ausk(jinlj)" Date: Tue, 25 Oct 2022 20:32:48 +0800 Subject: [PATCH 1/2] Fix yolop tensorrt subproject: not rely on Zed; support opencv built without cuda; generate engine file and test it --- toolkits/deploy/CMakeLists.txt | 40 +- toolkits/deploy/gen_wts.py | 2 +- toolkits/deploy/logging.h | 490 +----------------- toolkits/deploy/utils.h | 12 +- toolkits/deploy/yololayer.h | 4 +- .../deploy/{infer_files.cpp => yolov5.cpp} | 2 +- toolkits/deploy/yolov5.hpp | 12 +- 7 files changed, 50 insertions(+), 512 deletions(-) rename toolkits/deploy/{infer_files.cpp => yolov5.cpp} (97%) diff --git a/toolkits/deploy/CMakeLists.txt b/toolkits/deploy/CMakeLists.txt index ec9af2a9a..df42e4ea7 100644 --- a/toolkits/deploy/CMakeLists.txt +++ b/toolkits/deploy/CMakeLists.txt @@ -8,12 +8,18 @@ option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) set(CMAKE_CXX_STANDARD 11) set(CMAKE_BUILD_TYPE Release) - -find_package(ZED 3 REQUIRED) -find_package(CUDA ${ZED_CUDA_VERSION} EXACT REQUIRED) +find_package(ZED 3) +if(ZED_FOUND) + find_package(CUDA ${ZED_CUDA_VERSION} EXACT REQUIRED) +else(ZED_FOUND) + find_package(CUDA REQUIRED) +endif(ZED_FOUND) include_directories(${PROJECT_SOURCE_DIR}/include) +find_package(OpenCV REQUIRED) +include_directories(${OpenCV_INCLUDE_DIRS}) + # cuda include_directories(/usr/local/cuda-10.2/include) link_directories(/usr/local/cuda-10.2/lib64) @@ -26,20 +32,24 @@ link_directories(/usr/local/zed/lib) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED") -set(ZED_LIBS ${ZED_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUDART_LIBRARY}) - -coda_add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/yololayer.cu) +# to generate plugins +cuda_add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/yololayer.cu) target_link_libraries(myplugins nvinfer cudart) -find_package(OpenCV REQUIRED) -include_directories(${OpenCV_INCLUDE_DIRS}) +# to generate trt and test image dir +add_executable(yolov5 ${PROJECT_SOURCE_DIR}/yolov5.cpp) +target_link_libraries(yolov5 nvinfer cudart myplugins ${OpenCV_LIBS}) +add_definitions(-O3 -pthread) -add_executable(yolop ${PROJECT_SOURCE_DIR}/main.cpp) -target_link_libraries(yolop nvinfer) -target_link_libraries(yolop ${ZED_LIBS}) -target_link_libraries(yolop cudart) -target_link_libraries(yolop myplugins) -target_link_libraries(yolop ${OpenCV_LIBS}) -add_definitions(-O3 -pthread) +# to test with zed camera +if(ZED_FOUND) + set(ZED_LIBS ${ZED_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUDART_LIBRARY}) + add_executable(yolop ${PROJECT_SOURCE_DIR}/main.cpp) + target_link_libraries(yolop nvinfer) + target_link_libraries(yolop ${ZED_LIBS}) + target_link_libraries(yolop cudart) + target_link_libraries(yolop myplugins) + target_link_libraries(yolop ${OpenCV_LIBS}) +endif(ZED_FOUND) diff --git a/toolkits/deploy/gen_wts.py b/toolkits/deploy/gen_wts.py index 13c2db5fa..796ce64f3 100644 --- a/toolkits/deploy/gen_wts.py +++ b/toolkits/deploy/gen_wts.py @@ -11,7 +11,7 @@ device = torch.device('cpu') # Load model model = get_net(cfg) -checkpoint = torch.load('weights/End-to-end.pth', map_location=device) +checkpoint = torch.load(BASE_DIR + '/weights/End-to-end.pth', map_location=device) model.load_state_dict(checkpoint['state_dict']) # load to FP32 model.float() diff --git a/toolkits/deploy/logging.h b/toolkits/deploy/logging.h index 602b69fb5..7a4707b5b 100644 --- a/toolkits/deploy/logging.h +++ b/toolkits/deploy/logging.h @@ -1,22 +1,5 @@ -/* - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TENSORRT_LOGGING_H -#define TENSORRT_LOGGING_H - +// create by ausk(jinlj) 2022/10/25 +#pragma once #include "NvInferRuntimeCommon.h" #include #include @@ -28,476 +11,13 @@ using Severity = nvinfer1::ILogger::Severity; -class LogStreamConsumerBuffer : public std::stringbuf -{ -public: - LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mOutput(stream) - , mPrefix(prefix) - , mShouldLog(shouldLog) - { - } - - LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) - : mOutput(other.mOutput) - { - } - - ~LogStreamConsumerBuffer() - { - // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence - // std::streambuf::pptr() gives a pointer to the current position of the output sequence - // if the pointer to the beginning is not equal to the pointer to the current position, - // call putOutput() to log the output to the stream - if (pbase() != pptr()) - { - putOutput(); - } - } - - // synchronizes the stream buffer and returns 0 on success - // synchronizing the stream buffer consists of inserting the buffer contents into the stream, - // resetting the buffer and flushing the stream - virtual int sync() - { - putOutput(); - return 0; - } - - void putOutput() - { - if (mShouldLog) - { - // prepend timestamp - std::time_t timestamp = std::time(nullptr); - tm* tm_local = std::localtime(×tamp); - std::cout << "["; - std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; - std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; - std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; - std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; - std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; - std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; - // std::stringbuf::str() gets the string contents of the buffer - // insert the buffer contents pre-appended by the appropriate prefix into the stream - mOutput << mPrefix << str(); - // set the buffer to empty - str(""); - // flush the stream - mOutput.flush(); - } - } - - void setShouldLog(bool shouldLog) - { - mShouldLog = shouldLog; - } - -private: - std::ostream& mOutput; - std::string mPrefix; - bool mShouldLog; -}; - -//! -//! \class LogStreamConsumerBase -//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer -//! -class LogStreamConsumerBase -{ -public: - LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mBuffer(stream, prefix, shouldLog) - { - } - -protected: - LogStreamConsumerBuffer mBuffer; -}; - -//! -//! \class LogStreamConsumer -//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. -//! Order of base classes is LogStreamConsumerBase and then std::ostream. -//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field -//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. -//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. -//! Please do not change the order of the parent classes. -//! -class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream -{ -public: - //! \brief Creates a LogStreamConsumer which logs messages with level severity. - //! Reportable severity determines if the messages are severe enough to be logged. - LogStreamConsumer(Severity reportableSeverity, Severity severity) - : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(severity <= reportableSeverity) - , mSeverity(severity) - { - } - - LogStreamConsumer(LogStreamConsumer&& other) - : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(other.mShouldLog) - , mSeverity(other.mSeverity) - { - } - - void setReportableSeverity(Severity reportableSeverity) - { - mShouldLog = mSeverity <= reportableSeverity; - mBuffer.setShouldLog(mShouldLog); - } - -private: - static std::ostream& severityOstream(Severity severity) - { - return severity >= Severity::kINFO ? std::cout : std::cerr; - } - - static std::string severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; - } - } - - bool mShouldLog; - Severity mSeverity; -}; - -//! \class Logger -//! -//! \brief Class which manages logging of TensorRT tools and samples -//! -//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, -//! and supports logging two types of messages: -//! -//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) -//! - Test pass/fail messages -//! -//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is -//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. -//! -//! In the future, this class could be extended to support dumping test results to a file in some standard format -//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). -//! -//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger -//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT -//! library and messages coming from the sample. -//! -//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the -//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger -//! object. - class Logger : public nvinfer1::ILogger { public: - Logger(Severity severity = Severity::kWARNING) - : mReportableSeverity(severity) - { - } - - //! - //! \enum TestResult - //! \brief Represents the state of a given test - //! - enum class TestResult - { - kRUNNING, //!< The test is running - kPASSED, //!< The test passed - kFAILED, //!< The test failed - kWAIVED //!< The test was waived - }; - - //! - //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger - //! \return The nvinfer1::ILogger associated with this Logger - //! - //! TODO Once all samples are updated to use this method to register the logger with TensorRT, - //! we can eliminate the inheritance of Logger from ILogger - //! - nvinfer1::ILogger& getTRTLogger() - { - return *this; - } - - //! - //! \brief Implementation of the nvinfer1::ILogger::log() virtual method - //! - //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the - //! inheritance from nvinfer1::ILogger - //! - void log(Severity severity, const char* msg) override - { - LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; - } - - //! - //! \brief Method for controlling the verbosity of logging output - //! - //! \param severity The logger will only emit messages that have severity of this level or higher. - //! - void setReportableSeverity(Severity severity) + void log(Severity severity, const char* msg) noexcept override { - mReportableSeverity = severity; - } - - //! - //! \brief Opaque handle that holds logging information for a particular test - //! - //! This object is an opaque handle to information used by the Logger to print test results. - //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used - //! with Logger::reportTest{Start,End}(). - //! - class TestAtom - { - public: - TestAtom(TestAtom&&) = default; - - private: - friend class Logger; - - TestAtom(bool started, const std::string& name, const std::string& cmdline) - : mStarted(started) - , mName(name) - , mCmdline(cmdline) - { + if (severity < Severity::kINFO) { + std::cout << msg << std::endl; } - - bool mStarted; - std::string mName; - std::string mCmdline; - }; - - //! - //! \brief Define a test for logging - //! - //! \param[in] name The name of the test. This should be a string starting with - //! "TensorRT" and containing dot-separated strings containing - //! the characters [A-Za-z0-9_]. - //! For example, "TensorRT.sample_googlenet" - //! \param[in] cmdline The command line used to reproduce the test - // - //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). - //! - static TestAtom defineTest(const std::string& name, const std::string& cmdline) - { - return TestAtom(false, name, cmdline); - } - - //! - //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments - //! as input - //! - //! \param[in] name The name of the test - //! \param[in] argc The number of command-line arguments - //! \param[in] argv The array of command-line arguments (given as C strings) - //! - //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). - static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) - { - auto cmdline = genCmdlineString(argc, argv); - return defineTest(name, cmdline); - } - - //! - //! \brief Report that a test has started. - //! - //! \pre reportTestStart() has not been called yet for the given testAtom - //! - //! \param[in] testAtom The handle to the test that has started - //! - static void reportTestStart(TestAtom& testAtom) - { - reportTestResult(testAtom, TestResult::kRUNNING); - assert(!testAtom.mStarted); - testAtom.mStarted = true; - } - - //! - //! \brief Report that a test has ended. - //! - //! \pre reportTestStart() has been called for the given testAtom - //! - //! \param[in] testAtom The handle to the test that has ended - //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, - //! TestResult::kFAILED, TestResult::kWAIVED - //! - static void reportTestEnd(const TestAtom& testAtom, TestResult result) - { - assert(result != TestResult::kRUNNING); - assert(testAtom.mStarted); - reportTestResult(testAtom, result); } - - static int reportPass(const TestAtom& testAtom) - { - reportTestEnd(testAtom, TestResult::kPASSED); - return EXIT_SUCCESS; - } - - static int reportFail(const TestAtom& testAtom) - { - reportTestEnd(testAtom, TestResult::kFAILED); - return EXIT_FAILURE; - } - - static int reportWaive(const TestAtom& testAtom) - { - reportTestEnd(testAtom, TestResult::kWAIVED); - return EXIT_SUCCESS; - } - - static int reportTest(const TestAtom& testAtom, bool pass) - { - return pass ? reportPass(testAtom) : reportFail(testAtom); - } - - Severity getReportableSeverity() const - { - return mReportableSeverity; - } - -private: - //! - //! \brief returns an appropriate string for prefixing a log message with the given severity - //! - static const char* severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; - } - } - - //! - //! \brief returns an appropriate string for prefixing a test result message with the given result - //! - static const char* testResultString(TestResult result) - { - switch (result) - { - case TestResult::kRUNNING: return "RUNNING"; - case TestResult::kPASSED: return "PASSED"; - case TestResult::kFAILED: return "FAILED"; - case TestResult::kWAIVED: return "WAIVED"; - default: assert(0); return ""; - } - } - - //! - //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity - //! - static std::ostream& severityOstream(Severity severity) - { - return severity >= Severity::kINFO ? std::cout : std::cerr; - } - - //! - //! \brief method that implements logging test results - //! - static void reportTestResult(const TestAtom& testAtom, TestResult result) - { - severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " - << testAtom.mCmdline << std::endl; - } - - //! - //! \brief generate a command line string from the given (argc, argv) values - //! - static std::string genCmdlineString(int argc, char const* const* argv) - { - std::stringstream ss; - for (int i = 0; i < argc; i++) - { - if (i > 0) - ss << " "; - ss << argv[i]; - } - return ss.str(); - } - - Severity mReportableSeverity; }; - -namespace -{ - -//! -//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE -//! -//! Example usage: -//! -//! LOG_VERBOSE(logger) << "hello world" << std::endl; -//! -inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) -{ - return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); -} - -//! -//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO -//! -//! Example usage: -//! -//! LOG_INFO(logger) << "hello world" << std::endl; -//! -inline LogStreamConsumer LOG_INFO(const Logger& logger) -{ - return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); -} - -//! -//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING -//! -//! Example usage: -//! -//! LOG_WARN(logger) << "hello world" << std::endl; -//! -inline LogStreamConsumer LOG_WARN(const Logger& logger) -{ - return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); -} - -//! -//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR -//! -//! Example usage: -//! -//! LOG_ERROR(logger) << "hello world" << std::endl; -//! -inline LogStreamConsumer LOG_ERROR(const Logger& logger) -{ - return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); -} - -//! -//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR -// ("fatal" severity) -//! -//! Example usage: -//! -//! LOG_FATAL(logger) << "hello world" << std::endl; -//! -inline LogStreamConsumer LOG_FATAL(const Logger& logger) -{ - return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); -} - -} // anonymous namespace - -#endif // TENSORRT_LOGGING_H diff --git a/toolkits/deploy/utils.h b/toolkits/deploy/utils.h index 89808ad95..74b3acf3c 100644 --- a/toolkits/deploy/utils.h +++ b/toolkits/deploy/utils.h @@ -3,10 +3,14 @@ #include #include + +#ifdef HAVE_CUDA #include #include #include #include +#endif + #include #include "common.hpp" @@ -33,7 +37,7 @@ static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) { re.copyTo(out(cv::Rect(x, y, re.cols, re.rows))); cv::Mat tensor; out.convertTo(tensor, CV_32FC3, 1.f / 255.f); - + cv::subtract(tensor, cv::Scalar(0.485, 0.456, 0.406), tensor, cv::noArray(), -1); cv::divide(tensor, cv::Scalar(0.229, 0.224, 0.225), tensor, 1, -1); // std::cout << cv::format(out, cv::Formatter::FMT_NUMPY)<< std::endl; @@ -43,6 +47,7 @@ static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) { return tensor; } +#ifdef HAVE_CUDA void preprocess_img_gpu(cv::cuda::GpuMat& img, float* gpu_input, int input_w, int input_h) { int w, h, x, y; float r_w = input_w / (img.cols*1.0); @@ -77,6 +82,7 @@ void preprocess_img_gpu(cv::cuda::GpuMat& img, float* gpu_input, int input_w, in } cv::cuda::split(tensor, chw); } +#endif static inline int read_files_in_dir(const char *p_dir_name, std::vector &file_names) { DIR *p_dir = opendir(p_dir_name); @@ -111,6 +117,7 @@ void PrintMat(cv::Mat &A) std::cout << std::endl; } +#ifdef HAVE_CUDA void visualization(cv::cuda::GpuMat& cvt_img, cv::Mat& seg_res, cv::Mat& lane_res, std::vector& res, char& key) { static const std::vector segColor{cv::Vec3b(0, 0, 0), cv::Vec3b(0, 255, 0), cv::Vec3b(255, 0, 0)}; @@ -143,7 +150,7 @@ void visualization(cv::cuda::GpuMat& cvt_img, cv::Mat& seg_res, cv::Mat& lane_re cv::rectangle(cvt_img_cpu, r, cv::Scalar(0x27, 0xC1, 0x36), 2); cv::putText(cvt_img_cpu, std::to_string((int)res[j].class_id), cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 1.2, cv::Scalar(0xFF, 0xFF, 0xFF), 2); } - + #ifdef SHOW_IMG cv::imshow("img", cvt_img_cpu); key = cv::waitKey(1); @@ -151,5 +158,6 @@ void visualization(cv::cuda::GpuMat& cvt_img, cv::Mat& seg_res, cv::Mat& lane_re cv::imwrite("../zed_result.jpg", cvt_img_cpu); #endif } +#endif #endif // TRTX_YOLOV5_UTILS_H_ \ No newline at end of file diff --git a/toolkits/deploy/yololayer.h b/toolkits/deploy/yololayer.h index 6406a6af7..784f3551c 100644 --- a/toolkits/deploy/yololayer.h +++ b/toolkits/deploy/yololayer.h @@ -16,7 +16,7 @@ namespace Yolo float anchors[CHECK_COUNT * 2]; }; static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000; - static constexpr int CLASS_NUM = 13; + static constexpr int CLASS_NUM = 1; static constexpr int INPUT_H = 384; static constexpr int INPUT_W = 640; static constexpr int IMG_H = 360; @@ -140,4 +140,4 @@ namespace nvinfer1 REGISTER_TENSORRT_PLUGIN(YoloPluginCreator); }; -#endif +#endif diff --git a/toolkits/deploy/infer_files.cpp b/toolkits/deploy/yolov5.cpp similarity index 97% rename from toolkits/deploy/infer_files.cpp rename to toolkits/deploy/yolov5.cpp index f7f8281b3..d336d6942 100644 --- a/toolkits/deploy/infer_files.cpp +++ b/toolkits/deploy/yolov5.cpp @@ -139,7 +139,7 @@ int main(int argc, char** argv) { auto& res = batch_res[b]; nms(res, &prob[b * OUTPUT_SIZE], CONF_THRESH, NMS_THRESH); } - + // show results for (int b = 0; b < fcount; ++b) { auto& res = batch_res[b]; diff --git a/toolkits/deploy/yolov5.hpp b/toolkits/deploy/yolov5.hpp index d1ef7a9ee..f4f51ab8b 100644 --- a/toolkits/deploy/yolov5.hpp +++ b/toolkits/deploy/yolov5.hpp @@ -5,7 +5,7 @@ #include "cuda_utils.h" #include "logging.h" #include "utils.h" -#include "calibrator.h" +//#include "calibrator.h" #define USE_FP16 // set USE_INT8 or USE_FP16 or USE_FP32 #define DEVICE 0 // GPU id @@ -101,7 +101,7 @@ ICudaEngine* build_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilder IDeconvolutionLayer* deconv26 = network->addDeconvolutionNd(*conv25->getOutput(0), 128, DimsHW{ 2, 2 }, deconvwts26, emptywts); deconv26->setStrideNd(DimsHW{ 2, 2 }); deconv26->setNbGroups(128); - + auto bottleneck_csp27 = bottleneckCSP(network, weightMap, *deconv26->getOutput(0), 128, 64, 1, false, 1, 0.5, "model.27"); auto conv28 = convBlock(network, weightMap, *bottleneck_csp27->getOutput(0), 32, 3, 1, 1, "model.28"); // upsample 29 @@ -119,9 +119,9 @@ ICudaEngine* build_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilder deconv32->setStrideNd(DimsHW{ 2, 2 }); deconv32->setNbGroups(8); - auto conv33 = convBlock(network, weightMap, *deconv32->getOutput(0), 3, 3, 1, 1, "model.33"); + auto conv33 = convBlock(network, weightMap, *deconv32->getOutput(0), 2, 3, 1, 1, "model.33"); // segmentation output - ISliceLayer *slicelayer = network->addSlice(*conv33->getOutput(0), Dims3{ 0, (Yolo::INPUT_H - Yolo::IMG_H) / 2, 0 }, Dims3{ 3, Yolo::IMG_H, Yolo::IMG_W }, Dims3{ 1, 1, 1 }); + ISliceLayer *slicelayer = network->addSlice(*conv33->getOutput(0), Dims3{ 0, (Yolo::INPUT_H - Yolo::IMG_H) / 2, 0 }, Dims3{ 2, Yolo::IMG_H, Yolo::IMG_W }, Dims3{ 1, 1, 1 }); auto segout = network->addTopK(*slicelayer->getOutput(0), TopKOperation::kMAX, 1, 1); segout->getOutput(1)->setName(OUTPUT_SEG_NAME); @@ -135,7 +135,7 @@ ICudaEngine* build_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilder auto bottleneck_csp36 = bottleneckCSP(network, weightMap, *deconv35->getOutput(0), 128, 64, 1, false, 1, 0.5, "model.36"); auto conv37 = convBlock(network, weightMap, *bottleneck_csp36->getOutput(0), 32, 3, 1, 1, "model.37"); - + // upsample38 Weights deconvwts38{ DataType::kFLOAT, deval, 32 * 2 * 2 }; IDeconvolutionLayer* deconv38 = network->addDeconvolutionNd(*conv37->getOutput(0), 32, DimsHW{ 2, 2 }, deconvwts38, emptywts); @@ -156,7 +156,7 @@ ICudaEngine* build_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilder ISliceLayer *laneSlice = network->addSlice(*conv42->getOutput(0), Dims3{ 0, (Yolo::INPUT_H - Yolo::IMG_H) / 2, 0 }, Dims3{ 2, Yolo::IMG_H, Yolo::IMG_W }, Dims3{ 1, 1, 1 }); auto laneout = network->addTopK(*laneSlice->getOutput(0), TopKOperation::kMAX, 1, 1); laneout->getOutput(1)->setName(OUTPUT_LANE_NAME); - + // // std::cout << std::to_string(slicelayer->getOutput(0)->getDimensions().d[0]) << std::endl; // // ISliceLayer *tmp1 = network->addSlice(*slicelayer->getOutput(0), Dims3{ 0, 0, 0 }, Dims3{ 1, (Yolo::INPUT_H - 2 * Yolo::PAD_H), Yolo::INPUT_W }, Dims3{ 1, 1, 1 }); // // ISliceLayer *tmp2 = network->addSlice(*slicelayer->getOutput(0), Dims3{ 1, 0, 0 }, Dims3{ 1, (Yolo::INPUT_H - 2 * Yolo::PAD_H), Yolo::INPUT_W }, Dims3{ 1, 1, 1 }); From b1d7d9384f8953e8b75f41596e3fa76f241ae0ba Mon Sep 17 00:00:00 2001 From: "ausk(jinlj)" Date: Tue, 25 Oct 2022 20:33:42 +0800 Subject: [PATCH 2/2] Add yolop tensorrt README.md --- toolkits/deploy/README.md | 63 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 toolkits/deploy/README.md diff --git a/toolkits/deploy/README.md b/toolkits/deploy/README.md new file mode 100644 index 000000000..20bd07d58 --- /dev/null +++ b/toolkits/deploy/README.md @@ -0,0 +1,63 @@ +YoloP TensorRT Usage(简要说明) +===== + + +## 1. 准备构建环境 | Prepare building environments + +Make sure you have install `c++`(support c++11)、 `cmake`、`opencv`(4.x)、`cuda`(10.x)、`nvinfer`(7.x). + +Now, zedcam is not necessary! + +And we can also use opencv that built without cuda. + +## 2. 编译 | build + +Go to `YOLOP/toolkits/deploy`. + +``` +mkdir build +cd build + +cmake .. +make +``` + +Now you can get `yolov5` and `libmyplugins.so`. +If you have Zed installed, you can get `yolop`. + + +## 3. 生成和测试 trt | Generate and test trt + +Go to build dir (`YOLOP/toolkits/deploy/build`). + +### 3.1 gen wts +``` +python3 ../gen_wts.py +``` + +### 3.2 gen trt +``` +./yolov5 -s yolop.wts yolop.trt s +``` + +### 3.3 test trt +``` +mkdir ../results +./yolov5 -d yolop.trt ../../../inference/images/ +``` + +It will output like as follow if successful! (`Jetson Xavier NX - Jetpack 4.4`) +``` +1601ms +26ms +26ms +26ms +26ms +28ms +``` + +![](build/results/_3c0e7240-96e390d2.jpg) + + + +