Article

Discriminator-Enhanced Knowledge-Distillation Networks

Zhenping Li, Zhen Cao, Pengfei Li, Yong Zhong and Shaobo Li
1 Chengdu Institute of Computer Applications, Chinese Academy of Sciences, Chengdu 610041, China
2 School of Computer Science and Technology, University of Chinese Academy of Sciences, Beijing 100049, China
3 School of Electrical and Electronic Engineering, Nanyang Technological University, 50 Nanyang Avenue, Singapore 639798, Singapore
4 Key Laboratory of Advanced Manufacturing Technology, Ministry of Education, Guizhou University, Guiyang 550025, China
* Author to whom correspondence should be addressed.
These authors contributed equally to this work.
Appl. Sci. 2023, 13(14), 8041; https://doi.org/10.3390/app13148041
Submission received: 7 June 2023 / Revised: 29 June 2023 / Accepted: 8 July 2023 / Published: 10 July 2023
(This article belongs to the Section Computing and Artificial Intelligence)

Abstract
Query auto-completion (QAC) serves as a critical function in contemporary textual search systems, generating real-time query completion suggestions based on a user’s input prefix. Despite the prevalent use of language models (LMs) for QAC candidate generation, LM-based approaches frequently suffer from overcorrection during pair-wise loss training and from efficiency limitations. To address these challenges, this paper presents a novel framework, discriminator-enhanced knowledge distillation (Dis-KD), for the QAC task. The framework combines three core components: a large-scale pre-trained teacher model, a lightweight student model, and a discriminator for adversarial learning. Specifically, the discriminator helps discern generative-level differences between the teacher and the student models. An additional discriminator score loss is combined with the traditional knowledge-distillation loss, resulting in enhanced performance of the student model. In contrast to the stepwise evaluation of each generated word, our approach assesses the entire generated sequence, which alleviates the prevalent overcorrection issue in the generation process. Consequently, the proposed framework improves model accuracy while reducing parameter size. Empirical results highlight the superiority of Dis-KD over established baseline methods, with the student model surpassing the teacher model on QAC tasks for sub-word languages.

1. Introduction

Over the past decade, deep learning has provided a powerful approach for constructing high-performance large language models (LLMs), driving advances in natural language processing (NLP) and natural language generation (NLG) that surpass previously anticipated boundaries. The current paradigm of deep learning research typically involves large model architectures trained on massive amounts of data to achieve superior results on specific tasks. Query auto-completion is a foundational task in natural language processing. QAC systems aim to provide relevant and accurate suggestions to enhance the user experience when entering queries in the search box, thereby minimizing the time and effort of query formulation. QAC is an important function of contemporary search engines: when a user types a query in the search box, the system returns the top k options with the highest scores that match query Q in the candidate string set S. QAC belongs to the family of automatic completion tasks in text generation. Although large-scale neural networks have shown empirical success in text generation, two main drawbacks remain. First, the deployability and sustainability of these technologies for real-world applications have not received commensurate attention [1]. The computational performance issues and high cost associated with BERT [2] have posed challenges to supporting high-queries-per-second (QPS) intelligent systems [3]. As an illustration, the operational costs of running ChatGPT by OpenAI are estimated to be approximately CNY 3 million per month. This significant investment reflects the extensive computational resources and infrastructure required to support the complex architecture and training requirements of the model. Second, most models suffer from an over-correction problem, where deviations from the ground-truth sequence are immediately corrected by the cross-entropy loss. In detail, models usually compute loss based on a strict pairwise matching between the predicted words and the ground truth. Once the model generates a word deviating from the ground-truth sequence, the cross-entropy loss corrects the error immediately and draws the remaining generation back to the ground-truth sequence. However, a sequence usually has multiple reasonable representations, and the model cannot be said to have made a mistake merely because it generates a word different from the ground-truth word. A more reasonable choice is to judge the entire sequence. An example is illustrated in Appendix A.1.
With the aim of developing high-precision, low-computational-resource approaches for LLMs in the inference stage, we propose a novel discriminator-enhanced knowledge-distillation framework. Our framework optimizes the trade-off between model accuracy and computational efficiency and leverages a discriminator network to provide additional information during the distillation process. By incorporating the discriminator network, our approach minimizes the loss of information during the knowledge-distillation process, thus achieving high precision while keeping computational costs low.
Therefore, this paper focuses on the auto-completion task in neural-based text generation. Specifically, QAC systems provide completion suggestions to a user when they type a query in the search box. This is an essential feature for search engines and involves providing top-k-scored completions that match query Q in a collection S of scored strings.
Large neural networks or other distillation learning methods have shown empirical success in text generation. However, most of them are optimized using strict pairwise matching at the word level, which has two major drawbacks.
First, such methods often suffer from an over-correction problem, where deviations from the ground-truth sequence result in immediate correction by the cross-entropy loss [4]. In detail, models usually compute loss based on strict pairwise matching between the predicted words and the ground truth. Once the model generates a word deviating from the ground-truth sequence, the cross-entropy loss corrects the error immediately and draws the remaining generation back to the ground-truth sequence. However, a sequence usually has multiple reasonable representations, and the model cannot be said to have made a mistake merely because it generates a word different from the ground-truth word. A more reasonable choice is to judge based on the entire sequence; an example is illustrated in Appendix A.1. Second, popular generation models require larger models with deeper neural networks and greater computational resources to achieve better performance [2,5,6], limiting their application in low-computation-resource situations.
This study develops a novel methodology for improving the efficiency and speed of language models by incorporating a whole-sentence evaluation module based on the discriminator concept from adversarial neural networks. The core objectives of the present research are as follows:
(1)
The study seeks to pioneer a shift in the analysis of sentence quality by focusing on a holistic sentence-level evaluation, diverging from conventional pairwise methodologies. This is intended to rectify prevalent over-correction issues observed within traditional models.
(2)
A primary objective lies in the formulation of a novel knowledge-distillation framework, aiming to engineer a more compact, efficient, and enhanced student network. This is envisaged to be achieved via a unique distillation scaffold that incorporates an additional discriminator.
(3)
Capitalizing on the discriminator signal is another key target. To accomplish this, we intend to incorporate the policy gradient from reinforcement learning [7], thereby overcoming constraints linked with the utilization of discrete signals during the back-propagation process within natural language processing tasks.

2. Related Work

In recent years, neural networks have achieved remarkable success in generating natural language text. Among various text generation tasks, query auto-completion is a common one: web search engines suggest the most frequently used completions to users, an approach known as most popular completion (MPC) [8]. In the QAC task, both the accuracy and the inference speed of the model are critical factors, yet few prior studies have addressed the two simultaneously. In the following sections, we review methods for enhancing accuracy and compression techniques in the QAC task.
Traditional methods for QAC accomplish the task in two steps. The first step involves extracting candidates to increase recall. MPC [8] expands each query to a high-dimensional feature vector and uses cosine similarity to find the top N candidates from historical queries with the same prefix. Other approaches generate query suggestions from additional corpora to offer more accurate recommendations [9,10]. The second step involves re-ranking the candidates with additional features to increase precision [11]. Several kinds of additional information have been exploited for this task, including context or session information [8,12], personalization [8,13], time/popularity sensitivity [8,14,15], user behaviors [16,17,18,19], and click-through logs [20]. To overcome traditional methods’ insufficient feature extraction ability, neural network methods have been adopted in recent years. Neural language models can seamlessly incorporate additional features, such as user ID embeddings to model personalization [21,22]. Time-aware [11] and spelling-error-aware models [23] have also been developed under this framework to increase the accuracy rate. Deep neural networks, such as ELMo [24] and BERT, learn high-quality, deep contextualized word representations using deep neural models that can be directly integrated into various tasks for performance boosting. However, these models contain millions of parameters, making their application in practice difficult when computational resources are limited.
Knowledge distillation [25] is a promising framework for reducing the size of models. Several methods have been developed to leverage linguistic knowledge to make the distillation process more informative. Cui et al. proposed a method, where the student layer can learn from multiple consecutive teacher layers to obtain more comprehensive supervision information [26]. A soft-attention mechanism is used to determine the impact of different layers on the learning process. Similarly, Jiao et al. developed three types of loss functions that encourage the effective transfer of linguistic knowledge from teacher layers to the student model [27]. The losses are applied to different parts of the network, including the embedding layer, the transformer layer, and the prediction layer. The ultimate goal of these knowledge-distillation methods is to enable the student model to accurately mimic the characteristics of the teacher model, such as its hidden representations [28], output probabilities [25], or even the generated sentences themselves [29].
In contrast to the previous approaches, this paper proposes a novel approach to simultaneously address both critical and challenging problems by introducing a discriminator into the knowledge-distillation framework. Notably, our framework offers an easy extension to the existing distillation methods mentioned above.
In this section, we present the details of Dis-KD. The overall architecture of our model, illustrated in Figure 1, primarily comprises two components: (1) a conventional knowledge-distillation framework based on GPT-2, which is known for its strong ability to model the relationships among words and for its expressive power, and (2) a discriminator built on a simple LSTM network that evaluates the quality of the generated sequence. Because sampling from a discrete probabilistic model cannot be described as a differentiable operation, a novel loss, simplified from Monte Carlo tree search (MCTS), is designed for this component. The new loss provides two benefits: (1) it allows the generator to fully utilize the assessment of the entire generated sequence to improve itself, and (2) the discriminator’s results can guide distillation. In contrast to learning based solely on words, this method encourages the student model to learn information about the entire sequence generated by the teacher network.

2.1. Details of the Model

In this section, we introduce a standard knowledge-distillation method for compressing GPT-2 and present our proposed Dis-KD in detail.
Problem Definition. The original large teacher network is represented by a function f(x; θ), where x is the input to the network and θ denotes the model parameters. The goal of knowledge distillation is to learn a new set of parameters θ_s such that the student network achieves performance similar to the teacher at a much lower computational cost. The traditional strategy is to force the student model to imitate the outputs of the teacher model on the training dataset via a KL divergence; we add an additional discriminator signal to enhance this process.
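For concreteness, the sketch below shows this baseline soft-target imitation step in PyTorch, using randomly generated logits as hypothetical stand-ins for the teacher's and student's next-token outputs; the temperature value is illustrative rather than the paper's setting.

```python
import torch
import torch.nn.functional as F

def vanilla_kd_loss(student_logits, teacher_logits, T=2.0):
    """Baseline distillation term: KL divergence between the temperature-softened
    teacher and student next-token distributions (Hinton et al. [25])."""
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # batchmean matches the mathematical definition of KL divergence;
    # multiplying by T**2 keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T ** 2)

# toy example: a batch of 4 positions over a 10-word vocabulary
student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)   # teacher outputs are fixed during distillation
print(vanilla_kd_loss(student_logits, teacher_logits))
```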
In our experimental configuration, we employ a teacher network denoted as f(x; θ_t), a deep language model such as GPT-2. The student network g(x; θ_s), on the other hand, is a lighter model comprising fewer layers. We utilize the QAC dataset, which contains a set of sequences denoted as {w_1, w_2, …, w_i}. Here, each w_i corresponds to a word, subword, or character in the sequence. The input instance for GPT-2 is constructed as x = {w_1, w_2, …, w_{i-1}}, while the corresponding ground truth for the language-modeling task is y = {w_2, w_3, …, w_i}.
The input sequence is fed into the GPT-2 model, which generates a contextualized embedding h_i = GPT2(w_i | x_{i-1}) ∈ R^d, where d represents the dimension of the contextualized embedding. Furthermore, the model computes the output probability of the next word given the preceding sequence using a softmax layer, where W is a weight matrix to be learned:
$P(w_i \mid x_{i-1}) = \mathrm{Softmax}\big(W(h_1, \ldots, h_{i-1})\big)$ (1)
To implement knowledge distillation, we begin by training a teacher network, for instance a 6-layer GPT-2 model. The parameters of the teacher model are denoted by the subscript t in the following equation. The objective during training is to maximize the likelihood given below:
$L_t(y) = \sum_i \log P(w_i \mid w_{i-k}, \ldots, w_{i-1}; \theta_t)$ (2)
Upon training the teacher network, we employ it to guide the student network. When using the teacher model for inference to guide the student model, we modify Equation (1) as follows:
$P_t(w_i \mid x_{i-1}) = \mathrm{Softmax}\!\left(\frac{W(h_1, \ldots, h_{i-1})}{T}\right)$ (3)
Here, P_t(·|·) represents the probability output generated by the teacher model. This probability distribution is considered the soft label of the student model and is fixed during the knowledge-distillation process. Additionally, we use a temperature parameter T during knowledge distillation, which controls the extent to which the student model relies on the teacher’s soft predictions. A higher value of temperature leads to a more diverse probability distribution over classes, as suggested in previous studies [25].
To evaluate the difference between the probability distribution P_t(w_i | x_{i-1}) generated by the teacher model and the corresponding distribution P_s(w_i | x_{i-1}) produced by the student model, we employ a cross-entropy loss. This approach encourages the student model to imitate the behavior of the teacher model.
Let θ_s and θ_t represent the parameters of the student and teacher models, respectively, and let P_s(·|·) and P_t(·|·) denote the corresponding probability outputs from the student and teacher models. Additionally, N represents the length of the sequence, and W is the vocabulary size of the model.
Therefore, we define the distance between the predictions of the teacher and student models as follows:
$L_{CE} = -\sum_{n \in [N]} \sum_{w \in [W]} P_t(w_i \mid x_{i-1}; \theta_t) \cdot \log P_s(w_i \mid x_{i-1}; \theta_s)$ (4)
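A minimal sketch of Equations (3) and (4) is given below, again with hypothetical logits tensors in place of real GPT-2 outputs; applying the temperature to both distributions is a common convention and an assumption here, not something the paper states explicitly.

```python
import torch
import torch.nn.functional as F

def soft_targets(teacher_logits, T=2.0):
    """Equation (3): temperature-softened teacher distribution used as soft labels,
    detached so that it stays fixed during distillation."""
    return F.softmax(teacher_logits / T, dim=-1).detach()

def distillation_ce(student_logits, teacher_logits, T=2.0):
    """Equation (4): cross-entropy between the teacher distribution P_t and the
    student distribution P_s, summed over the vocabulary and averaged over positions
    (a plain sum, as written in Equation (4), differs only by a constant factor)."""
    p_t = soft_targets(teacher_logits, T)
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    return -(p_t * log_p_s).sum(dim=-1).mean()

# toy tensors: a sequence of N = 5 positions over a vocabulary of size W = 50
teacher_logits = torch.randn(5, 50)
student_logits = torch.randn(5, 50, requires_grad=True)
print(distillation_ce(student_logits, teacher_logits))
```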
To enhance the effectiveness of knowledge distillation, we aim to develop a discriminator that evaluates the entire sentence rather than individual words. To achieve this goal, we introduce an encoder, such as a bidirectional LSTM network, which reads the sequence of contextualized vector representations h_i generated by the language model (either the student or the teacher GPT-2).
The objective of the discriminator is to differentiate between the h_i generated by the student and teacher models. We aim to train the student model to generate representations that can successfully fool the discriminator. We define the signal received from the discriminator as s:
$s_i = \mathrm{sigmoid}\big(\mathrm{LSTM}(h_1, \ldots, h_i)\big)$ (5)
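The following is one possible PyTorch realization of such a discriminator, assuming GPT-2-sized hidden states (768) and the 256-dimensional bidirectional LSTM listed in Table 3; the layer sizes and head structure are our reading of the description, not the authors' released code.

```python
import torch
import torch.nn as nn

class SequenceDiscriminator(nn.Module):
    """Equation (5): a bidirectional LSTM reads the contextualized vectors
    h_1..h_i and emits a scalar s_i in (0, 1) for every prefix."""
    def __init__(self, hidden_dim=768, lstm_dim=256, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(hidden_dim, lstm_dim, num_layers=2,
                            batch_first=True, bidirectional=True)
        self.head = nn.Sequential(
            nn.Linear(2 * lstm_dim, lstm_dim),   # linear layer
            nn.Dropout(dropout),                 # dropout layer
            nn.Linear(lstm_dim, 1),              # second linear layer
        )

    def forward(self, hidden_states):             # (batch, seq_len, hidden_dim)
        out, _ = self.lstm(hidden_states)          # (batch, seq_len, 2 * lstm_dim)
        return torch.sigmoid(self.head(out)).squeeze(-1)   # s_i for every prefix

# toy usage: 2 sequences of length 6 with GPT-2-sized hidden states
disc = SequenceDiscriminator()
s = disc(torch.randn(2, 6, 768))
print(s.shape)   # torch.Size([2, 6])
```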
Taking inspiration from Sutton et al. [30], we simplify the reward computation used in SeqGAN [7], which employs a Monte Carlo search with a roll-out policy to calculate the reward at each time step. In our proposed framework, we use the discriminator not only to evaluate the entire sequence but also to assess the sub-sequences generated after each word. For instance, in the sequence “Say hello world”, we evaluate “Say”, “Say hello”, and “Say hello world”. Furthermore, we take into account the generation probability of each word. The reward for the entire sequence is then defined as
$\mathrm{Score} = \sum_{i=1}^{N} \frac{1}{i}\, P_t(w_i \mid x_{i-1}; \theta)\,(1 - s_i)$ (6)
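Equation (6) can be computed directly once the per-token teacher probabilities and discriminator signals are available; the sketch below uses hypothetical values for both.

```python
import torch

def sequence_score(p_teacher, s):
    """Equation (6): Score = sum_i (1/i) * P_t(w_i | x_{i-1}) * (1 - s_i).
    p_teacher: teacher probability of each generated token, shape (N,)
    s:         discriminator signal for each prefix, shape (N,)"""
    N = p_teacher.shape[0]
    weights = 1.0 / torch.arange(1, N + 1, dtype=p_teacher.dtype)   # the 1/i factor
    return (weights * p_teacher * (1.0 - s)).sum()

# hypothetical values for a sequence of N = 4 generated tokens
p_t = torch.tensor([0.8, 0.6, 0.7, 0.5])   # teacher probabilities of each token
s   = torch.tensor([0.2, 0.4, 0.3, 0.6])   # discriminator outputs per prefix
print(sequence_score(p_t, s))
```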
The training of the proposed framework involves two steps of backpropagation.
In the first step, the discriminator is optimized as a binary classifier. Once the discriminator is trained, Equation (6) is used to calculate the score of the student’s output and the teacher’s output. A smaller difference between the scores indicates that the output of the student model more closely resembles that of the teacher model. To optimize the generation of the entire sequence, the loss function is defined as follows:
$L_{Score} = -\big( y \log(y_{label}) + (1 - y) \log(1 - y_{label}) \big)$ (7)
The discriminator is designed to output a value of 1 if the input is determined to be the output of the teacher network, and 0 if it is deemed the output of the student network. For this simple binary classification task, we employ Equation (7), a straightforward cross-entropy loss, to optimize the discriminator.
Next, we go back to training the discriminator, and this cycle iterates several times. Even though the policy is simplified, the 2-layer distilled model outperforms the 6-layer teacher model. The data flow is described in Appendix A.2. The additional loss used for the student model is the language model loss, which is the same as the loss used in GPT:
$L_{LM} = -\frac{1}{N} \sum_{i=1}^{N} \log p(w_i \mid x_{i-1})$ (8)
Combined with the traditional knowledge-distillation loss, the final objective function can be formulated as follows:
$L_{Dis\text{-}KD} = (1 - \alpha)\, L_{CE} + \alpha\, L_{LM} + \beta\, L_{Score}$ (9)
where α and β are hyper-parameters that balance the importance of different objective functions; α is 0.6 and β is 0.35 in our model.
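Combining the three terms is then a one-liner; the sketch below uses the reported α = 0.6 and β = 0.35 as defaults, with placeholder scalar losses standing in for the terms defined above.

```python
import torch

def dis_kd_loss(l_ce, l_lm, l_score, alpha=0.6, beta=0.35):
    """Equation (9): L_Dis-KD = (1 - alpha) * L_CE + alpha * L_LM + beta * L_Score,
    with the alpha and beta values reported in the paper as defaults."""
    return (1.0 - alpha) * l_ce + alpha * l_lm + beta * l_score

# placeholder scalars; in practice these are the distillation cross-entropy,
# the GPT-style language-model loss, and the discriminator-score loss
print(dis_kd_loss(torch.tensor(1.2), torch.tensor(0.9), torch.tensor(0.4)))
```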

2.2. Training Process

The learning procedure of the model comprises three stages, which are repeated multiple times: learning, distinguishing, and cheating. The first stage is divided into two parts: initially learning in a traditional knowledge-distillation manner, and subsequently learning with the additional discriminator loss. In the second stage, the discriminator distinguishes the generated sequence from the teacher or student and then evaluates it. In the third stage, the updated student model attempts to deceive the discriminator and improve its performance by Equation (7). This approach can be considered a hybrid of learning and cheating, as the generator learns from the discriminator and then seeks to outwit it.
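The sketch below illustrates this learning/distinguishing/cheating cycle with a deliberately miniaturized, hypothetical setup: linear layers stand in for the 6-layer teacher and 2-layer student GPT-2, a one-layer classifier stands in for the bi-LSTM discriminator, and an MSE term stands in for the distillation losses. It is meant only to show the alternating optimization pattern, not to reproduce the paper's training code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# miniature stand-ins: linear maps replace the 6-layer teacher / 2-layer student
# GPT-2, and a one-layer classifier replaces the bi-LSTM discriminator
teacher = nn.Linear(16, 16)
student = nn.Linear(16, 16)
disc = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

opt_s = torch.optim.AdamW(student.parameters(), lr=5e-5)
opt_d = torch.optim.AdamW(disc.parameters(), lr=5e-5)

for step in range(3):                        # the cycle repeats many times in practice
    x = torch.randn(8, 16)                   # a batch of prefix representations
    h_t = teacher(x).detach()                # teacher hidden states are frozen
    h_s = student(x)

    # distinguishing: train the discriminator as a binary classifier
    d_real = disc(h_t)
    d_fake = disc(h_s.detach())
    loss_d = (F.binary_cross_entropy(d_real, torch.ones_like(d_real)) +
              F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # learning + cheating: update the student with a distillation term plus the
    # discriminator signal, so that it tries to fool the discriminator
    loss_kd = F.mse_loss(h_s, h_t)           # placeholder for the distillation losses
    loss_cheat = F.binary_cross_entropy(disc(h_s), torch.ones_like(d_fake))
    loss_s = 0.65 * loss_kd + 0.35 * loss_cheat
    opt_s.zero_grad()
    loss_s.backward()
    opt_s.step()

    print(f"step {step}: loss_d={loss_d.item():.3f}, loss_student={loss_s.item():.3f}")
```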

3. Experiments

This section elucidates the experimental application of our proposed Dis-KD approach on the query auto-completion (QAC) task and provides comparative analyses with direct QAC baselines and knowledge-distillation baselines. Detailed insights into the baselines, experimental results, and ablation studies are sequentially furnished in the subsequent subsections.

3.1. Dataset and Evaluation Metric

Our experiments were conducted on the AOL dataset [31], following the protocol delineated in [32]. In particular, the dataset comprises 17,521,031 training queries, 1,521,971 validation queries, and 1,317,632 test queries. Notably, the test dataset includes 670,810 seen queries and 646,822 unseen queries, with the latter constituting almost half of the test data.
In our context, ’seen’ data refer to test queries that appear in the training dataset, while ’unseen’ data refer to test queries that never appear in the training set. It should be noted that this arrangement does not induce label-leakage issues in QAC, owing to its distinct task definition.
In our experiment, if the query ‘ask.com’ is included in the test set, its presence or absence in the training dataset classifies it as either seen or unseen. Our primary task is to utilize the input subword to generate a complete query. For instance, during the model testing process, if a user enters the subword ‘ask’, the expected completion is ‘ask.com’, which may be generated from various candidate words, such as ‘ask jeeves’ or ‘ask.com’. Several queries starting with ‘ask’ can be found in the training data. We expect the model to rank candidate words in accordance with our anticipated order. To gauge the concordance of such expectations, we employ the mean reciprocal rank (MRR) metric, which is discussed in further detail later.
The AOL dataset comprises five distinct variables for examination: AnonID, Query, QueryTime, ItemRank, and ClickURL. These variables can inform diverse analytical models. Table 1 illustrates a representative data sample. Nonetheless, for this study, we limit our focus solely to the query column as input to our analytical model. For instance, ’ask.com’ is the data used for model training.
In our study, the BPE (byte pair encoding) tokenizer proposed by Sennrich et al. [33] was utilized to preprocess the language-model input. Specifically, the input text “ask.com” is tokenized into three subwords, namely “ask”, “.”, and “com”. Our model requires the input text to satisfy a minimum subword criterion: for “ask.com”, the input text used to test our model is “ask”, rather than “as” or “a”. This criterion was consistently applied to the input data in the test set. We used the model to generate ten candidate words to calculate the mean reciprocal rank (MRR) [34]. We show the output of the model in Table 5 and compare it with the outputs of other models.
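As an illustration of this tokenization step, the snippet below runs an off-the-shelf GPT-2 BPE tokenizer from the Hugging Face transformers library on the example query; the exact segmentation depends on the vocabulary actually used in the paper, so the printed tokens are only indicative.

```python
# illustrative only: the off-the-shelf "gpt2" vocabulary may segment the query
# differently from the tokenizer trained/used in the paper
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
query = "ask.com"
print(tokenizer.tokenize(query))      # a split along the lines of ["ask", ".", "com"]
print(tokenizer(query)["input_ids"])  # the corresponding vocabulary ids
```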
The dataset used in our experiments was accumulated over three months, with the final week used for testing and the preceding week used for validation. During data preprocessing, we removed non-ASCII characters and transformed uppercase characters into lowercase. We replaced multiple consecutive spaces with a single space and removed any leading or trailing spaces. Additionally, we merged adjacent duplicate submissions of the same query by the same user. To ensure the rigor of our analysis, we also filtered out queries shorter than three characters or longer than forty. These preprocessing steps enhanced the quality of our dataset, thereby bolstering the validity of our subsequent analysis. Detailed information is given in Table 2.
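A sketch of these cleaning rules as a small Python routine is given below; the user-ID/query tuple format and the helper name are illustrative assumptions.

```python
import re

def preprocess(records, min_len=3, max_len=40):
    """Cleaning rules described above: drop non-ASCII characters, lowercase,
    collapse repeated whitespace, strip, merge adjacent duplicates from the same
    user, and filter queries shorter than 3 or longer than 40 characters."""
    cleaned, prev = [], None
    for user_id, query in records:              # (AnonID, Query) pairs
        q = query.encode("ascii", errors="ignore").decode()   # remove non-ASCII
        q = re.sub(r"\s+", " ", q.lower()).strip()            # lowercase, one space
        if not (min_len <= len(q) <= max_len):
            continue
        if (user_id, q) == prev:                              # adjacent duplicate
            continue
        cleaned.append((user_id, q))
        prev = (user_id, q)
    return cleaned

sample = [(217, "Ask.COM "), (217, "ask.com"), (142, "rentdirect.com"), (993, "ab")]
print(preprocess(sample))   # [(217, 'ask.com'), (142, 'rentdirect.com')]
```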
In the evaluation of the QAC system, we consider two key metrics: accuracy and time. Accuracy reflects the degree to which the model’s suggestions match the user’s intended queries, and better performance in this regard leads to a more satisfactory user experience. To quantify the auto-completions, we use two methods: mean reciprocal rank (MRR) [34] and partial-matching MRR (PMRR) [35]. The completion speed is measured in QPS. For all of the aforementioned metrics, higher scores indicate better performance.
The MRR is a widely used metric for evaluating the effectiveness of QAC systems. The MRR of a given QAC system m is computed using a test dataset Q_test as follows:
$\mathrm{MRR}(m) = \frac{1}{|Q_{test}|} \sum_{q \in Q_{test}} RR\big(q, m(p)\big)$ (10)
Here, p denotes a prefix of a query q, and m(p) represents the ranked list of candidate completions of p returned by the QAC system m; our model generates ten candidates. The function RR returns the reciprocal rank of q if it appears in m(p) and 0 otherwise. In our test dataset, |Q_test| is 670,810 for the seen queries and 646,822 for the unseen queries.
For instance, consider a question-answering task, where the model is required to return the correct answer to a given question. The model generates multiple candidate answers and ranks them. Suppose the query is “Who is the President of the United States?” and the correct answer is “Biden.” If the model ranks “Biden” in the first position, the reciprocal rank is 1/1 = 1. If the model ranks “Biden” in the second position, the reciprocal rank is 1/2 = 0.5. If the correct answer does not appear in the returned candidate list, the reciprocal rank is 0. By averaging the reciprocal rank across all queries, the performance of the model can be evaluated.
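The computation of RR and MRR described above can be written compactly as follows; the toy QAC system and test pairs are hypothetical and serve only to show the bookkeeping.

```python
def reciprocal_rank(target, candidates):
    """RR(q, m(p)): 1/rank of the ground-truth query among the ranked candidates,
    or 0 if it does not appear in the list."""
    for rank, cand in enumerate(candidates, start=1):
        if cand == target:
            return 1.0 / rank
    return 0.0

def mean_reciprocal_rank(test_pairs, qac_system, n_candidates=10):
    """Equation (10): average reciprocal rank over all (prefix, query) pairs,
    where qac_system(prefix) returns a ranked list of completions."""
    rrs = [reciprocal_rank(q, qac_system(p)[:n_candidates]) for p, q in test_pairs]
    return sum(rrs) / len(rrs)

# a hypothetical system that always returns the same ranked list
def toy_system(prefix):
    return ["ask jeeves", "ask.com", "askreddit"]

test_pairs = [("ask", "ask.com"), ("ask", "ask jeeves"), ("ask", "ask fm")]
print(mean_reciprocal_rank(test_pairs, toy_system))   # (0.5 + 1.0 + 0.0) / 3 = 0.5
```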

3.2. Implementation Details

In our experiments, the QAC system generates N = 10 completion candidates using beam search with width B = 30. The models are trained for 30 epochs using the AdamW optimizer with a learning rate of 5 × 10^-5 and a batch size of 1024. The teacher model is a GPT-2 with six multi-head attention (MHA) layers, while the student model has two MHA layers. The model parameters follow the settings in Table 3. All evaluations were performed on an Intel Xeon Gold 5122 processor equipped with an NVIDIA Quadro GV100. In the result tables, the best value in each column is highlighted in bold. To better demonstrate the effectiveness of Dis-KD, we compare it against several strong QAC models.
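For reference, a hedged sketch of this decoding configuration (N = 10 candidates, beam width B = 30) using the Hugging Face generate() API on a generic GPT-2 checkpoint is shown below; it simply mirrors the stated settings rather than reproducing the authors' implementation.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("ask", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        num_beams=30,              # beam width B = 30
        num_return_sequences=10,   # N = 10 completion candidates
        max_new_tokens=8,          # illustrative completion length cap
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
    )
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```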
Our model architecture comprises three components. The student network consists of two GPT-2 layers with an embedding dimension of 768, followed by a linear layer of the same dimension; the teacher network differs only in that it consists of six GPT-2 layers. In addition, the discriminator component comprises a bidirectional LSTM network with an embedding dimension of 256, a linear layer, a dropout layer, another linear layer, and an activation function.

3.3. Baselines and Results

We present a comprehensive comparison of our proposed Dis-KD approach with several baseline models in two respects: QAC systems and knowledge distillation. We consider the following strong QAC models for comparison: LWG [36] and MCG [37], which are both generation models ranked by frequency; CLSM [38] and LSTM, which are commonly used to re-rank candidates generated by MCG; and the neural query language model (NQLM) [35], an end-to-end LSTM-based approach for both generation and ranking. We also include topic models as query generators: the pairwise coupled topic model (PCTM) [39], based on latent Dirichlet allocation (LDA) [40], and a topic-based LSTM model [41]. Table 4 presents the performance of the different QAC systems, where our proposed Dis-KD approach outperforms all the baselines.
In Table 4, we present the results of our evaluation of QAC systems at the word level. We first compare the performance of a 2-layer GPT-2 model with the other types of models and find that, not surprisingly, it performs worse than the vanilla LSTM model. We hypothesized that deeper neural language models could learn better word representations and yield better performance, so we repeated the experiments using a 6-layer GPT-2 model and observed that it outperforms all other baseline models. To demonstrate the effectiveness of our proposed framework, we add the discriminator loss and introduce a 2-layer GPT-2 model into the distillation framework. The results show that, after learning from the teacher model in our framework, the 2-layer GPT-2 student model not only far outperforms the original 2-layer model but also surpasses the 6-layer teacher model; all results are shown in Figure 2.
In the context of text completion or prediction tasks, the PCTM demonstrates exceptional performance on seen data. However, one of the limitations of PCTM, and topic models in general, is their inability to be effectively transferred to scenarios involving unknown query completion. This means that when encountering out-of-vocabulary words, such as “Dis-KD” in a dataset not seen during training, PCTM is unable to generate accurate completions. Consequently, while PCTM may be effective in certain applications, its reliance on known datasets constrains its applicability in scenarios where unseen data may be encountered.
In the area of knowledge distillation, various approaches have been proposed to train smaller models with performance comparable to larger models. One such approach is DistilBERT [42], which employs a triple loss that utilizes supervision from a bigger language model to distill a smaller one; this framework has been shown to achieve performance similar to that of the larger model. Another approach is Patient-KD [43], which utilizes the output from the last layer of the teacher network for distillation and enables the student model to patiently learn from multiple intermediate layers of the teacher model for incremental knowledge extraction. In addition, Table 5 shows an example of the candidate words recommended by the different methods.
Table 5. Candidate completions for the query “resident”, whose label is “resident evil”.
Dis-KD | LWG+Frequency | MCG+LSTM | LSTM | MCG
resident evil | residential | resident.com | residential | residential
resident inn | residents | resident bush | residents | residential zone
residential | resident evil | resident evil | resident evil | resident.com
resident evil | resident films | resident films | resident alien | resident card
resident evil | resident alien | resident movie | resident movie | resident film
As few works have directly applied knowledge distillation to GPT-2, we re-implement the above framework and transfer the BERT-based knowledge distillation to a GPT-2-based framework with the same student modules. Table 6 summarizes results for the knowledge-distillation baselines.
Our experiments demonstrate the efficacy of distillation as a means of transferring knowledge from a highly regularized, ensemble model or a large-scale model to a smaller, distilled model. Compared to DisGPT and Patient networks, our proposed approach presents a highly innovative discriminator framework that effectively mitigates the issue of over-correction. Notably, by eliminating the discriminator component, our network architecture remains identical to that of DisGPT, thus underscoring the substantial gains attributable to our novel discriminator concept in enhancing model performance.
We contrast the inference speed of our framework with that of the baselines in Table 7. Our framework takes only a third of the run time of the original framework on the same device during inference.

4. Conclusions

In this study, we introduce the discriminator-enhanced knowledge-distillation (Dis-KD) approach, a novel method for transferring knowledge from a teacher model to a student model, with the aim of improving model performance. This is achieved by utilizing evaluation results from a discriminator.
Our experimental results on the QAC task demonstrate that Dis-KD significantly outperforms the baseline methods, with the distilled two-layer student model even surpassing the six-layer teacher model. Moreover, our method is easy to optimize and can be combined with other methods to consistently improve performance. Our work makes the following main contributions:
(1)
We propose Dis-KD, a novel discriminator-enhanced knowledge-distillation framework, which can both enhance model accuracy and reduce parameter size.
(2)
To make the discriminator’s signal suitable for generation tasks, we introduce an easy-to-implement discriminator loss, as direct training on the signal from the discriminator is not differentiable.
(3)
Our approach intelligently leverages the loss of the discriminator as an evaluation signal for the entire sentence. By adopting this method, we can effectively overcome the over-correction problem, reduce the model size nearly threefold, and improve inference speed by roughly the same factor. Furthermore, it exceeds the performance of the original GPT-2 model in terms of the mean reciprocal rank (MRR).
The innovative distillation framework integrates a discriminator—conceptualized as a long short-term memory (LSTM) network—from a generative adversarial network (GAN). The discriminator signal, represented as a scalar output from the LSTM, serves two crucial functions. First, it assesses the complete output, thereby mitigating the risk of overcorrection. Second, it calculates the loss between the student and teacher networks, essentially functioning as an auxiliary loss to assist in the training of the student network. Consequently, the student network’s output should ideally closely mimic that of the teacher network, as both should elicit the same signal from the discriminator.
The application of this method not only compresses the model size but also empowers the student network to achieve superior performance using fewer parameters than the teacher network. The proposed approach presents a promising strategy for enhancing the training of deep neural networks, with potential applicability across diverse domains.

Limitations and Future Work

The primary limitation of Dis-KD, as observed in this study, is its relatively slow inference speed when executing query auto-completion. This could limit its ability to provide real-time suggestions to users, especially in high-volume or time-sensitive environments. Efforts to improve the computational efficiency of Dis-KD inference would therefore likely improve its practical utility and broaden its appeal across various domains. In specific large-scale internet applications, such as those at Alibaba and LinkedIn, deep learning methodologies have not yet been extensively deployed in actual production. Notably, the proposed method is roughly 145 times slower at inference than the most commonly generated (MCG) approach.
Our investigation paves the way for further research in several domains. Firstly, while our results are encouraging, we have not delved into the possibility of extending our methodology to include information beyond the immediate query. Incorporating elements such as users’ behavioral history could yield additional insights into user preferences and interests, extending beyond the context available just prior to a search [44].
The distillation framework presented herein offers potential applications across a broad spectrum of scenarios, beyond the realm of query auto-completion. It could be effectively employed in areas such as translation, entity relationship extraction, and dialogue systems, among others. We anticipate that the proposed framework will serve as a launchpad for further exploration of its applicability in diverse contexts.
Future research could focus on investigating how this framework can be deployed to enhance the objectives of various applications, with the goal of widening its impact and facilitating its integration into practical use cases. It will be intriguing to observe the ways in which this framework contributes to advancing the frontier of machine learning and artificial intelligence.

Author Contributions

Z.L.: Software, Methodology, Writing—original draft. Z.C.: Methodology, Writing—review and editing, Validation. P.L.: Investigation, Writing—review and editing, Validation. Y.Z.: Writing—review and editing. S.L.: Writing—review and editing, Supervision, Project administration. All authors have read and agreed to the published version of the manuscript.

Funding

This work was supported by the AI industrial technology innovation platform of Sichuan Province, grant number 2020ZHCG0002.

Institutional Review Board Statement

Not applicable.

Informed Consent Statement

Not applicable.

Data Availability Statement

Conflicts of Interest

The authors declare no conflict of interest.

Appendix A

Appendix A.1

Table A1. An example of over-correction.
reference: We should comply with the rule.
cand1: We should abide with the rule.
cand2: We should abide by the law.
cand3: We should abide by the rule.

Appendix A.2

Figure A1. The framework of Dis-KD. It consists of three major modules: a language module, a discriminator module, and a fusion loss module.

References

  1. Singh, P.; De Clercq, O.; Lefever, E. Distilling Monolingual Models from Large Multilingual Transformers. Electronics 2023, 12, 1022. [Google Scholar] [CrossRef]
  2. Devlin, J.; Chang, M.W.; Lee, K.; Toutanova, K. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv 2018, arXiv:1810.04805. [Google Scholar]
  3. Guo, S.; Wang, Q. Application of Knowledge Distillation Based on Transfer Learning of ERNIE Model in Intelligent Dialogue Intention Recognition. Sensors 2022, 22, 1270. [Google Scholar] [CrossRef] [PubMed]
  4. Zhang, W.; Feng, Y.; Liu, Q. Bridging the gap between training and inference for neural machine translation. In Proceedings of the Twenty-Ninth International Conference on International Joint Conferences on Artificial Intelligence, Yokohama, Japan, 7–15 January 2021; pp. 4790–4794. [Google Scholar]
  5. Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A.N.; Kaiser, Ł.; Polosukhin, I. Attention is all you need. In Proceedings of the Advances in Neural Information Processing Systems, Long Beach, CA, USA, 4–9 December 2017; pp. 5998–6008. [Google Scholar]
  6. Maimaiti, M.; Liu, Y.; Luan, H.; Sun, M. Enriching the transfer learning with pre-trained lexicon embedding for low-resource neural machine translation. Tsinghua Sci. Technol. 2021, 27, 150–163. [Google Scholar] [CrossRef]
  7. Yu, L.; Zhang, W.; Wang, J.; Yu, Y. Seqgan: Sequence generative adversarial nets with policy gradient. In Proceedings of the AAAI Conference on Artificial Intelligence, San Francisco, CA, USA, 4–9 February 2017; Volume 31. [Google Scholar]
  8. Bar-Yossef, Z.; Kraus, N. Context-sensitive query auto-completion. In Proceedings of the 20th International Conference on World Wide Web, Hyderabad, India, 28 March–1 April 2011; pp. 107–116. [Google Scholar]
  9. Bhatia, S.; Majumdar, D.; Mitra, P. Query suggestions in the absence of query logs. In Proceedings of the 34th International ACM SIGIR Conference on Research and Development in Information Retrieval, Beijing China, 24–28 July 2011; pp. 795–804. [Google Scholar]
  10. Maxwell, D.; Bailey, P.; Hawking, D. Large-scale generative query autocompletion. In Proceedings of the 22nd Australasian Document Computing Symposium, Brisbane, QLD, Australia, 7–8 December 2017; pp. 1–8. [Google Scholar]
  11. Cai, F.; Liang, S.; de Rijke, M. Prefix-adaptive and time-sensitive personalized query auto completion. IEEE Trans. Knowl. Data Eng. 2016, 28, 2452–2466. [Google Scholar] [CrossRef]
  12. Jiang, J.Y.; Ke, Y.Y.; Chien, P.Y.; Cheng, P.J. Learning user reformulation behavior for query auto-completion. In Proceedings of the 37th International ACM SIGIR Conference on Research & Development in Information Retrieval, Gold Coast, QLD, Australia, 6–11 July 2014; pp. 445–454. [Google Scholar]
  13. Shokouhi, M. Learning to personalize query auto-completion. In Proceedings of the 36th international ACM SIGIR conference on Research and development in information retrieval, Dublin, Ireland, 28 July–1 August 2013; pp. 103–112. [Google Scholar]
  14. Shokouhi, M.; Radinsky, K. Time-sensitive query auto-completion. In Proceedings of the 35th International ACM SIGIR Conference on Research and Development in Information Retrieval, Portland, ON, USA, 12–16 August 2012; pp. 601–610. [Google Scholar]
  15. Whiting, S.; Jose, J.M. Recent and robust query auto-completion. In Proceedings of the 23rd International Conference on World Wide Web, Seoul, Republic of Korea, 7–11 April 2014; pp. 971–982. [Google Scholar]
  16. Hofmann, K.; Mitra, B.; Radlinski, F.; Shokouhi, M. An eye-tracking study of user interactions with query auto completion. In Proceedings of the 23rd ACM International Conference on Conference on Information and Knowledge Management, Shanghai, China, 3–7 November 2014; pp. 549–558. [Google Scholar]
  17. Li, Y.; Dong, A.; Wang, H.; Deng, H.; Chang, Y.; Zhai, C. A two-dimensional click model for query auto-completion. In Proceedings of the 37th International ACM SIGIR conference on Research & Development in Information Retrieval, Gold Coast, QLD, Australia, 6–11 July 2014; pp. 455–464. [Google Scholar]
  18. Mitra, B.; Shokouhi, M.; Radlinski, F.; Hofmann, K. On user interactions with query auto-completion. In Proceedings of the 37th International ACM SIGIR conference on Research & Development in Information Retrieval, Gold Coast, QLD, Australia, 6–11 July 2014; pp. 1055–1058. [Google Scholar]
  19. Zhang, A.; Goyal, A.; Kong, W.; Deng, H.; Dong, A.; Chang, Y.; Gunter, C.A.; Han, J. Adaqac: Adaptive query auto-completion via implicit negative feedback. In Proceedings of the 38th International ACM SIGIR conference on Research and Development in Information Retrieval, Santiago, Chile, 9–13 August 2015; pp. 143–152. [Google Scholar]
  20. Li, L.; Deng, H.; Dong, A.; Chang, Y.; Baeza-Yates, R.; Zha, H. Exploring query auto-completion and click logs for contextual-aware web search and query suggestion. In Proceedings of the 26th International Conference on World Wide Web, Perth, Australia, 3–7 April 2017; pp. 539–548. [Google Scholar]
  21. Jiang, D.; Chen, W.; Cai, F.; Chen, H. Neural attentive personalization model for query auto-completion. In Proceedings of the 2018 IEEE 3rd Advanced Information Technology, Electronic and Automation Control Conference (IAEAC), Chongqing, China, 12–14 October 2018; pp. 725–730. [Google Scholar]
  22. Jaech, A.; Ostendorf, M. Personalized language model for query auto-completion. arXiv 2018, arXiv:1804.09661. [Google Scholar]
  23. Wang, P.W.; Zhang, H.; Mohan, V.; Dhillon, I.S.; Kolter, J.Z. Realtime query completion via deep language models. In Proceedings of the eCOM@ SIGIR, Ann Arbor, MI, USA, 12 July 2018. [Google Scholar]
  24. Gardner, M.; Grus, J.; Neumann, M.; Tafjord, O.; Dasigi, P.; Liu, N.F.; Peters, M.; Schmitz, M.; Zettlemoyer, L. AllenNLP: A Deep Semantic Natural Language Processing Platform. arXiv 2018, arXiv:1803.07640. [Google Scholar]
  25. Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv 2015, arXiv:1503.02531. [Google Scholar]
  26. Cui, B.; Li, Y.; Zhang, Z. Joint structured pruning and dense knowledge distillation for efficient transformer model compression. Neurocomputing 2021, 458, 56–69. [Google Scholar] [CrossRef]
  27. Jiao, X.; Yin, Y.; Shang, L.; Jiang, X.; Chen, X.; Li, L.; Wang, F.; Liu, Q. TinyBERT: Distilling BERT for Natural Language Understanding. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings, Online, 16–20 November 2020; pp. 4163–4174. [Google Scholar]
  28. Romero, A.; Ballas, N.; Kahou, S.E.; Chassang, A.; Gatta, C.; Bengio, Y. Fitnets: Hints for thin deep nets. arXiv 2014, arXiv:1412.6550. [Google Scholar]
  29. Kim, Y.; Rush, A.M. Sequence-level knowledge distillation. arXiv 2016, arXiv:1606.07947. [Google Scholar]
  30. Sutton, R.S.; McAllester, D.A.; Singh, S.P.; Mansour, Y. Policy gradient methods for reinforcement learning with function approximation. In Proceedings of the Advances in Neural Information Processing Systems, Denver, CO, USA; 2000; pp. 1057–1063. [Google Scholar]
  31. Pass, G.; Chowdhury, A.; Torgeson, C. A picture of search. In Proceedings of the 1st International Conference on Scalable Information Systems, Hong Kong, 30 May–1 June 2006; p. 1-es. [Google Scholar]
  32. Kim, G. Subword language model for query auto-completion. arXiv 2019, arXiv:1909.00599. [Google Scholar]
  33. Sennrich, R.; Haddow, B.; Birch, A. Neural machine translation of rare words with subword units. arXiv 2015, arXiv:1508.07909. [Google Scholar]
  34. Carbonell, J.; Goldstein, J. The use of MMR, diversity-based reranking for reordering documents and producing summaries. In Proceedings of the 21st Annual International ACM SIGIR Conference on Research and Development in Information Retrieval, Melbourne, Australia, 24–28 August 1998; pp. 335–336. [Google Scholar]
  35. Park, D.H.; Chiba, R. A neural language model for query auto-completion. In Proceedings of the 40th International ACM SIGIR Conference on Research and Development in Information Retrieval, Tokyo, Japan, 7–11 August 2017; pp. 1189–1192. [Google Scholar]
  36. Mitra, B.; Craswell, N. Query auto-completion for rare prefixes. In Proceedings of the 24th ACM International on Conference on Information and Knowledge Management, Melbourne, Australia, 19–23 October 2015; pp. 1755–1758. [Google Scholar]
  37. Wang, S.; Guo, W.; Gao, H.; Long, B. Efficient Neural Query Auto Completion. In Proceedings of the 29th ACM International Conference on Information & Knowledge Management, Online, 19–23 October 2020; pp. 2797–2804. [Google Scholar]
  38. Shen, Y.; He, X.; Gao, J.; Deng, L.; Mesnil, G. Learning semantic representations using convolutional neural networks for web search. In Proceedings of the 23rd International Conference on World Wide Web, Seoul, Republic of Korea, 7–11 April 2014; pp. 373–374. [Google Scholar]
  39. Konishi, T.; Ohwa, T.; Fujita, S.; Ikeda, K.; Hayashi, K. Extracting search query patterns via the pairwise coupled topic model. In Proceedings of the Ninth ACM International Conference on Web Search and Data Mining, San Francisco, CA, USA, 22–25 February 2016; pp. 655–664. [Google Scholar]
  40. Blei, D.M.; Ng, A.Y.; Jordan, M.I. Latent dirichlet allocation. J. Mach. Learn. Res. 2003, 3, 993–1022. [Google Scholar]
  41. Abri, R.; Abri, S.; Cetin, S. Providing A Topic-Based LSTM Model to Re-Rank Search Results. In Proceedings of the 2022 7th International Conference on Machine Learning Technologies (ICMLT), Rome, Italy, 11–13 March 2022; pp. 249–254. [Google Scholar]
  42. Sanh, V.; Debut, L.; Chaumond, J.; Wolf, T. DistilBERT, a distilled version of BERT: Smaller, faster, cheaper and lighter. arXiv 2019, arXiv:1910.01108. [Google Scholar]
  43. Sun, S.; Cheng, Y.; Gan, Z.; Liu, J. Patient knowledge distillation for bert model compression. arXiv 2019, arXiv:1908.09355. [Google Scholar]
  44. Sordoni, A.; Bengio, Y.; Vahabi, H.; Lioma, C.; Grue Simonsen, J.; Nie, J.Y. A hierarchical recurrent encoder-decoder for generative context-aware query suggestion. In Proceedings of the 24th ACM International on Conference on Information and Knowledge Management, Melbourne, Australia, 19–23 October 2015; pp. 553–562. [Google Scholar]
Figure 1. The framework of Dis-KD. The data are initially inputted into both the teacher and student models. During the first iteration, the discriminator is trained to discern which model is the teacher and which is the student. In the second iteration, three losses are used to train the student model. The KL loss function is employed to compute the error between the teacher and student model. The cross-entropy loss is employed to compute the error between the label and the student’s model’s prediction. The discriminator’s loss is employed to give comments from the perspective of the whole sentence.
Figure 2. Results of baselines and Dis-KD.
Table 1. Dataset examples.
Anon ID | Query | Query Time | Item Rank | Click URL
142 | rentdirect.com | 1 March 2006 07:17:12 | |
217 | ask.com | 31 March 2006 14:31:10 | 1 | www.ask.com
993 | myspace.co | 1 March 2006 12:13:36 | |
Table 2. Dataset summary.
Dataset | Number | Seen | Unseen
Train | 17,521,031 | |
Validation | 1,521,971 | |
Test | 1,317,632 | 670,810 | 646,822
Table 3. Model parameter setting.
Parameters | Values
Number of candidates | 10
Beam search size | 30
Learning epochs | 30
Optimizer | AdamW
Batch size | 1024
Learning rate | 5 × 10^-5
Heads in GPT-2 | 12
Student network layers | 2
Teacher network layers | 6
Embedding in GPT-2 | 768
Embedding in LSTM | 256
Layers of LSTM | 2
Dropout | 0.1
Table 4. Results of completion generation.
Generation | Ranking | MRR (Seen) | MRR (Unseen)
Dis-KD | Dis-KD | 0.5731 | 0.1763
LWG | Frequency | 0.4465 | 0.2241
MCG | Frequency | 0.4469 | 0.2610
MCG | CLSM | 0.4224 | 0.2628
MCG | LSTM | 0.4293 | 0.2669
PCTM | | 0.614 |
LSTM-Topic | | 0.564 |
NQLM | NQLM | 0.5395 | 0.1580
GPT-2(2) | GPT-2(2) | 0.4702 | 0.1660
GPT-2(6) | GPT-2(6) | 0.5696 | 0.1659
Table 6. Results of completion generation at the sub-word level. The best results are highlighted in bold.
Model | MRR (Seen) | MRR (Unseen) | PMRR (Seen) | PMRR (Unseen)
Dis-KD | 0.5731 | 0.1763 | 0.6339 | 0.3458
DisGPT | 0.5135 | 0.1432 | 0.5914 | 0.3129
Patient | 0.5278 | 0.1598 | 0.6013 | 0.3212
Table 7. Inference latency per query (ms/query).
Method | Latency
Dis-KD | 26.1 ms
GPT-2(2) | 26.0 ms
GPT-2(6) | 74.3 ms
LSTM | 34.1 ms
MCG | 0.18 ms
