What Makes a "Good" Data Augmentation in Knowledge Distillation -- A Statistical Perspective

Huan Wang1,2,†  Suhas Lohit2,*  Mike Jones2  Yun Fu1

NeurIPS 2022

1Northeastern University, Boston, MA  2MERL, Cambridge, MA
†Work done while Huan was an intern at MERL. *Corresponding author: slohit@merl.com
We present a proven proposition that precisely answers the question "What makes a good data augmentation (DA) in knowledge distillation (KD)?": a good DA should reduce the covariance of the teacher-student cross-entropy. We also present a practical metric that needs only the teacher to measure how good a DA is for KD: the standard deviation of the teacher's mean probability (abbreviated as T. stddev). Interestingly, T. stddev shows a strong positive correlation (p-values far below 5%) with the student's test loss (S. test loss) on CIFAR100 and Tiny ImageNet (see the figure above), despite knowing nothing about the student, implying that the goodness of a DA in KD is probably student-invariant. Based on this theory, we further propose an entropy-based data-picking algorithm that boosts the prior state-of-the-art DA scheme (CutMix) in KD, resulting in a new strong DA method, CutMixPick. Finally, we show how the theory can be used in practice to harvest considerable performance gains simply by training with a stronger DA for more epochs.
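To make the metric concrete, below is a minimal PyTorch sketch of how T. stddev could be estimated for a candidate DA scheme. The precise statistic (mean ground-truth-class probability over several augmented views of each sample, then the standard deviation across samples), the augment helper, and num_views are our illustrative assumptions, not necessarily the exact recipe in the paper.

# Minimal sketch (assumptions noted above): estimate T. stddev for a DA scheme.
import torch
import torch.nn.functional as F

@torch.no_grad()
def teacher_stddev(teacher, dataset, augment, num_views=10, device="cpu"):
    """dataset yields (image, label); augment applies the candidate DA to one image tensor."""
    teacher.eval().to(device)
    mean_probs = []
    for image, label in dataset:
        # Teacher probabilities over several random augmentations of the same image.
        views = torch.stack([augment(image) for _ in range(num_views)]).to(device)
        probs = F.softmax(teacher(views), dim=1)          # (num_views, num_classes)
        mean_probs.append(probs[:, label].mean().item())  # mean ground-truth probability
    # Stddev of the teacher's mean probability across the training set.
    return torch.tensor(mean_probs).std().item()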

Abstract

Knowledge distillation (KD) is a general neural network training approach that uses a teacher to guide a student. Existing works mainly study KD from the network output side (e.g., trying to design a better KD loss function), while few have attempted to understand it from the input side. In particular, its interplay with data augmentation (DA) has not been well understood. In this paper, we ask: Why do some DA schemes (e.g., CutMix) inherently perform much better than others in KD? What makes a “good” DA in KD? Our investigation from a statistical perspective suggests that a good DA scheme should reduce the variance of the teacher’s mean probability, which eventually leads to a lower generalization gap for the student. Beyond the theoretical understanding, we also introduce a new entropy-based data-mixing DA scheme to enhance CutMix. Extensive empirical studies support our claims and demonstrate how we can harvest considerable performance gains simply by using a better DA scheme in knowledge distillation.
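The entropy-based picking idea can be sketched as follows. This is an illustrative sketch only: the selection criterion (teacher prediction entropy), the per-batch keep_ratio, and the choice to keep the highest-entropy CutMix samples are assumptions for illustration rather than the exact CutMixPick algorithm.

# Illustrative sketch of entropy-based picking on top of CutMix (assumptions noted above).
import torch
import torch.nn.functional as F

@torch.no_grad()
def pick_by_teacher_entropy(teacher, mixed_images, keep_ratio=0.5):
    """mixed_images: a batch of CutMix-augmented images, shape (B, C, H, W)."""
    probs = F.softmax(teacher(mixed_images), dim=1)               # (B, num_classes)
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=1)  # per-image teacher entropy
    k = max(1, int(keep_ratio * mixed_images.size(0)))
    idx = entropy.topk(k).indices                                 # keep the highest-entropy samples
    return mixed_images[idx], idx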

Correlation Results on CIFAR100 & Tiny ImageNet

[Figures: scatter plots of S. test loss vs. T. stddev across DA schemes on CIFAR100 and Tiny ImageNet, showing a strong positive correlation on both datasets.]
Correlation Results on ImageNet100

ImageNet100 is a 100-class subset randomly drawn from the full ImageNet-1K dataset. The correlation is generally weaker than on CIFAR100 and Tiny ImageNet, presumably because ImageNet100 is a more challenging dataset. Still, as the p-values suggest, the correlation remains fairly strong, indicating that the effectiveness of our metric generalizes to standard 224x224 RGB images.
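For reference, the kind of correlation test behind the reported p-values can be reproduced in a few lines. The coefficient used here (Pearson, with Spearman shown as well) and the numbers below are placeholders for illustration, not values from the paper.

# Placeholder example: correlate per-DA-scheme T. stddev with S. test loss.
from scipy import stats

t_stddev    = [0.112, 0.105, 0.098, 0.091, 0.087]  # one value per DA scheme (placeholder numbers)
s_test_loss = [1.42, 1.37, 1.31, 1.25, 1.21]       # matching student test losses (placeholder numbers)

r, p = stats.pearsonr(t_stddev, s_test_loss)
rho, p_s = stats.spearmanr(t_stddev, s_test_loss)
print(f"Pearson r = {r:.3f} (p = {p:.4f}); Spearman rho = {rho:.3f} (p = {p_s:.4f})")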

Correlation Results on ImageNet

The results on ImageNet do not fully align with our expectation: the correlation between S. test loss and T. stddev below is not statistically significant. We do not yet know why (presumably it is related to the number of classes in the dataset). We consider this a limitation of this work and will explore it in future versions.

Boosting KD via Stronger DA + More Epochs
(CIFAR100 and Tiny ImageNet)

As shown below, compared to the original KD results, we can harvest considerable performance gains simply by using a stronger DA (CutMix and our proposed CutMixPick) with more training epochs.
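As a rough illustration of this recipe, one KD training step with a stronger DA (CutMix here) could look like the sketch below. The temperature T, the loss weight alpha_kd, and the mixed-label cross-entropy term are common defaults assumed for illustration, not necessarily the exact settings behind the reported results.

# Hedged sketch of one KD training step with CutMix as the stronger DA.
import numpy as np
import torch
import torch.nn.functional as F

def cutmix(images, labels, alpha=1.0):
    """Standard CutMix: paste a random patch from a shuffled copy of the batch."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0), device=images.device)
    H, W = images.shape[2:]
    rh, rw = int(H * (1 - lam) ** 0.5), int(W * (1 - lam) ** 0.5)
    cy, cx = np.random.randint(H), np.random.randint(W)
    y1, y2 = max(cy - rh // 2, 0), min(cy + rh // 2, H)
    x1, x2 = max(cx - rw // 2, 0), min(cx + rw // 2, W)
    images = images.clone()
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    lam = 1 - (y2 - y1) * (x2 - x1) / (H * W)  # adjust lambda for box clipping
    return images, labels, labels[perm], lam

def kd_step(student, teacher, images, labels, T=4.0, alpha_kd=0.9):
    """One KD loss evaluation on a CutMix-augmented batch (illustrative defaults)."""
    images, y_a, y_b, lam = cutmix(images, labels)
    with torch.no_grad():
        t_logits = teacher(images)
    s_logits = student(images)
    # Standard temperature-scaled KD term (Hinton et al.).
    kd = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                  F.softmax(t_logits / T, dim=1),
                  reduction="batchmean") * T * T
    # Mixed-label cross-entropy on the CutMix pair.
    ce = lam * F.cross_entropy(s_logits, y_a) + (1 - lam) * F.cross_entropy(s_logits, y_b)
    return alpha_kd * kd + (1 - alpha_kd) * ce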

BibTeX

@inproceedings{wang2022what,
    author = {Huan Wang and Suhas Lohit and Mike Jones and Yun Fu},
    title = {What Makes a "Good" Data Augmentation in Knowledge Distillation -- A Statistical Perspective},
    booktitle = {NeurIPS},
    year = {2022},
}