Train, convert and predict with ONNX Runtime

This example demonstrates an end-to-end scenario, from training a machine-learned model to using it in its converted form.

Train a logistic regression

The first step consists of retrieving the iris dataset.

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

Then we fit a model.

clr = LogisticRegression()
clr.fit(X_train, y_train)
LogisticRegression()

We compute the predictions on the test set and display the confusion matrix.

from sklearn.metrics import confusion_matrix  # noqa: E402

pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))
[[19  0  0]
 [ 0 10  1]
 [ 0  0  8]]
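To make the matrix easier to read, here is a minimal sketch of what a confusion matrix counts: entry (i, j) is the number of samples of true class i predicted as class j, so the diagonal holds the correct predictions. The helper `confusion` is hypothetical and only for illustration; scikit-learn's `confusion_matrix` uses the same row/column convention. The matrix above sums to 38 because `train_test_split` keeps 25% of the 150 iris samples for testing by default (rounded up).

```python
import math

import numpy as np


def confusion(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


# Tiny hand-made example: one sample of class 1 is mistaken for class 2.
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 2, 2]
cm = confusion(y_true, y_pred, 3)
print(cm)

# The matrix printed above for the iris test set.
iris_cm = np.array([[19, 0, 0], [0, 10, 1], [0, 0, 8]])
n_test = iris_cm.sum()             # number of test samples
n_correct = np.trace(iris_cm)      # correct predictions sit on the diagonal
accuracy = n_correct / n_test
print(n_test, math.ceil(0.25 * 150), n_correct, round(accuracy, 3))
```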

Conversion to ONNX format

We use the sklearn-onnx module (imported as skl2onnx) to convert the model into ONNX format.

from skl2onnx import convert_sklearn  # noqa: E402
from skl2onnx.common.data_types import FloatTensorType  # noqa: E402

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
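In `FloatTensorType([None, 4])`, `None` leaves the first (batch) dimension unspecified and `4` is the number of iris features, so the converted model accepts any number of rows. The sketch below mimics that compatibility rule with a hypothetical helper, `matches_declared_shape`; it is not part of ONNX Runtime's API, which performs its own input validation.

```python
import numpy as np


def matches_declared_shape(array, declared):
    """Hypothetical helper: check an array against an ONNX-style declared shape,
    where None stands for a dimension of any size (typically the batch size)."""
    if array.ndim != len(declared):
        return False
    return all(d is None or d == s for d, s in zip(declared, array.shape))


batch = np.zeros((38, 4), dtype=np.float32)   # a whole test set
single_row = batch[0:1]                        # shape (1, 4)

ok_batch = matches_declared_shape(batch, [None, 4])         # dynamic batch size
ok_single = matches_declared_shape(single_row, [None, 4])   # a batch of one also fits
fixed_batch = matches_declared_shape(batch, [1, 4])         # fixed batch size of 1
wrong_rank = matches_declared_shape(batch[0], [None, 4])    # 1-D row, rank mismatch
print(ok_batch, ok_single, fixed_batch, wrong_rank)
```

A declared shape such as `[1, 4]` would therefore only fit single-row inputs.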

We load the model with ONNX Runtime and look at its input and output.

import onnxruntime as rt  # noqa: E402

sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())

print(f"input name='{sess.get_inputs()[0].name}' and shape={sess.get_inputs()[0].shape}")
print(f"output name='{sess.get_outputs()[0].name}' and shape={sess.get_outputs()[0].shape}")
input name='float_input' and shape=[None, 4]
output name='output_label' and shape=[None]

We compute the predictions.

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

import numpy  # noqa: E402

pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))
[[19  0  0]
 [ 0 10  0]
 [ 0  0  9]]

The predictions are identical: the confusion matrix between the scikit-learn and ONNX Runtime predictions is diagonal.
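Note that ONNX Runtime received `X_test.astype(numpy.float32)` while scikit-learn worked in float64. The cast is lossy, so labels most likely agree only because no test sample lies close enough to a decision boundary for the rounding to matter. The sketch below shows the rounding on the first iris sample and a direct label comparison; `labels_a` and `labels_b` are made-up values for illustration.

```python
import numpy as np

# First sample of the iris dataset, in float64 as returned by load_iris.
x64 = np.array([[5.1, 3.5, 1.4, 0.2]])
x32 = x64.astype(np.float32)   # the cast applied before sess.run
print(x32.dtype)
# 5.1 has no exact binary representation; float32 keeps fewer bits than float64.
print(float(x32[0, 0]), x64[0, 0], float(x32[0, 0]) == x64[0, 0])

# Comparing two label vectors directly, as an alternative to a confusion matrix.
labels_a = np.array([0, 1, 2, 1])   # hypothetical labels
labels_b = np.array([0, 1, 2, 1])
print(np.array_equal(labels_a, labels_b), np.mean(labels_a == labels_b))
```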

Probabilities

Probabilities are needed to compute other relevant metrics such as the ROC curve. Let's first see how to get them with scikit-learn.

prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])
[[0.00448644 0.8987307  0.09678287]
 [0.0040423  0.88517075 0.11078696]
 [0.08153688 0.90836333 0.01009979]]
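Each row of `predict_proba` sums to one. Assuming the multinomial formulation (which recent scikit-learn versions use by default for multiclass problems with the lbfgs solver), the probabilities are the softmax of the linear scores `X @ coef_.T + intercept_`. The numpy sketch below uses made-up coefficients, not the fitted ones, to show the computation.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical coefficients for 3 classes and 4 features (not the fitted values).
coef = np.array([[-0.4, 0.9, -2.3, -1.0],
                 [0.5, -0.3, -0.2, -0.8],
                 [-0.1, -0.6, 2.5, 1.8]])
intercept = np.array([9.0, 2.0, -11.0])
X = np.array([[5.1, 3.5, 1.4, 0.2],
              [6.3, 3.3, 6.0, 2.5]])

scores = X @ coef.T + intercept   # plays the role of clr.decision_function(X)
proba = softmax(scores)           # plays the role of clr.predict_proba(X)
print(proba.round(4))
print(proba.sum(axis=1))          # each row sums to 1
print(proba.argmax(axis=1))       # plays the role of clr.predict(X)
```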

And then with ONNX Runtime. The probabilities are returned as a list of dictionaries, one per sample, mapping each class label to its probability.

prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]

import pprint  # noqa: E402

pprint.pprint(prob_rt[0:3])
[{0: 0.004486436024308205, 1: 0.8987307548522949, 2: 0.09678284823894501},
 {0: 0.004042299464344978, 1: 0.8851708769798279, 2: 0.1107867956161499},
 {0: 0.08153686672449112, 1: 0.908363401889801, 2: 0.010099786333739758}]
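To compare both outputs, the dictionaries can be turned back into a 2-D array. The sketch below copies the first three rows printed above and uses a tolerance rather than strict equality, since ONNX Runtime computed them in float32. The helper `dicts_to_array` is hypothetical. As far as we know, skl2onnx also offers a `zipmap` conversion option to get a plain probability tensor instead of dictionaries; check its documentation for the exact syntax.

```python
import numpy as np

# First three rows returned by ONNX Runtime (copied from the output above).
prob_rt_rows = [
    {0: 0.004486436024308205, 1: 0.8987307548522949, 2: 0.09678284823894501},
    {0: 0.004042299464344978, 1: 0.8851708769798279, 2: 0.1107867956161499},
    {0: 0.08153686672449112, 1: 0.908363401889801, 2: 0.010099786333739758},
]
# First three rows of clr.predict_proba, as printed (rounded by numpy).
prob_sklearn_rows = np.array([
    [0.00448644, 0.8987307, 0.09678287],
    [0.0040423, 0.88517075, 0.11078696],
    [0.08153688, 0.90836333, 0.01009979],
])


def dicts_to_array(rows):
    """Hypothetical helper: list of {class: probability} -> (n_samples, n_classes) array."""
    classes = sorted(rows[0])
    return np.array([[row[c] for c in classes] for row in rows])


prob_rt_array = dicts_to_array(prob_rt_rows)
max_diff = np.abs(prob_rt_array - prob_sklearn_rows).max()
print(np.allclose(prob_rt_array, prob_sklearn_rows, atol=1e-6), max_diff)
```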

Let’s benchmark both implementations on the whole test set.

from timeit import Timer  # noqa: E402


def speed(inst, number=5, repeat=10):
    timer = Timer(inst, globals=globals())
    raw = numpy.array(timer.repeat(repeat, number=number))
    ave = raw.sum() / len(raw) / number
    mi, ma = raw.min() / number, raw.max() / number
    print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
    return ave


print("Execution time for clr.predict")
speed("clr.predict(X_test)")

print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")
Execution time for clr.predict
Average 8.39e-05 min=6.64e-05 max=0.000182
Execution time for ONNX Runtime
Average 1.81e-05 min=1.66e-05 max=2.58e-05

1.805053999987649e-05
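The `speed` function divides each total returned by `Timer.repeat` by `number`, because each total covers `number` executions. A variant that takes a callable instead of a string evaluated in `globals()`, and returns the statistics instead of printing them, might look like this; `timing_stats` is a hypothetical name and the timed lambda is only a placeholder workload.

```python
import timeit

import numpy as np


def timing_stats(func, number=5, repeat=10):
    """Per-call average, min and max, computed the same way as speed().

    timeit.repeat returns `repeat` totals, each covering `number` calls,
    so dividing by `number` gives seconds per call.
    """
    raw = np.array(timeit.repeat(func, repeat=repeat, number=number))
    per_call = raw / number
    return {"average": per_call.mean(), "min": per_call.min(), "max": per_call.max()}


stats = timing_stats(lambda: sum(range(1000)))   # placeholder workload
print({k: f"{v:1.3g}" for k, v in stats.items()})
```

With the objects defined above, the same helper could be called as `timing_stats(lambda: clr.predict(X_test))`.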

Let’s benchmark a scenario similar to what a web service experiences: the model has to make one prediction at a time instead of a batch of predictions.
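The `loop` function slices `X_test[im : im + 1]` rather than indexing `X_test[im]`. Slicing keeps the array 2-D with shape `(1, 4)`, which both scikit-learn (it rejects a 1-D array in `predict`) and the ONNX input declared as `[None, 4]` expect. The `i % nrow` index cycles through the rows when more predictions are requested than the test set contains. A small numpy illustration with a stand-in array:

```python
import numpy as np

X = np.arange(12, dtype=np.float64).reshape(3, 4)   # stand-in for X_test: 3 rows, 4 features

row_1d = X[1]        # shape (4,): the batch dimension is lost
row_2d = X[1:2]      # shape (1, 4): still a batch, of size one
print(row_1d.shape, row_2d.shape)

# Cycling over the rows the way loop() does, with n larger than the number of rows.
n, nrow = 7, X.shape[0]
indices = [i % nrow for i in range(n)]
print(indices)
```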

def loop(X_test, fct, n=None):
    nrow = X_test.shape[0]
    if n is None:
        n = nrow
    for i in range(n):
        im = i % nrow
        fct(X_test[im : im + 1])


print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 50)")


def sess_predict(x):
    return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 50)")
Execution time for clr.predict
Average 0.00283 min=0.00234 max=0.00355
Execution time for sess_predict
Average 0.000316 min=0.000311 max=0.000344

0.0003159323399995628
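Since each measurement covers 50 single-row predictions, dividing by 50 gives the per-request latency, the figure that matters for a web service. The arithmetic below uses the averages reported in this run; they will differ on another machine.

```python
# Averages reported above for 50 single-row predictions (seconds per loop).
loop_sklearn = 0.00283
loop_onnxruntime = 0.000316
n_calls = 50

per_call_sklearn = loop_sklearn / n_calls
per_call_onnxruntime = loop_onnxruntime / n_calls
ratio = loop_sklearn / loop_onnxruntime
print(f"scikit-learn: {per_call_sklearn:.3g} s per call")
print(f"onnxruntime:  {per_call_onnxruntime:.3g} s per call")
print(f"ratio: {ratio:.1f}x")
```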

Let’s do the same for the probabilities.

print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 50)")


def sess_predict_proba(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 50)")
Execution time for predict_proba
Average 0.00329 min=0.00325 max=0.00337
Execution time for sess_predict_proba
Average 0.000317 min=0.000312 max=0.000341

0.000316713559998334

This second comparison is fairer because, in this experiment, ONNX Runtime computes both the label and the probabilities in every case.

Benchmark with RandomForest

We first train and save a model in ONNX format.

from sklearn.ensemble import RandomForestClassifier  # noqa: E402

rf = RandomForestClassifier(n_estimators=10)
rf.fit(X_train, y_train)

initial_type = [("float_input", FloatTensorType([1, 4]))]
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
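scikit-learn documents `RandomForestClassifier.predict_proba` as the mean of the class probabilities predicted by the individual trees, so the amount of work grows with `n_estimators`. A numpy sketch of that averaging, with made-up per-tree probabilities:

```python
import numpy as np

# Hypothetical class probabilities predicted by 3 trees for 2 samples
# (shape: n_trees x n_samples x n_classes).
per_tree = np.array([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
])

# Forest probability = mean over the tree axis; every extra tree adds one more term.
forest_proba = per_tree.mean(axis=0)
print(forest_proba)
print(forest_proba.argmax(axis=1))
```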

We compare both implementations on single-row predictions.

sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())


def sess_predict_proba_rf(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 50)")

print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 50)")
Execution time for predict_proba
Average 0.0221 min=0.0219 max=0.0236
Execution time for sess_predict_proba
Average 0.000312 min=0.000307 max=0.000341

0.00031169780000254836

Let’s see how the timings evolve with the number of trees.

measures = []

for n_trees in range(5, 51, 5):
    print(n_trees)
    rf = RandomForestClassifier(n_estimators=n_trees)
    rf.fit(X_train, y_train)
    initial_type = [("float_input", FloatTensorType([1, 4]))]
    onx = convert_sklearn(rf, initial_types=initial_type)
    with open(f"rf_iris_{n_trees}.onnx", "wb") as f:
        f.write(onx.SerializeToString())
    sess = rt.InferenceSession(f"rf_iris_{n_trees}.onnx", providers=rt.get_available_providers())

    def sess_predict_proba_loop(x):
        return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]  # noqa: B023

    tsk = speed("loop(X_test, rf.predict_proba, 25)", number=5, repeat=4)
    trt = speed("loop(X_test, sess_predict_proba_loop, 25)", number=5, repeat=4)
    measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})

from pandas import DataFrame  # noqa: E402

df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()
[Figure: Speed comparison between scikit-learn and ONNX Runtime for a random forest on the Iris dataset. x-axis: number of trees; y-axis: prediction time (s), log scale.]
5
Average 0.00809 min=0.00768 max=0.00925
Average 0.000159 min=0.00015 max=0.000179
10
Average 0.0114 min=0.0109 max=0.0123
Average 0.000159 min=0.000152 max=0.000179
15
Average 0.0145 min=0.0141 max=0.0154
Average 0.000161 min=0.000153 max=0.00018
20
Average 0.0176 min=0.0172 max=0.0187
Average 0.000163 min=0.000154 max=0.000187
25
Average 0.0208 min=0.0204 max=0.0218
Average 0.000165 min=0.000155 max=0.000188
30
Average 0.0239 min=0.0235 max=0.025
Average 0.000167 min=0.000158 max=0.000185
35
Average 0.0269 min=0.0265 max=0.0279
Average 0.000168 min=0.00016 max=0.000186
40
Average 0.03 min=0.0297 max=0.0309
Average 0.000172 min=0.000163 max=0.000194
45
Average 0.0331 min=0.0328 max=0.0341
Average 0.000172 min=0.000163 max=0.000192
50
Average 0.0363 min=0.0359 max=0.0373
Average 0.000175 min=0.000167 max=0.000198
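The scikit-learn averages grow roughly linearly with the number of trees while the ONNX Runtime averages barely move. Fitting a straight line to the values printed above (each one covers 25 single-row predictions) separates a per-tree cost from a fixed per-call overhead. This is a rough reading of a single run, not a general cost model.

```python
import numpy as np

# Averages printed above (seconds for 25 single-row predictions).
n_trees = np.arange(5, 51, 5)
t_sklearn = np.array([0.00809, 0.0114, 0.0145, 0.0176, 0.0208,
                      0.0239, 0.0269, 0.03, 0.0331, 0.0363])
t_onnxruntime = np.array([0.000159, 0.000159, 0.000161, 0.000163, 0.000165,
                          0.000167, 0.000168, 0.000172, 0.000172, 0.000175])

# Least-squares line t = slope * n_trees + intercept for each runtime.
slope_sk, intercept_sk = np.polyfit(n_trees, t_sklearn, 1)
slope_rt, intercept_rt = np.polyfit(n_trees, t_onnxruntime, 1)

calls = 25
print(f"scikit-learn: {slope_sk / calls:.2e} s per tree per call, "
      f"{intercept_sk / calls:.2e} s fixed cost per call")
print(f"onnxruntime:  {slope_rt / calls:.2e} s per tree per call, "
      f"{intercept_rt / calls:.2e} s fixed cost per call")
```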

Total running time of the script: (0 minutes 6.487 seconds)