Unsupervised data pruning: less data to learn better


Foundation models | Scaling law | Large models | Data pruning

image by the author using DALL-E

Scaling laws have been observed in different contexts (images, text, language, speech, and so on). Is increasing the number of parameters really the only recipe for a better model? And if not, what can you actually do?

What is scaling law and why is it problematic

In recent years we have seen the number of parameters in models increase by leaps and bounds. All the biggest companies have been pushing to create ever more capable models. This has led to lower error on benchmark datasets and to the emergence of unanticipated behaviors. But what is the scaling law?

In short, the scaling law states that the "test error often falls off as a power law with either the amount of training data, model size, or compute." Put another way, to improve the performance of a model one must increase one of these three factors: the number of examples during training, the number of parameters, or the duration of training.
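For intuition, a power law with a small exponent means that each order-of-magnitude increase in data shaves only a modest slice off the error. Here is a minimal sketch (the constants a and b are purely illustrative, not taken from any specific paper):

```python
import numpy as np

def power_law_error(D, a=1.0, b=0.1):
    """Test error as a power law of dataset size D: error = a * D**(-b).
    The constants a and b are illustrative; real exponents are fit per domain."""
    return a * D ** (-b)

# With a small exponent, 10x more data only modestly reduces the error.
for D in [1e6, 1e7, 1e8]:
    print(f"D = {D:.0e} -> error = {power_law_error(D):.4f}")
# D = 1e+06 -> error = 0.2512
# D = 1e+07 -> error = 0.1995
# D = 1e+08 -> error = 0.1585
```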

Previous technical studies had suggested that test loss could decrease as a power law of the training dataset size. This was formalized in 2017, when Hestness et al. explored the idea across several machine-learning domains (machine translation, language modeling, image processing, and speech recognition).

The scaling law was then formally described in an OpenAI article, where the authors showed that increasing the size of the model, the size of the dataset, or the amount of compute used for training improves performance.

scaling law as shown by OpenAI (source: here)

Formally they wrote:

Model performance depends most strongly on scale, which consists of three factors: the number of model parameters N (excluding embeddings), the size of the dataset D, and the amount of compute C used for training. Within reasonable limits, performance depends very weakly on other architectural hyperparameters such as depth vs. width. – source

In summary, they assert that performance has a power-law relationship with each of these three factors. In addition, if you increase the number of parameters N you must also increase the number of examples D in tandem, or a penalty (overfitting) emerges. The relationship between N and D is such that increasing the number of parameters by a factor of 8 requires roughly 5 times more data.
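A quick back-of-the-envelope check of that ratio, assuming the sub-linear relation D ∝ N^0.74 reported in the OpenAI scaling-law paper (the exponent is theirs; the snippet only does the arithmetic):

```python
# If D must grow as N**0.74 (exponent reported by Kaplan et al., 2020),
# an 8x increase in parameters calls for roughly a 5x increase in data.
params_factor = 8
data_factor = params_factor ** 0.74
print(f"8x parameters -> ~{data_factor:.1f}x data")  # ~4.7x, i.e. roughly 5x
```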

This concept was taken to extremes with GPT-3 and later models (Google's PaLM reached more than 500 billion parameters). These models have shown incredible capabilities, and some have speculated that massively increasing the number of parameters will give us general intelligence. But will it?

No. In a nutshell, neural networks are pattern-matching machines (or universal approximators): they find the patterns they saw in the training set. A larger network trained on more data can store more patterns in its parameters and recognize more of them. So is the answer simply infinite data? Well, data is not infinite.

Such power law scaling has motivated significant societal investments in data collection, compute, and associated energy consumption. However, power law scaling is extremely weak and unsustainable. – source

As the OpenAI article shows, it takes an order-of-magnitude increase in data, parameters, or compute to decrease the error by a mere 2–3%. Likewise, scaling vision transformers took two billion data points to gain a few more accuracy points on ImageNet.

source (here)

To recap, so far we have witnessed a paradigm that states more is better. But is it the only strategy?

"can we achieve exponential scaling instead, with a good strategy for selecting training examples?"

Meanwhile, much of the data is redundant: the model often sees many similar examples, and datasets are frequently assembled by downloading thousands and thousands of examples from the Internet at random. Previous studies have shown that one can, for example, sort the examples in the training set in order of increasing difficulty, from easy and redundant examples up to the most difficult ones, and then shrink the dataset while maintaining performance. In other words, easy and redundant examples only consume training cycles without adding anything to learning.
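In code, difficulty-ordered pruning boils down to sorting by a per-example score and slicing; here is a minimal, generic sketch (the scoring itself is left abstract, since the concrete metrics are discussed below):

```python
import numpy as np

def prune_by_difficulty(scores, keep_fraction=0.8, keep="hard"):
    """Return indices of the examples to keep, given per-example
    difficulty scores (higher = harder) and the fraction to retain."""
    order = np.argsort(scores)                 # easiest ... hardest
    n_keep = int(len(scores) * keep_fraction)
    return order[-n_keep:] if keep == "hard" else order[:n_keep]

# Example: drop the easiest 20% according to a (placeholder) difficulty ranking.
scores = np.random.rand(50_000)                # stand-in for a real metric
kept = prune_by_difficulty(scores, keep_fraction=0.8, keep="hard")
```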

More about these concepts, how to efficiently scale vision transformers, and the difference between model-centric and data-centric AI:

Why Do We Have Huge Language Models and Small Vision Transformers?

A critical analysis of your dataset

Unsupervised data pruning: can you remove useless training data points without knowing the labels?

Previous studies left several questions open. First, can the power law of error with respect to data be beaten, i.e. can the number of examples be reduced sharply without sacrificing performance? Second, the strategies described above require that the collected examples be labeled, which is time-consuming and expensive. An optimal strategy should therefore be unsupervised.

A paper was recently published that tries to answer these questions. The study is a collaboration between Meta, Stanford, and the University of Tübingen:

Beyond neural scaling laws: beating power law scaling via data pruning

The authors start from the observation that scaling laws are inefficient: the exponents of the power law are close to zero, so an increase in parameters or data leads to only a minimal reduction in error, suggesting an inefficient use of resources. What we would like is to prune a dataset without hurting the model's performance, even when the dataset is unlabeled (after all, labeling a dataset is one of the most expensive and time-consuming operations). How?

The authors first explored this possibility in a teacher-student setting. In this setup you have a larger model (the teacher) that has already been trained on a large number of examples (not restricted to CNNs, even though that is one of the most common cases). You feed data to the pre-trained teacher and train a smaller student model using the teacher's outputs (class probabilities, also called soft labels) instead of the original labels.

In a nutshell, the authors used CIFAR-10 as the dataset, obtained probabilities from a teacher model, and trained a student model for a few epochs using the teacher's output probabilities as labels. They then calculated the margin between student and teacher outputs, which provides a measure of how difficult an example is for the model to learn.
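For concreteness, here is a rough sketch of how such a margin-style difficulty score could be computed from teacher soft labels and student logits. It is an illustration of the idea, not the authors' exact implementation; the function name and the specific margin definition are assumptions.

```python
import torch
import torch.nn.functional as F

def margin_difficulty(student_logits, teacher_probs):
    """One plausible margin-style difficulty score: the gap between the
    student's probability for the teacher's top class and its probability
    for the best competing class. Small (or negative) margins mark hard
    examples; large margins mark easy ones."""
    student_probs = F.softmax(student_logits, dim=-1)            # (N, C)
    teacher_top = teacher_probs.argmax(dim=-1)                   # teacher's label
    p_top = student_probs.gather(1, teacher_top[:, None]).squeeze(1)
    # probability of the best *other* class according to the student
    masked = student_probs.scatter(1, teacher_top[:, None], 0.0)
    p_other = masked.max(dim=-1).values
    return p_top - p_other                                       # higher = easier
```

Examples with a small or negative margin are the ones the student struggles to reproduce, i.e. the "hard" examples in the discussion that follows.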

image source: original article

The first interesting result: when the number of examples in the training set is large, it is better to keep the difficult examples while pruning (and thus drop the easy ones); when the training set is small, it is better to prune the difficult examples. This may seem counterintuitive. The authors note that easy examples provide coarse-grained information about the target function (basically the general patterns of the dataset), whereas the more difficult examples provide fine-grained information that might otherwise be missed in a large dataset.

Intuitively, in the limited data regime, it is challenging to model outliers since the basics are not adequately captured; hence, it is more important to keep easy examples so that the model can get to moderate error. However, with a larger dataset, the easy examples can be learned without difficulty, making modeling outliers the fundamental challenge. source: here

In other words, when there is little data it is better for the model to learn general information about the patterns in the data, while when there is a lot of data the hard examples help the model to better understand the decision boundary between classes.

Also, from an information-theoretic perspective, the authors suggest that data pruning increases the information gained from each individual example, since uninformative examples are filtered out.

Data pruning improves transfer learning

image from Pixabay (source: here)

One of the reasons for our interest in scaling laws is foundation models. A foundation model is a large model (transformer, vision transformer, and so on) that has been trained on a large amount of unlabeled data and can then be used for different downstream tasks.

Simply put, we train a very large model on a large amount of data and then, through transfer learning, adapt it to another task. Examples are BERT and GPT-3 for text tasks and ResNet for computer vision tasks (DALL-E and Stable Diffusion also include a pre-trained language model as a component).

Training a foundation model is extremely expensive, and so far we have tried to increase its parameters and the amount of data used for training. However, DeepMind's Chinchilla and subsequent studies show that improving the data used for training would perhaps be more useful. So the authors of this study wondered: could data pruning help with transfer learning?

The authors took a pre-trained vision transformer (ViT) and fine-tuned it on a pruned subset containing 10% of CIFAR-10; this worked better than fine-tuning the ViT on all of CIFAR-10. In addition, the authors pre-trained ResNet50 on different pruned subsets of ImageNet (a reduced version) and then fine-tuned on CIFAR-10. The results show that pre-training on a well-pruned subset yields better performance than pre-training on the whole ImageNet.
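As an illustration of the fine-tuning setup, here is a minimal sketch using torchvision's CIFAR-10 and a timm ViT. The `kept_indices` placeholder stands in for whatever 10% subset a pruning metric would select; the model name, hyperparameters, and subset choice are assumptions, not the authors' configuration.

```python
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import timm

# kept_indices would come from a pruning metric; here it is a placeholder
# for the 10% of CIFAR-10 judged most useful (5,000 of 50,000 examples).
kept_indices = list(range(5_000))

tfm = transforms.Compose([
    transforms.Resize(224),                  # ViT expects 224x224 inputs
    transforms.ToTensor(),
])
cifar = datasets.CIFAR10(root="data", train=True, download=True, transform=tfm)
pruned_cifar = Subset(cifar, kept_indices)
loader = DataLoader(pruned_cifar, batch_size=64, shuffle=True)

model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

model.train()
for images, labels in loader:                # one epoch shown for brevity
    optimizer.zero_grad()
    loss = loss_fn(model(images), labels)
    loss.backward()
    optimizer.step()
```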

image source: original article

Thus, intriguingly, pruning the pre-training data of an upstream task can still maintain high performance on a different downstream task. Overall, these results demonstrate the promise of data pruning in transfer learning, for both the pre-training and the fine-tuning phase.

SparseGPT: fewer parameters is better?

Scaling the approach on a large dataset

Previous pruning studies were done on small datasets, so it is important to know how they generalize to large datasets. For this reason, the authors benchmarked eight previously proposed pruning metrics on ImageNet and measured how they affect model performance.

The results show that these metrics make it possible to keep only a fraction of the examples (the hard ones) and still perform better than random pruning. But while they work well on small datasets, only a few of them match the performance obtained by training on the full dataset. The authors also note:

We found that all pruning metrics amplify class imbalance, which results in degraded performance. Fig. 5 shows many data pruning metrics do not scale well to ImageNet, while the few that do require substantial amounts of compute. Furthermore, all these metrics require labels, thereby limiting their ability to prune data for large-scale foundation models trained on massive unlabeled datasets. Thus there is a clear need for simple, scalable, self-supervised pruning metrics. source: here

image source: original article

As a solution, the authors propose a self-supervised pruning metric (a rough code sketch follows the list):

  • The first step is to use a self-supervised pre-trained model (SwAV) to extract a low-dimensional representation of each example in the dataset.
  • Using k-means clustering, they group these representations.
  • Then, for each example, they compute the cosine distance to its cluster centroid. If an example is close to the centroid (and thus similar in representation to the others in its cluster), it is considered easy; if it is far from the centroid, it is considered difficult.
  • Lastly, one can decide to prune a chosen fraction of easy or difficult examples, depending on the case.
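Here is a rough sketch of that pipeline, assuming the SwAV (or any self-supervised) embeddings have already been computed. The cluster count, seed, and the 20% pruning fraction in the usage comment are illustrative choices, not the authors' settings.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances

def prototype_difficulty(embeddings, n_clusters=100, seed=0):
    """Sketch of the prototype metric described above: cluster
    self-supervised embeddings with k-means and score every example by the
    cosine distance to its own cluster centroid (close = easy, far = hard)."""
    km = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    labels = km.fit_predict(embeddings)                        # (N,)
    dists = cosine_distances(embeddings, km.cluster_centers_)  # (N, k)
    return dists[np.arange(len(labels)), labels]

# embeddings = ...  # (N, d) features from a self-supervised model such as SwAV
# scores = prototype_difficulty(embeddings)
# kept = np.argsort(scores)[int(0.2 * len(scores)):]  # drop the 20% easiest
```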

we find our self-supervised prototype metric matches or exceeds the performance of the best supervised metric, memorization, until only 70–80% of the data is kept, despite the fact that our metric does not use labels and is much simpler and cheaper to compute than many previously proposed supervised metrics. source: here

The results matched the state-of-the-art supervised metric (memorization), which requires labels and is much slower to compute.

image source: original article

Conclusions

The authors show that data pruning can reduce error at a rate that beats the power-law scaling described above. In addition, they show how, through unsupervised learning, one can obtain a coreset (a subset of a dataset that allows one to train a model with performance comparable to using the full dataset). The approach is inexpensive, scalable, and does not need labels.

Looking forward, the authors suggest that this approach can still be improved to allow even more aggressive pruning, which would be extremely useful for training large foundation models. They also suggest:

If highly pruned versions of these datasets can be used to train a large number of different models, one can conceive of such carefully chosen data subsets as foundation datasets in which the initial computational cost of data pruning can be amortized across efficiency gains in training many downstream models, just as the initial computational cost of training foundation models is amortized across the efficiency gains of fine-tuning across many downstream tasks. source: here

In conclusion, pruning the dataset before training saves time and cost (less labeling work). Not to mention that, prospectively, reducing overrepresented populations could help identify and combat bias during training.

What do you guys think? Have you tried pruning your dataset?

If you have found this interesting:

You can look for my other articles, subscribe to get notified when I publish new ones, or connect with me on LinkedIn.

Here is the link to my GitHub repository, where I am planning to collect code and many resources related to machine learning, Artificial Intelligence, and more.

GitHub – SalvatoreRa/tutorial: Tutorials on machine learning, artificial intelligence, data science…

or you may be interested in one of my recent articles:

Microsoft BioGPT: Towards the ChatGPT of life science?

Everything but everything you need to know about ChatGPT

Tags: Artificial Intelligence Data Science Machine Learning Mathematics Science
