Machine Learning Systems Architect, PhD Mathematician
I’ve released a new package, sklearn_instrumentation, for instrumenting scikit-learn models. It decorates methods of classes derived from sklearn.base.BaseEstimator, both in sklearn itself and in any scikit-learn compatible package, so that those methods can be evaluated for performance.
The package works by decorating the most common methods of sklearn estimators and transformers that inherit from BaseEstimator. The decorators themselves are straightforward, as the following minimal example shows.
from functools import wraps
def my_instrument(func, **dkwargs):
    """Wrap an estimator method with instrumentation.
    :param func: The method to be instrumented.
    :param dkwargs: Decorator kwargs, which can be passed to the
        decorator at decoration time. For estimator instrumentation
        this allows different parametrizations for each ml model.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        """Wrapping function.
        :param args: The args passed to methods, typically
            just ``X`` and/or ``y``
        :param kwargs: The kwargs passed to methods, usually
            weights or other params
        """
        # Code goes here before execution of the estimator method
        retval = func(*args, **kwargs)
        # Code goes here after execution of the estimator method
        return retval
    return wrapper
The function above accepts func, the estimator method to be instrumented. The dkwargs keyword arguments are decorator kwargs, which can modify the decorator's behavior at instrumentation time.
The inner function, wrapper, wraps the estimator method with logic that runs before and after the method executes.
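To see the pattern in action, here is a sketch that applies a decorator of this shape by hand to a toy estimator. MeanRegressor, call_recorder, CALL_LOG and the label decorator kwarg are hypothetical names made up for illustration; with the package, the instrumentor does the decorating for you.

```python
from functools import wraps

# Hypothetical call log filled in by the decorator below.
CALL_LOG = []


def call_recorder(func, **dkwargs):
    """Record each call to ``func`` with the number of rows passed as ``X``."""
    # Decorator kwarg: a label for this decoration, defaulting to the method's qualified name.
    label = dkwargs.get("label", func.__qualname__)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # For a method decorated on the class, args[0] is ``self`` and args[1] is ``X``.
        n_rows = len(args[1]) if len(args) > 1 else None
        retval = func(*args, **kwargs)
        CALL_LOG.append((label, n_rows))
        return retval

    return wrapper


class MeanRegressor:
    """Toy estimator (hypothetical, not part of scikit-learn) that predicts the mean of y."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_] * len(X)


# Decorate the class methods by hand, passing a decorator kwarg to only one of them.
MeanRegressor.fit = call_recorder(MeanRegressor.fit, label="mean.fit")
MeanRegressor.predict = call_recorder(MeanRegressor.predict)

model = MeanRegressor().fit([[1], [2], [3]], [1.0, 2.0, 3.0])
print(model.predict([[4], [5]]))  # [2.0, 2.0]
print(CALL_LOG)  # [('mean.fit', 3), ('MeanRegressor.predict', 2)]
```

Because dkwargs are fixed at decoration time, the same decorator function can be parametrized differently for each method or model it wraps.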
Decorators can also be written as classes, for cases where state must be kept across instrumentations of multiple methods or classes. The pattern is similar, and the decoration is applied through the class’s __call__ method, like below.
from functools import wraps
from sklearn_instrumentation.instruments.base import BaseInstrument
class MyInstrument(BaseInstrument):
    def __init__(self, *args, **kwargs):
        # handle any statefulness here
        pass
    def __call__(self, func, **dkwargs):
        """Wrap an estimator method with instrumentation.
        :param func: The method to be instrumented.
        :param dkwargs: Decorator kwargs, which can be passed to the
            decorator at decoration time. For estimator instrumentation
            this allows different parametrizations for each ml model.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            """Wrapping function.
            :param args: The args passed to methods, typically
                just ``X`` and/or ``y``
            :param kwargs: The kwargs passed to methods, usually
                weights or other params
            """
            # Code goes here before execution of the estimator method
            retval = func(*args, **kwargs)
            # Code goes here after execution of the estimator method
            return retval
        return wrapper
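As an illustration of why statefulness matters, the sketch below shares one instrument instance across methods of two different classes and counts calls per method. To stay runnable without the package installed it does not subclass BaseInstrument; CallCounter, Scaler and Classifier are hypothetical.

```python
from collections import Counter
from functools import wraps


class CallCounter:
    """Stateful instrument: one instance shared by every method it decorates.

    Standalone sketch; the package's own instruments subclass BaseInstrument,
    which is left out here so the example runs without the package.
    """

    def __init__(self):
        # State shared across all decorated methods and classes.
        self.counts = Counter()

    def __call__(self, func, **dkwargs):
        key = dkwargs.get("label", func.__qualname__)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # ``self`` here is the CallCounter captured by the closure.
            self.counts[key] += 1
            return func(*args, **kwargs)

        return wrapper


class Scaler:
    """Toy transformer (hypothetical) that divides by the largest value seen in fit."""

    def fit(self, X, y=None):
        self.max_ = max(max(row) for row in X)
        return self

    def transform(self, X):
        return [[v / self.max_ for v in row] for row in X]


class Classifier:
    """Toy classifier (hypothetical) that always predicts 0."""

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [0] * len(X)


counter = CallCounter()
targets = [
    (Scaler, "fit"),
    (Scaler, "transform"),
    (Classifier, "fit"),
    (Classifier, "predict"),
]
# The same instance decorates methods on two different classes.
for cls, name in targets:
    setattr(cls, name, counter(getattr(cls, name)))

X = [[1, 2], [3, 4]]
scaler = Scaler().fit(X)
Xt = scaler.transform(X)
Classifier().fit(Xt, [0, 1]).predict(Xt)
scaler.transform(X)
print(dict(counter.counts))
# {'Scaler.fit': 1, 'Scaler.transform': 2, 'Classifier.fit': 1, 'Classifier.predict': 1}
```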
Instrumentation is applied to the most common methods in sklearn estimators, namely fit, predict, predict_proba, and transform. The package also allows the user to specify their own set of methods to decorate.
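The method-selection idea can be sketched in a few lines. This is not the package's code: instrument_methods, tag and DEFAULT_METHODS are hypothetical, and decorating only the methods a class defines itself (rather than inherited ones) is an assumption of the sketch, loosely matching how the log output later in this post names the class that defines a method, for example BaseForest.fit.

```python
from functools import wraps

# Default method names, mirroring the list in the text (assumption for this sketch).
DEFAULT_METHODS = ("fit", "predict", "predict_proba", "transform")


def tag(func, **dkwargs):
    """Trivial instrument: marks the wrapper so we can see what was decorated."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    wrapper.instrumented = True
    return wrapper


def instrument_methods(cls, instrument, methods=DEFAULT_METHODS):
    """Decorate the named methods that ``cls`` itself defines; return their names."""
    decorated = []
    for name in methods:
        # Look only in the class's own namespace, so an inherited method is
        # wrapped once, on the class that defines it.
        if callable(cls.__dict__.get(name)):
            setattr(cls, name, instrument(cls.__dict__[name]))
            decorated.append(name)
    return decorated


class Toy:
    """Toy estimator (hypothetical)."""

    def fit(self, X, y=None):
        return self

    def predict(self, X):
        return X

    def score(self, X, y):
        return 1.0


print(instrument_methods(Toy, tag))                     # ['fit', 'predict']
print(instrument_methods(Toy, tag, methods=["score"]))  # ['score']
```

Passing a custom methods sequence is how a user-specified method set would plug in under this design.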
To apply instrumentation, first create an instrumentor. Here we create one with a logging instrument, TimeElapsedLogger, which logs the elapsed time of method calls.
from sklearn_instrumentation import SklearnInstrumentor
from sklearn_instrumentation.instruments.logging import TimeElapsedLogger
instrument = TimeElapsedLogger()
instrumentor = SklearnInstrumentor(instrument=instrument)
With the instrumentor created, there are three ways to apply instrumentation.
Package instrumentation crawls the package's module hierarchy for classes inheriting from BaseEstimator and instruments every such class by decorating its common methods with the instrument. Multiple packages can be passed, so other scikit-learn compatible packages can be instrumented as well.
instrumentor.instrument_packages(["sklearn", "xgboost", "lightgbm"])
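The crawling step can be approximated with the standard library: walk the package's modules with pkgutil, import each one, and collect the classes that subclass a given base. This sketch, including the find_subclasses name, illustrates the general technique and is not the package's code. It runs against the standard library's json package so it works anywhere; with scikit-learn installed, find_subclasses("sklearn", BaseEstimator) would be the analogous call.

```python
import importlib
import inspect
import json
import pkgutil


def find_subclasses(package_name, base):
    """Collect classes that subclass ``base`` anywhere in a package's module tree."""
    package = importlib.import_module(package_name)
    modules = [package]
    # walk_packages recurses into subpackages; the prefix makes names importable.
    for info in pkgutil.walk_packages(package.__path__, prefix=package_name + "."):
        if info.name.rsplit(".", 1)[-1] == "__main__":
            continue  # importing a __main__ module would run it as a script
        try:
            modules.append(importlib.import_module(info.name))
        except ImportError:
            continue  # skip modules whose optional dependencies are missing
    found = set()
    for module in modules:
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, base):
                found.add(obj)
    return found


# Every class found here would then have its common methods decorated.
print(find_subclasses("json", json.JSONEncoder))
```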
Estimator instrumentation works by crawling the instance attribute hierarchy of a machine learning metaestimator, e.g. a Pipeline. The instrumentor will inspect the attribute hierarchy of the estimator recursively, instrumenting any of the configured instance methods found.
instrumentor.instrument_estimator(classification_pipeline)  # a pipeline instance
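Here is a rough sketch of instance-level crawling, assuming the metaestimator exposes its children through ordinary instance attributes, lists, tuples and dicts (as Pipeline does with its steps list of (name, estimator) tuples). Each method found is replaced by an instance attribute wrapping the bound method, so only that object is affected. ToyPipeline, Step, count_calls, METHODS and instrument_instances are hypothetical.

```python
from functools import wraps

# Assumed method subset for this sketch.
METHODS = ("fit", "predict", "transform")


def count_calls(func, **dkwargs):
    """Tiny instrument: keep a call count on the wrapper itself."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


def instrument_instances(obj, instrument, methods=METHODS, _seen=None):
    """Recursively patch methods on ``obj`` and on objects held in its attributes."""
    seen = set() if _seen is None else _seen
    if id(obj) in seen:
        return  # guard against cycles and shared references
    seen.add(id(obj))
    if isinstance(obj, (list, tuple)):
        children = list(obj)
    elif isinstance(obj, dict):
        children = list(obj.values())
    elif hasattr(obj, "__dict__") and not callable(obj):
        # Snapshot children before adding wrappers so the wrappers are not crawled.
        children = list(vars(obj).values())
        for name in methods:
            bound = getattr(obj, name, None)
            if callable(bound):
                # The instance attribute shadows the class method for this object only.
                setattr(obj, name, instrument(bound))
    else:
        return
    for child in children:
        instrument_instances(child, instrument, methods, seen)


class Step:
    """Toy estimator (hypothetical)."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


class ToyPipeline:
    """Hypothetical stand-in for Pipeline: holds (name, estimator) steps."""

    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y=None):
        for _, step in self.steps:
            X = step.fit(X, y).transform(X)
        return self


a, b = Step(), Step()
pipe = ToyPipeline([("a", a), ("b", b)])
instrument_instances(pipe, count_calls)
pipe.fit([[1.0]])
print(pipe.fit.calls, a.fit.calls, a.transform.calls)  # 1 1 1
print(hasattr(Step().fit, "calls"))  # False: a fresh instance is untouched
```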
Class instrumentation also crawls the instance attribute hierarchy of a machine learning metaestimator. The difference is that it instruments the classes it finds rather than the instances themselves. This has the same effect as package instrumentation, except that it’s faster, consumes less memory, and only instruments the classes present in the model. It also lets instrumentation persist through fitting on metaestimators that use cloning.
instrumentor.instrument_estimator_classes(classification_pipeline)  # a pipeline instance
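The point about cloning can be demonstrated directly. scikit-learn's clone builds a fresh object of the same class from the estimator's constructor parameters, so attributes set on the original instance, including instance-level instrumentation, do not carry over, while patches on the class do. The sketch below uses a hypothetical toy_clone that mimics that behavior on a toy Model; record and CALLS are also hypothetical.

```python
from functools import wraps

CALLS = []


def record(func, **dkwargs):
    """Instrument that records the qualified name of each call."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        CALLS.append(func.__qualname__)
        return func(*args, **kwargs)

    return wrapper


class Model:
    """Toy estimator (hypothetical) with one constructor parameter."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self):
        return {"alpha": self.alpha}

    def fit(self, X, y=None):
        return self


def toy_clone(est):
    """Rebuild from constructor params, roughly as sklearn.base.clone does."""
    return type(est)(**est.get_params())


# Instance instrumentation: the patch lives on the object, so a clone loses it.
m = Model(alpha=0.5)
m.fit = record(m.fit)
toy_clone(m).fit([[0]])
print(CALLS)  # []

# Class instrumentation: the patch lives on the class, so every clone keeps it.
Model.fit = record(Model.fit)
toy_clone(m).fit([[0]])
print(CALLS)  # ['Model.fit']
```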
After instrumentation, we’ll have logging output of the elapsed time for each estimator’s method calls. Here is a full example of a model using the TimeElapsedLogger instrument.
import logging
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import FeatureUnion
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn_instrumentation import SklearnInstrumentor
from sklearn_instrumentation.instruments.logging import TimeElapsedLogger
logging.basicConfig(level=logging.INFO)
# Create an instrumentor and instrument sklearn
instrumentor = SklearnInstrumentor(instrument=TimeElapsedLogger())
instrumentor.instrument_packages(["sklearn"])
# Create a toy model for classification
ss = StandardScaler()
pca = PCA(n_components=3)
rf = RandomForestClassifier()
classification_model = Pipeline(
    steps=[
        (
            "fu",
            FeatureUnion(
                transformer_list=[
                    ("ss", ss),
                    ("pca", pca),
                ]
            ),
        ),
        ("rf", rf),
    ]
)
X, y = load_iris(return_X_y=True)
# Observe logging
classification_model.fit(X, y)
# INFO:sklearn_instrumentation.instruments.logging:Pipeline.fit starting.
# INFO:sklearn_instrumentation.instruments.logging:Pipeline._fit starting.
# INFO:sklearn_instrumentation.instruments.logging:StandardScaler.fit starting.
# INFO:sklearn_instrumentation.instruments.logging:StandardScaler.fit elapsed time: 0.0006406307220458984 seconds
# INFO:sklearn_instrumentation.instruments.logging:StandardScaler.transform starting.
# INFO:sklearn_instrumentation.instruments.logging:StandardScaler.transform elapsed time: 0.0001430511474609375 seconds
# INFO:sklearn_instrumentation.instruments.logging:PCA._fit starting.
# INFO:sklearn_instrumentation.instruments.logging:PCA._fit elapsed time: 0.0006711483001708984 seconds
# INFO:sklearn_instrumentation.instruments.logging:Pipeline._fit elapsed time: 0.0026731491088867188 seconds
# INFO:sklearn_instrumentation.instruments.logging:BaseForest.fit starting.
# INFO:sklearn_instrumentation.instruments.logging:BaseForest.fit elapsed time: 0.1768970489501953 seconds
# INFO:sklearn_instrumentation.instruments.logging:Pipeline.fit elapsed time: 0.17983102798461914 seconds
# Observe logging
classification_model.predict(X)
# INFO:sklearn_instrumentation.instruments.logging:Pipeline.predict starting.
# INFO:sklearn_instrumentation.instruments.logging:FeatureUnion.transform starting.
# INFO:sklearn_instrumentation.instruments.logging:StandardScaler.transform starting.
# INFO:sklearn_instrumentation.instruments.logging:StandardScaler.transform elapsed time: 0.00024509429931640625 seconds
# INFO:sklearn_instrumentation.instruments.logging:_BasePCA.transform starting.
# INFO:sklearn_instrumentation.instruments.logging:_BasePCA.transform elapsed time: 0.0002181529998779297 seconds
# INFO:sklearn_instrumentation.instruments.logging:FeatureUnion.transform elapsed time: 0.0012080669403076172 seconds
# INFO:sklearn_instrumentation.instruments.logging:ForestClassifier.predict starting.
# INFO:sklearn_instrumentation.instruments.logging:ForestClassifier.predict_proba starting.
# INFO:sklearn_instrumentation.instruments.logging:ForestClassifier.predict_proba elapsed time: 0.013531208038330078 seconds
# INFO:sklearn_instrumentation.instruments.logging:ForestClassifier.predict elapsed time: 0.013692140579223633 seconds
# INFO:sklearn_instrumentation.instruments.logging:Pipeline.predict elapsed time: 0.015219926834106445 seconds
The package comes with several logging instruments. Optional extra dependencies enable performance profiling with the memory-profiler and pyinstrument packages. The package also provides instruments for production system monitoring, supporting observability metrics (StatsD, Prometheus) as well as distributed tracing implementations.