InfoBatch: Dataset Pruning on the Fly

January 17, 2024 (1y ago)

Zangwei Zheng, zangwei@u.nus.edu
National University of Singapore

ICLR 2024 Oral
Other versions: [arXiv] [Code] [Chinese]
Discuss on X with the author.

TL;DR

Training for many epochs wastes time on easy, well‑learned samples. InfoBatch speeds things up by dynamically pruning data and rescaling the loss to keep performance. It delivers 20–40% faster training on image classification, semantic segmentation, vision pretraining, diffusion models, and LLM instruction fine‑tuning—without losing accuracy.

[Figure: overview of InfoBatch]

How does InfoBatch work?

We provide a plug‑and‑play PyTorch implementation of InfoBatch (under active development). With the three changes shown below, you can drop it into your existing training code.

[Code: the three changes needed to enable InfoBatch]

Here is a brief overview of the InfoBatch algorithm.

  • First, InfoBatch randomly drops a fraction $r$ of the samples whose loss is below the mean loss tracked across the dataset. The paper discusses more advanced strategies, but this simple rule already works very well.
  • Second, for the remaining below‑average‑loss samples, InfoBatch rescales their loss by $(1 - r)^{-1}$ to keep overall training unbiased.
  • Third, during the last few epochs of training, InfoBatch disables pruning and trains on the full dataset to mitigate forgetting.

The hyperparameter $\delta$ controls the fraction of epochs that perform on‑the‑fly pruning; the remaining $(1 - \delta)$ fraction of epochs use the full dataset. A good starting point is $r = 0.5$ and $\delta = 0.875$.
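
To make this rule concrete, here is a simplified sketch of the per‑epoch decision in plain PyTorch. It only illustrates the rule above and is not the official implementation; the function name `infobatch_plan` and its interface are invented for exposition.

```python
import torch

def infobatch_plan(scores, r=0.5):
    """Toy per-epoch pruning decision.

    scores: per-sample losses recorded the last time each sample was seen.
    r:      prune ratio; each below-average sample is dropped with probability r.
    Returns the indices to train on this epoch and their loss weights.
    """
    below_avg = scores < scores.mean()
    # Always keep above-average samples; keep a below-average sample with
    # probability (1 - r).
    keep = ~below_avg | (torch.rand_like(scores) > r)
    # Kept below-average samples are up-weighted by 1 / (1 - r), so the expected
    # contribution of that group to the gradient is unchanged.
    weights = torch.ones_like(scores)
    weights[below_avg] = 1.0 / (1.0 - r)
    kept = keep.nonzero(as_tuple=True)[0]
    return kept, weights[kept]

# Example: the low-loss samples (0.05-0.2) are pruning candidates; the rest are kept.
scores = torch.tensor([0.10, 0.20, 0.15, 0.90, 1.20, 0.05, 1.00, 0.80])
indices, loss_weights = infobatch_plan(scores, r=0.5)

# During the last (1 - delta) fraction of epochs, this pruning step is skipped
# entirely and the full dataset is used.
```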

Returning to the plug‑and‑play snippet: (1) the dataset is wrapped so that InfoBatch can manage the index order, (2) the InfoBatch sampler is passed to the DataLoader constructor, and (3) between the forward and backward pass, the per‑sample loss is handed to InfoBatch, which rescales it and updates the pruning statistics. For more mathematical discussion and ablations, see the paper. For parallel training, see the code.
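
In sketch form, the three changes look roughly like this. The wrapper API shown (`InfoBatch(...)`, `.sampler`, `.update()`) is based on the repository's README and may differ in detail; treat the exact argument names and the return value of `update` as assumptions and defer to the code for the authoritative usage.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from infobatch import InfoBatch  # the plug-and-play package from the repository

# A tiny synthetic dataset keeps the sketch self-contained; use your own Dataset.
num_epochs, num_classes = 10, 10
base_dataset = TensorDataset(torch.randn(1024, 32),
                             torch.randint(0, num_classes, (1024,)))

# (1) Wrap the dataset so InfoBatch can manage per-sample scores and index order.
#     The keyword names prune_ratio/delta are assumptions about the constructor.
train_dataset = InfoBatch(base_dataset, num_epochs, prune_ratio=0.5, delta=0.875)

# (2) Hand InfoBatch's sampler to the DataLoader so it decides which samples
#     each epoch actually visits.
train_loader = DataLoader(train_dataset, batch_size=128, sampler=train_dataset.sampler)

model = torch.nn.Linear(32, num_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss(reduction="none")  # per-sample losses are required

for epoch in range(num_epochs):
    for x, y in train_loader:
        out = model(x)
        loss = criterion(out, y)
        # (3) Between forward and backward: InfoBatch rescales the kept
        #     below-average losses and records them as next epoch's scores
        #     (assumed here to return a scalar ready for .backward()).
        loss = train_dataset.update(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Note that `reduction="none"` matters: InfoBatch needs per‑sample losses both to rescale the kept samples and to refresh their scores for the next epoch's pruning decision.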

Applications

The idea behind InfoBatch is simple but effective across many applications.

  • Image classification: 40% faster training with no accuracy drop, which prior pruning methods fail to achieve.
  • MAE pretraining: 20% time saved for ViT and Swin, with no downstream accuracy loss.
  • Semantic segmentation: 40% time saved with no mIoU degradation.
  • Diffusion models: 27% time saved with comparable FID.
  • LLM instruction fine‑tuning: 20% time saved.