개요

우선 공식 샘플을 이용한 것이 훨씬 간단하므로 그것을 살펴보고, 그 후에 직접 구현하는 것을 사용해서 내부가 어떻게 구성되었는지를 파악하자.

Build

NVIDIA 샘플 코드를 프로젝트에 추가하는 방법은 아래 링크 참조

TensorRT C++/ NVIDIA Sample 소스를 프로젝트에 추가 하기

Trt 엔진 파일을 build 하기 위해서는 builder, network, config, parser가 필요하다. 추가로 builder와 parser를 생성하기 위해 logger가 필요한데, 이는 Sample 코드의 것을 사용하면 된다. 아래와 같이 build 코드를 만들 수 있다.

#include <string>
#include <NvInfer.h>
#include <NvOnnxParser.h>  // onnx 파일의 parser. TensorRT는 이 외에 Caffe, UFF parser를 지원한다.
#include "common/common.h"  // Nvidia의 sample 코드

using namespace std;
using namespace nvinfer1;
using namespace nvonnxparser;
using namespace samplesCommon;

bool BuildEngine(
    const string& pathOnnxModelFile,
    const string& pathEngineFile,
    const size_t sizeWorkSpaceMax,
    const int sizeBatchMax,
    const int batchCount,
    const int channels,
    const int width,
    const int height
)
{
    // builder, network, config, parser를 만든다.
		// builder, network, config, parser 등은 unique_ptr을 사용할 수 없기 때문에, NVIDIA Sample에서 unique_ptr 처럼 사용할 수 있도록 만들어둔 SampleUniquePtr을 사용한다.
    SampleUniquePtr<IBuilder> builder = SampleUniquePtr<IBuilder>(createInferBuilder(sample::gLogger.getTRTLogger()));

    if (builder)
    {
				// flag 값은 공식문서에 나와 있는 것을 그대로 따른다.
        uint32_t flag = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        SampleUniquePtr<INetworkDefinition> network = SampleUniquePtr<INetworkDefinition>(builder->createNetworkV2(flag));

        if (network)
        {
            SampleUniquePtr<IBuilderConfig> config = SampleUniquePtr<IBuilderConfig>(builder->createBuilderConfig());

            if (config)
            {
                SampleUniquePtr<IParser> parser = SampleUniquePtr<IParser>(createParser(*network, sample::gLogger.getTRTLogger()));

                if (parser)
                {
										// trt로 변환할 onnx 파일을 parse 한다. flag는 역시 공식 문서에 있는 것을 그대로 따른다
                    bool parsed = parser->parseFromFile(pathOnnxModelFile.c_str(), static_cast<int>(ILogger::Severity::kWARNING));

                    // error가 있는지 체크
                    for (int32_t i = 0; i < parser->getNbErrors(); ++i)
                    {
                        std::cout << parser->getError(i)->desc() << std::endl;
                    }

                    if (parsed)
                    {
												// profileStream을 만드는 것은 공식 문서에는 없고, Sample 코드에만 있다 - 없어도 build는 됨
                        auto profileStream = samplesCommon::makeCudaStream();
                        if (!profileStream)
                        {
                            return false;
                        }

                        config->setProfileStream(*profileStream);
                        config->setMaxWorkspaceSize(sizeWorkSpaceMax);
                        builder->setMaxBatchSize(sizeBatchMax);
                        network->getInput(0)->setDimensions(Dims4(batchCount, channels, width, height));

                        SampleUniquePtr<IHostMemory> engine = SampleUniquePtr<IHostMemory>(builder->buildSerializedNetwork(*network, *config));

                        if (engine)
                        {
														// 엔진이 만들어졌으면 지정된 경로에 바이너리 형태로 저장한다.
                            std::ofstream file(pathEngineFile, std::ios::out | std::ios::binary);
                            file.write((char*)(engine->data()), engine->size());
                            file.close();

		                        return true;
                        }
                    }
                }
            }
        }
    }

    return false;
}

만일 NVIDIA Sample 코드 없이 build를 구성하려면 위의 Logger와 SampleUniquePtr 부분만 직접 구현하면 된다. —makeCudaStream은 없어도 무방하니 생략

Logger는 공식문서에 나와 있는 대로 ILogger를 상속 받고 log 만 override 해서 만들면 된다. 아래 코드 참조.

#pragma once

#include <NvInfer.h>
#include <iostream>

using namespace nvinfer1;

// 클래스 이름은 자신이 원하는 것으로 정의
// logger는 전역 변수로 만들어서 여러 곳에서 공통적으로 사용할 수 있게 사용하면 된다.
class TrtLogger : public ILogger
{
		// log는 noexcept를 해야 에러나지 않는다. --공식 문서에는 noexcept가 안 써 있음
    void log(Severity severity, const char* msg) noexcept override
    {
        // suppress info-level messages
        if (severity <= Severity::kWARNING)
        {
            std::cout << msg << std::endl;
        }
    }
};

SampleUniquePtr은 Deleter는 아래와 같이 생겼으며, 이를 참조하여 자신이 원하는 이름으로 바꾸어서 정의한 후 사용하면 된다.

struct InferDeleter
{
    template <typename T>
    void operator()(T* obj) const
    {
        delete obj;
    }
};

template <typename T>
using SampleUniquePtr = std::unique_ptr<T, InferDeleter>;