ONNX Inference Workflow

This workflow performs inference on ONNX models using ONNX Runtime.

Constructor Arguments

  • model_source (ModelSource): The source of the model. For available model sources, see here.
  • model_args (Optional[dict[str, Any]]): The arguments for loading the model. These differ depending on the model source. For more information, see here.

Additional Installations

Since this workflow uses ONNX Runtime, you'll need to install infernet-ml[onnx_inference]. Alternatively, you can install the required packages directly; the optional dependency group "[onnx_inference]" is provided for convenience.

To install via pip:

pip install "infernet-ml[onnx_inference]"

The optional dependencies for this workflow require CMake to be installed on your system. On macOS, you can install it by running brew install cmake. On Ubuntu and Windows, consult the CMake documentation for installation instructions.

Input Format

The input to this workflow follows the ONNX Runtime input_feed format: a dictionary whose keys are the model's input names and whose values are data convertible to NumPy arrays.
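As a small illustration of "convertible to NumPy", the sketch below turns a request payload of nested Python lists into arrays. It only needs NumPy; the input name `"input"` is taken from the iris example on this page and would differ for other models. The outer list acts as the batch dimension, so one row of four features becomes an array of shape `(1, 4)`.

```python
import numpy as np

# A payload in the input_feed format: model input name -> nested list of values.
# "input" is the input name used by the iris example on this page.
payload = {"input": [[1.0380048, 0.5586108, 1.1037828, 1.712096]]}

# Each value must be convertible to a NumPy array; list nesting becomes dimensions.
feed = {name: np.asarray(value) for name, value in payload.items()}

print({name: arr.shape for name, arr in feed.items()})  # {'input': (1, 4)}
```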

They are pre-processed as follows:

    def do_preprocessing(self, input_data: dict[Any, Any]) -> Any:
        # torch.Tensor(...) yields a float32 tensor, so every input is cast to float32
        return {k: torch.Tensor(input_data[k]).numpy() for k in input_data}

This is intentionally kept generic. If you need model-specific pre-processing, subclass ONNXInferenceWorkflow and override the do_preprocessing method.
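One common reason to override is dtype: because the default pre-processing goes through torch.Tensor, integer inputs such as token IDs arrive as float32, which ONNX Runtime rejects for models that declare int64 inputs. The sketch below shows a hypothetical replacement policy as a standalone function using only NumPy; the function name `preprocess_mixed_dtypes` and the input names `"input_ids"` and `"scale"` are illustrative assumptions. In a subclass, the override would simply be `def do_preprocessing(self, input_data): return preprocess_mixed_dtypes(input_data)`.

```python
from typing import Any

import numpy as np


def preprocess_mixed_dtypes(input_data: dict[Any, Any]) -> dict[Any, np.ndarray]:
    """Hypothetical policy: integer inputs become int64, everything else float32."""
    feed: dict[Any, np.ndarray] = {}
    for name, value in input_data.items():
        array = np.asarray(value)
        if np.issubdtype(array.dtype, np.integer):
            # Keep integer data (e.g. token IDs) as int64 instead of casting to float.
            feed[name] = array.astype(np.int64)
        else:
            # Match the float32 dtype that the default torch.Tensor conversion produces.
            feed[name] = array.astype(np.float32)
    return feed


feed = preprocess_mixed_dtypes({"input_ids": [[101, 2023, 102]], "scale": [0.5]})
print({name: arr.dtype for name, arr in feed.items()})
```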

Example

from infernet_ml.workflows.inference.onnx_inference_workflow import ONNXInferenceWorkflow
from infernet_ml.utils.model_loader import ModelSource
 
workflow = ONNXInferenceWorkflow(
    model_source=ModelSource.HUGGINGFACE_HUB,
    model_args={
        "repo_id": "Ritual-Net/iris-classification",
        "filename": "iris.onnx",
    }
)
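# Load the model from the configured source before running inference
workflow.setup()
# Keys are the model's input names; values must be convertible to NumPy arrays
results = workflow.inference({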
    "input": [[1.0380048, 0.5586108, 1.1037828, 1.712096]]
})
print(f"results: {results}")