
Inference plugin

Intro

An inference plugin interfaces directly with an inference backend, which could be a driver or a piece of hardware.

It's important to understand that CVEDIA-RT does all the pre- and post-processing of the tensors. The inference plugin is only responsible for:

  1. Setting up the device
  2. Identifying and registering device capabilities and driver versions
  3. Enumerating how many backends of the same type are available
  4. Thread safety
  5. Loading models
  6. Forward pass
  7. Unloading models
  8. Gracefully shutting down
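To make this division of labour concrete, here is a minimal sketch of the order in which a host might drive those responsibilities: enumerate devices, select one, load a model, run a forward pass, then unload. The InferenceBackend interface and the runOnce function are hypothetical and built only from standard types; the real CVEDIA-RT core uses pCValue, expected<> and cvedia::rt::Tensor, and thread safety is left out here.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for a plugin core; not the CVEDIA-RT API.
class InferenceBackend {
public:
    virtual ~InferenceBackend() = default;
    virtual std::vector<std::pair<std::string, std::string>> getDeviceGuids() = 0;
    virtual bool setDevice(std::string const& device) = 0;
    virtual bool loadModel(std::string const& path) = 0;
    virtual std::vector<float> runInference(std::vector<float> const& input) = 0;
    virtual void unloadBackend() = 0;
};

// Drives a backend through the lifecycle listed above:
// enumerate -> select device -> load model -> forward pass -> unload.
inline std::vector<float> runOnce(InferenceBackend& backend, std::string const& modelPath,
                                  std::vector<float> const& input) {
    auto devices = backend.getDeviceGuids();
    if (devices.empty() || !backend.setDevice(devices.front().first)) {
        return {};
    }
    if (!backend.loadModel(modelPath)) {
        return {};
    }
    auto output = backend.runInference(input);
    backend.unloadBackend();
    return output;
}
```

In CVEDIA-RT the host, not the plugin, decides when each step happens; the sketch only fixes the relative order.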

Key methods

All sources for the methods described here can be found in our GitHub plugin repository.

We will mainly be discussing the core methods (mnncore.cpp) of the MNN plugin.

1. getCapabilities

Here you can return any device-specific details; what applies depends on whether you're bound to a driver, firmware, etc.

This is entirely optional; you can return an empty VAL().

Example:

pCValue MNNCore::getCapabilities() {
    cmap propmap;

    propmap["firmware"] = VAL(std::string("1.0.0"));
    return VAL(propmap);
}
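Values such as the "1.0.0" firmware string above are often compared against a minimum supported version. A plain string comparison gets this wrong ("1.10.0" sorts before "1.9.0"), so a small parser helps. The helpers parseVersion and meetsMinimumVersion below are hypothetical and assume a dotted major.minor.patch format.

```cpp
#include <array>
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical helper: parses "major.minor.patch" (e.g. "1.0.0").
// Missing or non-numeric parts are treated as 0.
inline std::array<int, 3> parseVersion(std::string const& text) {
    std::array<int, 3> parts{{0, 0, 0}};
    std::istringstream in(text);
    std::string token;
    for (std::size_t i = 0; i < parts.size() && std::getline(in, token, '.'); ++i) {
        try {
            parts[i] = std::stoi(token);
        } catch (...) {
            parts[i] = 0;
        }
    }
    return parts;
}

// True if `reported` is at least `minimum`; std::array compares element by element.
inline bool meetsMinimumVersion(std::string const& reported, std::string const& minimum) {
    return parseVersion(reported) >= parseVersion(minimum);
}
```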

2. getDeviceGuids

This method complements the declaration in the inference handler, where you register a backend (URI scheme) and a file format (extension):

extern "C" EXPORT void registerHandler() {
    api::inference::registerSchemeHandler("mnn", &MNNInferenceHandler::create);
    api::inference::registerExtHandler(".mnn", &MNNInferenceHandler::create);
}
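Conceptually, these two calls map a scheme ("mnn") and a file extension (".mnn") to a factory. The registry below is a hypothetical illustration of that idea, not CVEDIA-RT's implementation: Handler, HandlerRegistry and the scheme-before-extension lookup order are all assumptions made for the sketch.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical handler base; stands in for an inference handler such as MNNInferenceHandler.
struct Handler {
    virtual ~Handler() = default;
    virtual std::string name() const = 0;
};

using HandlerFactory = std::function<std::unique_ptr<Handler>()>;

class HandlerRegistry {
public:
    void registerScheme(std::string const& scheme, HandlerFactory factory) {
        schemes_[scheme] = std::move(factory);
    }
    void registerExt(std::string const& ext, HandlerFactory factory) {
        extensions_[ext] = std::move(factory);
    }
    // Prefers an explicit scheme; otherwise falls back to the file extension.
    std::unique_ptr<Handler> create(std::string const& scheme, std::string const& path) const {
        auto s = schemes_.find(scheme);
        if (s != schemes_.end()) {
            return s->second();
        }
        auto dot = path.rfind('.');
        if (dot != std::string::npos) {
            auto e = extensions_.find(path.substr(dot));
            if (e != extensions_.end()) {
                return e->second();
            }
        }
        return nullptr;
    }

private:
    std::map<std::string, HandlerFactory> schemes_;
    std::map<std::string, HandlerFactory> extensions_;
};
```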

In the core method, you define which devices (instances of this backend) are available, potentially by scanning the hardware.

For example, if there are multiple NVIDIA GPUs, the user can address individual cards with tensorrt.1://..., where tensorrt comes from registerHandler and 1 from getDeviceGuids.

If the device in question doesn't support or need individual addressing, you can simply return auto, as the MNN implementation below does:
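To illustrate how such an address splits into its parts, here is a hypothetical parser for URIs of the form scheme[.device]://rest. ModelUri and parseModelUri are names invented for this sketch, and defaulting to auto when no device is given is an assumption, not documented CVEDIA-RT behaviour. It avoids C++17 features so it would also build with older GCC releases.

```cpp
#include <string>

// Hypothetical decomposition of "<scheme>[.<device>]://<rest>".
struct ModelUri {
    std::string scheme;  // e.g. "tensorrt", as registered via registerSchemeHandler
    std::string device;  // e.g. "1", as reported by getDeviceGuids
    std::string rest;    // everything after "://"
};

// Returns false if the string has no "://" or an empty scheme/device.
inline bool parseModelUri(std::string const& uri, ModelUri& out) {
    auto sep = uri.find("://");
    if (sep == std::string::npos || sep == 0) {
        return false;
    }
    std::string prefix = uri.substr(0, sep);
    out.rest = uri.substr(sep + 3);
    auto dot = prefix.find('.');
    if (dot == std::string::npos) {
        out.scheme = prefix;
        out.device = "auto";  // assumption: no explicit device means "auto"
    } else {
        out.scheme = prefix.substr(0, dot);
        out.device = prefix.substr(dot + 1);
    }
    return !out.scheme.empty() && !out.device.empty();
}
```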

std::vector<std::pair<std::string, std::string>> MNNCore::getDeviceGuids() {
    std::vector<std::pair<std::string, std::string>> out;

    out.push_back(std::make_pair(string("auto"), "Runs on best available device"));

    return out;
}
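The MNN plugin only reports auto. For a backend that can address cards individually, such as the multi-GPU case above, getDeviceGuids might return one entry per device. This sketch mirrors the return type of MNNCore::getDeviceGuids; deviceCount is a hypothetical value you would obtain from your driver, and keeping auto as the first entry is an assumption.

```cpp
#include <string>
#include <utility>
#include <vector>

// Sketch: lists "auto" plus one addressable entry per device ("0", "1", ...).
// deviceCount stands in for whatever device query your driver offers.
inline std::vector<std::pair<std::string, std::string>> enumerateDevices(int deviceCount) {
    std::vector<std::pair<std::string, std::string>> out;
    out.emplace_back("auto", "Runs on best available device");
    for (int i = 0; i < deviceCount; ++i) {
        out.emplace_back(std::to_string(i), "Device " + std::to_string(i));
    }
    return out;
}
```

With two GPUs, the user could then target the second card with the index "1", as in the tensorrt.1://... example.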

3. setDevice

While getDeviceGuids lists the available devices, setDevice sets the context to a specific one.

MNN example:

expected<void> MNNCore::setDevice(std::string const& device) {

    LOGD << "Setting device to " << device;

    if (!loadBackend()) {
        return unexpected(RTErrc::NoSuchDevice);
    }

    return {};
}
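The MNN version only logs the requested device. A backend with real addressing would typically reject anything it did not advertise. Here is a minimal sketch of that check; DeviceSelector is hypothetical, and where it returns false the real plugin would return unexpected(RTErrc::NoSuchDevice).

```cpp
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

// Hypothetical device selection: only GUIDs previously reported by
// getDeviceGuids-style enumeration are accepted.
class DeviceSelector {
public:
    explicit DeviceSelector(std::vector<std::pair<std::string, std::string>> guids)
        : guids_(std::move(guids)) {}

    // Returns false (and keeps the current device) for unknown GUIDs.
    bool setDevice(std::string const& device) {
        auto it = std::find_if(guids_.begin(), guids_.end(),
                               [&](std::pair<std::string, std::string> const& g) { return g.first == device; });
        if (it == guids_.end()) {
            return false;
        }
        current_ = device;
        return true;
    }

    std::string const& current() const { return current_; }

private:
    std::vector<std::pair<std::string, std::string>> guids_;
    std::string current_;
};
```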

4. loadModel

This method should be implemented as two overloads: one taking a file path and one taking the model weights as input.

A model in a format compatible with the backend is expected to be loaded onto the device. This happens once, after setDevice, during the very first runInference call.

CVEDIA-RT can load models directly from our distribution platform (model forge) or straight from disk, which is why there are two implementations of the same method.
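The "load once, on the first inference call" behaviour can be sketched with a flag guarded by the same mutex that serialises the forward pass. LazyModel is hypothetical, and loader and forward stand in for loadModel and the backend's forward pass.

```cpp
#include <functional>
#include <mutex>
#include <utility>
#include <vector>

// Sketch of lazy model loading on the first runInference call.
class LazyModel {
public:
    LazyModel(std::function<bool()> loader,
              std::function<std::vector<float>(std::vector<float> const&)> forward)
        : loader_(std::move(loader)), forward_(std::move(forward)) {}

    std::vector<float> runInference(std::vector<float> const& input) {
        std::lock_guard<std::mutex> lock(mux_);
        if (!loaded_) {
            loaded_ = loader_();  // if loading fails, it is retried on the next call
            if (!loaded_) {
                return {};
            }
        }
        return forward_(input);
    }

private:
    std::mutex mux_;
    bool loaded_ = false;
    std::function<bool()> loader_;
    std::function<std::vector<float>(std::vector<float> const&)> forward_;
};
```

A plain flag is used instead of std::call_once so that a failed load can be retried without throwing, and so the same lock also protects the session during the forward pass.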
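The path overload delegates to a readFile helper whose implementation is not shown. Here is one possible version, assuming it returns the raw bytes as std::vector<unsigned char> (the type the weights overload accepts) and an empty vector on failure. Both the name's behaviour and the empty-on-failure convention are assumptions; a caller should treat empty weights as a load failure.

```cpp
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical readFile: reads an entire file into memory as raw bytes.
// Returns an empty vector if the file cannot be opened.
inline std::vector<unsigned char> readFile(std::string const& path) {
    std::ifstream file(path, std::ios::binary);
    if (!file) {
        return {};
    }
    return std::vector<unsigned char>(std::istreambuf_iterator<char>(file),
                                      std::istreambuf_iterator<char>());
}
```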

expected<void> MNNCore::loadModel(string const& path) {

    auto weights = readFile(path);

    return loadModel(path, weights);
}

expected<void> MNNCore::loadModel(string const& path, std::vector<unsigned char> const& weights) {

    unique_lock<mutex> m(sessMux_);

    auto ptr = MNN::Interpreter::createFromBuffer(weights.data(), weights.size());

    if (ptr) {
        network_ = std::shared_ptr<MNN::Interpreter>(ptr);

        MNN::ScheduleConfig config;
        config.type = MNN_FORWARD_AUTO;

        MNN::BackendConfig backendConfig;
        backendConfig.precision = MNN::BackendConfig::Precision_Normal;
        backendConfig.memory = MNN::BackendConfig::Memory_Normal;
        backendConfig.power = MNN::BackendConfig::Power_Normal;

        config.backendConfig = &backendConfig;

        LOGD << "Creating network session";
        session_ = network_->createSession(config);
        if (!session_) {
            LOGE << "createSession returned nullptr";
            return unexpected(RTErrc::OperationFailed);
        }

        LOGD << "Input tensors";
        auto inputs = network_->getSessionInputAll(session_);

        if (inputs.size() != 1) {
            LOGE << "Found " << inputs.size() << " input tensors. Only one supported currently";
            return unexpected(RTErrc::UnsupportedModel);
        }

        for (auto const& input : inputs) {
            for (auto const& s : input.second->shape()) {
                inputShape_.push_back(s);
            }

            LOGD << "- " << input.first << " (" << shapeToString(inputShape_) << ")";

            deviceInputTensor_ = input.second;

            // Only support 1 input
            break;
        }

        LOGD << "Output tensors";
        auto outputs = network_->getSessionOutputAll(session_);
        for (auto const& output : outputs) {
            std::vector<int> op = output.second->shape();

            // Total element count, used later to wrap the host buffer in runInference
            size_t total = 1;
            for (auto const& s : op) {
                total *= static_cast<size_t>(s);
            }

            deviceOutputTensors_.push_back(output.second);

            LOGD << "- " << output.first << " (" << shapeToString(op) << ")";

            outputShape_.push_back(op);
            outSize_.push_back(total);
        }

        hostInputTensor_ = new MNN::Tensor(deviceInputTensor_, deviceInputTensor_->getDimensionType());
        for (auto t : deviceOutputTensors_) {
            hostOutputTensors_.push_back(new MNN::Tensor(t, t->getDimensionType()));
        }

        modelLoaded_ = true;

        network_->releaseModel();

        return {};
    }
    else {
        modelLoaded_ = false;

        LOGE << "MNN failed to load model at " << path;
        return unexpected(RTErrc::LoadModelFailed);
    }
}
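loadModel calls shapeToString without showing it, and computes each output's element count inline. Here are hypothetical versions of both helpers; the "x" separator is a formatting choice for this sketch, and elementCount assumes static, non-negative dimensions (a dynamic dimension such as -1 would need separate handling).

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical shapeToString: formats {1, 3, 224, 224} as "1x3x224x224".
inline std::string shapeToString(std::vector<int> const& shape) {
    std::ostringstream ss;
    for (std::size_t i = 0; i < shape.size(); ++i) {
        if (i > 0) {
            ss << "x";
        }
        ss << shape[i];
    }
    return ss.str();
}

// Product of all dimensions, as stored in outSize_; assumes no negative dims.
inline std::size_t elementCount(std::vector<int> const& shape) {
    std::size_t total = 1;
    for (int s : shape) {
        total *= static_cast<std::size_t>(s);
    }
    return total;
}
```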

5. runInference

This is where everything comes together. This method receives a tensor whose input has already been normalized (according to the model configuration).

This data now needs to be sent to the backend in a type the driver understands, so you might need to transform it. CVEDIA-RT uses the xtensor library, which allows easy, hardware-accelerated array transformations.

Note that if your platform works with quantized data, you may need to convert the float tensor into a suitable type. The same applies when the backend replies: if the output needs to be dequantized or otherwise transformed, you will have to handle it here.
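A common example of such a transformation is a layout change: some drivers expect channels-last (NHWC) input while models are often fed channels-first (NCHW). With xtensor this is usually a single xt::transpose call; the sketch below uses plain loops so it stays self-contained. The function name nchwToNhwc is hypothetical, and whether your driver needs this at all depends on the backend.

```cpp
#include <cstddef>
#include <vector>

// Converts a dense float tensor from NCHW to NHWC layout.
// Returns an empty vector if src does not hold exactly n*c*h*w elements.
inline std::vector<float> nchwToNhwc(std::vector<float> const& src,
                                     std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
    if (src.size() != n * c * h * w) {
        return {};
    }
    std::vector<float> dst(src.size());
    for (std::size_t in = 0; in < n; ++in)
        for (std::size_t ic = 0; ic < c; ++ic)
            for (std::size_t ih = 0; ih < h; ++ih)
                for (std::size_t iw = 0; iw < w; ++iw) {
                    std::size_t from = ((in * c + ic) * h + ih) * w + iw;  // NCHW offset
                    std::size_t to = ((in * h + ih) * w + iw) * c + ic;    // NHWC offset
                    dst[to] = src[from];
                }
    return dst;
}
```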
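As an illustration, here is affine (scale and zero-point) int8 quantization, one widely used scheme: q = clamp(round(x / scale) + zeroPoint, -128, 127) and x = (q - zeroPoint) * scale. Whether your backend uses this scheme, and which scale and zero point apply, comes from your model or driver; the function names are hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Float -> int8 before sending data to a quantized backend.
inline std::vector<std::int8_t> quantize(std::vector<float> const& x, float scale, int zeroPoint) {
    std::vector<std::int8_t> q;
    q.reserve(x.size());
    for (float v : x) {
        long r = std::lround(v / scale) + zeroPoint;
        r = std::min(127L, std::max(-128L, r));  // saturate to the int8 range
        q.push_back(static_cast<std::int8_t>(r));
    }
    return q;
}

// int8 -> float on the backend's reply, before returning xt::xarray<float> results.
inline std::vector<float> dequantize(std::vector<std::int8_t> const& q, float scale, int zeroPoint) {
    std::vector<float> x;
    x.reserve(q.size());
    for (std::int8_t v : q) {
        x.push_back((static_cast<int>(v) - zeroPoint) * scale);
    }
    return x;
}
```

Values outside the representable range saturate, so 100.0 with a scale of 0.5 comes back as 63.5 rather than 100.0.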

MNN example:

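expected<vector<xt::xarray<float>>> MNNCore::runInference(std::vector<cvedia::rt::Tensor>& input) {

    unique_lock<mutex> m(sessMux_);

    vector<xt::xarray<float>> output;
    if (input.empty())
        return output;

    if (!modelLoaded_) {
        LOGE << "runInference called before a model was loaded";
        return unexpected(RTErrc::OperationFailed);
    }

    auto inData = input[0].move<float>();

    // Reject inputs that don't match the model's input tensor instead of overrunning the host buffer
    if (inData.size() != static_cast<size_t>(hostInputTensor_->elementSize())) {
        LOGE << "Input has " << inData.size() << " elements, model expects " << hostInputTensor_->elementSize();
        return unexpected(RTErrc::OperationFailed);
    }

    memcpy(hostInputTensor_->host<float>(), inData.data(), inData.size() * sizeof(float));

    // Copy input data to MNN
    deviceInputTensor_->copyFromHostTensor(hostInputTensor_);

    network_->runSession(session_);

    for (size_t i = 0; i < outputShape_.size(); i++) {
        deviceOutputTensors_[i]->copyToHostTensor(hostOutputTensors_[i]);

        float* outData = hostOutputTensors_[i]->host<float>();

        // xt::adapt wraps the host buffer without copying; pushing it into an xt::xarray copies the data,
        // so the returned results stay valid when the next call reuses the host buffer
        std::vector<size_t> sizetShape(outputShape_[i].begin(), outputShape_[i].end());
        auto xarr = xt::adapt(outData, outSize_[i], xt::no_ownership(), sizetShape);
        output.push_back(xarr);
    }

    return output;
}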
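The thread-safety pattern in runInference is worth isolating: one mutex serialises access to the session and its reusable host buffers, and results are copied into caller-owned storage before the lock is released. SerializedRunner is a hypothetical stand-in, and its forward pass (adding 1.0) is a placeholder.

```cpp
#include <cstddef>
#include <mutex>
#include <vector>

// Sketch: serialised access to a shared, reusable host buffer with copy-out results.
class SerializedRunner {
public:
    explicit SerializedRunner(std::size_t size) : hostBuffer_(size) {}

    std::vector<float> run(std::vector<float> const& input) {
        std::lock_guard<std::mutex> lock(mux_);
        if (input.size() != hostBuffer_.size()) {
            return {};  // mismatched input, analogous to the size check in runInference
        }
        // Placeholder forward pass that writes into the shared host buffer.
        for (std::size_t i = 0; i < input.size(); ++i) {
            hostBuffer_[i] = input[i] + 1.0f;
        }
        // Return a copy made under the lock, like converting the xt::adapt view to xt::xarray.
        return hostBuffer_;
    }

private:
    std::mutex mux_;
    std::vector<float> hostBuffer_;
};
```

Returning a view of hostBuffer_ instead would let a second caller overwrite the first caller's results.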

6. unloadBackend and destroy

When the backend stops being used, either because a new model is being loaded or because there are no more references to it, CVEDIA-RT destroys it. This results in an unloadBackend call followed by destruction of the class.

Your backend needs to handle this gracefully, avoiding memory leaks and threading issues.
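In the MNN example, the host tensors are allocated with new in loadModel, so they must be freed by hand. An alternative is to hold them in std::unique_ptr so they are released automatically on destruction, on reload, and on early error returns. FakeTensor is a hypothetical stand-in for MNN::Tensor that counts live instances so leaks are easy to spot; OwningCore is likewise illustrative. The MNN session itself is released through network_->releaseSession, not delete, so it is not covered here.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical tensor that tracks how many instances are alive.
struct FakeTensor {
    explicit FakeTensor(int& live) : live_(live) { ++live_; }
    ~FakeTensor() { --live_; }
    int& live_;
};

class OwningCore {
public:
    explicit OwningCore(int& live) : live_(live) {}

    // Reloading replaces the old buffers; unique_ptr frees them, so nothing leaks.
    void loadModel(std::size_t outputs) {
        hostInput_.reset(new FakeTensor(live_));
        hostOutputs_.clear();
        for (std::size_t i = 0; i < outputs; ++i) {
            hostOutputs_.push_back(std::unique_ptr<FakeTensor>(new FakeTensor(live_)));
        }
    }
    // No hand-written delete: the unique_ptr members free everything when OwningCore is destroyed.

private:
    int& live_;
    std::unique_ptr<FakeTensor> hostInput_;
    std::vector<std::unique_ptr<FakeTensor>> hostOutputs_;
};
```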

MNN example:

void MNNCore::unloadBackend() {

    unique_lock<mutex> m(sessMux_);

    backendLoaded_ = false;
}

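MNNCore::~MNNCore()
{
    if (session_) {
        network_->releaseSession(session_);
    }

    // Host tensors were allocated with new in loadModel and must be freed here
    // (assumes hostInputTensor_ is initialised to nullptr when no model was loaded)
    delete hostInputTensor_;
    for (auto t : hostOutputTensors_) {
        delete t;
    }

    unloadBackend();
}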

Different edges

CVEDIA-RT can run on different edge devices and operating systems, and we compile against several toolchains.

Where and how your plugin runs also depends on your backend and its driver support.

We currently support:

  • GCC 5.5+ (8.4 preferred)
  • x86_64
  • aarch64
  • armv6
  • armv7
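If a plugin needs target-specific code paths, GCC's predefined macros can check the build against the list above at compile time. This is a sketch: targetArch is a hypothetical helper, and armv6/armv7 detection relies on the ACLE __ARM_ARCH macro, which older toolchains may not define.

```cpp
#include <string>

// Fail early on compilers older than the supported GCC 5.5 (clang also defines __GNUC__, so it is excluded).
#if defined(__GNUC__) && !defined(__clang__) && (__GNUC__ < 5 || (__GNUC__ == 5 && __GNUC_MINOR__ < 5))
#error "GCC 5.5 or newer is required"
#endif

// Reports which of the listed architectures the code is being compiled for.
inline std::string targetArch() {
#if defined(__x86_64__)
    return "x86_64";
#elif defined(__aarch64__)
    return "aarch64";
#elif defined(__arm__) && defined(__ARM_ARCH) && __ARM_ARCH == 7
    return "armv7";
#elif defined(__arm__) && defined(__ARM_ARCH) && __ARM_ARCH == 6
    return "armv6";
#else
    return "unsupported";
#endif
}
```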