Christopher Flynn


Machine Learning Model Serialization

2020-01-04

Photo of dill by Jay Jay on Unsplash.

There are a few ways to put trained machine learning (ML) models into production. The most common method is to serialize the model in a particular format after training and deserialize it in the production environment. In Python, there are several language-specific serialization formats based on pickle. Alternatives include more language-agnostic export formats, such as ONNX and PMML. In Java, the H2O framework serializes models as POJOs or MOJOs, which are Plain Old Java Object and Model ObJect Optimized structures, respectively.

Python is the most common language for machine learning modeling, and different Python ML frameworks make different serialization recommendations. In particular, sklearn recommends the joblib package, pytorch's save and load functions use Python's built-in pickle module, and keras supports exporting in HDF5 format. There is also the alternative serialization package dill, which generalizes pickle at the cost of performance.

What is pickling?

Pickling is a way to write Python objects to a bytestream, which can then be written to disk as a file. One can take this file and load it back into a separate Python interpreter at a later date, recovering the objects from the previous session.
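A minimal, standard-library-only sketch of that round trip is below; the `model_state` dictionary and the `model_state.pkl` filename are illustrative stand-ins for a real trained model.

```python
import os
import pickle
import tempfile

# Illustrative stand-in for the state of a trained model.
model_state = {"weights": [0.25, -1.5, 3.0], "bias": 0.1, "labels": ("setosa", "versicolor")}

# Serialize to an in-memory byte stream...
payload = pickle.dumps(model_state)
print(type(payload))  # <class 'bytes'>

# ...or write the same byte stream to a file on disk.
path = os.path.join(tempfile.mkdtemp(), "model_state.pkl")
with open(path, "wb") as f:
    pickle.dump(model_state, f)

# Later, possibly in a different interpreter, read the file back.
with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored == model_state)  # True
```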

Pickled files are tightly coupled with the environment in which they were created. Pickle protocols are backwards compatible, but a file written with a newer protocol cannot be read by an older version of Python, and pickled instances of third-party classes, such as sklearn estimators, may fail to load or behave differently under a different version of their library. Similarly, the classes and functions referenced by pickled Python objects must be importable when deserializing. In general, the destination environment for a pickle file should have all of the dependencies and definitions required by the pickled objects.
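Part of this coupling is visible in the pickle format itself. The sketch below, using a made-up `params` dictionary, shows that the protocol number is written into the payload header and that a protocol can be pinned explicitly. Pinning an older protocol only addresses format compatibility between interpreters; it does nothing for objects whose library versions differ.

```python
import pickle

params = {"n_estimators": 100, "classes": [0, 1, 2]}

# Pin an older protocol for readers on older interpreters (protocol 2 is
# readable by every Python 3), or use the newest one this interpreter supports.
legacy = pickle.dumps(params, protocol=2)
latest = pickle.dumps(params, protocol=pickle.HIGHEST_PROTOCOL)

# Protocols 2+ begin with the PROTO opcode (0x80) followed by the protocol
# number, so the first two bytes record which protocol wrote the payload.
print(legacy[:2])                            # b'\x80\x02'
print(latest[1] == pickle.HIGHEST_PROTOCOL)  # True

# Both payloads load here; the protocol only matters for older readers.
print(pickle.loads(legacy) == pickle.loads(latest) == params)  # True
```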

What can be pickled?

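According to the documentation, the following types can be pickled:

- None, True, and False
- integers, floating-point numbers, and complex numbers
- strings, bytes, and bytearrays
- tuples, lists, sets, and dictionaries containing only picklable objects
- functions (built-in and user-defined) and classes defined at the top level of a module (using def, not lambda)
- instances of such classes whose __dict__, or the result of calling __getstate__(), is picklable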

Note that functions (built-in and user-defined) are pickled by “fully qualified” name reference, not by value. This means that only the function name is pickled, along with the name of the module the function is defined in. Neither the function’s code, nor any of its function attributes are pickled. Thus the defining module must be importable in the unpickling environment, and the module must contain the named object, otherwise an exception will be raised.

The above implies that classes and functions must be importable in order to be unpickled. This means that a custom function defined in your training script cannot be deserialized in another environment unless it is part of a module that is importable in the unpickling environment.
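To see the by-reference behavior directly, this sketch pickles a standard library function and a lambda. Depending on the object and the Python version, the failure surfaces as either `pickle.PicklingError` or `AttributeError`, so the sketch catches both.

```python
import math
import pickle

# A top-level function is pickled as a (module, name) reference, not as code.
payload = pickle.dumps(math.sqrt)
print(b"math" in payload, b"sqrt" in payload)  # True True

# Unpickling looks the name up again, so we get back the very same object.
print(pickle.loads(payload) is math.sqrt)  # True

# A lambda has no importable name, so the reference cannot be resolved.
lambda_failed = False
try:
    pickle.dumps(lambda x: x * x)
except (pickle.PicklingError, AttributeError) as exc:
    lambda_failed = True
    print("cannot pickle the lambda:", type(exc).__name__)
```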

pickle

The built-in pickle module serializes Python objects into a byte stream. This means you can save a Python object, like a class instance, to a file, send it to another environment or computer, and deserialize it back into a Python object to interact with again.

The pickle module can save and load class instances. However, when deserializing, the class definition must be importable from the same module path as in the original environment. If we try to pickle a pytorch model in a research environment, and then load it into a production environment that doesn’t have pytorch installed, the model will fail to deserialize.
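The following sketch simulates that situation without pytorch. A hypothetical module `toy_models` is registered in `sys.modules` and given a small `LinearModel` class; the pickled instance loads while the module is importable and fails with an `ImportError` (here a `ModuleNotFoundError`) once the module is removed, much like loading in an environment where the package is not installed.

```python
import os
import pickle
import sys
import tempfile
import types

# Register a hypothetical module "toy_models" so the class below is importable
# by name, as it would be if it lived in an installed package.
toy_models = types.ModuleType("toy_models")
sys.modules["toy_models"] = toy_models


class LinearModel:
    """A tiny stand-in for a trained model."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def predict(self, x):
        return self.weight * x + self.bias


# Make the class look as if it were defined in toy_models.
LinearModel.__module__ = "toy_models"
toy_models.LinearModel = LinearModel

path = os.path.join(tempfile.mkdtemp(), "linear.pkl")
with open(path, "wb") as f:
    # Stores the instance's attributes plus the reference "toy_models.LinearModel".
    pickle.dump(LinearModel(weight=2.0, bias=1.0), f)

# Loading works while toy_models can be imported...
with open(path, "rb") as f:
    prediction = pickle.load(f).predict(3.0)
print(prediction)  # 7.0

# ...but fails once it cannot, as in an environment missing the package.
del sys.modules["toy_models"]
load_failed = False
try:
    with open(path, "rb") as f:
        pickle.load(f)
except ImportError as exc:
    load_failed = True
    print("load failed:", exc)
```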

joblib

The joblib package provides dump and load functions for serializing Python objects, with particular optimizations for large numpy arrays. It is intended to be a drop-in replacement for pickle and can be effective for sklearn models which store lots of data internally, such as random forest or cluster-based classifiers.
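A minimal sketch of the joblib API, assuming joblib and numpy are installed. The `state` dictionary is an illustrative stand-in for an sklearn estimator, which can be passed to `joblib.dump` in the same way; the `compress` argument is optional.

```python
import os
import tempfile

import joblib
import numpy as np

# Stand-in for an estimator with large internal arrays (e.g., a random forest).
state = {"coef": np.random.default_rng(0).normal(size=(1000, 100)), "n_classes": 3}

path = os.path.join(tempfile.mkdtemp(), "state.joblib")
# compress accepts a level from 0 (no compression) to 9.
joblib.dump(state, path, compress=3)

restored = joblib.load(path)
print(np.array_equal(restored["coef"], state["coef"]), restored["n_classes"])  # True 3
```

For a fitted sklearn model the calls look the same, for example `joblib.dump(model, "model.joblib")` followed by `joblib.load("model.joblib")`.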

dill

The dill package extends the functionality of pickle by enabling the serialization of a much larger set of Python objects. According to its documentation, dill can pickle the standard types that pickle supports as well as more exotic ones, including lambdas and nested functions.

This includes custom functions. In particular, dill can be extremely useful for sklearn models that use the flexible FunctionTransformer, which applies a custom transformation given as a function argument at instantiation.
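As a small illustration, assuming dill is installed, the sketch below serializes a lambda and a closure (`make_scaler` is a made-up helper). The standard pickle module rejects both, while dill stores them by value and restores working functions.

```python
import pickle

import dill


def make_scaler(factor):
    # Returns a closure over `factor`: a nested function with no importable name.
    def scale(values):
        return [v * factor for v in values]
    return scale


double = make_scaler(2)
square = lambda x: x ** 2

# The standard pickle module cannot resolve either object by name.
for fn in (double, square):
    try:
        pickle.dumps(fn)
    except (pickle.PicklingError, AttributeError):
        print("pickle failed for", fn.__qualname__)

# dill serializes them by value (code plus closure), so they round-trip.
restored_double = dill.loads(dill.dumps(double))
restored_square = dill.loads(dill.dumps(square))
print(restored_double([1, 2, 3]), restored_square(4))  # [2, 4, 6] 16
```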

Custom functions serialized with dill may have problems being deserialized, mostly because of packages the function uses that were imported outside of it. This can be mitigated by running those imports in the destination environment before deserializing. Alternatively, a function can import the packages it uses directly inside its own body. For example, the following function

import os

def get_env_var(s):
    return os.getenv(s)

should be written like this in the file in which it is being serialized:

def get_env_var(s):
    import os
    return os.getenv(s)

This makes the code a bit less conventional, but without the local import, calling the deserialized function may raise a NameError if os has not been imported in the destination environment. The local import does add some overhead, since the import statement runs on every call; after the first import, however, the module is cached in sys.modules, so each repeated import is a relatively cheap lookup.
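The per-call cost of the function-level import can be measured roughly with `timeit`. Timings depend on the machine and Python version, so no expected numbers are shown; `get_env_var_global` and `get_env_var_local` are illustrative names for the two variants above.

```python
import os
import sys
import timeit


def get_env_var_global(s):
    # Uses the module-level `os` import.
    return os.getenv(s)


def get_env_var_local(s):
    # Re-runs the import on every call; after the first import this resolves
    # to the module object already cached in sys.modules.
    import os
    return os.getenv(s)


print("os" in sys.modules)  # True

n = 100_000
t_global = timeit.timeit(lambda: get_env_var_global("PATH"), number=n)
t_local = timeit.timeit(lambda: get_env_var_local("PATH"), number=n)
print(f"module-level import:   {t_global / n * 1e9:.0f} ns per call")
print(f"function-level import: {t_local / n * 1e9:.0f} ns per call")
```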

In general, dill's additional flexibility comes at the cost of performance: deserializing dill files typically takes a bit longer than deserializing pickle files. This can be a problem when loading models on the fly for time-sensitive predictions, but the model can always be deserialized once when the application starts and then kept in memory.
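One way to keep the model in memory is to cache the loader so that deserialization happens only once per process. This sketch uses `functools.lru_cache` with pickle and a stand-in dictionary model; `load_model` and `predict` are hypothetical names, and `dill.load` could be substituted for `pickle.load` unchanged.

```python
import functools
import os
import pickle
import tempfile


@functools.lru_cache(maxsize=None)
def load_model(path):
    """Deserialize the model once; later calls reuse the cached object."""
    with open(path, "rb") as f:
        return pickle.load(f)


def predict(path, features):
    # Hypothetical request handler: only the first call pays the load cost.
    model = load_model(path)
    return model["weight"] * features + model["bias"]


# Write a stand-in "model" so the sketch is self-contained.
path = os.path.join(tempfile.mkdtemp(), "model.pkl")
with open(path, "wb") as f:
    pickle.dump({"weight": 2.0, "bias": 1.0}, f)

print(predict(path, 3.0))                    # 7.0
print(load_model(path) is load_model(path))  # True
print(load_model.cache_info().misses)        # 1
```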

Example

Here is a simple example in which we create an sklearn Pipeline with a custom function implemented using the FunctionTransformer. The pipeline trains a RandomForestClassifier using the prepackaged iris dataset.

import pickle

import dill
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

# Load the iris dataset
data = load_iris()
X = data["data"]
y = data["target"]

# Split into training/testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y)


# Create a function for use with the FunctionTransformer
# Transform function that relies on imported package
def scale(X_input):
    """Scale the input matrix."""
    import os
    scale_factor = os.getenv("SCALE_FACTOR", 2)
    return X_input * scale_factor


# Create a simple toy model that transforms the dataset
# and uses a random forest
model = Pipeline(
    [
        ("transform", FunctionTransformer(scale)),
        ("forest", RandomForestClassifier())
    ]
)


if __name__ == "__main__":
    # Train the model
    model.fit(X_train, y_train)

    # Serialize the model using dill.
    with open("model.dill", "wb") as f:
        dill.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

Here we deserialize the model in a separate environment, using the same Python version and the same version of sklearn (and any other dependencies required by the model).

import inspect

import dill
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X = data["data"]
y = data["target"]

X_train, X_test, y_train, y_test = train_test_split(X, y)


if __name__ == "__main__":
    # Load the model
    with open("model.dill", "rb") as f:
        model = dill.load(f)

    # Show the model.
    # Note that the custom function has been serialized/deserialized.
    print(model)
    # Pipeline(memory=None,
    #          steps=[('transform',
    #                  FunctionTransformer(accept_sparse=False, check_inverse=True,
    #                                      func=<function scale at 0x110081e18>,
    #                                      inv_kw_args=None, inverse_func=None,
    #                                      kw_args=None, validate=False)),
    #                 ('forest',
    #                  RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,
    #                                         class_weight=None, criterion='gini',
    #                                         max_depth=None, max_features='auto',
    #                                         max_leaf_nodes=None, max_samples=None,
    #                                         min_impurity_decrease=0.0,
    #                                         min_impurity_split=None,
    #                                         min_samples_leaf=1, min_samples_split=2,
    #                                         min_weight_fraction_leaf=0.0,
    #                                         n_estimators=100, n_jobs=None,
    #                                         oob_score=False, random_state=None,
    #                                         verbose=0, warm_start=False))],
    #          verbose=False)

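    # View the source of the embedded serialized function (thanks, dill).
    # inspect reads the source from the file path stored in the function's
    # code object, so this only works if the original script is still
    # available at that path.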
    print(inspect.getsource(model.steps[0][1].func))
    # def scale(X_input):
    #     """Scale the input matrix."""
    #     import os
    #     scale_factor = os.getenv("SCALE_FACTOR", 2)
    #     return X_input * scale_factor

    # Make some predictions
    predictions = model.predict(X_test)

    print(predictions)
    # [2 1 1 0 2 1 0 0 2 0 2 1 1 0 1 0 1 0 2 0 0 0 2 2 1 0 0 1 0 2 2 2 2 2 2 2 2 1]

If we serialize the model with pickle instead, we won't be able to deserialize it in a new environment. That's because of the custom function scale: pickle stores only a reference to scale in the training script's __main__ module, and the __main__ module of the loading script defines no such function. Here is the error you would see:

Traceback (most recent call last):
  File "/Users/flynn/.asdf/installs/python/3.6.9/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/flynn/.asdf/installs/python/3.6.9/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/flynn/projects/mltest/environment.py", line 18, in <module>
    model = pickle.load(f)
AttributeError: Can't get attribute 'scale' on <module 'environment' from '/Users/flynn/projects/mltest/environment.py'>

Because of this limitation, dill might be more suitable for model serialization in many cases.

Guidelines

For Python ML models, particularly sklearn, here are some guidelines for serialization.

For modeling, it can also be helpful to follow some rules of thumb, namely:

In general, dill will provide the most flexibility in terms of getting the model serialized and should be considered the path of least resistance when it comes to serializing ML models for production.

