Train, convert and predict with ONNX Runtime

This example demonstrates an end-to-end scenario, starting with the training of a machine-learned model and ending with its use in its converted form.

Train a logistic regression

The first step consists of retrieving the Iris dataset.

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
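
train_test_split shuffles the rows and, by default, keeps 25% of them for testing. On the 150 Iris samples that gives 38 test rows, which matches the 38 samples counted in the confusion matrices below. The stdlib sketch that follows mimics that split on row indices; split_indices is a hypothetical helper, not scikit-learn's implementation, and passing random_state to train_test_split is the usual way to make the real split reproducible.

```python
import math
import random


def split_indices(n_samples, test_size=0.25, seed=None):
    """Hypothetical helper: shuffle row indices and cut off a test fraction."""
    # Round the test size up, as scikit-learn does for a float test_size.
    n_test = math.ceil(test_size * n_samples)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # The remaining rows form the training set.
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(150, seed=0)
print(len(train_idx), len(test_idx))  # 112 38
```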

Then we fit a model.

clr = LogisticRegression()
clr.fit(X_train, y_train)
LogisticRegression()

We compute the predictions on the test set and show the confusion matrix.

from sklearn.metrics import confusion_matrix  # noqa: E402

pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))
[[11  0  0]
 [ 0 10  0]
 [ 0  0 17]]
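
Row i, column j of the confusion matrix counts the test samples whose true class is i and whose predicted class is j, so a purely diagonal matrix like the one above means every test sample was classified correctly. A minimal stdlib sketch of that counting follows; confusion and accuracy are hypothetical names, and sklearn.metrics.confusion_matrix remains the tool to use in practice.

```python
def confusion(y_true, y_pred, n_classes):
    """Count (true, predicted) pairs; rows are true classes, columns predictions."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def accuracy(matrix):
    """Fraction of samples on the diagonal, i.e. correctly classified."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


m = confusion([0, 1, 2, 2, 1], [0, 2, 2, 2, 1], n_classes=3)
print(m)            # [[1, 0, 0], [0, 1, 1], [0, 0, 2]]
print(accuracy(m))  # 0.8
```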

Conversion to ONNX format

We use the sklearn-onnx package (imported as skl2onnx) to convert the model to the ONNX format.

from skl2onnx import convert_sklearn  # noqa: E402
from skl2onnx.common.data_types import FloatTensorType  # noqa: E402

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

We load the model with ONNX Runtime and inspect its first input and its first output.

import onnxruntime as rt  # noqa: E402

sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())

print(f"input name='{sess.get_inputs()[0].name}' and shape={sess.get_inputs()[0].shape}")
print(f"output name='{sess.get_outputs()[0].name}' and shape={sess.get_outputs()[0].shape}")
input name='float_input' and shape=[None, 4]
output name='output_label' and shape=[None]

We compute the predictions.

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

import numpy  # noqa: E402

pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))
[[11  0  0]
 [ 0 10  0]
 [ 0  0 17]]

The predictions are identical: ONNX Runtime assigns every test sample the same label as scikit-learn.
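
The astype(numpy.float32) cast passed to sess.run is needed because the model was converted with FloatTensorType, i.e. a single-precision input, while scikit-learn works on float64 arrays. Rounding a value to float32 only changes it around the seventh significant digit, which is consistent with identical labels here and with probabilities that agree to within about 1e-7 in the next section. The stdlib sketch below illustrates the rounding with struct; to_float32 is a hypothetical helper.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


value = 5.1  # a typical Iris feature value
rounded = to_float32(value)
print(rounded == value)      # False: 5.1 is not exactly representable in float32
print(abs(rounded - value))  # a tiny rounding error
```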

Probabilities

Probabilities are needed to compute other relevant metrics such as the ROC curve. Let’s first see how to get them with scikit-learn.

prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])
[[2.44417883e-06 4.32263509e-02 9.56771205e-01]
 [3.48213852e-03 7.70855560e-01 2.25662302e-01]
 [9.84002140e-01 1.59978050e-02 5.54569926e-08]]

And then with ONNX Runtime. The probabilities are returned as a list of dictionaries, one per sample, each mapping a class label to its probability:

prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]

import pprint  # noqa: E402

pprint.pprint(prob_rt[0:3])
[{0: 2.444179017402348e-06, 1: 0.04322638735175133, 2: 0.956771194934845},
 {0: 0.003482144558802247, 1: 0.7708556056022644, 2: 0.22566227614879608},
 {0: 0.9840022325515747, 1: 0.015997808426618576, 2: 5.5456922609664616e-08}]
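
By default skl2onnx appends a ZipMap operator to classifiers, which is why ONNX Runtime returns one dictionary per sample instead of a 2-D array (skl2onnx also documents a zipmap option, set to False through the options argument of convert_sklearn, to get a plain array). The sketch below turns such dictionaries back into rows ordered by class label and compares them with the scikit-learn values, using the first three rows printed above. The helper names are hypothetical and the 1e-6 tolerance is an assumption chosen to absorb float32 rounding.

```python
def zipmap_to_rows(prob_dicts):
    """Convert ZipMap output (list of {label: prob}) into rows sorted by label."""
    return [[d[label] for label in sorted(d)] for d in prob_dicts]


def max_abs_diff(rows_a, rows_b):
    """Largest element-wise absolute difference between two lists of rows."""
    return max(abs(a - b) for ra, rb in zip(rows_a, rows_b) for a, b in zip(ra, rb))


# Values copied from the scikit-learn and ONNX Runtime outputs above.
prob_sklearn_head = [
    [2.44417883e-06, 4.32263509e-02, 9.56771205e-01],
    [3.48213852e-03, 7.70855560e-01, 2.25662302e-01],
    [9.84002140e-01, 1.59978050e-02, 5.54569926e-08],
]
prob_rt_head = [
    {0: 2.444179017402348e-06, 1: 0.04322638735175133, 2: 0.956771194934845},
    {0: 0.003482144558802247, 1: 0.7708556056022644, 2: 0.22566227614879608},
    {0: 0.9840022325515747, 1: 0.015997808426618576, 2: 5.5456922609664616e-08},
]

diff = max_abs_diff(prob_sklearn_head, zipmap_to_rows(prob_rt_head))
print(diff < 1e-6)  # True
```

With the full ONNX Runtime output, numpy.array(zipmap_to_rows(prob_rt)) gives an array with the same layout as prob_sklearn.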

Let’s benchmark both runtimes on the whole test set.

from timeit import Timer  # noqa: E402


def speed(inst, number=5, repeat=10):
    timer = Timer(inst, globals=globals())
    raw = numpy.array(timer.repeat(repeat, number=number))
    ave = raw.sum() / len(raw) / number
    mi, ma = raw.min() / number, raw.max() / number
    print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
    return ave
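
speed runs the statement repeat times, each time executing it number times, so every raw measurement covers number calls and is divided by number to obtain a per-call time. The same arithmetic is shown below as a pure function on already-collected timings; the name summarize and the sample timings are hypothetical.

```python
def summarize(raw, number):
    """Per-call average, min and max from timings that each cover `number` calls."""
    ave = sum(raw) / len(raw) / number
    return ave, min(raw) / number, max(raw) / number


# Hypothetical totals (seconds) for 4 repeats of 5 calls each.
ave, mi, ma = summarize([0.0005, 0.0004, 0.0006, 0.0005], number=5)
print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
```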


print("Execution time for clr.predict")
speed("clr.predict(X_test)")

print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")
Execution time for clr.predict
Average 8.73e-05 min=8.25e-05 max=0.000106
Execution time for ONNX Runtime
Average 2.73e-05 min=2.59e-05 max=3.57e-05

2.7336219998232993e-05

Let’s benchmark a scenario similar to what a web service experiences: the model has to make one prediction at a time instead of a batch of predictions.

def loop(X_test, fct, n=None):
    nrow = X_test.shape[0]
    if n is None:
        n = nrow
    for i in range(n):
        im = i % nrow
        fct(X_test[im : im + 1])
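
Two details of loop matter. Slicing with X_test[im : im + 1] keeps a 2-D array of shape (1, 4), which is what both predict and the ONNX input expect, whereas X_test[im] would return a 1-D row. And i % nrow wraps around, so asking for 50 predictions on a 38-row test set reuses the first rows. A small NumPy check, with a hypothetical 3x4 array standing in for X_test:

```python
import numpy as np

X = np.arange(12, dtype=np.float64).reshape(3, 4)  # stand-in for X_test

print(X[1:2].shape)  # (1, 4): still a batch, with one row
print(X[1].shape)    # (4,): a 1-D row, not a batch

nrow = X.shape[0]
print([i % nrow for i in range(7)])  # [0, 1, 2, 0, 1, 2, 0]: indices wrap around
```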


print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 50)")


def sess_predict(x):
    return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 50)")
Execution time for clr.predict
Average 0.00317 min=0.003 max=0.00406
Execution time for sess_predict
Average 0.00046 min=0.000453 max=0.000485

0.0004604251400007797

Let’s do the same for the probabilities.

print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 50)")


def sess_predict_proba(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 50)")
Execution time for predict_proba
Average 0.00398 min=0.00394 max=0.00405
Execution time for sess_predict_proba
Average 0.000471 min=0.000462 max=0.000503

0.0004714528200008772

This second comparison is fairer because, in this experiment, ONNX Runtime computes both the label and the probabilities on every call, so scikit-learn's predict_proba is the closer equivalent.
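
Each timed statement performs 50 single-row predictions, so dividing the reported averages by 50 gives a per-prediction latency, and the ratio of the two averages gives the speed-up. The sketch below applies that arithmetic to the numbers printed in this run; they will differ on another machine.

```python
# Averages printed above, in seconds per statement (50 predictions each).
timings = {
    "predict": (0.00317, 0.00046),  # (scikit-learn, ONNX Runtime)
    "predict_proba": (0.00398, 0.000471),
}

for name, (t_sklearn, t_rt) in timings.items():
    per_call_sklearn = t_sklearn / 50
    per_call_rt = t_rt / 50
    print(
        f"{name}: scikit-learn {per_call_sklearn * 1e6:.1f} µs, "
        f"ONNX Runtime {per_call_rt * 1e6:.1f} µs, speed-up x{t_sklearn / t_rt:.1f}"
    )
```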

Benchmark with RandomForest

We first train and save a model in ONNX format.

from sklearn.ensemble import RandomForestClassifier  # noqa: E402

rf = RandomForestClassifier(n_estimators=10)
rf.fit(X_train, y_train)

initial_type = [("float_input", FloatTensorType([1, 4]))]
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
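
Unlike the logistic regression, this forest is converted with FloatTensorType([1, 4]): the batch dimension is fixed to 1 instead of None, so the model is meant to be fed one row at a time, which suits the single-prediction benchmark below. ONNX Runtime is expected to reject inputs whose fixed dimensions do not match. The convention can be illustrated with a hypothetical helper, shape_accepts, where None stands for a dimension of any size; this is a sketch of the rule, not ONNX Runtime's own check.

```python
def shape_accepts(declared, actual):
    """Return True if `actual` fits `declared`; None matches any size."""
    if len(declared) != len(actual):
        return False
    return all(d is None or d == a for d, a in zip(declared, actual))


print(shape_accepts([None, 4], (38, 4)))  # True: dynamic batch size
print(shape_accepts([1, 4], (1, 4)))      # True: exactly one row
print(shape_accepts([1, 4], (38, 4)))     # False: the batch must have one row
```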

We compare the prediction times of scikit-learn and ONNX Runtime, one sample at a time.

sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())


def sess_predict_proba_rf(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 50)")

print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 50)")
Execution time for predict_proba
Average 0.022 min=0.0213 max=0.0248
Execution time for sess_predict_proba
Average 0.000458 min=0.000448 max=0.000494

0.00045759126000007196

Let’s see how the timings change with the number of trees.

measures = []

for n_trees in range(5, 51, 5):
    print(n_trees)
    rf = RandomForestClassifier(n_estimators=n_trees)
    rf.fit(X_train, y_train)
    initial_type = [("float_input", FloatTensorType([1, 4]))]
    onx = convert_sklearn(rf, initial_types=initial_type)
    with open("rf_iris_%d.onnx" % n_trees, "wb") as f:
        f.write(onx.SerializeToString())
    sess = rt.InferenceSession("rf_iris_%d.onnx" % n_trees, providers=rt.get_available_providers())

    def sess_predict_proba_loop(x):
        return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]  # noqa: B023

    tsk = speed("loop(X_test, rf.predict_proba, 25)", number=5, repeat=4)
    trt = speed("loop(X_test, sess_predict_proba_loop, 25)", number=5, repeat=4)
    measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})

from pandas import DataFrame  # noqa: E402

df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()
[Figure: Speed comparison between scikit-learn and ONNX Runtime for a random forest on the Iris dataset. x-axis: number of trees; y-axis: prediction time (s), log scale.]
5
Average 0.00815 min=0.00778 max=0.00912
Average 0.000236 min=0.000225 max=0.000252
10
Average 0.0111 min=0.0108 max=0.0121
Average 0.000241 min=0.000228 max=0.000269
15
Average 0.0146 min=0.0142 max=0.0155
Average 0.000242 min=0.000228 max=0.000266
20
Average 0.018 min=0.0172 max=0.0194
Average 0.000247 min=0.000236 max=0.000275
25
Average 0.0204 min=0.0199 max=0.0217
Average 0.000247 min=0.000235 max=0.000274
30
Average 0.0237 min=0.0231 max=0.0253
Average 0.000251 min=0.000237 max=0.000279
35
Average 0.0266 min=0.0261 max=0.0277
Average 0.000256 min=0.000243 max=0.000283
40
Average 0.0293 min=0.0284 max=0.0313
Average 0.000256 min=0.000245 max=0.000279
45
Average 0.032 min=0.0314 max=0.0333
Average 0.00027 min=0.000251 max=0.000293
50
Average 0.0357 min=0.0352 max=0.0367
Average 0.000266 min=0.000255 max=0.000289
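
The printed averages make the trend in the figure explicit: scikit-learn's time grows roughly linearly with the number of trees, while ONNX Runtime's stays almost flat. The sketch below fits an ordinary least-squares line through the scikit-learn averages of this run, using plain Python, to estimate the extra cost per tree, and reports how much the ONNX Runtime timings vary.

```python
n_trees = list(range(5, 51, 5))
# Averages printed above (seconds for 25 single-row predictions).
sklearn_times = [0.00815, 0.0111, 0.0146, 0.018, 0.0204,
                 0.0237, 0.0266, 0.0293, 0.032, 0.0357]
rt_times = [0.000236, 0.000241, 0.000242, 0.000247, 0.000247,
            0.000251, 0.000256, 0.000256, 0.00027, 0.000266]

# Ordinary least squares: slope = Sxy / Sxx, intercept from the means.
mean_x = sum(n_trees) / len(n_trees)
mean_y = sum(sklearn_times) / len(sklearn_times)
sxx = sum((x - mean_x) ** 2 for x in n_trees)
sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(n_trees, sklearn_times))
slope = sxy / sxx
intercept = mean_y - slope * mean_x

print(f"scikit-learn: {slope:.2e} s per extra tree, intercept {intercept:.2e} s")
print(f"ONNX Runtime: max/min ratio {max(rt_times) / min(rt_times):.2f}")
```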


Total running time of the script: (0 minutes 6.669 seconds)
