Train, convert and predict with ONNX Runtime#
This example demonstrates an end-to-end scenario, from training a machine learned model to using it in its converted form.
Train a logistic regression#
The first step consists in retrieving the iris dataset.
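A minimal sketch of this step (the use of train_test_split with its default random split is an assumption):

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
# Split into a training and a test set (default split parameters are an assumption).
X_train, X_test, y_train, y_test = train_test_split(X, y)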
Then we fit a model.
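For instance, a logistic regression with default hyperparameters (the exact settings are an assumption):

from sklearn.linear_model import LogisticRegression  # noqa: E402

clr = LogisticRegression()
clr.fit(X_train, y_train)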
We compute the prediction on the test set and we show the confusion matrix.
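A sketch of this step, reusing clr and producing the matrix printed below:

from sklearn.metrics import confusion_matrix  # noqa: E402

pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))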
[[ 9  0  0]
 [ 0 14  1]
 [ 0  0 14]]
Conversion to ONNX format#
We use the sklearn-onnx module to convert the model into ONNX format.
from skl2onnx import convert_sklearn # noqa: E402
from skl2onnx.common.data_types import FloatTensorType # noqa: E402
initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
We load the model with ONNX Runtime and look at its input and output.
import onnxruntime as rt # noqa: E402
sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())
print(f"input name='{sess.get_inputs()[0].name}' and shape={sess.get_inputs()[0].shape}")
print(f"output name='{sess.get_outputs()[0].name}' and shape={sess.get_outputs()[0].shape}")
input name='float_input' and shape=[None, 4]
output name='output_label' and shape=[None]
We compute the predictions.
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
import numpy # noqa: E402
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))
[[ 9  0  0]
 [ 0 14  0]
 [ 0  0 15]]
The predictions are perfectly identical.
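One way to check this programmatically, assuming both arrays hold the same integer labels, is a simple sketch like:

assert numpy.array_equal(pred, pred_onx)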
Probabilities#
Probabilities are needed to compute other relevant metrics such as the ROC Curve. Let’s see how to get them first with scikit-learn.
prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])
[[2.21560687e-02 9.34919279e-01 4.29246524e-02]
 [9.69516133e-01 3.04836967e-02 1.70527949e-07]
 [4.49723368e-04 1.96033068e-01 8.03517208e-01]]
And then with ONNX Runtime. As the output below shows, the probabilities are returned as a list of dictionaries, one per sample, mapping each class index to its probability.
prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]
import pprint # noqa: E402
pprint.pprint(prob_rt[0:3])
[{0: 0.022156063467264175, 1: 0.9349193572998047, 2: 0.042924605309963226},
 {0: 0.9695160984992981, 1: 0.030483705922961235, 2: 1.7052775547199417e-07},
 {0: 0.00044972379691898823, 1: 0.19603325426578522, 2: 0.8035170435905457}]
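Since the ONNX model returns one dictionary per sample, a dense array comparable to prob_sklearn can be rebuilt if needed; a minimal sketch (assuming the keys are the class indices 0, 1, 2):

# Rebuild a (n_samples, n_classes) array from the list of per-sample dictionaries.
prob_rt_array = numpy.array([[p[i] for i in sorted(p)] for p in prob_rt])
print(prob_rt_array[:3])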
Let’s benchmark.
from timeit import Timer # noqa: E402
def speed(inst, number=5, repeat=10):
    timer = Timer(inst, globals=globals())
    raw = numpy.array(timer.repeat(repeat, number=number))
    ave = raw.sum() / len(raw) / number
    mi, ma = raw.min() / number, raw.max() / number
    print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
    return ave
print("Execution time for clr.predict")
speed("clr.predict(X_test)")
print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")
Execution time for clr.predict
Average 5.16e-05 min=4.48e-05 max=7.23e-05
Execution time for ONNX Runtime
Average 1.85e-05 min=1.72e-05 max=2.53e-05
1.854979999734496e-05
Let’s benchmark a scenario similar to what a webservice experiences: the model has to predict one observation at a time, as opposed to a batch of predictions.
def loop(X_test, fct, n=None):
    nrow = X_test.shape[0]
    if n is None:
        n = nrow
    for i in range(n):
        im = i % nrow
        fct(X_test[im : im + 1])
print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 50)")
def sess_predict(x):
    return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]
print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 50)")
Execution time for clr.predict
Average 0.00189 min=0.0016 max=0.00231
Execution time for sess_predict
Average 0.000314 min=0.000307 max=0.000341
0.00031365842000923293
Let’s do the same for the probabilities.
print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 50)")
def sess_predict_proba(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]
print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 50)")
Execution time for predict_proba
Average 0.00218 min=0.00215 max=0.00224
Execution time for sess_predict_proba
Average 0.000313 min=0.000309 max=0.000339
0.0003127810799924191
This second comparison is fairer, as in this example ONNX Runtime computes both the label and the probabilities on every call.
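Both outputs can also be retrieved in a single call if needed; a sketch reusing the names defined above:

# Request the label and the probability outputs together in one run.
label_onx, prob_onx = sess.run(
    [label_name, prob_name], {input_name: X_test.astype(numpy.float32)}
)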
Benchmark with RandomForest#
We first train and save a model in ONNX format.
from sklearn.ensemble import RandomForestClassifier # noqa: E402
rf = RandomForestClassifier(n_estimators=10)
rf.fit(X_train, y_train)
initial_type = [("float_input", FloatTensorType([1, 4]))]
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
We compare.
sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())
def sess_predict_proba_rf(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]
print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 50)")
print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 50)")
Execution time for predict_proba
Average 0.016 min=0.0158 max=0.0174
Execution time for sess_predict_proba
Average 0.000306 min=0.000302 max=0.000337
0.0003064713400090113
Let’s see how the comparison evolves with different numbers of trees.
measures = []
for n_trees in range(5, 51, 5):
    print(n_trees)
    rf = RandomForestClassifier(n_estimators=n_trees)
    rf.fit(X_train, y_train)
    initial_type = [("float_input", FloatTensorType([1, 4]))]
    onx = convert_sklearn(rf, initial_types=initial_type)
    with open(f"rf_iris_{n_trees}.onnx", "wb") as f:
        f.write(onx.SerializeToString())
    sess = rt.InferenceSession(f"rf_iris_{n_trees}.onnx", providers=rt.get_available_providers())

    def sess_predict_proba_loop(x):
        return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]  # noqa: B023

    tsk = speed("loop(X_test, rf.predict_proba, 25)", number=5, repeat=4)
    trt = speed("loop(X_test, sess_predict_proba_loop, 25)", number=5, repeat=4)
    measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})
from pandas import DataFrame # noqa: E402
df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()
[Figure: Speed comparison between scikit-learn and ONNX Runtime for a random forest on the Iris dataset]
5
Average 0.00577 min=0.00543 max=0.00679
Average 0.000159 min=0.00015 max=0.000181
10
Average 0.0084 min=0.00788 max=0.00921
Average 0.000161 min=0.000151 max=0.000181
15
Average 0.0109 min=0.0102 max=0.0129
Average 0.000161 min=0.000153 max=0.000183
20
Average 0.0131 min=0.0126 max=0.0146
Average 0.000163 min=0.000154 max=0.000187
25
Average 0.0155 min=0.0149 max=0.0167
Average 0.000165 min=0.000156 max=0.000189
30
Average 0.0178 min=0.0174 max=0.0187
Average 0.000166 min=0.000157 max=0.00019
35
Average 0.0202 min=0.0198 max=0.0215
Average 0.000168 min=0.000159 max=0.000194
40
Average 0.0224 min=0.022 max=0.0233
Average 0.000169 min=0.00016 max=0.000189
45
Average 0.0249 min=0.0245 max=0.0258
Average 0.000169 min=0.000161 max=0.000192
50
Average 0.0271 min=0.0267 max=0.0281
Average 0.000174 min=0.000166 max=0.000194
<matplotlib.legend.Legend object at 0x7ebddc295fc0>
Total running time of the script: (0 minutes 4.941 seconds)