Advanced Data Preparation Using Custom Transformers in Scikit-Learn

Author:Murphy  |  View: 21871  |  Time: 2025-03-23 18:21:33
Image by Daniel K Cheung on Unsplash

Scikit-Learn provides many useful tools for data preparation, but sometimes the pre-built options aren't enough.

In this article, I'll show you how to create advanced data preparation workflows using custom Transformers. If you've been using scikit-learn for a while and want to level up your skills, learning about Transformers is an excellent way to move beyond "beginner mode" and pick up some of the more advanced capabilities required in modern Data Science teams.

If the topic sounds a bit advanced, don't worry – this article is packed full of examples which will help you feel confident with both the code and the concepts.

I'll start with a brief overview of scikit-learn's Transformer class and then walk through two ways to build customised Transformers:

  1. Using a FunctionTransformer
  2. Writing a custom Transformer from scratch

Transformers: The best way to preprocess data in scikit-learn

The Transformer is one of the central building blocks of scikit-learn. It's so foundational, in fact, that chances are you've already been using one without even realising.

In scikit-learn, a Transformer is any object with the fit() and transform() methods. In plain English, that means a Transformer is a class (i.e. a reusable chunk of code) that takes your raw dataset as an input and returns a transformed version of that dataset.

Image by author

Importantly, scikit-learn Transformers are NOT the same as the "transformers" used in Large Language Models (LLMs) like BERT and GPT-4, or the models available through the HuggingFace transformers library. In the context of LLMs, a "transformer" (lower-case 't') is a deep learning model; a scikit-learn Transformer (upper-case 'T') is a completely different (and much simpler) entity. You can think of it simply as a tool for preprocessing data in a typical ML workflow.

When you import scikit-learn, you get automatic access to a bunch of pre-built Transformers designed for common ML data preprocessing tasks like imputing missing values, rescaling features and one-hot encoding. Some of the most popular Transformers include:

  1. sklearn.impute.SimpleImputer – a Transformer that will replace missing values in your dataset
  2. sklearn.preprocessing.MinMaxScaler – a Transformer that can rescale the numerical features in your dataset
  3. sklearn.preprocessing.OneHotEncoder – a Transformer for one-hot encoding categorical features
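To make the fit()/transform() pattern concrete, here is a minimal sketch using two of the Transformers listed above. The tiny DataFrame is invented purely for illustration (it is not the Titanic data), and the variable names are my own.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

# Hypothetical toy data, invented for illustration only
toy = pd.DataFrame({"fare": [10.0, 20.0, 30.0], "sex": ["male", "female", "male"]})

# fit() learns the min and max of the column; transform() applies the rescaling
scaler = MinMaxScaler()
scaler.fit(toy[["fare"]])
fare_scaled = scaler.transform(toy[["fare"]])
print(fare_scaled.ravel())  # values 0.0, 0.5 and 1.0

# fit_transform() performs both steps in a single call
encoder = OneHotEncoder()
sex_encoded = encoder.fit_transform(toy[["sex"]]).toarray()
print(encoder.categories_[0])  # the categories learned during fit: female, male
print(sex_encoded)             # one column per category
```

The important point is the two-step rhythm: fit() learns something from the data, and transform() applies it.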

Using a scikit-learn sklearn.pipeline.Pipeline, you can even chain together multiple Transformers to build multi-step data preparation workflows, in preparation for subsequent ML modelling:

Image by author
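As a rough sketch of what such a chain looks like in code, the snippet below strings a SimpleImputer and a MinMaxScaler together in a Pipeline. The toy array and the step names ("impute", "scale") are assumptions made up for this example.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

# Hypothetical toy feature matrix with one missing value
X_toy = np.array([[1.0], [np.nan], [3.0], [5.0]])

toy_pipeline = Pipeline(steps=[
    ("impute", SimpleImputer(strategy="mean")),  # step 1: fill NaN with the column mean (3.0)
    ("scale", MinMaxScaler()),                   # step 2: rescale to the [0, 1] range
])

# Each step is fitted and applied in order; the output of one feeds the next
X_toy_prepared = toy_pipeline.fit_transform(X_toy)
print(X_toy_prepared.ravel())  # values 0.0, 0.5, 0.5 and 1.0
```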

If you're not familiar with Pipelines or ColumnTransformers, they're a great way to simplify your ML code, and you can read more about them in my previous article:

Simplify Your Data Preparation With These 4 Lesser-Known Scikit-Learn Classes

What's wrong with the pre-built scikit-learn Transformers?

Nothing at all!

If you're working with simple datasets and performing standard data preparation steps, chances are that scikit-learn's pre-built transformers will be perfectly adequate. There's no need to reinvent the wheel by writing custom ones from scratch.

But – and let's be honest – when are datasets ever really simple in real life?

(Spoiler: never.)

If you're working with real-world data or need to implement some juicy preprocessing method, chances are that scikit-learn's built-in Transformers won't always be adequate. Sooner or later, you're going to need to implement custom data transformations.

Luckily, scikit-learn provides a few ways to extend its basic Transformer functionalities and build more customised Transformers. To showcase how these work, I'll be using the canonical Titanic Survival Prediction dataset. Even on this supposedly "simple" dataset, you'll find that there's plenty of opportunity for getting creative with your data preparation. And, as I'll show, custom Transformers are the ideal tool for the task.

Data

First, let's load the dataset and split it into training and testing subsets:

import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Load data and split into training and testing sets
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
X.drop(['boat', 'body', 'home.dest'], axis=1, inplace=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2)
X_train.head()
Image by author. Titanic dataset available under a CC0 public domain license.

Because this article focuses on how to build customised Transformers, I won't go into detail on the standard preprocessing steps which can be easily applied to this dataset using scikit-learn's in-built Transformers (e.g. one-hot encoding categorical variables like sex using a OneHotEncoder, or replacing missing values using a SimpleImputer).

Instead, I'll focus on how to incorporate more complex preprocessing steps which cannot be implemented using "off-the-shelf" Transformers.

One such step involves extracting each passenger's title (e.g. Mr, Mrs, Master) from the name field. Why might we want to do this? Well, if we know that each passenger's title contains an indication of their class/age/sex, and we assume that these factors influenced passengers' ability to get on lifeboats, it's reasonable to hypothesise that titles might be informative about survival chances. For example, a passenger with a "Master" title (indicating that they are a child) might be more likely to survive than a passenger with a "Mr" title (indicating that they are an adult).

The problem, of course, is that there's no in-built scikit-learn class which can do something as specific as extracting the title from the name field. To extract the titles, we need to build a custom Transformer.

1. FunctionTransformer

The quickest way to build a custom Transformer is by using the FunctionTransformer class, which allows you to create Transformers directly from normal Python functions.

To use a FunctionTransformer, you start by defining a function which takes an input dataset X, performs the desired transformation, and returns a transformed version of X. Then, wrap your function in a FunctionTransformer, and scikit-learn will create a customised Transformer which implements your function.

For example, here's a function that can extract a passenger's Title from the name field of our Titanic dataset:

from sklearn.preprocessing import FunctionTransformer

def extract_title(X):
    """Extract the title from each passenger's `name`."""
    X['title'] = X['name'].str.split(', ', expand=True)[1].str.split('.', expand=True)[0]
    return X

extract_title_transformer = FunctionTransformer(extract_title)
print(type(extract_title_transformer))
# <class 'sklearn.preprocessing._function_transformer.FunctionTransformer'>

As you can see, wrapping the function in a FunctionTransformer turned it into a scikit-learn Transformer, giving it the .fit() and .transform() methods.

We can then incorporate this Transformer into our data preparation pipeline alongside any additional preprocessing steps/transformers we want to include:

from sklearn.pipeline import Pipeline

preprocessor = Pipeline(steps=[
  ('extract_title', extract_title_transformer),
  # ... any other transformers we want to include, e.g. SimpleImputer or MinMaxScaler
])

X_train_transformed = preprocessor.fit_transform(X_train)
X_train_transformed
Image by author

If you want to define a more complex function/Transformer that takes additional arguments, you can pass these to the function by incorporating them into the kw_args argument of FunctionTransformer. For example, let's define another function which identifies whether each passenger is from an upper-class/professional background, based on their title:

def extract_title(X):
    """Extract the title from each passenger's `name`."""
    X['title'] = X['name'].str.split(', ', expand=True)[1].str.split('.', expand=True)[0]
    return X  # the next Pipeline step receives whatever this function returns

def is_upper_class(X, upper_class_titles):
    """If the passenger's title is in the list of `upper_class_titles`, return 1, else 0."""
    X['upper_class'] = X['title'].apply(lambda x: 1 if x in upper_class_titles else 0)
    return X

preprocessor = Pipeline(steps=[
  ('extract_title', FunctionTransformer(extract_title)),
  ('is_upper_class', FunctionTransformer(is_upper_class,
                      kw_args={'upper_class_titles':['Dr', 'Col', 'Major', 'Lady', 'Rev', 'Sir', 'Capt']})),
    # ... any other transformers we want to include, e.g. SimpleImputer or MinMaxScaler
])

X_train_transformed = preprocessor.fit_transform(X_train)
X_train_transformed
Image by author

As you can see, using FunctionTransformer is a really simple way to incorporate these complex preprocessing steps into a Pipeline without fundamentally changing the structure of our code.
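One detail worth knowing: the functions above add columns to the DataFrame they receive, so X_train itself is modified when the Pipeline runs. If you'd rather leave the input untouched, a possible variant is to work on a copy. The sketch below shows this on a small hypothetical DataFrame; the `_copy` function names, the two sample names (written in the dataset's "Surname, Title. Given names" format) and the shortened title list are all illustrative assumptions.

```python
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

def extract_title_copy(X):
    """Return a copy of X with a `title` column; the input is left untouched."""
    X = X.copy()
    X["title"] = X["name"].str.split(", ", expand=True)[1].str.split(".", expand=True)[0]
    return X

def is_upper_class_copy(X, upper_class_titles):
    """Return a copy of X with a 0/1 `upper_class` flag derived from `title`."""
    X = X.copy()
    X["upper_class"] = X["title"].isin(upper_class_titles).astype(int)
    return X

# Hypothetical toy input, for illustration only
toy_names = pd.DataFrame({"name": ["Braund, Mr. Owen Harris", "Minahan, Dr. William Edward"]})

toy_preprocessor = Pipeline(steps=[
    ("extract_title", FunctionTransformer(extract_title_copy)),
    ("is_upper_class", FunctionTransformer(is_upper_class_copy,
                                           kw_args={"upper_class_titles": ["Dr", "Col"]})),
])

toy_out = toy_preprocessor.fit_transform(toy_names)
print(toy_out[["title", "upper_class"]])
print(list(toy_names.columns))  # the original frame still has only the `name` column
```

Whether you copy or modify in place is a design choice: copying costs a little memory but avoids surprises when the same DataFrame is reused later.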

1.1 The limitations of FunctionTransformer: Stateful transformations

FunctionTransformer is a powerful and elegant solution, but it's only suitable when you want to apply stateless transformations (i.e. rule-based transformations which are not dependent on prior values computed from the training data). If you want to define a custom Transformer that can transform testing datasets based on the values observed in the training dataset, you can't use a FunctionTransformer, and you'll need to take a different approach.

If that sounds a bit confusing, take a minute to reconsider the function we just wrote to extract passengers' titles:

def extract_title(X):
    """Extract the title from each passenger's `name`."""
    X['title'] = X['name'].str.split(', ', expand=True)[1].str.split('.', expand=True)[0]
    return X

The function is stateless because it has no memory of the past; it does not use any pre-computed values during the operation. Each time we call this function, it will be applied from scratch as if it were being done for the very first time.

A stateful function, by contrast, retains information from previous operations and uses this when implementing the current operation. To illustrate this distinction, here are two functions that replace missing values in our dataset with a mean value:

# Note: 'column1' is a placeholder column name used for illustration

# Stateless - no prior information is used in the transformation
def impute_mean_stateless(X):
    X['column1'] = X['column1'].fillna(X['column1'].mean())
    return X

# Stateful - prior information about the training set is used in the transformation
column1_mean_train = X_train['column1'].mean()
def impute_mean(X):
    X['column1'] = X['column1'].fillna(column1_mean_train)
    return X

The first function is a stateless function because no prior information is used in the transformation; the mean is only calculated using the dataset X which is passed to the function.

The second is a stateful function which uses column1_mean_train (i.e. the mean value of column1 from the training set X_train) to replace missing values in X.

The distinction between stateless and stateful transformation might seem a bit abstruse, but it's an incredibly important concept in ML tasks where we have separate training and testing datasets. Whenever we want to replace missing values, scale features or perform one-hot encoding on our testing datasets, we want these transformations to be based on the values observed in the training dataset. In other words, we want our Transformer to be fit to the training set. Using the example of imputing missing values with the mean, we would want the "mean" value to be the mean value of the training set.
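Scikit-learn's built-in Transformers already behave this way. As a minimal sketch (with toy numbers invented for illustration), here is SimpleImputer learning the mean from a training set and then reusing that stored value on a testing set:

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Hypothetical toy data: the training mean is 20, the test mean is 100
toy_train = pd.DataFrame({"age": [10.0, 20.0, 30.0]})
toy_test = pd.DataFrame({"age": [100.0, np.nan]})

imputer = SimpleImputer(strategy="mean")
imputer.fit(toy_train)           # the mean is learned from the training set ...
print(imputer.statistics_)       # ... and stored on the fitted Transformer

toy_test_imputed = imputer.transform(toy_test)
print(toy_test_imputed.ravel())  # the NaN is filled with 20, the training mean
```

The value stored in `statistics_` is the "state" that makes this a stateful transformation.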

The problem with using FunctionTransformer is that it can't be used to implement stateful transformations. Even though a Transformer created with FunctionTransformer technically has the .fit() method, calling it doesn't learn anything from the data, so we can't really "fit" this Transformer to the training set. Why? Because the transformations in a FunctionTransformer-created Transformer always depend only on the function's input X. Our Transformer will always re-calculate the values using the dataset it is passed; it has no way of imputing/transforming with a pre-calculated value.

1.2 An example to illustrate the limitations of FunctionTransformer

To illustrate this, here's an example where I try to "fit" a FunctionTransformer-based Transformer to a training set and then transform the testing set using this supposedly "fitted" transformer. As you can see, the missing values in the testing set are not replaced with the mean value from the training set; they are recalculated based on the testing set. In other words, the Transformer was unable to apply a stateful transformation.

# Show the test set, pre-transformation
X_test.head(3)
Image by author, showing a missing value in the third row of the testing set in the 'age' column
print("X_train mean: ", X_train['age'].mean())
# X_train mean:  29.857414148681055

print("X_test mean: ", X_test['age'].mean())
# X_test mean:  29.97444952830189
def impute_mean(X):
    X['age'] = X['age'].fillna(X['age'].mean())
    return X

impute_mean_FT = FunctionTransformer(impute_mean) # Convert function to Transformer
prepro = impute_mean_FT.fit(X_train) # The Transformer is "fitted" to the train set
prepro.transform(X_test) # The fitted Transformer is used to transform the test set
Image by author. The missing value in the third row was replaced with the mean of the testing set, not the mean of the training set, illustrating the inability of FunctionTransformer to produce Transformers capable of stateful transformations.

If this all sounds a bit confusing, don't sweat it. The key takeaway message is: if you want to define a custom Transformer that can preprocess testing datasets based on the values observed in the training dataset, you can't use a FunctionTransformer, and you'll need to take a different approach.

2. Create a custom Transformer from scratch

One alternative approach is to define a new Transformer class which inherits from a class found in the sklearn.base module: TransformerMixin. This new class will then function as a Transformer, and is suitable for applying both stateless and stateful transformations.

Here's how we'd take our extract_title code snippet and turn it into a Transformer using this approach:

from sklearn.base import TransformerMixin

class ExtractTitle(TransformerMixin):
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        X['title'] = X['name'].str.split(', ', expand=True)[1].str.split('.', expand=True)[0]
        return X

preprocessor = Pipeline(steps=[
    ('extract_title', ExtractTitle()),
])

X_train_transformed = preprocessor.fit_transform(X_train)
X_train_transformed.head()
Image by author

As you can see, we achieve the exact same transformation as we did when constructing our Transformer using FunctionTransformer.
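The ExtractTitle class above is stateless, which is why its fit() method simply returns self. For a stateful transformation, fit() is where you would compute and store whatever needs to be remembered from the training set. Here's a minimal sketch of that idea; the MeanImputer class, its `column` argument and the toy data are hypothetical names of my own, and I've also inherited from sklearn.base.BaseEstimator, which is a common companion to TransformerMixin.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

class MeanImputer(BaseEstimator, TransformerMixin):
    """Hypothetical example: fill missing values in one column with the training mean."""

    def __init__(self, column):
        self.column = column

    def fit(self, X, y=None):
        # State learned from the data passed to fit(); the trailing underscore
        # follows scikit-learn's naming convention for fitted attributes
        self.mean_ = X[self.column].mean()
        return self

    def transform(self, X, y=None):
        X = X.copy()  # avoid modifying the caller's DataFrame
        X[self.column] = X[self.column].fillna(self.mean_)
        return X

# Hypothetical toy data: the training mean is 20, the test mean is 100
toy_train = pd.DataFrame({"age": [10.0, 20.0, 30.0]})
toy_test = pd.DataFrame({"age": [100.0, np.nan]})

mean_imputer = MeanImputer(column="age").fit(toy_train)
toy_test_filled = mean_imputer.transform(toy_test)
print(mean_imputer.mean_)               # 20.0
print(toy_test_filled["age"].tolist())  # [100.0, 20.0]
```

Unlike the FunctionTransformer version, the missing value in the test set is filled with the mean of the training set, because that mean was stored during fit().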

2.1 Passing arguments to a custom Transformer

If you need to pass data to your custom Transformer, simply define an __init__() method before defining the fit() and transform() methods:

class IsUpperClass(TransformerMixin):
    def __init__(self, upper_class_titles):
        self.upper_class_titles = upper_class_titles        

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        X['upper_class'] = X['title'].apply(lambda x: 1 if x in self.upper_class_titles else 0)
        return X

preprocessor = Pipeline(steps=[
    ('extract_title', ExtractTitle()),  # creates the `title` column that IsUpperClass relies on
    ('IsUpperClass', IsUpperClass(upper_class_titles=['Dr', 'Col', 'Major', 'Lady', 'Rev', 'Sir', 'Capt'])),
])

X_train_transformed = preprocessor.fit_transform(X_train)
X_train_transformed.head()
Image by author
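A related tip: the IsUpperClass class above inherits only from TransformerMixin. If you also inherit from sklearn.base.BaseEstimator, scikit-learn can inspect and change the arguments you defined in __init__() through get_params() and set_params(), which is what tools like clone() and grid search rely on. The sketch below illustrates this; the class name IsUpperClassEstimator, the shortened title lists and the toy DataFrame are assumptions made up for the example.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone

class IsUpperClassEstimator(BaseEstimator, TransformerMixin):
    """Variant of IsUpperClass that also inherits from BaseEstimator (illustrative name)."""

    def __init__(self, upper_class_titles):
        # Store the argument unchanged, under the same name as the parameter
        self.upper_class_titles = upper_class_titles

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        X = X.copy()
        X["upper_class"] = X["title"].isin(self.upper_class_titles).astype(int)
        return X

flagger = IsUpperClassEstimator(upper_class_titles=["Dr", "Col"])
print(flagger.get_params())  # {'upper_class_titles': ['Dr', 'Col']}

flagger.set_params(upper_class_titles=["Rev"])
flagger_copy = clone(flagger)  # an unfitted copy with the same parameters

toy_titles = pd.DataFrame({"title": ["Mr", "Rev"]})
print(flagger_copy.fit_transform(toy_titles)["upper_class"].tolist())  # [0, 1]
```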

And there you have it: two methods for building customised Transformers in scikit-learn. The ability to use these methods is an incredibly valuable one, and something which I regularly use in my day-to-day job as a Data Scientist. I hope you've found it useful.


Thanks for reading!
