There are a few ways to put trained machine learning (ML) models into production. The most common method is to serialize the model using some particular format after training, and deserialize that model in the production environment. In Python, there are several language-specific serialization formats based on pickle. Alternatives include more framework-agnostic export formats such as pmml. In Java, the H2O framework serializes models using POJO and MOJO, which stand for Plain Old Java Object and Model ObJect Optimized, respectively.
Commonly, Python is the language of choice for machine learning modeling, and different Python ML frameworks have different serialization recommendations. In particular:

- sklearn recommends using the joblib package,
- pytorch's load and save methods use Python's built-in pickle module, and
- keras supports exporting trained models in the hdf5 format.

There is also an alternative serialization package, dill, which generalizes pickle at the cost of performance.
Pickling is a way to write Python objects to a bytestream, which can then be written to disk as a file. One can take this file and load it back into a separate Python interpreter at a later date, recovering the objects from the previous session.
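As a minimal sketch of that round trip (a plain dict stands in for a trained model here):

    import pickle

    # Any picklable object works; a dict stands in for a trained model.
    model = {"weights": [0.1, 0.2, 0.7]}

    # Serialize to a bytestream and write it to disk.
    with open("model.pkl", "wb") as f:
        pickle.dump(model, f)

    # Later, possibly in a different interpreter session, load it back.
    with open("model.pkl", "rb") as f:
        restored = pickle.load(f)

    assert restored == model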
Pickled files are tightly coupled with the environment in which they were created. There is no guarantee that a pickled object in one version of Python will work in a previous or newer version. Similarly the imported objects that are used in pickled Python objects must also be available when deserializing. In general, the destination environment for a pickle file should have all of the dependencies and definitions required in the pickled objects.
According to the documentation, the following types can be pickled:

- None, True, and False;
- integers, floating point numbers, and complex numbers;
- strings, bytes, and bytearrays;
- tuples, lists, sets, and dictionaries containing only picklable objects;
- functions defined at the top level of a module (using def, not lambda);
- built-in functions defined at the top level of a module;
- classes that are defined at the top level of a module;
- instances of such classes whose __dict__ or the result of calling __getstate__() is picklable (see the section Pickling Class Instances for details).
Note that functions (built-in and user-defined) are pickled by “fully qualified” name reference, not by value. This means that only the function name is pickled, along with the name of the module the function is defined in. Neither the function’s code, nor any of its function attributes are pickled. Thus the defining module must be importable in the unpickling environment, and the module must contain the named object, otherwise an exception will be raised.
The above implies that it is a necessary condition that classes and functions be importable in order to be unpickled. This means that any custom function defined in your script cannot be deserialized in another environment, unless it is part of a module that is importable in the unpickling environment.
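A rough sketch of that constraint: a function defined at the top level of a module pickles fine (only its qualified name is stored), while a lambda with no importable name does not.

    import pickle

    def top_level(x):
        return x + 1

    # Works: only the reference "__main__.top_level" is stored.
    payload = pickle.dumps(top_level)

    # Fails: a lambda has no importable qualified name.
    try:
        pickle.dumps(lambda x: x + 1)
    except (pickle.PicklingError, AttributeError) as e:
        print(e)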
The built-in pickle module allows for the serialization of Python objects into a bytestream. This means you can save a Python object, like a class instance, to a file, send it to another environment or computer, and deserialize it back into a Python object to be interacted with again.

The pickle module can save and load class instances. However, when deserializing, the class definition must be importable from the same module path as in the original environment. If we try to pickle a pytorch model in a research environment, and then load it into a production environment that doesn't have pytorch installed, the model will fail to deserialize.
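To sketch what that coupling looks like (research_models and WrappedModel are hypothetical names, not a real package):

    import pickle

    # Hypothetical module that only exists in the research environment.
    from research_models import WrappedModel

    with open("model.pkl", "wb") as f:
        pickle.dump(WrappedModel(), f)

    # The pickle stores only the reference "research_models.WrappedModel".
    # Loading it in a production environment where that module (or pytorch,
    # in the example above) cannot be imported raises ModuleNotFoundError
    # instead of returning the model.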
The joblib package provides dump and load functions for serializing Python objects, with particular optimizations for large numpy arrays. It is intended to be a drop-in replacement for pickle and can be effective for sklearn models which store lots of data internally, such as random forest or cluster-based classifiers.
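A minimal sketch of that workflow, assuming joblib and sklearn are installed:

    import joblib
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier

    X, y = load_iris(return_X_y=True)
    model = RandomForestClassifier().fit(X, y)

    # dump/load mirror pickle.dump/pickle.load, with efficient handling
    # of the large numpy arrays stored inside the fitted forest.
    joblib.dump(model, "model.joblib")
    restored = joblib.load("model.joblib")
    print(restored.predict(X[:5]))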
The dill package extends the functionality of pickle by enabling the serialization of a much larger set of Python objects. According to its documentation, dill can pickle the standard types handled by pickle, as well as more exotic ones such as lambdas, nested functions, and functions with yields.
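For instance, here is a rough sketch contrasting dill and pickle on a lambda:

    import pickle
    import dill

    double = lambda x: 2 * x

    try:
        pickle.dumps(double)  # fails: no importable qualified name
    except pickle.PicklingError as e:
        print("pickle:", e)

    # dill serializes the function by value, so it round-trips.
    payload = dill.dumps(double)
    print(dill.loads(payload)(21))  # 42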
Part of this includes custom functions. In particular, dill can be extremely useful for sklearn models which use the flexible FunctionTransformer. This transformer allows custom transformations using a function argument passed on instantiation.
Custom functions serialized with dill may have problems being deserialized, mostly due to the use of imported packages within the function itself. This can be mitigated by invoking those imports before deserializing. Alternatively, the functions can be declared such that they import any package dependencies within the function body if the function uses them directly. As an example, the following function
    import os

    def get_env_var(s):
        return os.getenv(s)
should be written like this in the file in which it is being serialized:
    def get_env_var(s):
        import os
        return os.getenv(s)
This makes the code a bit sloppier and unconventional, but if the os module has not already been imported in the destination environment, the first version of the function may throw a NameError when it is called after deserialization, since os has yet to be defined. Importing inside the function also adds some overhead, since the import statement is needlessly re-executed every time the function is called.
dill's additional flexibility comes at the cost of performance: deserializing dill files will typically take a bit longer than deserializing pickle files. This can be problematic when loading models on the fly for time-sensitive predictions, but the model can always be kept in memory after being deserialized when starting an application.
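One common pattern, sketched here with a hypothetical predict_handler, is to pay the deserialization cost once at application startup and keep the model in memory for subsequent predictions:

    import dill

    # Load once at startup (module import time), not once per request.
    with open("model.dill", "rb") as f:
        MODEL = dill.load(f)

    def predict_handler(features):
        """Hypothetical request handler that reuses the in-memory model."""
        return MODEL.predict([features]).tolist()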
Here is a simple example in which we create an sklearn Pipeline with a custom function implemented using the FunctionTransformer. The pipeline trains a RandomForestClassifier using the prepackaged iris dataset.
    import pickle
    import dill

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer

    # Load the iris dataset
    data = load_iris()
    X = data["data"]
    y = data["target"]

    # Split into training/testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # Create a function for use with the FunctionTransformer
    # Transform function that relies on an imported package
    def scale(X_input):
        """Scale the input matrix."""
        import os
        scale_factor = os.getenv("SCALE_FACTOR", 2)
        return X_input * scale_factor

    # Create a simple toy model that transforms the dataset
    # and uses a random forest
    model = Pipeline(
        [
            ("transform", FunctionTransformer(scale)),
            ("forest", RandomForestClassifier()),
        ]
    )

    if __name__ == "__main__":
        # Train the model
        model.fit(X_train, y_train)

        # Serialize the model using dill.
        with open("model.dill", "wb") as f:
            dill.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)
Here we deserialize the model in a separate environment, using the same Python version and the same version of sklearn (and any other dependencies required by the model).
    import inspect

    import dill
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    data = load_iris()
    X = data["data"]
    y = data["target"]
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    if __name__ == "__main__":
        # Load the model
        with open("model.dill", "rb") as f:
            model = dill.load(f)

        # Show the model.
        # Note that the custom function has been serialized/deserialized.
        print(model)
        # Pipeline(memory=None,
        #          steps=[('transform',
        #                  FunctionTransformer(accept_sparse=False, check_inverse=True,
        #                                      func=<function scale at 0x110081e18>,
        #                                      inv_kw_args=None, inverse_func=None,
        #                                      kw_args=None, validate=False)),
        #                 ('forest',
        #                  RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,
        #                                         class_weight=None, criterion='gini',
        #                                         max_depth=None, max_features='auto',
        #                                         max_leaf_nodes=None, max_samples=None,
        #                                         min_impurity_decrease=0.0,
        #                                         min_impurity_split=None,
        #                                         min_samples_leaf=1, min_samples_split=2,
        #                                         min_weight_fraction_leaf=0.0,
        #                                         n_estimators=100, n_jobs=None,
        #                                         oob_score=False, random_state=None,
        #                                         verbose=0, warm_start=False))],
        #          verbose=False)

        # View the source of the embedded serialized function.
        # (thanks, dill)
        print(inspect.getsource(model.named_steps["transform"].func))
        # def scale(X_input):
        #     """Scale the input matrix."""
        #     import os
        #     scale_factor = os.getenv("SCALE_FACTOR", 2)
        #     return X_input * scale_factor

        # Make some predictions
        predictions = model.predict(X_test)
        print(predictions)
        # [2 1 1 0 2 1 0 0 2 0 2 1 1 0 1 0 1 0 2 0 0 0 2 2 1 0 0 1 0 2 2 2 2 2 2 2 2 1]
If we perform the serialization with pickle instead, we won't be able to deserialize the model in a new environment. That's because of the custom function we've defined, called scale. Here is the error you would see:
    Traceback (most recent call last):
      File "/Users/flynn/.asdf/installs/python/3.6.9/lib/python3.6/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/Users/flynn/.asdf/installs/python/3.6.9/lib/python3.6/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/Users/flynn/projects/mltest/environment.py", line 18, in <module>
        model = pickle.load(f)
    AttributeError: Can't get attribute 'scale' on <module 'environment' from '/Users/flynn/projects/mltest/environment.py'>
Because of this limitation, dill might be more suitable for model serialization in many cases.
For Python ML models, particularly sklearn, here are some guidelines for serialization:

- Use pickle to serialize objects with an importable hierarchy.
- Use joblib for objects which contain lots of data in numpy arrays.
- Use dill when pickle or joblib won't work, or when you have custom functions that need to be serialized as part of the model.
For modeling, it can also be helpful to keep a rule of thumb in mind: dill will provide the most flexibility in terms of getting the model serialized and should be considered the path of least resistance when it comes to serializing ML models for production.