Data-Free Knowledge Distillation

14 May 2023 · 34 min read

Introduction

In this post, I'm going to review the current state of the art in data-free knowledge distillation. I'll start with the basics of knowledge distillation, go through the motivation for data-free knowledge distillation, and then introduce several works, including the current state of the art, that I consider important for understanding the subject.

Knowledge Distillation

Knowledge distillation (Hinton et al., 2015) is the technique of transferring knowledge from a model or an ensemble of models (teacher) to another model (student). As the field moves towards ever larger models, it is becoming increasingly important to find ways to transfer their knowledge to smaller models that can be deployed on hardware with limited resources, or in applications where energy consumption is a concern. Some of the most common applications of knowledge distillation are:

  • Transfer to a smaller version of the teacher model with fewer layers and fewer neurons per layer
  • Transfer to a smaller architecture
  • Transfer to a quantized version of the teacher model

Response-based KD

The most common form of knowledge distillation is response-based knowledge distillation, in which the student model is trained to mimic the output of the teacher model. The equation below shows a possible loss function for a multi-class classification problem, using the KL divergence between the outputs of the teacher and the student:

$$\mathcal{L}_{KD} = D_{KL}(\text{Softmax}(T(\mathbf{x})) \,\|\, \text{Softmax}(S(\mathbf{x})))$$
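To make this concrete, here is a minimal PyTorch sketch of a response-based distillation loss. The temperature does not appear in the equation above but is standard practice from Hinton et al. (2015); the function and variable names are my own.

```python
import torch
import torch.nn.functional as F

def response_kd_loss(teacher_logits, student_logits, temperature=4.0):
    """KL divergence between the softened teacher and student distributions."""
    # Soften both distributions with the same temperature.
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    # F.kl_div expects log-probabilities as its first argument; "batchmean"
    # matches the mathematical definition of the KL divergence. The T^2 factor
    # keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2

# Random logits standing in for T(x) and S(x).
teacher_logits = torch.randn(8, 10)
student_logits = torch.randn(8, 10, requires_grad=True)
loss = response_kd_loss(teacher_logits, student_logits)
loss.backward()
```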

Feature-based KD

In feature-based knowledge distillation, the student model is trained to mimic the intermediate-layer features of the teacher model. The equation below shows a possible loss function that minimizes the mean squared error between the features of the teacher and the student, where $f_T$ and $f_S$ are the features of the teacher and student respectively, and $\Phi_T$ and $\Phi_S$ are projections that map the features of the teacher and student to a common space:

$$\mathcal{L}_{KD} = \text{MSE}(\Phi_T(f_T(x)), \Phi_S(f_S(x)))$$
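A minimal sketch of such a feature-matching term, assuming a single pair of intermediate feature maps and a 1×1 convolution as the student-side projection $\Phi_S$ (the teacher-side projection is taken to be the identity here); the channel sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative channel sizes: teacher features have 256 channels, student 128.
teacher_channels, student_channels = 256, 128

# Phi_S: a 1x1 convolution mapping the student features into the teacher's space.
proj_student = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

def feature_kd_loss(f_teacher, f_student):
    """MSE between the (projected) intermediate features of teacher and student."""
    return F.mse_loss(proj_student(f_student), f_teacher)

# Fake intermediate feature maps standing in for f_T(x) and f_S(x).
f_t = torch.randn(8, teacher_channels, 14, 14)
f_s = torch.randn(8, student_channels, 14, 14)
loss = feature_kd_loss(f_t, f_s)
```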

Relation-based KD

Relation-based knowledge distillation explores relations between different layers and data samples and optimizes the student model to mimic those relations. For example, Yim et al. (2017) use the FSP matrix between the features of two layers to measure the relation between them. The student model is then trained to minimize the difference between its FSP matrices and those of the teacher model.
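The sketch below shows how an FSP-style matrix can be computed from two feature maps of the same network, following my reading of Yim et al. (2017); it assumes the two maps share the same spatial size (in practice pooling is used when they do not).

```python
import torch

def fsp_matrix(f1, f2):
    """FSP matrix between feature maps of shape (B, C1, H, W) and (B, C2, H, W).

    Entry (i, j) is the spatially averaged inner product between channel i of
    the first map and channel j of the second map.
    """
    b, c1, h, w = f1.shape
    c2 = f2.shape[1]
    f1 = f1.reshape(b, c1, h * w)
    f2 = f2.reshape(b, c2, h * w)
    return torch.bmm(f1, f2.transpose(1, 2)) / (h * w)  # (B, C1, C2)

def fsp_loss(f1_teacher, f2_teacher, f1_student, f2_student):
    """Squared distance between the teacher's and the student's FSP matrices."""
    return (fsp_matrix(f1_teacher, f2_teacher)
            - fsp_matrix(f1_student, f2_student)).pow(2).mean()
```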

Data-Free Knowledge Distillation

In many cases, the teacher model is trained on datasets that are very large or are not publicly available for copyright or privacy reasons. This makes it difficult to distil knowledge from those models using the same data they were trained on. Introduced by Lopes et al. (2017), data-free knowledge distillation is the process of distilling knowledge from a model without using any of the original data. That original work used metadata from the original dataset to reconstruct a dataset similar to the original one. However, in many cases, the metadata is not available or is difficult to obtain. For that reason, subsequent works have proposed new methods that try to create a dataset without any metadata.

In this section, I'll introduce several works that I consider important for understanding the current state of the art in data-free knowledge distillation.

Zero-Shot Knowledge Distillation in Deep Networks

ZSKD (Nayak et al., 2019) introduces a method to synthesize data impressions by optimising an input so that the output of the teacher model matches a sample drawn from a Dirichlet distribution. Since the teacher model is frozen, the only free parameter in the optimization is the data impression itself.

Given an output $s$ after the softmax layer of the teacher model, which is a vector of probabilities of a certain input belonging to each class, the authors propose to model the probability of the output $y$ belonging to a class $c$ as a Dirichlet distribution, $p(y^c) = \text{Dir}(C, \alpha^c)$, where $C$ is the number of classes and $\alpha^c$ is the concentration parameter of the Dirichlet distribution.

To find the concentration parameter $\alpha^c$, they compute the cosine similarity matrix between all $C$ classes (obtained from the class weight vectors of the teacher's final layer) and use the $c$-th row of that matrix as $\alpha^c$.

Then, for each class $c$, $N$ outputs $y$ are sampled from the Dirichlet distribution, and one data impression is optimized for each $y$ using a cross-entropy loss.
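The sketch below illustrates the general recipe as I understand it: sample a soft target from the Dirichlet distribution and optimize a random input so that the frozen teacher's output matches it. The teacher, the concentration parameter `alpha_c` (a positive vector with one entry per class), the input shape, and the hyperparameters are placeholders rather than the authors' exact settings.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

def synthesize_impression(teacher, alpha_c, image_shape=(1, 3, 32, 32),
                          steps=200, lr=0.01):
    """Optimize one data impression for a target sampled from Dir(alpha^c)."""
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)                       # the teacher stays frozen
    y_target = Dirichlet(alpha_c).sample()            # soft target, shape (num_classes,)
    x = torch.randn(image_shape, requires_grad=True)  # the data impression
    optimizer = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        log_p = F.log_softmax(teacher(x), dim=1).squeeze(0)
        # Cross-entropy between the sampled soft target and the teacher's output;
        # only x receives gradients.
        loss = -(y_target * log_p).sum()
        loss.backward()
        optimizer.step()
    return x.detach()
```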

Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion

DeepInversion (Yin et al., 2019) synthesizes data impressions that maximize the confidence of the model for a class $y_i$, while regularizing the synthesized inputs using the statistics stored in the batch-norm layers, together with image priors on their l2 norm and total variation. By using the statistics of the batch-norm layers, DeepInversion is able to generate data impressions that recover some of the statistics of the original dataset.

The loss function is the following:

$$\mathcal{L} = \mathcal{L}_{CE}(T(\hat{x}), y_i) + \lambda_{L2} \mathcal{L}_{L2} + \lambda_{TV} \mathcal{L}_{TV} + \lambda_{feature} \mathcal{L}_{feature}$$
$$\mathcal{L}_{L2} = \left\| \hat{x} \right\|_2^2$$
$$\mathcal{L}_{TV} = \sum_{i,j} \left\| \hat{x}_{i,j} - \hat{x}_{i+1,j} \right\|_2^2 + \left\| \hat{x}_{i,j} - \hat{x}_{i,j+1} \right\|_2^2$$
$$\mathcal{L}_{feature} = \sum_l \left\| \mu_l(\hat{x}) - \mathbb{E}[\mu_l(x)] \right\|_2 + \sum_l \left\| \sigma_l^2(\hat{x}) - \mathbb{E}[\sigma_l^2(x)] \right\|_2$$
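Below is a minimal sketch of the batch-norm statistics term, using forward hooks to read the per-batch statistics of the synthesized images and compare them against the running statistics stored in the teacher's BatchNorm layers; it is a simplification of the full DeepInversion objective and the class name is my own.

```python
import torch
import torch.nn as nn

class BNStatsLoss:
    """Hooks every BatchNorm2d layer and accumulates the feature-statistics loss."""

    def __init__(self, model):
        self.losses = []
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.register_forward_hook(self._hook)

    def _hook(self, bn, inputs, output):
        x = inputs[0]
        # Statistics of the current (synthesized) batch at this layer.
        mean = x.mean(dim=[0, 2, 3])
        var = x.var(dim=[0, 2, 3], unbiased=False)
        # Distance to the running statistics stored during the teacher's training.
        self.losses.append(torch.norm(mean - bn.running_mean, 2)
                           + torch.norm(var - bn.running_var, 2))

    def compute(self):
        loss = sum(self.losses)
        self.losses = []
        return loss
```

In use, one would create `bn_stats = BNStatsLoss(teacher)` once, run `teacher(x_hat)` on the current batch of synthesized images, and add `lambda_feature * bn_stats.compute()` to the loss before back-propagating into `x_hat`.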

Adaptive DeepInversion

To further improve the quality of the generated data impressions, Yin et al. (2019) propose Adaptive DeepInversion. Given that the student's knowledge evolves during training, it makes sense to generate data impressions on which the teacher and the student disagree. The Adaptive DeepInversion loss extends the DeepInversion loss with a term that, when minimized, maximizes the Jensen-Shannon divergence between the teacher and student outputs.

$$\mathcal{L}_{AdaptiveDeepInversion} = \mathcal{L}_{DeepInversion} + \lambda_{adversarial} \, (1 - JS(T(x), S(x)))$$
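A short sketch of the Jensen-Shannon term, assuming `t_logits` and `s_logits` are the teacher and student outputs for a batch of synthesized inputs.

```python
import torch.nn.functional as F

def js_divergence(t_logits, s_logits):
    """Jensen-Shannon divergence between teacher and student output distributions."""
    p = F.softmax(t_logits, dim=1)
    q = F.softmax(s_logits, dim=1)
    m = 0.5 * (p + q)
    kl_pm = F.kl_div(m.log(), p, reduction="batchmean")  # KL(p || m)
    kl_qm = F.kl_div(m.log(), q, reduction="batchmean")  # KL(q || m)
    return 0.5 * (kl_pm + kl_qm)

# Adaptive DeepInversion adds lambda_adversarial * (1 - js_divergence(t, s)) to the
# DeepInversion loss, so minimizing it pushes the teacher and the student apart.
```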

Data-Free Learning of Student Networks

DAFL (Chen et al., 2019) is another method that uses a generative model to generate data impressions. They employ a cross-entropy loss against the teacher's own predictions to push the teacher's outputs towards low entropy, maximise the l1 norm of the teacher's intermediate features (the activations before the final classifier), and maximise the information entropy of the teacher's average output to encourage a balanced representation of the different classes.
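A rough sketch of the three terms on a batch of generated samples, under my reading of the paper; `features` stands for the teacher's activations before its classifier, and the weights `alpha` and `beta` are placeholders.

```python
import torch
import torch.nn.functional as F

def dafl_generator_loss(logits, features, alpha=0.1, beta=5.0):
    """DAFL-style objective for a batch of generated images.

    logits:   teacher outputs, shape (B, num_classes)
    features: teacher activations before the classifier, shape (B, D)
    """
    # One-hot loss: push each sample towards the teacher's own argmax prediction,
    # i.e. towards confident, low-entropy outputs.
    pseudo_labels = logits.argmax(dim=1)
    loss_onehot = F.cross_entropy(logits, pseudo_labels)

    # Activation loss: encourage a large l1 norm of the intermediate features.
    loss_act = -features.abs().mean()

    # Information-entropy loss: the average prediction over the batch should be
    # close to uniform, so that every class is represented.
    mean_probs = F.softmax(logits, dim=1).mean(dim=0)
    loss_ie = (mean_probs * torch.log(mean_probs + 1e-8)).sum()  # negative entropy

    return loss_onehot + alpha * loss_act + beta * loss_ie
```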

Data-Free Knowledge Distillation with Soft Targeted Transfer Set Synthesis

Wang (2021) introduces a method that models the output of the l-th layer of the teacher model as a multivariate normal distribution in order to generate data impressions. Considering a teacher model $T$ with $L$ layers, the output of the l-th layer is $s^l$, a vector of size $K$. $s^l$ is modeled as a multivariate normal distribution with mean $\mu^l$ and covariance matrix $\Sigma^l$, where the mean is set to $\mu^l = 0$. To obtain the covariance matrix $\Sigma^l$, they argue that the correlation between the outputs of the l-th layer is implicitly encoded in the weights of that layer. Therefore, they use the weights of the l-th layer to compute the correlation matrix, $R^l_{ij} = \frac{w_i^\top w_j}{\left\| w_i \right\| \cdot \left\| w_j \right\|}$. The covariance matrix is then given by $\Sigma^l = D \times R^l \times D$, where $D$ is a learnable parameter.
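As a small sketch of this construction, assuming a fully connected layer with weight matrix `W` whose rows are the $w_i$, and a learnable vector `d` playing the role of $D$ (everything here is illustrative rather than the author's exact formulation):

```python
import torch

def layer_covariance(W, d):
    """Build Sigma^l = D x R^l x D from the weights of the l-th layer.

    W: weight matrix of the layer, shape (K, D_in); row i is w_i.
    d: learnable per-output scales, shape (K,).
    """
    # Cosine similarity between every pair of weight vectors -> correlation matrix R^l.
    W_norm = W / W.norm(dim=1, keepdim=True).clamp_min(1e-8)
    R = W_norm @ W_norm.t()          # shape (K, K)
    D = torch.diag(d)
    return D @ R @ D                 # shape (K, K)

# Example: sample one soft target y^l ~ N(0, Sigma^l) for a layer with 10 outputs.
W = torch.randn(10, 64)
d = torch.ones(10, requires_grad=True)
sigma = layer_covariance(W, d)
y_soft = torch.distributions.MultivariateNormal(
    torch.zeros(10), covariance_matrix=sigma).sample()
```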

To generate data impressions, they sample a total of $N$ soft outputs $y_{soft}^l$ from the distributions of the different layers, together with $N$ noise inputs, and optimize each data impression to minimise the KL divergence between the sampled target of the l-th layer, $y_{soft}^l$, and the current output of that layer, $\hat{y}^l$, as shown in the following equation:

$$\mathcal{L}_{mv} = \text{KL}(y_{soft}^l \,\|\, \hat{y}^l)$$

Additionally, as is common in synthetic data generation, they add a term to the loss function that encourages higher activations at the last convolutional layer. This term is given by:

$$\mathcal{L}_{act} = -\frac{1}{n} \sum_{i=1}^n \left\| s^{lastConv}_i \right\|_1$$

The final loss function to optimize the data impressions is given by:

$$\mathcal{L} = \mathcal{L}_{mv} + \lambda_{act} \mathcal{L}_{act}$$

Zero-shot Knowledge Transfer via Adversarial Belief Matching

ZSKT (Micaelli & Storkey, 2019) trains a generator to produce the data impressions in an adversarial way, i.e. to maximize the disagreement between the teacher and the student models.

$$\mathcal{L}_{adv} = -D_{KL}(T(x) \,\|\, S(x))$$

Additionally, when transferring knowledge from the teacher to the student, they use an extra loss term that exploits the fact that the two models often share a similar block structure: it minimizes the l2 distance between the normalized spatial attention maps of corresponding hidden layers of the teacher and student. The spatial attention map of a layer $l$ is denoted $\mathcal{A}_l$ and computed as $\mathcal{A}_l = \frac{1}{C_l} \sum_{c=1}^{C_l} a_{lc}^2$, where $C_l$ is the number of channels of layer $l$ and $a_{lc}$ is the activation of channel $c$ of layer $l$. The final distillation loss is given by:

$$\mathcal{L}_{dist} = D_{KL}(T(x) \,\|\, S(x)) + \beta \sum_{l=1}^L \left\| \frac{\mathcal{A}^S_l}{\left\| \mathcal{A}^S_l \right\|_2} - \frac{\mathcal{A}^T_l}{\left\| \mathcal{A}^T_l \right\|_2} \right\|_2$$
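A sketch of the attention maps and the attention-matching term, assuming the chosen teacher and student layers produce feature maps of the same spatial size; the helper names are mine.

```python
import torch
import torch.nn.functional as F

def spatial_attention(feature_map):
    """A_l = (1/C_l) * sum_c a_lc^2, flattened and l2-normalized per sample."""
    attn = feature_map.pow(2).mean(dim=1)   # (B, H, W)
    attn = attn.flatten(start_dim=1)        # (B, H*W)
    return F.normalize(attn, p=2, dim=1)

def attention_loss(teacher_feats, student_feats):
    """Sum over layers of the l2 distance between normalized attention maps."""
    return sum((spatial_attention(ft) - spatial_attention(fs)).norm(p=2, dim=1).mean()
               for ft, fs in zip(teacher_feats, student_feats))
```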

Data-free network quantization with adversarial knowledge distillation

DFQ (Choi et al., 2020) is a method to distil knowledge from a teacher model to a quantized student model. To avoid the mode collapse that tends to affect generator networks, and which leads to low diversity in the generated data impressions, they use ensembles of generators and of students, with the students acting as discriminators. To optimize the generators, they use the batch normalization statistics loss, minimise the entropy of each individual teacher output, and maximize the entropy of the average teacher output over the batch. Moreover, to ensure that the generated data impressions remain diverse, they employ the KL divergence between the teacher and student outputs as the discriminator loss. The final loss function is given by:

$$\min_{\{S_i\}_{i=1}^{S}} \; \max_{\{g_j\}_{j=1}^{G}} \; \sum_{j=1}^{G}\left(\frac{1}{S}\sum_{i=1}^{S}\mathcal{L}_{discr}(g_j, S_i) - \alpha\, \mathcal{L}_{dfq}(g_j) \right)$$
$$\mathcal{L}_{discr} = \mathbb{E}_{\rho(z)}\left[D_{KL}(T \circ g(z) \,\|\, S \circ g(z))\right]$$
$$\mathcal{L}_{dfq} = \sum_{l,c} D_{KL}^{\mathcal{N}}\big((\mu_{l,c}(x), \sigma_{l,c}^2(x)) \,\|\, (\mu_{l,c}, \sigma^2_{l,c})\big) + \mathbb{E}_{\rho(z)}\left[H(T \circ g(z))\right] - H\big(\mathbb{E}_{\rho(z)}[T \circ g(z)]\big)$$

Contrastive Model Inversion for Data-Free Knowledge Distillation

CMI (Fang et al., 2021) starts from the observation that previous methods for data-free knowledge distillation produce sets of data impressions with a low level of diversity. To address this issue, they propose using a contrastive loss as a diversity objective, and build a method that generates a more diverse set of data impressions. To measure the similarity between data impressions, they add a new model $h = f \circ T$ that projects the output of the teacher for an input $x$ into a new embedding space, and then compute the cosine similarity between the embeddings of the data impressions.

$$\text{sim}(x_i, x_j) = \frac{h(x_i)^\top h(x_j)}{\left\| h(x_i) \right\| \cdot \left\| h(x_j) \right\|}$$

Similarly to Adaptive DeepInversion, they employ a unified inversion framework to generate data impressions, consisting of a class-conditional loss (cover all possible classes), a batch normalization loss, and an adversarial generation loss (cover the space where the teacher and the student disagree). However, they optimize a generative model to produce the data impressions instead of optimizing the data impressions directly. For the adversarial generation loss, they only consider it when the classes predicted by the teacher and the student are the same.

The final loss is given by:

$$\mathcal{L}_{CMI} = \mathcal{L}_{CE}(T(x), y_i) + \lambda_{bn} \mathcal{L}_{bn} + \lambda_{adv} \mathcal{L}_{adv} + \lambda_{ctr} \mathcal{L}_{ctr}$$
$$\mathcal{L}_{bn} = \sum_{l} D_{KL}^{\mathcal{N}}\big((\mu_{l}(x), \sigma_{l}^2(x)) \,\|\, (\mu_{l}, \sigma^2_{l})\big)$$
$$\mathcal{L}_{adv} = -D_{KL}(T(x) \,\|\, S(x)) \cdot \mathbb{1}\{\arg\max T(x) = \arg\max S(x)\}$$
$$\mathcal{L}_{ctr} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{sim}(x_i, \tilde{x}_i))}{\sum_{j=1}^N \exp(\text{sim}(x_i, x_j))}$$
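A sketch of the contrastive term under an InfoNCE-style reading, where `h` has already been applied to each synthesized image $x_i$ and to an augmented view $\tilde{x}_i$ that serves as its positive; the temperature and other details are my own simplifications.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(h_x, h_x_tilde, temperature=0.1):
    """InfoNCE-style loss over embeddings of images and their augmented views.

    h_x, h_x_tilde: embeddings h(x_i) and h(x~_i), both of shape (N, D).
    """
    z1 = F.normalize(h_x, dim=1)
    z2 = F.normalize(h_x_tilde, dim=1)
    # Cosine similarity between every image and every augmented view.
    sim = z1 @ z2.t() / temperature          # (N, N)
    targets = torch.arange(sim.size(0))      # positives sit on the diagonal
    # Pulling each image towards its own view while contrasting against the rest
    # pushes the synthesized samples apart in the embedding space.
    return F.cross_entropy(sim, targets)
```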

Up to 100× Faster Data-Free Knowledge Distillation

FastDFKD (Fang et al., 2022) reuses common features across syntheses through meta-learned updates, which allows it to generate data impressions quickly, unlike previous methods that learn each batch of data impressions from scratch. This results in a large speedup in the generation of data impressions, which according to the authors can reach 100x in some cases.

Conclusion

In this post, I've reviewed the current state of the art in data-free knowledge distillation. While many methods have been proposed, I'm still not convinced that data-free knowledge distillation is a viable solution for distilling knowledge from large models in real-world applications. The main reason is that the datasets used to benchmark these methods are very small and do not represent the complexity of real-world datasets; moreover, the resolution of the images is very low. The comparison between the different methods is also not clear, and it's difficult to understand which method is better than the others; the reproducibility of the results and the availability of code are also lacking in some cases.

Another problem with this field is the weak motivation for the methods proposed. While all of them argue that they might be useful in scenarios where the data is not available, I'm not convinced that this is a problem most of the time. I would argue that even when the original data is not available, it is often possible to find a similar dataset that can serve as a starting point, with methods like these used to enrich it. Finally, I expect that in the coming months we will see new methods that leverage recent advances in diffusion models to generate data impressions.

References

Chen, H., Wang, Y., Xu, C., Yang, Z., Liu, C., Shi, B., Xu, C., Xu, C., & Tian, Q. (2019). Data-Free Learning of Student Networks. 2019 IEEE/CVF International Conference on Computer Vision (ICCV), 3513–3521.
Choi, Y., Choi, J. P., El-Khamy, M., & Lee, J. (2020). Data-Free Network Quantization With Adversarial Knowledge Distillation. 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 3047–3057.
Fang, G., Mo, K., Wang, X., Song, J., Bei, S., Zhang, H., & Song, M. (2022). Up to 100x Faster Data-free Knowledge Distillation. AAAI Conference on Artificial Intelligence.
Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network.
Lopes, R. G., Fenu, S., & Starner, T. (2017). Data-Free Knowledge Distillation for Deep Neural Networks.
Micaelli, P., & Storkey, A. (2019). Zero-shot Knowledge Transfer via Adversarial Belief Matching. arXiv. https://doi.org/10.48550/ARXIV.1905.09768
Nayak, G. K., Mopuri, K. R., Shaj, V., Babu, R. V., & Chakraborty, A. (2019). Zero-Shot Knowledge Distillation in Deep Networks.
Wang, Z. (2021). Data-Free Knowledge Distillation with Soft Targeted Transfer Set Synthesis.
Yim, J., Joo, D., Bae, J., & Kim, J. (2017). A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 7130–7138. https://doi.org/10.1109/CVPR.2017.754
Yin, H., Molchanov, P., Li, Z., Alvarez, J. M., Mallya, A., Hoiem, D., Jha, N. K., & Kautz, J. (2019). Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion. arXiv. https://doi.org/10.48550/ARXIV.1912.08795