There are a few ways to put trained machine learning (ML) models into production. The most common method is to serialize the model using some particular format after training, and deserialize that model in the production environment. In Python, there are several language-specific serialization formats based on pickle. Alternatives include more agnostic exported formats, namely onnx or pmml. In Java, the H2O framework serializes using POJO or MOJO, which are Plain Old Java Object and Model ObJect Optimized structures, respectively.
Commonly, Python is the language for machine learning modeling. Different Python ML frameworks have different serialization recommendations. In particular: sklearn recommends using the joblib package, pytorch's load and save methods use Python's built-in pickle module, and keras supports exporting in hdf5 format. There is also an alternative serialization package, dill, which generalizes pickle at the cost of performance.
Pickling is a way to write Python objects to a bytestream, which can then be written to disk as a file. One can take this file and load it back into a separate Python interpreter at a later date, recovering the objects from the previous session.
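As a minimal sketch (the file name and the object being pickled are arbitrary here), a round trip looks like this:

import pickle

# Any picklable object will do; a dict of model metadata is used here.
metadata = {"model": "random_forest", "n_estimators": 100}

# Serialize the object to a bytestream and write it to disk.
with open("metadata.pkl", "wb") as f:
    pickle.dump(metadata, f)

# Later, possibly in a different interpreter session, read it back.
with open("metadata.pkl", "rb") as f:
    restored = pickle.load(f)

print(restored)  # {'model': 'random_forest', 'n_estimators': 100}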
Pickled files are tightly coupled with the environment in which they were created. There is no guarantee that a pickled object in one version of Python will work in a previous or newer version. Similarly the imported objects that are used in pickled Python objects must also be available when deserializing. In general, the destination environment for a pickle file should have all of the dependencies and definitions required in the pickled objects.
According to the documentation, the following types can be pickled:

- None, True, and False
- integers, floating point numbers, and complex numbers
- strings, bytes, and bytearrays
- tuples, lists, sets, and dictionaries containing only picklable objects
- functions defined at the top level of a module (using def, not lambda)
- built-in functions defined at the top level of a module
- classes that are defined at the top level of a module
- instances of such classes whose __dict__ or the result of calling __getstate__() is picklable (see the section Pickling Class Instances for details)

Note that functions (built-in and user-defined) are pickled by "fully qualified" name reference, not by value. This means that only the function name is pickled, along with the name of the module the function is defined in. Neither the function's code, nor any of its function attributes are pickled. Thus the defining module must be importable in the unpickling environment, and the module must contain the named object, otherwise an exception will be raised.
The above implies that it is a necessary condition that classes and functions be importable in order to be unpickled. This means that any custom function defined in your script cannot be deserialized in another environment unless it is part of a module that is importable in the unpickling environment.
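To see what "pickled by name reference" means in practice, here is a small sketch using a standard library function; the pickled payload records only the module and attribute names, and unpickling simply imports the name again:

import math
import pickle

# Functions are pickled by fully qualified name, not by value.
payload = pickle.dumps(math.sqrt)

# The bytestream contains only the names "math" and "sqrt";
# none of the function's code is serialized.
print(b"math" in payload, b"sqrt" in payload)  # True True

# Unpickling looks the name up again, returning the very same object.
print(pickle.loads(payload) is math.sqrt)  # True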
The built-in pickle module allows for the serialization of Python objects into a bytestream. This means you can save a Python object, such as a class instance, to a file, send it to another environment or computer, and deserialize it back into a Python object to be interacted with again.
The pickle module can save and load class instances. However, when deserializing, the class definition must be importable from the same module path as in the original environment. If we try to pickle a pytorch model in a research environment, and then load it into a production environment that doesn't have pytorch installed, the model will fail to deserialize.
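As a sketch of that pytorch workflow (using a toy torch.nn.Linear model and an arbitrary file name; note that recent pytorch releases restrict torch.load to weights-only loading by default):

import torch
import torch.nn as nn

# A toy model; any trained pytorch model behaves the same way here.
model = nn.Linear(4, 3)

# torch.save pickles the whole model object, including a reference to its
# class, which lives inside the torch package.
torch.save(model, "model.pt")

# torch.load unpickles it; this only succeeds in an environment where torch
# (and the saved model's class) can be imported.
restored = torch.load("model.pt")
print(restored)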
The joblib package provides dump and load functions for serializing Python objects, with particular optimizations for large numpy arrays. It is intended to be a drop-in replacement for pickle and can be effective for sklearn models which store lots of data internally, such as random forest or cluster-based classifiers.
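As a minimal sketch (the file name is arbitrary), dumping and loading an sklearn model with joblib mirrors the pickle interface:

import joblib
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Train a small model; its fitted trees hold numpy arrays internally.
X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier().fit(X, y)

# joblib.dump and joblib.load are drop-in replacements for pickle.dump
# and pickle.load, optimized for objects carrying large numpy arrays.
joblib.dump(forest, "forest.joblib")
restored = joblib.load("forest.joblib")

print(restored.predict(X[:5]))  # [0 0 0 0 0]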
The dill package extends the functionality of pickle by enabling the serialization of a much larger set of Python objects. According to its documentation, dill can pickle all of the standard types that pickle handles, as well as more exotic types such as nested functions, lambdas, and functions with yields.
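For instance, here is a small sketch of a case that pickle rejects but dill handles, a lambda serialized by value rather than by name:

import pickle
import dill

square = lambda x: x ** 2

# pickle serializes functions by qualified name, so a lambda fails.
try:
    pickle.dumps(square)
except (pickle.PicklingError, AttributeError) as exc:
    print(f"pickle failed: {exc}")

# dill serializes the function itself, so the round trip works.
restored = dill.loads(dill.dumps(square))
print(restored(4))  # 16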
Part of this includes custom functions. In particular, dill can be extremely useful for sklearn models which use the flexible FunctionTransformer. This transformer applies a custom transformation using a function passed as an argument on instantiation.
Custom functions serialized with dill may have problems being deserialized, mostly due to the use of imported packages in the function itself. This can be mitigated by invoking these imports before deserializing. Alternatively, the functions can be declared such that they import any package dependencies within the function body, if the function uses them directly. As an example, the following function
import os
def get_env_var(s):
    return os.getenv(s)
should be written like this in the file in which it is being serialized:
def get_env_var(s):
    import os
    return os.getenv(s)
This makes the code a bit sloppier and less conventional, but without the inline import, if the os module has not already been imported in the deserialization environment, then using the deserialized function may throw a NameError because os is not defined. The trade-off is a little overhead: the import statement is executed every time the function is called, although Python caches imported modules, so the cost is usually small.
In general, dill's additional flexibility comes at the cost of performance. In other words, deserializing dill files will typically take a bit longer than pickle files. This can be problematic for loading models on-the-fly for time-sensitive predictions, but the model can always be kept in memory after being deserialized when starting an application.
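One common pattern, sketched here against the model.dill file produced later in this article (the get_model helper is hypothetical), is to pay the deserialization cost once at startup and keep the model cached in memory:

import dill

_model = None

def get_model(path="model.dill"):
    """Load the serialized model on first use and reuse it afterwards."""
    global _model
    if _model is None:
        with open(path, "rb") as f:
            _model = dill.load(f)
    return _model

# At application startup (or on the first request):
# model = get_model()
# predictions = model.predict(X_new)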
Here is a simple example in which we create an sklearn Pipeline with a custom function implemented using the FunctionTransformer. The pipeline trains a RandomForestClassifier using the prepackaged iris dataset.
import pickle
import dill
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
# Load the iris dataset
data = load_iris()
X = data["data"]
y = data["target"]
# Split into training/testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Create a function for use with the FunctionTransformer.
# The transform relies on a package (os) imported inside the function;
# environment variables are read as strings, hence the float() cast.
def scale(X_input):
    """Scale the input matrix."""
    import os
    scale_factor = float(os.getenv("SCALE_FACTOR", 2))
    return X_input * scale_factor
# Create a simple toy model that transforms the dataset
# and uses a random forest
model = Pipeline(
    [
        ("transform", FunctionTransformer(scale)),
        ("forest", RandomForestClassifier())
    ]
)
if __name__ == "__main__":
    # Train the model
    model.fit(X_train, y_train)

    # Serialize the model using dill.
    with open("model.dill", "wb") as f:
        dill.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)
Here we deserialize the model in a separate environment, using the same Python version and the same version of sklearn (and any other dependencies required by the model).
import inspect
import dill
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
data = load_iris()
X = data["data"]
y = data["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y)
if __name__ == "__main__":
    # Load the model
    with open("model.dill", "rb") as f:
        model = dill.load(f)

    # Show the model.
    # Note that the custom function has been serialized/deserialized.
    print(model)
    # Pipeline(memory=None,
    #          steps=[('transform',
    #                  FunctionTransformer(accept_sparse=False, check_inverse=True,
    #                                      func=<function scale at 0x110081e18>,
    #                                      inv_kw_args=None, inverse_func=None,
    #                                      kw_args=None, validate=False)),
    #                 ('forest',
    #                  RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,
    #                                         class_weight=None, criterion='gini',
    #                                         max_depth=None, max_features='auto',
    #                                         max_leaf_nodes=None, max_samples=None,
    #                                         min_impurity_decrease=0.0,
    #                                         min_impurity_split=None,
    #                                         min_samples_leaf=1, min_samples_split=2,
    #                                         min_weight_fraction_leaf=0.0,
    #                                         n_estimators=100, n_jobs=None,
    #                                         oob_score=False, random_state=None,
    #                                         verbose=0, warm_start=False))],
    #          verbose=False)

    # View the source of the embedded serialized function.
    # (thanks, dill)
    print(inspect.getsource(model.steps[0][1].func))
    # def scale(X_input):
    #     """Scale the input matrix."""
    #     import os
    #     scale_factor = float(os.getenv("SCALE_FACTOR", 2))
    #     return X_input * scale_factor

    # Make some predictions
    predictions = model.predict(X_test)
    print(predictions)
    # [2 1 1 0 2 1 0 0 2 0 2 1 1 0 1 0 1 0 2 0 0 0 2 2 1 0 0 1 0 2 2 2 2 2 2 2 2 1]
If we perform the serialization with pickle instead, we won't be able to deserialize the model in a new environment. That's because of the custom function we've defined called scale. Here is the error you would see:
Traceback (most recent call last):
  File "/Users/flynn/.asdf/installs/python/3.6.9/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/flynn/.asdf/installs/python/3.6.9/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/flynn/projects/mltest/environment.py", line 18, in <module>
    model = pickle.load(f)
AttributeError: Can't get attribute 'scale' on <module 'environment' from '/Users/flynn/projects/mltest/environment.py'>
Because of this limitation, dill might be more suitable for model serialization in many cases.
For Python ML models, particularly those built with sklearn, here are some guidelines for serialization:

- Use pickle to serialize objects with an importable hierarchy.
- Use joblib for objects which contain lots of data in numpy arrays.
- Use dill when pickle or joblib won't work, or when you have custom functions that need to be serialized as part of the model.

For modeling, it can also be helpful to follow some rules of thumb, namely keeping custom functions simple and self-contained (for example, importing their package dependencies inside the function body when serializing with dill).

In general, dill will provide the most flexibility in terms of getting the model serialized and should be considered the path of least resistance when it comes to serializing ML models for production.