Train, convert and predict with ONNX Runtime

This example demonstrates an end-to-end scenario, from training a machine-learned model to using it in its converted form.

Train a logistic regression

The first step consists of retrieving the Iris dataset and splitting it into a training set and a test set.

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
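train_test_split is called here without test_size, so scikit-learn holds out 25% of the rows, rounding the test size up. The plain-Python sketch below reproduces only that size computation and a shuffled split; holdout_split is a hypothetical helper for illustration, not scikit-learn's implementation.

```python
import math
import random


def holdout_split(rows, labels, test_size=0.25, seed=None):
    """Hypothetical shuffled hold-out split (not scikit-learn's code)."""
    n = len(rows)
    # Test size is rounded up; the training set gets the remaining rows.
    n_test = math.ceil(test_size * n)
    n_train = n - n_test
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    train_idx, test_idx = indices[:n_train], indices[n_train:]
    return (
        [rows[i] for i in train_idx],
        [rows[i] for i in test_idx],
        [labels[i] for i in train_idx],
        [labels[i] for i in test_idx],
    )


# Iris has 150 samples (50 per class): 112 go to training, 38 to testing.
rows = [[float(i)] * 4 for i in range(150)]
labels = [i // 50 for i in range(150)]
X_tr, X_te, y_tr, y_te = holdout_split(rows, labels, seed=0)
print(len(X_tr), len(X_te))
```

The 38 test rows are consistent with the confusion matrices printed in this example, whose entries sum to 38. Since no random_state is passed, the rows that land in each set change from one run to the next.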

Then we fit a model.

clr = LogisticRegression()
clr.fit(X_train, y_train)
LogisticRegression()

We compute the predictions on the test set and display the confusion matrix.

from sklearn.metrics import confusion_matrix  # noqa: E402

pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))
[[ 9  0  0]
 [ 0 14  1]
 [ 0  0 14]]
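In the matrix returned by confusion_matrix, row i counts the samples whose true label is i and column j those predicted as j, so every off-diagonal entry is a misclassification. A plain-Python sketch of that counting, with count_confusion and accuracy as hypothetical helper names:

```python
def count_confusion(y_true, y_pred, n_classes):
    """Rows are true labels, columns are predicted labels."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for true_label, pred_label in zip(y_true, y_pred):
        matrix[true_label][pred_label] += 1
    return matrix


def accuracy(matrix):
    # Correct predictions sit on the diagonal.
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


m = count_confusion([0, 1, 1, 2], [0, 1, 2, 2], n_classes=3)
print(m)            # [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
print(accuracy(m))  # 0.75
```

Applied to the matrix printed above, the single off-diagonal entry gives an accuracy of 37/38, about 0.974.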

Conversion to ONNX format

We use the sklearn-onnx module (imported as skl2onnx) to convert the model into ONNX format.

from skl2onnx import convert_sklearn  # noqa: E402
from skl2onnx.common.data_types import FloatTensorType  # noqa: E402

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
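FloatTensorType([None, 4]) declares a float32 input whose first (batch) dimension is left unspecified, so the converted model accepts any number of rows with 4 features each. The matching rule can be sketched as follows; shape_accepts is a hypothetical helper for illustration, not part of sklearn-onnx or ONNX Runtime.

```python
def shape_accepts(declared, actual):
    """Hypothetical check: does `actual` fit `declared`? None matches any size."""
    if len(declared) != len(actual):
        return False  # rank mismatch, e.g. a 1-D row instead of a 2-D batch
    return all(d is None or d == a for d, a in zip(declared, actual))


print(shape_accepts([None, 4], [38, 4]))  # True: any batch size is allowed
print(shape_accepts([None, 4], [4]))      # False: rank 1 instead of rank 2
print(shape_accepts([1, 4], [38, 4]))     # False: batch size fixed to one row
```

The last case is relevant to the random forest later in this example, whose input is declared as [1, 4] and is therefore fed one row per call.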

We load the model with ONNX Runtime and look at its input and output.

import onnxruntime as rt  # noqa: E402

sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())

print(f"input name='{sess.get_inputs()[0].name}' and shape={sess.get_inputs()[0].shape}")
print(f"output name='{sess.get_outputs()[0].name}' and shape={sess.get_outputs()[0].shape}")
input name='float_input' and shape=[None, 4]
output name='output_label' and shape=[None]

We compute the predictions.

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

import numpy  # noqa: E402

pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))
[[ 9  0  0]
 [ 0 14  0]
 [ 0  0 15]]

The predictions are identical: the confusion matrix between the scikit-learn and ONNX Runtime predictions is diagonal.
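The features are cast with astype(numpy.float32) because the model input was declared as FloatTensorType, while load_iris returns float64 data. The cast rounds each value to the nearest float32, a relative change of at most 2**-24 (about 6e-8), and ONNX Runtime then evaluates the model in single precision. This is far too small to flip a label here, but it explains why the probabilities below differ from scikit-learn's around the seventh significant digit. A standard-library sketch of the rounding, using struct in place of NumPy (to_float32 is a hypothetical helper):

```python
import struct


def to_float32(value):
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


# First row of the Iris dataset.
for v in [5.1, 3.5, 1.4, 0.2]:
    r = to_float32(v)
    # Relative rounding error introduced by the float32 cast.
    print(v, r, abs(r - v) / v)
```

Values such as 3.5 are exactly representable in float32 and survive the cast unchanged.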

Probabilities

Probabilities are needed to compute other relevant metrics such as the ROC curve. Let’s first see how to get them with scikit-learn.
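A ROC curve needs scores rather than labels: each threshold applied to a class's probability yields one (false positive rate, true positive rate) point. The area under that curve equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, ties counting one half. The quadratic-time sketch below computes that quantity for one class against the rest; pairwise_auc is a hypothetical helper, and sklearn.metrics.roc_auc_score is the library route for real use.

```python
def pairwise_auc(is_positive, scores):
    """AUC as P(score_pos > score_neg), ties counted as 1/2 (O(n^2) sketch)."""
    pos = [s for flag, s in zip(is_positive, scores) if flag]
    neg = [s for flag, s in zip(is_positive, scores) if not flag]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# One-vs-rest for class 1: positives are the samples whose true label is 1,
# scores are the predicted probabilities of class 1.
y_true = [1, 0, 1, 0]
proba_class_1 = [0.8, 0.6, 0.4, 0.2]
auc_class_1 = pairwise_auc([y == 1 for y in y_true], proba_class_1)
print(auc_class_1)  # 0.75
```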

prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])
[[2.21560687e-02 9.34919279e-01 4.29246524e-02]
 [9.69516133e-01 3.04836967e-02 1.70527949e-07]
 [4.49723368e-04 1.96033068e-01 8.03517208e-01]]

And then with ONNX Runtime. The probabilities are returned as a list of dictionaries, one per row, mapping each class label to its probability.

prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]

import pprint  # noqa: E402

pprint.pprint(prob_rt[0:3])
[{0: 0.022156063467264175, 1: 0.9349193572998047, 2: 0.042924605309963226},
 {0: 0.9695160984992981, 1: 0.030483705922961235, 2: 1.7052775547199417e-07},
 {0: 0.00044972379691898823, 1: 0.19603325426578522, 2: 0.8035170435905457}]
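scikit-learn returns a 2-D array with one column per class, whereas the converted classifier returns one dictionary per row keyed by class label: by default sklearn-onnx wraps the probabilities in a ZipMap operator, and its documentation describes a zipmap converter option to turn this off. To compare both outputs, the dictionaries can be turned back into rows. The sketch below does that on the three rows printed above and checks agreement within an absolute tolerance of 1e-6; zipmap_rows_to_matrix is a hypothetical helper, and the tolerance is an assumption chosen to absorb float32 rounding.

```python
def zipmap_rows_to_matrix(prob_dicts, classes):
    """Hypothetical helper: [{label: prob}, ...] -> [[prob, ...], ...] in class order."""
    return [[row[c] for c in classes] for row in prob_dicts]


# Values copied from the two outputs printed above.
prob_sklearn_head = [
    [2.21560687e-02, 9.34919279e-01, 4.29246524e-02],
    [9.69516133e-01, 3.04836967e-02, 1.70527949e-07],
    [4.49723368e-04, 1.96033068e-01, 8.03517208e-01],
]
prob_rt_head = [
    {0: 0.022156063467264175, 1: 0.9349193572998047, 2: 0.042924605309963226},
    {0: 0.9695160984992981, 1: 0.030483705922961235, 2: 1.7052775547199417e-07},
    {0: 0.00044972379691898823, 1: 0.19603325426578522, 2: 0.8035170435905457},
]

rt_matrix = zipmap_rows_to_matrix(prob_rt_head, classes=[0, 1, 2])
# Largest absolute difference between the two sets of probabilities.
max_diff = max(
    abs(a - b)
    for row_sk, row_rt in zip(prob_sklearn_head, rt_matrix)
    for a, b in zip(row_sk, row_rt)
)
print(max_diff < 1e-6)  # True
```

On the full outputs, the same list comprehension over prob_rt with clr.classes_ as the class order yields rows directly comparable to prob_sklearn.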

Let’s benchmark both predictors on the whole test set in a single call.

from timeit import Timer  # noqa: E402


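def speed(inst, number=5, repeat=10):
    # Execute the statement `inst` `number` times, and repeat that `repeat` times.
    timer = Timer(inst, globals=globals())
    raw = numpy.array(timer.repeat(repeat, number=number))
    # Each entry of raw is the total time of `number` runs: normalize per run.
    ave = raw.sum() / len(raw) / number
    mi, ma = raw.min() / number, raw.max() / number
    print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
    return ave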


print("Execution time for clr.predict")
speed("clr.predict(X_test)")

print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")
Execution time for clr.predict
Average 5.16e-05 min=4.48e-05 max=7.23e-05
Execution time for ONNX Runtime
Average 1.85e-05 min=1.72e-05 max=2.53e-05

Let’s benchmark a scenario similar to what a web service experiences: the model has to make one prediction at a time instead of processing a batch of predictions.
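The loop helper defined next simulates one request per call. Two details matter: i % nrow wraps around to the first row once n exceeds the number of test rows (50 calls over 38 rows here), and the slice X_test[im : im + 1] keeps a two-dimensional batch holding a single row, which is the rank the model input expects, whereas X_test[im] would be one-dimensional. The same access pattern on plain Python lists, in a hypothetical visited_slices helper:

```python
def visited_slices(rows, n=None):
    """Hypothetical helper mimicking `loop`: the one-row slices passed to `fct`."""
    nrow = len(rows)
    if n is None:
        n = nrow
    # Slicing (not indexing) keeps the outer batch level: [[...]] rather than [...].
    return [rows[i % nrow : i % nrow + 1] for i in range(n)]


rows = [[0.0] * 4, [1.0] * 4, [2.0] * 4]
batches = visited_slices(rows, n=5)
print([b[0][0] for b in batches])  # [0.0, 1.0, 2.0, 0.0, 1.0]
print(batches[0])                  # [[0.0, 0.0, 0.0, 0.0]]
```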

def loop(X_test, fct, n=None):
    nrow = X_test.shape[0]
    if n is None:
        n = nrow
    for i in range(n):
        im = i % nrow
        fct(X_test[im : im + 1])


print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 50)")


def sess_predict(x):
    return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 50)")
Execution time for clr.predict
Average 0.00189 min=0.0016 max=0.00231
Execution time for sess_predict
Average 0.000314 min=0.000307 max=0.000341
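Dividing these timings by the 50 calls gives a rough per-request cost, which can be set against the single batched call measured earlier on the whole test set. The arithmetic below only reuses the averages printed in this example; on another machine the numbers will differ.

```python
calls = 50

# Averages printed above, in seconds.
sklearn_batch = 5.16e-05  # one clr.predict call on the whole test set
rt_batch = 1.85e-05       # one sess.run call on the whole test set
sklearn_loop = 0.00189    # 50 single-row clr.predict calls
rt_loop = 0.000314        # 50 single-row sess.run calls

sklearn_per_call = sklearn_loop / calls
rt_per_call = rt_loop / calls
print(f"scikit-learn: {sklearn_per_call:.3g} s per single-row call")
print(f"ONNX Runtime: {rt_per_call:.3g} s per single-row call")
print(f"speed-up on the batch: {sklearn_batch / rt_batch:.2g}")
print(f"speed-up per single row: {sklearn_per_call / rt_per_call:.2g}")
```

A single-row clr.predict call costs about three quarters of a call on the full 38-row test set, which suggests that fixed per-call overhead dominates in this regime; on this machine, ONNX Runtime's per-call overhead is roughly six times lower.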

Let’s do the same for the probabilities.

print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 50)")


def sess_predict_proba(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 50)")
Execution time for predict_proba
Average 0.00218 min=0.00215 max=0.00224
Execution time for sess_predict_proba
Average 0.000313 min=0.000309 max=0.000339

This second comparison is fairer because, in this experiment, ONNX Runtime computes both the label and the probabilities in every case, whereas clr.predict only has to compute the label.

Benchmark with RandomForest

We first train and save a model in ONNX format.

from sklearn.ensemble import RandomForestClassifier  # noqa: E402

rf = RandomForestClassifier(n_estimators=10)
rf.fit(X_train, y_train)

# Fixed batch dimension of 1: this model only accepts one row per call.
initial_type = [("float_input", FloatTensorType([1, 4]))]
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

We compare the time needed to compute the probabilities one row at a time.

sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())


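def sess_predict_proba_rf(x):
    # input_name and prob_name come from the logistic regression session; the random
    # forest was converted with the same input name and the default output names.
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]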


print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 50)")

print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 50)")
Execution time for predict_proba
Average 0.016 min=0.0158 max=0.0174
Execution time for sess_predict_proba
Average 0.000306 min=0.000302 max=0.000337

Let’s repeat the measurement with different numbers of trees.

measures = []

for n_trees in range(5, 51, 5):
    print(n_trees)
    rf = RandomForestClassifier(n_estimators=n_trees)
    rf.fit(X_train, y_train)
    initial_type = [("float_input", FloatTensorType([1, 4]))]
    onx = convert_sklearn(rf, initial_types=initial_type)
    with open(f"rf_iris_{n_trees}.onnx", "wb") as f:
        f.write(onx.SerializeToString())
    sess = rt.InferenceSession(f"rf_iris_{n_trees}.onnx", providers=rt.get_available_providers())

    def sess_predict_proba_loop(x):
        return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]  # noqa: B023

    tsk = speed("loop(X_test, rf.predict_proba, 25)", number=5, repeat=4)
    trt = speed("loop(X_test, sess_predict_proba_loop, 25)", number=5, repeat=4)
    measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})

from pandas import DataFrame  # noqa: E402

df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()
Figure: Speed comparison between scikit-learn and ONNX Runtime for a random forest on the Iris dataset (prediction time in seconds, log scale, against the number of trees).
5
Average 0.00577 min=0.00543 max=0.00679
Average 0.000159 min=0.00015 max=0.000181
10
Average 0.0084 min=0.00788 max=0.00921
Average 0.000161 min=0.000151 max=0.000181
15
Average 0.0109 min=0.0102 max=0.0129
Average 0.000161 min=0.000153 max=0.000183
20
Average 0.0131 min=0.0126 max=0.0146
Average 0.000163 min=0.000154 max=0.000187
25
Average 0.0155 min=0.0149 max=0.0167
Average 0.000165 min=0.000156 max=0.000189
30
Average 0.0178 min=0.0174 max=0.0187
Average 0.000166 min=0.000157 max=0.00019
35
Average 0.0202 min=0.0198 max=0.0215
Average 0.000168 min=0.000159 max=0.000194
40
Average 0.0224 min=0.022 max=0.0233
Average 0.000169 min=0.00016 max=0.000189
45
Average 0.0249 min=0.0245 max=0.0258
Average 0.000169 min=0.000161 max=0.000192
50
Average 0.0271 min=0.0267 max=0.0281
Average 0.000174 min=0.000166 max=0.000194
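The scikit-learn timings above grow roughly linearly with the number of trees, while the ONNX Runtime timings stay almost flat. An ordinary least-squares slope over the printed averages quantifies this; the data are copied from the output above (each value covers 25 single-row calls), and ols_slope is a helper written for illustration.

```python
def ols_slope(xs, ys):
    """Ordinary least-squares slope of ys against xs."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    return sxy / sxx


n_trees = list(range(5, 51, 5))
# Average time (s) of 25 single-row predict_proba calls, printed above.
sklearn_times = [0.00577, 0.0084, 0.0109, 0.0131, 0.0155,
                 0.0178, 0.0202, 0.0224, 0.0249, 0.0271]
rt_times = [0.000159, 0.000161, 0.000161, 0.000163, 0.000165,
            0.000166, 0.000168, 0.000169, 0.000169, 0.000174]

calls = 25
# Extra seconds per additional tree, for a single call.
sk_slope = ols_slope(n_trees, sklearn_times) / calls
rt_slope = ols_slope(n_trees, rt_times) / calls
print(f"scikit-learn: {sk_slope:.2g} s per extra tree per call")
print(f"ONNX Runtime: {rt_slope:.2g} s per extra tree per call")
```

On this machine, each additional tree adds roughly 19 microseconds to a scikit-learn call and only a tiny fraction of that to an ONNX Runtime call, which is why the gap in the log-scale plot widens as trees are added.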

Total running time of the script: (0 minutes 4.941 seconds)