Leverage Python Inheritance in ML projects


Introduction

Many people approaching Machine Learning don't have a strong background in computer engineering, and when they need to work on a real product their code can be messy and difficult to manage. This is why I always strongly recommend learning to use coding best practices which will enable you to work smoothly within a team and level up the project you're working on. Today I want to talk about Python inheritance and show some simple examples of how to use it within the field of Machine Learning.

In software development and other information technology fields, technical debt (also known as design debt or code debt) is the implied cost of future reworking because a solution prioritizes expedience over long-term design.

If you want to learn more about design patterns, you might also enjoy some of my previous articles.

Python Inheritance

Inheritance is not just a Python concept but a general concept in Object-Oriented Programming (OOP). In this tutorial we will work with classes and objects, a paradigm that is used less in Python than in other languages such as Java.

In OOP, we can define a general class representing something in the world, for example a Person, which we define simply by a name, surname, and age in the following way.

Python">class Person:
    def __init__(self, name, surname, age):
        self.name = name
        self.surname = surname
        self.age = age

    def __str__(self):
        return f"Name: {self.name}, surname: {self.surname}, age: {self.age}"

    def grow(self):
        self.age += 1

In this class, we defined a simple constructor (__init__). Then we defined the __str__ method, which takes care of printing the object the way we want. Finally, we have the grow() method, which makes the person one year older.

Now we can instantiate an object and use this class.

person = Person("Marcello", "Politi", 28)
person.grow()
print(person)

# output will be
# Name: Marcello, surname: Politi, age: 29

Now what if we want to define a particular type of person, for example, a worker? Well, we can do the same thing as before, but we add another input variable to add its salary.

class Worker:
    def __init__(self, name, surname, age, salary):
        self.name = name
        self.surname = surname
        self.age = age
        self.salary = salary

    def __str__(self):
        return f"Name: {self.name}, surname: {self.surname}, age: {self.age}, salary: {self.salary}"

    def grow(self):
        self.age += 1

That's it. But is this the best way to implement it? Notice that most of the Worker code is identical to the Person code. That's because a worker is a particular kind of person, so it shares a lot with a person.

What we can do is tell Python that Worker should inherit everything from Person, and then manually add only the things a general person doesn't have.

class Worker(Person):
    def __init__(self, name, surname, age, salary):
        super().__init__(name, surname, age)
        self.salary = salary

    def __str__(self):
        text = super().__str__()
        return text + f",salary: {self.salary}"

In the Worker class, the constructor calls the constructor of the Person class through the super() built-in, and then also sets the salary attribute.

The same applies to the __str__ method: we reuse the text returned by Person via super() and append the salary when printing the object.
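To make sure the subclass behaves as expected, we can quickly try it out (the salary value below is just an illustrative number):

worker = Worker("Marcello", "Politi", 28, 30000)
worker.grow()  # grow() is inherited from Person
print(worker)

# output will be
# Name: Marcello, surname: Politi, age: 29, salary: 30000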

Inheritance in Machine Learning

There are no fixed rules about when to use inheritance in Machine Learning; I don't know what project you're working on or what your code looks like. I just want to stress that you should adopt an OOP paradigm in your codebase. Still, let's look at some examples of how to use inheritance.

Define a BaseModel

Let's code a base machine learning model class defined by some standard variables. This class will have a method to load the data, one to train, one to evaluate, and one to preprocess the data. However, each specific model preprocesses the data differently, so the subclasses that inherit from the base model must override the preprocessing method. Note that BaseMLModel itself inherits from ABC. This tells Python that the class is abstract: it should not be instantiated directly, but only used as a template to build subclasses.

The same is true for preprocess_train_data, which is marked as an @abstractmethod. This means that subclasses must implement this method themselves.

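Abstract classes are easy to see in action. Below is a minimal standalone sketch (the toy names Base and Concrete are mine, not part of the article's code): Python refuses to instantiate a class that still has unimplemented abstract methods.

from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def do_something(self):
        ...

# Base()  # raises TypeError: can't instantiate abstract class Base

class Concrete(Base):
    def do_something(self):
        return "done"

print(Concrete().do_something())  # prints "done"

With that in mind, here is the full base class: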

from abc import ABC, abstractmethod
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
import numpy as np

class BaseMLModel(ABC):
    def __init__(self, test_size=0.2, random_state=42):
        self.model = None  # This will be set in subclasses
        self.test_size = test_size
        self.random_state = random_state
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None

    def load_data(self, X, y):
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=self.test_size, random_state=self.random_state
        )

    @abstractmethod
    def preprocess_train_data(self):
        """Each model can define custom preprocessing for training data."""
        pass

    def train(self):
        self.X_train, self.y_train = self.preprocess_train_data()
        self.model.fit(self.X_train, self.y_train)

    def evaluate(self):
        predictions = self.model.predict(self.X_test)
        return accuracy_score(self.y_test, predictions)

Now let's see how we can inherit from this class. First, we implement a LogisticRegressionModel, which has its own preprocessing step.


class LogisticRegressionModel(BaseMLModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = LogisticRegression(**kwargs)

    def preprocess_train_data(self):
        #Standardize features for Logistic Regression
        mean = self.X_train.mean(axis=0)
        std = self.X_train.std(axis=0)
        X_train_scaled = (self.X_train - mean) / std
        return X_train_scaled, self.y_train

Then we can define as many subclasses as we want. Here I define one for a Random Forest.

class RandomForestModel(BaseMLModel):
    def __init__(self, n_important_features=2, **kwargs):
        super().__init__()
        self.model = RandomForestClassifier(**kwargs)
        self.n_important_features = n_important_features

    def preprocess_train_data(self):
        #Select top `n_important_features` features based on variance
        feature_variances = np.var(self.X_train, axis=0)
        top_features_indices = np.argsort(feature_variances)[-self.n_important_features:]
        X_train_selected = self.X_train[:, top_features_indices]
        return X_train_selected, self.y_train
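To show how little code it takes to plug in yet another algorithm, here is a sketch of a third, hypothetical subclass (the name SVMModel and its min-max scaling step are my own illustrative choices, not part of the original example):

from sklearn.svm import SVC

class SVMModel(BaseMLModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = SVC(**kwargs)

    def preprocess_train_data(self):
        # Rescale each feature to the [0, 1] range before fitting the SVM
        # (assumes no feature is constant, which holds for the Iris data)
        min_vals = self.X_train.min(axis=0)
        max_vals = self.X_train.max(axis=0)
        X_train_scaled = (self.X_train - min_vals) / (max_vals - min_vals)
        return X_train_scaled, self.y_train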

Then we can use all of this in our main function:

if __name__ == "__main__":
    # Load dataset
    data = load_iris()
    X, y = data.data, data.target

    # Logistic Regression
    log_reg_model = LogisticRegressionModel(max_iter=200)
    log_reg_model.load_data(X, y)
    log_reg_model.train()
    print(f"Logistic Regression Accuracy: {log_reg_model.evaluate()}")

    # Random Forest
    rf_model = RandomForestModel(n_estimators=100, n_important_features=3)
    rf_model.load_data(X, y)
    rf_model.train()
    print(f"Random Forest Accuracy: {rf_model.evaluate()}")

Final Thoughts

One of the main benefits of Python inheritance in ML projects is the design of modular, maintainable, and scalable codebases. Inheritance avoids redundant code by keeping common logic in a base class such as BaseMLModel, thereby reducing duplication. It also makes it easy to encapsulate common behaviour in a base class while letting subclasses define the particular details.

The main benefit in my opinion is that a well-organized, object-oriented codebase allows multiple developers within a team to work independently on separate parts. In our example, a lead engineer could define the base model, and then each developer could focus on a single algorithm and write the subclass.

Before diving into complex design patterns, focus on leveraging OOP best practices. Doing so will make you a better programmer compared to many others in the ML field.

Follow me on Medium if you like this article!

Tags: Data Science Machine Learning Programming Python python-inheritance
