Torch Inference Workflow
This workflow is used to perform inference on Torch models.
Constructor Arguments
model_source (ModelSource)
: The source of the model. For available model sources, see here.
model_args (Optional[dict[str, Any]])
: The arguments for loading the model. These differ depending on the model source. For more information, see here.
PyTorch requires a model's class definition to be importable when the model is deserialized. scikit-learn models are handled by default, since the sk2torch library is included with this workflow. If your model was trained with a custom module, you will need to make that module importable as well. Refer to our Torch Iris Classification Example to see this in action.
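A minimal sketch of one way to make a custom module importable, assuming the model class lives in a local file that is not installed as a package (the path and module name here are hypothetical):
import sys

# make the directory containing my_model.py importable before loading the model
sys.path.append("/path/to/your/project")

from my_model import IrisClassificationModel  # imported only so the pickled model can be deserialized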
Additional Installations
Since this workflow uses the torch library, you'll need to install infernet-ml[torch_inference]. Alternatively, you can install the required packages directly; the optional dependency group [torch_inference] is provided for your convenience.
To install via pip:
pip install infernet-ml[torch_inference]
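Note that some shells, such as zsh, treat square brackets specially; in that case, quote the extra:
pip install "infernet-ml[torch_inference]"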
Input Format
Input format is the following dictionary:
{
    "dtype": str,
    "values": list[Any]
}
dtype (str)
: The data type of the input, for example "float". Refer to the supported data types below.
values (list[Any])
: The input values. The shape of the list should match the input shape of the model.
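For example, a single four-feature sample could be passed as a nested list (the feature values here are purely illustrative):
{
    "dtype": "float",
    "values": [[5.1, 3.5, 1.4, 0.2]]
}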
The input is pre-processed as follows:
def do_preprocessing(self, input_data: dict[str, Any]) -> torch.Tensor:
    # lookup dtype from str
    dtype = DTYPES.get(input_data["dtype"], None)
    values = input_data["values"]
    return torch.tensor(values, dtype=dtype)
This is intentionally kept generic. If you need to perform any specific pre-processing, you can do so by subclassing this class and overriding the do_preprocessing method, as shown in the sketch below.
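As a sketch of what that might look like, the subclass below standardizes the raw values before building the tensor. The class name and the normalization constants are hypothetical; only TorchInferenceWorkflow and do_preprocessing come from the library:
from typing import Any

import torch

from infernet_ml.workflows.inference.torch_inference_workflow import TorchInferenceWorkflow


class NormalizingTorchWorkflow(TorchInferenceWorkflow):
    # hypothetical subclass: scales inputs before they reach the model
    MEAN = 0.0  # illustrative constants, not taken from the library
    STD = 1.0

    def do_preprocessing(self, input_data: dict[str, Any]) -> torch.Tensor:
        # reuse the generic conversion, then apply custom scaling
        tensor = super().do_preprocessing(input_data)
        return (tensor - self.MEAN) / self.STD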
Supported Data Types
DTYPES = {
    "float": torch.float,
    "double": torch.double,
    "cfloat": torch.cfloat,
    "cdouble": torch.cdouble,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "short": torch.short,
    "int": torch.int,
    "long": torch.long,
    "bool": torch.bool,
}
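Note that DTYPES.get falls back to None for a string that is not in the mapping, in which case torch.tensor infers the dtype from the values themselves. A quick illustration, assuming the DTYPES mapping above ("float32" is intentionally not a supported name):
dtype = DTYPES.get("float32", None)  # not a key in DTYPES, so dtype is None
tensor = torch.tensor([[1.0, 2.0]], dtype=dtype)  # torch infers its default float dtype (float32)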
Example
from infernet_ml.workflows.inference.torch_inference_workflow import TorchInferenceWorkflow
from infernet_ml.utils.model_loader import ModelSource

# importing the model class makes it available when the pickled model is deserialized
from path.to.your.model import IrisClassificationModel

workflow = TorchInferenceWorkflow(
    model_source=ModelSource.HUGGINGFACE_HUB,
    model_args={
        "repo_id": "Ritual-Net/iris-classification",
        "filename": "iris.torch",
    },
)
workflow.setup()

results = workflow.inference({
    "values": [[1.0380048, 0.5586108, 1.1037828, 1.712096]],
    "dtype": "float",
})

print(f"results: {results}")