Neural Prototype Trees

Author: Murphy  |  Time: 2025-03-23 18:27:01

Machine learning and artificial intelligence are now used in a tremendous number of fields, but with this increased use come increased risks and ethical tests that models need to pass.

Consider a motivating example from recent news: a Tesla crashed into a tree while in self-driving mode. According to authorities, the driver said that once she put the car into self-driving mode, it pulled to the right, went off the road, and hit a tree. These claims are currently under investigation, but imagine how difficult it would be to identify the reasoning behind the car's sudden, erratic decisions. Did it misclassify something? What did it see that confused it? With traditional black-box models, investigating the model internals is difficult and costly.

So what's the alternative? Is there an interpretable way to do image classification? Yes, through prototype learning [2] and neural prototype trees [1]! With these architectures, the model takes a very intuitive approach to prediction: identifying pieces that look familiar. Does the bird have a long beak? Does it have a red throat? It must be a hummingbird!

In this article, I aim to explain how these models work and discuss some of their benefits and downsides. The two main papers I'll be referencing frequently are below, and I highly encourage anyone interested to give them a read:

This Looks Like That: Deep Learning for Interpretable Image Recognition

Neural Prototype Trees for Interpretable Fine-Grained Image Recognition (CVPR 2021 Open Access Repository)

This is the second article I've written about variants of neural decision trees. If you haven't read the first one, I recommend skimming it, as many of the concepts outlined here build on the standard neural tree.

Neural Networks as Decision Trees

What are Prototypes?

The idea of prototypes in image recognition was first introduced by Chen, Li et al. (2019) in their paper "This Looks Like That: Deep Learning for Interpretable Image Recognition" [2]. A prototype is a latent representation of a training-image patch that is associated with a given class. As the name suggests, the model works somewhat like a human: it dissects the input image and looks for prototypical parts that provide evidence for the image belonging to a certain class. For each prototype, the network computes the Euclidean distance between the prototype and every patch of the image's convolutional feature map, keeps the smallest distance, and inverts it into a similarity score. These scores are then passed through a fully connected layer to produce the final class probabilities.
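To make the scoring step concrete, here is a minimal NumPy sketch of ProtoPNet-style scoring for a single image. It assumes a (H, W, D) feature map, P prototype vectors and a (K, P) weight matrix; the function names are hypothetical, and the log((d + 1) / (d + ε)) inversion of the squared distance is the similarity used in the ProtoPNet paper as we understand it, so treat its exact form as an assumption.

```python
import numpy as np

def prototype_similarities(features, prototypes, eps=1e-4):
    """Similarity of each prototype to its best-matching patch (hypothetical helper).

    features:   (H, W, D) convolutional feature map of one image
    prototypes: (P, D) learned prototype vectors
    Returns a (P,) array of similarity scores.
    """
    patches = features.reshape(-1, features.shape[-1])              # (H*W, D)
    # Squared Euclidean distance between every patch and every prototype.
    diffs = patches[:, None, :] - prototypes[None, :, :]            # (H*W, P, D)
    sq_dists = (diffs ** 2).sum(axis=-1)                            # (H*W, P)
    # Keep only the closest patch for each prototype.
    min_dists = sq_dists.min(axis=0)                                # (P,)
    # Invert distance into similarity: large when some patch is close to the prototype.
    return np.log((min_dists + 1.0) / (min_dists + eps))

def class_probabilities(features, prototypes, weights):
    """Fully connected layer over the similarities, followed by a softmax."""
    logits = weights @ prototype_similarities(features, prototypes)  # (K,)
    logits = logits - logits.max()                                   # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

rng = np.random.default_rng(0)
features = rng.normal(size=(7, 7, 16))     # toy 7x7 feature map with 16 channels
prototypes = rng.normal(size=(4, 16))      # 4 toy prototypes
prototypes[0] = features[2, 3]             # prototype 0 exactly matches one patch
weights = rng.normal(size=(3, 4))          # 3 classes
probs = class_probabilities(features, prototypes, weights)
```

The random weights here only exercise the shapes; in a trained model they are learned, and the prototype that matches a patch exactly receives the largest possible similarity, log(1 / ε).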

Once the model is trained, users can simply match up the learned prototypes with patches from the training set to create a very interpretable explanation of any prediction.
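Matching prototypes to training patches can be sketched as a nearest-neighbour search in feature space. The sketch below assumes the feature maps of N training images are stacked into an (N, H, W, D) array and returns, for each prototype, its nearest training patch and that patch's (image, row, column) location; the function name is hypothetical. It omits the class restriction ProtoPNet applies (each prototype is only matched against images of its own class) and builds the full distance matrix in memory, which is fine for a toy example but not for a real dataset.

```python
import numpy as np

def project_prototypes(prototypes, train_features):
    """Replace each prototype by its nearest training patch (hypothetical helper).

    prototypes:     (P, D) learned prototype vectors
    train_features: (N, H, W, D) feature maps of N training images
    Returns (projected prototypes, list of (image, row, col) patch locations).
    """
    n, h, w, d = train_features.shape
    patches = train_features.reshape(-1, d)                          # (N*H*W, D)
    diffs = patches[:, None, :] - prototypes[None, :, :]             # (N*H*W, P, D)
    sq_dists = (diffs ** 2).sum(axis=-1)                             # (N*H*W, P)
    best = sq_dists.argmin(axis=0)                                   # (P,) flat patch index
    # Convert flat indices back to (image, row, col) so the patch can be cropped and shown.
    locations = [tuple(int(i) for i in np.unravel_index(b, (n, h, w))) for b in best]
    return patches[best].copy(), locations

rng = np.random.default_rng(1)
train_features = rng.normal(size=(5, 4, 4, 8))                       # 5 toy training images
prototypes = train_features[1, 2, 3][None, :] + 1e-3                 # 1 prototype near a known patch
projected, locations = project_prototypes(prototypes, train_features)
```

Here the single prototype snaps onto the patch at image 1, row 2, column 3; cropping the corresponding region of that training image is what produces the "this looks like that" explanation.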


One problem with a simple bag-of-prototypes model (such as ProtoPNet from the original prototype paper [2]) is that all prototypes are matched at the same time, whereas human image recognition relies on a sequence of steps. If something doesn't have paws or ears but has a tail, it probably isn't a cat, so the network shouldn't be assigning it evidence for being a cat. This is where the Neural Prototype Tree [1] comes in. Instead of a bag of prototypes, Nauta et al. [1] design their model as a neural decision tree. The tree provides this sequential decision making and offers global explainability rather than only local interpretability.

Review of Soft Decision Trees

The Neural Prototype Tree is a soft decision tree, not a hard one. While hard decision trees enforce deterministic branching (you either go left or right), soft decision trees use probabilistic branching (you go left with probability p and right with probability 1 - p). Additionally, while a hard decision tree outputs a single value, a soft decision tree outputs a probability distribution over all classes. The probability of reaching a leaf is the product of the branch probabilities along the path to it, and the predicted distribution is the sum of the leaf distributions, each weighted by the probability of reaching that leaf. The classification decision is simply the class with the highest probability.
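The soft-tree prediction rule fits in a few lines. This minimal sketch assumes a complete binary tree stored in breadth-first (heap) order, where node i has children 2i + 1 (left) and 2i + 2 (right), and takes the probability of going right at every internal node as given; the function names and storage layout are illustrative choices, not taken from the paper.

```python
import numpy as np

def leaf_path_probs(p_right):
    """Probability of reaching each leaf of a complete binary soft tree.

    p_right: (2**depth - 1,) probability of taking the right edge at each internal
             node, in heap order (children of node i are 2i+1 and 2i+2).
    Returns a (2**depth,) array of leaf arrival probabilities that sums to 1.
    """
    n_internal = len(p_right)
    reach = np.zeros(2 * n_internal + 1)
    reach[0] = 1.0                                         # every sample starts at the root
    for i in range(n_internal):                            # parents come before children
        reach[2 * i + 1] = reach[i] * (1.0 - p_right[i])   # left edge
        reach[2 * i + 2] = reach[i] * p_right[i]           # right edge
    return reach[n_internal:]                              # leaves are the last nodes in heap order

def soft_tree_predict(p_right, leaf_dists):
    """Mixture of leaf class distributions, weighted by path probabilities."""
    return leaf_path_probs(p_right) @ leaf_dists           # (L,) @ (L, K) -> (K,)

p_right = np.array([0.8, 0.3, 0.6])          # depth-2 tree: root, then its left and right child
leaf_dists = np.array([[0.9, 0.1],
                       [0.5, 0.5],
                       [0.2, 0.8],
                       [0.1, 0.9]])          # one class distribution per leaf (2 classes)
probs = soft_tree_predict(p_right, leaf_dists)
predicted_class = int(np.argmax(probs))
```

The leaves are reached with probabilities 0.14, 0.06, 0.32 and 0.48, so the tree predicts roughly [0.268, 0.732] and picks class 1, even though no single leaf is reached with certainty.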

Decision Making in a Prototype Tree

As in the standard neural tree, each leaf node contains a probability distribution over the classes. Each internal node holds one prototype, and its branching decision is based on the distances between the image's patches and that prototype. The node's score for an image is the minimum distance between any patch in the image and the prototype, which is then translated into a probability. Simply put: go right if you find the prototype in the image, and go left if you can't find it!
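A single routing decision can be sketched as follows. Following our reading of the ProtoTree paper, the Euclidean distance d between the prototype and its closest patch is mapped to exp(-d), which is 1 when some patch matches the prototype exactly and decays towards 0 as the best match gets worse; treat this exact mapping as an assumption. The function name is hypothetical.

```python
import numpy as np

def presence_probability(features, prototype):
    """Probability of taking the right ("prototype present") edge at one node.

    features:  (H, W, D) feature map of one image
    prototype: (D,) prototype stored in this node
    """
    patches = features.reshape(-1, features.shape[-1])        # (H*W, D)
    dists = np.linalg.norm(patches - prototype, axis=1)       # Euclidean distance per patch
    return float(np.exp(-dists.min()))                        # only the closest patch matters

rng = np.random.default_rng(2)
features = rng.normal(size=(7, 7, 16))
matching = features[4, 1].copy()           # a prototype that appears in the image
absent = np.full(16, 10.0)                 # a prototype far from every patch
p_match = presence_probability(features, matching)
p_absent = presence_probability(features, absent)
```

Evaluating this at every internal node yields the branch probabilities that the soft tree multiplies along each path; the left edge simply gets 1 minus this value.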

Because every split asks whether a particular prototype is present, these models are highly interpretable and can be visualized in the same way as the standard prototype model: by showing the training patch that each prototype matches.

Learning Leaf Distributions

In a normal decision tree, the label of a leaf is learned by looking at the training samples that end up in that leaf, but in a soft tree, the leaf distributions are part of the global learning problem. However, the authors noticed that learning the leaf distributions jointly with the prototypes led to inferior classification results. To rectify this, they use a derivative-free update scheme for the leaf probabilities.
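The idea behind such a derivative-free update can be sketched as follows. As we understand it, ProtoTree's leaf update builds on the fixed-point update from Deep Neural Decision Forests (Kontschieder et al., 2015): each training sample's evidence for its true class is shared among the leaves in proportion to how much each leaf contributed to the current prediction. The simplified full-batch version below holds the path probabilities fixed and renormalizes each leaf after every step; it illustrates that technique, it is not ProtoTree's exact mini-batch rule, and the function name is hypothetical.

```python
import numpy as np

def update_leaf_distributions(path_probs, labels, leaf_dists, n_iters=20):
    """Derivative-free fixed-point update of soft-tree leaf distributions (sketch).

    path_probs: (N, L) probability that sample n reaches leaf l (held fixed here)
    labels:     (N,) integer class labels
    leaf_dists: (L, K) current leaf class distributions (rows sum to 1)
    """
    n_classes = leaf_dists.shape[1]
    onehot = np.eye(n_classes)[labels]                       # (N, K)
    dists = leaf_dists.copy()
    for _ in range(n_iters):
        y_hat = path_probs @ dists                           # (N, K) current predictions
        # Weight of each sample's true class, divided by how strongly it is predicted now.
        ratio = onehot / np.clip(y_hat, 1e-12, None)         # (N, K), zero off the true class
        # Leaf l gains dist_l[y] * sum_n pi_l(x_n) / y_hat_n[y] for each class y.
        new = dists * (path_probs.T @ ratio)                 # (L, K)
        dists = new / np.clip(new.sum(axis=1, keepdims=True), 1e-12, None)
    return dists

path_probs = np.array([[1.0, 0.0],
                       [1.0, 0.0],
                       [0.0, 1.0],
                       [0.0, 1.0]])          # samples 0-1 reach leaf 0, samples 2-3 reach leaf 1
labels = np.array([0, 0, 1, 1])
uniform = np.full((2, 2), 0.5)               # start both leaves from a uniform distribution
learned = update_leaf_distributions(path_probs, labels, uniform)
```

No gradients are involved: with hard routing, as in this toy case, each leaf converges to the class of the samples that reach it, while soft routing splits the credit between leaves.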

This update is interleaved with mini-batch gradient descent, which learns the prototypes and the convolutional parameters, resulting in an efficient learning procedure.

Pruning

To aid interpretability, the authors also introduce a pruning mechanism. If a leaf node contains an effectively uniform distribution, it has little discriminating power, so it is better to prune it, since smaller trees are easier to read and interpret. Formally, the authors define a threshold t and remove every leaf ℓ whose highest class probability is at most t, i.e. max(c_ℓ) ≤ t. If all leaves in a subtree are removed, that subtree and its associated prototypes can be removed as well, giving a more compact tree. Usually, t = 1/K + ε, where K is the number of classes and ε is a small tolerance.
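Here is a recursive sketch of the pruning rule on a toy tree of nested dictionaries (leaves hold a "dist" array, internal nodes hold a "prototype" name and two children; this format and the function names are hypothetical). One detail is our assumption rather than something stated above: when only one child of a node survives, the sketch replaces the node by that surviving child and drops the node's prototype, so every remaining internal node still has two children.

```python
import numpy as np

def prune(node, threshold):
    """Return a pruned copy of a binary prototype tree, or None if nothing survives."""
    if "dist" in node:                                   # leaf
        return None if np.max(node["dist"]) <= threshold else node
    left = prune(node["left"], threshold)
    right = prune(node["right"], threshold)
    if left is None and right is None:                   # whole subtree is uninformative
        return None
    if left is None or right is None:                    # assumption: collapse single-child node
        return left if left is not None else right
    return {"prototype": node["prototype"], "left": left, "right": right}

def prototypes_in(node):
    """Set of prototype names still used by a (possibly pruned) tree."""
    if node is None or "dist" in node:
        return set()
    return {node["prototype"]} | prototypes_in(node["left"]) | prototypes_in(node["right"])

K = 3
t = 1 / K + 0.01                                         # t = 1/K + epsilon
tree = {
    "prototype": "p0",
    "left": {"prototype": "p1",
             "left": {"dist": np.array([0.34, 0.33, 0.33])},    # near-uniform: pruned
             "right": {"dist": np.array([0.33, 0.34, 0.33])}},  # near-uniform: pruned
    "right": {"prototype": "p2",
              "left": {"dist": np.array([0.90, 0.05, 0.05])},
              "right": {"dist": np.array([0.10, 0.10, 0.80])}},
}
pruned = prune(tree, t)
```

Both leaves under p1 fall below the threshold, so the p1 subtree disappears, the root p0 is left with a single child, and the pruned tree keeps only prototype p2.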

Performance

The authors benchmarked their method on the CUB-200-2011 (birds) and Stanford Cars datasets against other interpretable image-recognition methods (such as ones with attention-based interpretability). They found that they could get close to state-of-the-art accuracy with a relatively small ensemble of trees with small heights (9 and 11).

Conclusion

Interpretable deep-learning image classifiers offer a number of advantages over black-box models. They can help to build trust, improve debugging, and explain predictions. Additionally, they can be used to explore the data and learn more about the relationships between different features.

Overall, Neural Prototype Trees are a promising new approach to trustworthy image recognition. A doctor is more likely to trust a cancer-detecting model if they can check which characteristics of the image the model is looking at. These prototype trees can even be augmented with mechanisms like attention to increase accuracy further!

Resources and References

  1. Github for the Neural Prototype Tree: https://github.com/M-Nauta/ProtoTree
  2. If you're interested in Explainable Machine Learning and AI, consider giving me a follow: https://medium.com/@upadhyan.

References

[1] M. Nauta, R. van Bree, C. Seifert. Neural Prototype Trees for Interpretable Fine-Grained Image Recognition (2021). IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).

[2] C. Chen, O. Li, C. Tao, A.J. Barnett, J. Su, C. Rudin. This Looks Like That: Deep Learning for Interpretable Image Recognition (2019). 33rd Conference on Neural Information Processing Systems (NeurIPS).

