Stochastic Weight Averaging Revisited

Averaging neural network weights sampled by a backbone stochastic gradient descent (SGD) is a simple yet effective approach to assist the backbone SGD in finding better optima, in terms of generalization. From a statistical perspective, weight averaging (WA) contributes to variance reduction. Recently, the stochastic weight averaging (SWA) method was proposed; its distinguishing feature is the use of a cyclical or high constant (CHC) learning rate schedule (LRS) to generate the weight samples for WA. This gave rise to a new insight on WA, namely that WA helps to discover wider optima and thereby leads to better generalization. We conduct extensive experimental studies of SWA, involving a dozen modern DNN model structures and a dozen benchmark open-source image, graph, and text datasets. We disentangle the contributions of the WA operation and the CHC LRS to SWA, showing that the WA operation in SWA still contributes to variance reduction but does not always lead to wide optima. The experimental results indicate that there are global scale geometric structures in the DNN loss landscape. We then present an algorithm termed periodic SWA (PSWA), which makes use of a series of WA operations to discover the global geometric structures. PSWA outperforms its backbone SGD remarkably, providing experimental evidence for the existence of global geometric structures. Code for reproducing the experimental results is available at https://github.com/ZJLAB-AMMI/PSWA.


Introduction
Stochastic gradient descent (SGD) equipped with a decaying learning rate schedule (LRS) is the de facto approach to train modern deep neural networks (DNNs). Averaging neural network (NN) weights sampled by a backbone SGD is shown to be a simple yet effective approach to assist the backbone SGD in finding better optima, in terms of generalization. The idea of weight averaging (WA), also referred to as iterate averaging or tail-averaging [1], goes back to [2; 3]. A WA procedure averages the final few iterates of SGD. From a statistical perspective, it has been proved that the WA operation contributes to decreasing the variance in the final iterate of its backbone SGD, resulting in a stabilizing effect in terms of regularization properties and prediction guarantees [4]. We term this view as variance reduction in what follows.

Algorithm 1 Stochastic Weight Averaging (SWA), partial listing
...
Compute current learning rate $\alpha$ according to the LRS.
...
8: end if
9: end for
10: return $w_{\mathrm{SWA}}$
Recently, a stochastic WA (SWA) method has been proposed and has received a lot of attention. It is extremely easy to implement, yet it can improve SGD to achieve better generalization without significant computational overhead [5; 6; 7]. SWA starts after a converged SGD (namely the backbone SGD), which runs preceding it and outputs a local optimum $w_{\mathrm{SGD}}$ of the loss function $f(w)$, where $w$ denotes the NN weights. SWA rewarms its backbone SGD, starting at $w_{\mathrm{SGD}}$. The rewarmed SGD employs a cyclical or high constant (CHC) LRS. The application of the CHC LRS is the major feature that distinguishes SWA from other WA methods. Novel local optima are sampled along the trajectory of this rewarmed SGD process. Then a WA operation is used, which outputs the mean of such optima, denoted by $w_{\mathrm{SWA}}$, as the final output of SWA. Pseudo-code for SWA is shown in Algorithm 1, which maintains a running average of the weights sampled every $c$ iterations.
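The SWA procedure described above can be sketched in a few lines. This is a minimal illustration rather than the authors' implementation: it assumes plain SGD without momentum, and the names `swa`, `grad_fn`, and `lr_schedule` are hypothetical placeholders for a stochastic gradient oracle and a CHC LRS.

```python
import numpy as np

def swa(w_start, grad_fn, lr_schedule, n_iters, cycle_len):
    """Rewarmed SGD from w_start; returns the running average of the
    weights sampled every `cycle_len` iterations.

    grad_fn and lr_schedule are illustrative callables (assumptions):
    grad_fn(w) returns a gradient estimate, lr_schedule(i) the learning
    rate at iteration i (1-based). Assumes n_iters >= cycle_len, so that
    at least one sample is taken.
    """
    w = np.array(w_start, dtype=float)
    w_swa = np.zeros_like(w)
    n_models = 0
    for i in range(1, n_iters + 1):
        lr = lr_schedule(i)          # current learning rate from the LRS
        w = w - lr * grad_fn(w)      # plain SGD step (no momentum)
        if i % cycle_len == 0:       # end of a cycle: take a weight sample
            w_swa = (w_swa * n_models + w) / (n_models + 1)
            n_models += 1
    return w_swa

# Toy usage on f(w) = 0.5 * w**2 with a high constant learning rate.
w_out = swa([4.0], grad_fn=lambda w: w, lr_schedule=lambda i: 0.5,
            n_iters=4, cycle_len=2)
```

In the toy run the iterates are 2, 1, 0.5, 0.25, and the samples taken at iterations 2 and 4 are averaged.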
A common insight to explain SWA's success is that the local optima discovered by its rewarmed backbone SGD are located at the boundary of a high-quality basin region in the DNN weight parameter space. Doing WA over such local optima then results in a wider optimum, which is closer to the center of the basin region [5], and a wider optimum leads to better generalization [5; 8].
We now have two seemingly independent views on the role of WA: one is statistical, namely the variance reduction perspective, and the other is geometric, namely the wider optimum perspective. What, then, is the relationship between these two views, and how can they be reconciled?
After a detailed inspection of SWA [5], we find that its behavior results from a combined effect of several possibly intertwined factors, namely the convergence rate of the SGD that runs preceding SWA, the CHC LRS, the WA operation, and finally the application of the momentum technique and weight decay. The common geometric view cannot explain the specific role of each factor. For example, it cannot answer the following questions:
The above concerns motivate us to revisit SWA. As SWA is a fundamental, generic, architecture-agnostic technique for training DNNs, any new finding or insight from this re-inspection could have a broad impact on deep learning. The major contributions of this paper can be summarized as follows:
1. we disentangle the contributions of the WA operation, the CHC LRS, the application of momentum and weight decay, and the rate at which the preceding SGD converges, to the behavior of SWA.
2. we find that the actual function of the WA operation in SWA is variance reduction, in the same spirit as tail-averaging [1].
3. we find cases in which SWA fails to discover better optima than its backbone SGD.
4. we find experimental evidence for the existence of global geometric structures in the DNN loss landscape; we show that such global structures can be exploited by the WA operation.
5. we propose a novel algorithm design termed periodic SWA (PSWA), inspired by the above experimental finding, and demonstrate that it is preferable to SGD when the training budget is too limited for SGD to converge.

Related Work
Iterate Averaging The basic idea of iterate averaging, also referred to as tail-averaging in [1], goes back to [2; 3]. The tail-averaging method averages the final few iterates of SGD. In this way, it decreases the variance in the final iterate of SGD and brings a stabilizing effect in terms of regularization properties and prediction guarantees [4]. A generalization error bound for tail-averaging in the context of least squares regression with stochastic approximation is derived in [1]. We show in this paper that the WA operation is a type of tail-averaging, which gives SWA its stabilizing effect and its variance reduction function.
Cyclical Learning Rates The benefits of employing cyclical learning rates (CLRs) in SGD have been demonstrated in [9; 10]. Following that, the CLR strategy has been widely used in developing advanced DNN optimizers, such as fast geometric ensembling (FGE) [11], snapshot ensembles [12], and super-convergence training [13], and in exploring the loss landscape of DNNs [14]. In this paper, we show that the CLR strategy also plays a major role in SWA's success.
Convergence Theory In [15], Zhu et al. present a convergence theory for training DNNs, based on two assumptions: the input data points are distinct and the DNN architecture is over-parameterized. This theory states that, at least for fully-connected neural networks (NNs), convolutional NNs (CNNs), and residual NNs (ResNets), SGD with a random weight initialization can attain 100% accuracy in classification tasks, with the number of iterations scaling polynomially in the number of training samples and the number of NN layers. Cheridito et al. demonstrate that, for ReLU networks that have a much larger depth than their width, SGD fails to converge if the number of restarted SGD trajectories does not increase to infinity fast enough [16]. Here we investigate the specific roles of the CHC LRS and the WA operation in promoting SGD's convergence. While our result is empirical, it may stimulate more theoretical research on DNN convergence.
Loss Landscape Study & Sharpness-aware Minimization Another commonly used way to investigate the convergence problem is through loss landscape analysis. Hessian spectrum analysis has been shown to be an effective approach to inspect the smoothness, curvature, and sharpness of NN loss landscapes [17; 18]. Yao et al. developed an open-source scalable framework for fast computation of Hessian information in DNNs [19]. It has been common wisdom that, at least in some cases, NNs generalize better when they converge to a wider local optimum, and vice versa [8]. However, the relationship between the local sharpness of the loss landscape and a global property like generalization performance may be merely correlational rather than causal [20].
The empirical finding of the relationship between local sharpness and global generalization motivated the design of practical approaches to improving the generalization property of SGD. For example, the sharpness-aware minimization (SAM) method seeks NN weights that lie in a wider loss basin by modifying the optimization objective function to be sharpness-aware [21]. SWA can be seen as a type of wideness-aware solver for DNN optimization. It is reported that SWA finds wider minima than SAM [6]. Our work characterizes the root cause of SWA's success and provides more empirical evidence toward a deeper understanding of the loss landscape of DNNs.

Main Results
Our goal is to inspect the real cause that leads to SWA's behavior. Toward this goal, we experiment with different DNN architectures on different datasets. We present the main results indexed with questions of our interest. All details about the experimental settings are described in Section 6.1.

Does SWA always find wider optima than SGD?
The results reported in the SWA paper [5] show that SGD generally converges to the boundary of a wide basin region and that SWA helps to find an optimum located inside that wide basin region. All experiments conducted there use image datasets, such as CIFAR-{10, 100} [22] and ImageNet ILSVRC-2012 [23; 24]. We wonder whether SWA always finds wider optima than SGD, so we conduct experiments on graph and text datasets.
Results show that the answer is no. Specifically, on the graph dataset MUTAG, we use SWA to train a graph isomorphism network (GIN) for graph classification. The baseline optimizer selected is Adam, which is an advanced SGD method that performs better for tasks based on graph data. We find that if we run Adam for 300 epochs, we get a test accuracy (TA) value of 89%, while if we replace Adam with SWA for the last 30 epochs, we only get a smaller TA value of 84%. We also consider the graph node classification task using graph neural network (GNN) models, such as graph convolutional network (GCN) [25], GraphSAGE [26], and graph attention network (GAT) [27], on the public open-source datasets Cora, Citeseer, and Pubmed. The parameter setting for the experiment is shown in Table 5 in Section 6.4. The TA comparison result is presented in Table 1.
The above experimental results for graph datasets show that using SWA does not always lead to better generalization than some advanced SGD optimizers like Adam.
On a text dataset termed Microsoft Research Paraphrase Corpus (MRPC), we use SGD with momentum to fine-tune the pre-trained model RoBERTa for testing whether two sentences are semantically equivalent. See details about the experimental setting in Section 6.1. On average, SGD gives a TA value of 87.98%, while SWA only achieves 87.50%.
We find that, even for image datasets, SWA does not always converge to the boundary of a wide basin region, especially when we remove the momentum module from its backbone SGD. For such cases, we find that SWA may converge to a deep loss valley, where the averaged gradients over mini-batch training samples are all close to zero. Then the products of such gradients and the learning rate, that is, the weight updates, are also close to zero. In such cases, SWA fails to find a wider optimum with better generalization. See details of the experimental results in Section 6.3.
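The deep-valley argument can be made explicit with a short derivation. It assumes plain SGD without momentum, as in the setting just described; the symbols $g_t$ (mini-batch gradient), $\epsilon$ (gradient-norm bound), $t_0$ (start of SWA), and $t_1 < \dots < t_K$ (sampling times) are introduced here for illustration only.

```latex
w_{t+1} = w_t - \alpha_t g_t
\quad\Longrightarrow\quad
\lVert w_{t+1} - w_t \rVert = \alpha_t \lVert g_t \rVert .

% If \lVert g_t \rVert \le \epsilon for all t \ge t_0, every SWA sample w_{t_k} obeys
\lVert w_{t_k} - w_{t_0} \rVert \le \epsilon \sum_{t=t_0}^{t_k - 1} \alpha_t ,
\qquad\text{hence}\qquad
\Bigl\lVert \frac{1}{K} \sum_{k=1}^{K} w_{t_k} - w_{t_0} \Bigr\rVert
\le \epsilon \sum_{t=t_0}^{t_K - 1} \alpha_t .
```

Under this bound, when the mini-batch gradients are near zero the samples and their average all stay close to the starting weight, so the WA operation has little room to move toward a wider region.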

What is the real function of the WA operation to SWA?
To answer the question in the title of this section, we conduct ablation studies across different DNN models and datasets. As mentioned above, the SWA procedure consists of a rewarmed SGD process based on a CHC LRS, from which a set of NN weights are sampled, and a WA operation that yields the average of these sampled weights. We refer to the sampled weights as SWA samples in what follows. The momentum and weight decay operations are removed here to provide a clean investigation. First, we consider image classification with the DNN structures VGG16 [30], Preactivation ResNet-164 (PreResNet-164) [31], and WideResNet-28-10 [32], using the datasets CIFAR-{10,100}. For each model, SWA runs after a preceding SGD process has converged. The results are presented in Figure 1. The effect of using the CHC LRS can be revealed by comparing the TA values of SWA samples to that of SGD. The effect of WA can be checked by comparing the TA value of SWA to those of the separate SWA samples. As is shown, neither using the CHC LRS nor performing WA brings a clear benefit in terms of the TA value. This result coincides with that revealed in the above subsection.
We then consider cases in which the backbone SGD that runs preceding SWA converges to a bad optimum, corresponding to Case II in Section 6.1. In this case, we do not give enough budget for DNN training: the number of training epochs is only 30. The results based on CIFAR-{10,100} are presented in Figure 2. In this case, the application of the CHC LRS makes the resulting SWA samples produce strikingly greater TA values than the backbone SGD that runs before SWA. We also find that WA contributes an additional increase in the TA value.
Finally, we conduct an ablation study based on the Imagenet dataset. See the result in Section 6.2. It is shown that SWA samples provide much bigger TA values than SGD and that, except for VGG16, the WA operation provides an additional increase in the TA value.
For all the aforementioned cases, the WA operation always outputs a TA value that is bigger than the smallest TA value given by the separate SWA samples. That is, the WA operation in SWA functions similarly to tail-averaging [1], in decreasing the variance of TA values associated with SWA samples.
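The variance reduction effect of tail-averaging can be illustrated on a toy problem. The sketch below is not one of the paper's experiments: it assumes a one-dimensional quadratic loss with additive Gaussian gradient noise and a constant learning rate, and all names and parameter values are chosen for illustration only.

```python
import numpy as np

def last_and_tail_average(rng, n_iters=200, tail=50, lr=0.5, noise_std=1.0):
    """Constant-step SGD on f(w) = 0.5 * w**2 with noisy gradients.

    Returns (last iterate, mean of the final `tail` iterates).
    All settings are illustrative assumptions, not the paper's setup.
    """
    w = 1.0
    iterates = []
    for _ in range(n_iters):
        grad = w + noise_std * rng.standard_normal()  # noisy gradient of f
        w = w - lr * grad
        iterates.append(w)
    return iterates[-1], float(np.mean(iterates[-tail:]))

rng = np.random.default_rng(0)
runs = np.array([last_and_tail_average(rng) for _ in range(500)])
# Variance across independent runs: last iterate vs. tail average.
var_last, var_avg = runs.var(axis=0)
print(f"variance of last iterate: {var_last:.4f}")
print(f"variance of tail average: {var_avg:.4f}")
```

Across independent runs the tail average fluctuates far less than the last iterate, which is the stabilizing effect the text attributes to the WA operation.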

On global geometric structure of the DNN loss landscape
As presented above, we find cases in which SWA is initialized by a backbone SGD that does not converge well and SWA performs strikingly better than its backbone SGD, while if the backbone SGD converges well, the performance gap between SWA and SGD is reduced or even becomes indistinguishable.
When the backbone SGD converges well, the NN weights employed by SWA center around a local optimum discovered by this SGD. Thus, SWA can only make use of a very local geometric structure around this local optimum. When the backbone SGD does not converge well, the NN samples fed to SWA span a much wider area. This motivates us to raise the following questions: Is there any global geometric structure in the DNN loss landscape that can be encountered by an SGD at the early stage of its life cycle? If such a global structure exists, how can it be exploited to facilitate the discovery of higher quality local optima?
We propose a novel algorithm design, termed periodic SWA (PSWA), that starts at an early stage of its backbone SGD. PSWA exploits the aforementioned possible global structures by performing WA sequentially. We show experimental results in the following section, which demonstrate that PSWA outperforms its backbone SGD remarkably and thus provide evidence for the existence of such global geometric structures.
PSWA consists of a series of SWA procedures that run sequentially. The first SWA procedure is initialized by an NN weight given by the backbone SGD that runs preceding SWA. For each of the other SWA procedures, its starting weight seed is the output of its former SWA procedure. Different from the original SWA method, which is invoked when its preceding SGD has converged, PSWA is started when its preceding SGD is at a very early stage of its working period. In addition, PSWA uses exactly the same LRS as its backbone SGD, as shown in Figure 6. If the sequentially performed SWA procedures can continually bring performance gains compared with the backbone SGD, then it would indicate that PSWA has made use of some global structures of the loss landscape to search for local optima.
Notably, the PSWA algorithm is a byproduct of our experimental findings in Section 3. The aim of our experiments here is to test whether our hypothesis raised in Section 4.1 holds.

On performance of PSWA
We compare PSWA with the backbone SGD on the datasets CIFAR-10 and CIFAR-100, based on the DNN structures VGG16, PreResNet-164, and WideResNet-28-10. The momentum factor for SGD is 0.9, and the weight decay parameter is set at 0.0005. PSWA starts after the 40th epoch with a period of 20 epochs. Within one period of PSWA, a full SWA procedure is conducted. In a SWA procedure, we sample one NN weight per epoch, then average the weights that have been sampled within this SWA procedure as the current output of PSWA.
The experimental results are shown in Fig. 3. We see that PSWA indeed provides a remarkable performance gain compared with its backbone SGD at the early stage of the training process. This provides experimental evidence for the existence of global geometric structures in the DNN loss landscape that can be encountered by an SGD process at the early stage of its working period, and demonstrates that such structures can be exploited by the WA operations for improving the backbone SGD. From an algorithmic perspective, we cannot claim that PSWA is better than SGD in general, since the qualities of their final outputs at the end of the training process are indistinguishable. However, if the training budget cannot support the whole training process, then PSWA is clearly preferable to SGD, since it provides much better weight samples than SGD at the early stage of the training process.
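The periodic procedure described above can be sketched as follows. This is one possible reading of the text, not the released implementation: `train_one_epoch` is a hypothetical callable standing in for one epoch of the backbone SGD, and resetting the training weights to the averaged weights at the end of each period reflects the statement that each SWA procedure starts from the output of the former one.

```python
import numpy as np

def pswa(w_init, train_one_epoch, lr_schedule, n_epochs,
         start_epoch=40, period=20):
    """Return the PSWA output after each epoch (None before PSWA starts).

    train_one_epoch(w, lr) -> new weights and lr_schedule(epoch) -> lr
    are illustrative placeholders (assumptions).
    """
    w = np.array(w_init, dtype=float)
    outputs = []
    w_avg, n_samples = None, 0
    for epoch in range(1, n_epochs + 1):
        w = train_one_epoch(w, lr_schedule(epoch))  # same LRS as backbone SGD
        if epoch <= start_epoch:
            outputs.append(None)
            continue
        # Sample one weight per epoch; average within the current period.
        if n_samples == 0:
            w_avg = w.copy()
        else:
            w_avg = (w_avg * n_samples + w) / (n_samples + 1)
        n_samples += 1
        outputs.append(w_avg.copy())
        if n_samples == period:
            w = w_avg.copy()   # next SWA procedure starts from this output
            n_samples = 0
    return outputs

# Toy usage: deterministic "epoch" that shrinks w toward 0.
demo = pswa(np.array([8.0]), lambda w, lr: w - lr * w, lambda e: 0.5,
            n_epochs=6, start_epoch=2, period=2)
```

The defaults `start_epoch=40` and `period=20` follow the values stated in the text; everything else in the toy usage is made up for the example.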
We also compare two special examples of PSWA, termed double SWA (DSWA) and triple SWA (TSWA), to SWA. DSWA and TSWA consist of two and three sequentially performed SWA procedures, respectively. See the pseudo-code for DSWA and TSWA in Algorithm 2 and Algorithm 3. To make a fair comparison, we let SWA, DSWA, and TSWA run the same number of iterations so that their computational budgets are almost the same. We do not use momentum or weight decay, to remove their influence on the comparison.

Algorithm 2 Double Stochastic Weight Averaging (DSWA)
Input: weights $\hat{w}$, LRS, cycle length $c$, number of iterations $n$ (assumed to be a multiple of 2)
Output: $w_{\mathrm{dswa}}$
1: Run the SWA procedure (namely Algorithm 1) with input $\hat{w}$, $c$, $n/2$. Denote the output by $w_{\mathrm{swa}}$.
2: $\hat{w} \leftarrow w_{\mathrm{swa}}$.
3: Run the SWA procedure again with input $\hat{w}$, $c$, $n/2$. Denote the output by $w_{\mathrm{dswa}}$.
4: return $w_{\mathrm{dswa}}$

Algorithm 3 Triple Stochastic Weight Averaging (TSWA)
Input: weights $\hat{w}$, LRS, cycle length $c$, number of iterations $n$ (assumed to be a multiple of 3)
Output: $w_{\mathrm{tswa}}$
1: Run the SWA procedure (namely Algorithm 1) with input $\hat{w}$, $c$, $n/3$. Denote the output by $w_{\mathrm{swa}}$.
2: $\hat{w} \leftarrow w_{\mathrm{swa}}$.
3: Run the SWA procedure again with input $\hat{w}$, $c$, $n/3$. Denote the output by $w_{\mathrm{dswa}}$.
4: $\hat{w} \leftarrow w_{\mathrm{dswa}}$.
5: Run the SWA procedure again with input $\hat{w}$, $c$, $n/3$. Denote the output by $w_{\mathrm{tswa}}$.
6: return $w_{\mathrm{tswa}}$

We find that, if the backbone SGD that runs preceding SWA is non-converged or converges to a bad local optimum, corresponding to Case II in Section 6.1, DSWA and TSWA indeed find flatter optima that lead to better generalization than SWA; see the results in Tables 3 and 4 and Figure 4. If the backbone SGD converges well, corresponding to Case I in Section 6.1, then DSWA and TSWA fail to find flatter optima than SWA, as shown in Figure 5. Note that Figures 4 and 5 are obtained in the same way as Figure 5 in [5].
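The chaining used by Algorithms 2 and 3 can be sketched with one generic helper. This is an illustrative sketch, not the authors' code: `swa` is a simplified stand-in for Algorithm 1 using plain SGD, `chained_swa` is a hypothetical name, and restarting the learning rate schedule at the beginning of every stage is an assumption.

```python
import numpy as np

def swa(w_start, grad_fn, lr_schedule, n_iters, cycle_len):
    """Simplified stand-in for Algorithm 1: running average of the
    weights sampled every `cycle_len` iterations of plain SGD."""
    w = np.array(w_start, dtype=float)
    w_avg, n_models = np.zeros_like(w), 0
    for i in range(1, n_iters + 1):
        w = w - lr_schedule(i) * grad_fn(w)
        if i % cycle_len == 0:
            w_avg = (w_avg * n_models + w) / (n_models + 1)
            n_models += 1
    return w_avg

def chained_swa(w_hat, grad_fn, lr_schedule, n_iters, cycle_len, n_stages):
    """n_stages=1 gives SWA, 2 gives DSWA, 3 gives TSWA.

    The total budget n_iters is split evenly, so all variants use the
    same number of SGD iterations.
    """
    if n_iters % n_stages != 0:
        raise ValueError("n_iters must be a multiple of n_stages")
    w = np.array(w_hat, dtype=float)
    for _ in range(n_stages):
        # Each stage starts from the output of the previous stage.
        w = swa(w, grad_fn, lr_schedule, n_iters // n_stages, cycle_len)
    return w

# Toy usage on f(w) = 0.5 * w**2 with a constant learning rate.
w_swa = chained_swa([4.0], lambda w: w, lambda i: 0.5, 4, 1, n_stages=1)
w_dswa = chained_swa([4.0], lambda w: w, lambda i: 0.5, 4, 1, n_stages=2)
```

Splitting the budget evenly mirrors the fair-comparison setup described in the text, where SWA, DSWA, and TSWA run the same number of iterations.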

Conclusions
In this paper, we investigated how the weight averaging operation and the cyclical or high constant learning rate schedule each contribute to SWA. Through experiments on a broad range of NN architectures, we identified a link between SGD and the global loss landscape and developed a novel insight into SWA from both a statistical and a geometric perspective. Specifically, we find that SWA works because it provides a mechanism to combine the advantages of the WA operation and the CHC LRS. The CHC LRS contributes to discovering global scale geometric structures, and WA contributes to exploiting such structures. By leveraging SGD's behavior in the early training phase, we proposed a novel algorithm, periodic SWA, which is shown to be capable of finding high quality local optima much more quickly than SGD.
Experiment setting for results reported in Section 3.1 For the graph classification task, we ran our experiments on the public open-source dataset MUTAG, which is commonly used for the graph classification task. See details about this dataset at https://paperswithcode.com/dataset/mutag. We use Adam [33] to train a GIN model for 300 epochs. We set the learning rate $\alpha$ at 0.01, and use the default parameter setting for the exponential decay rates $\beta_1$ and $\beta_2$, namely $\beta_1 = 0.9$ and $\beta_2 = 0.999$. SWA starts at the 270th epoch, using a constant learning rate of 0.02.
For experiments on the text dataset MRPC (see details about this dataset at https://paperswithcode.com/dataset/mrpc), the learning rate of SGD is fixed at $10^{-4}$ during the first 20 epochs, then is linearly decreased to $10^{-6}$ in the following 20 epochs, then is fixed at $10^{-6}$ for the last 10 epochs. The momentum and the weight decay factors are set at 0.9 and 0.01, respectively. SWA is started at the 45th epoch. For each epoch of SWA, the learning rate is linearly decreased from $10^{-5}$ to $5 \times 10^{-6}$.
Experiment setting for results reported in Section 3.2 For the image classification experiments presented in Section 3.2, we consider two major cases, termed Case I and Case II here, for each DNN architecture under consideration. In Case I, we run SWA after a converged SGD. In Case II, we run it after a non-converged SGD. We adopt the same type of LRS as used in [5]. An example of the LRS we use is shown in Figure 6. This LRS covers $L = 160$ epochs in total. The first half of this LRS takes a constant higher value $C_h$, followed by a segment in which the learning rate decreases linearly. The ending segment of this LRS takes a constant lower value $C_l$. For the LRS shown in Figure 6, $C_h = 0.05$ and $C_l = 0.01$. For Case I, we set the value of $L$ to be big enough, and that of $C_l$ small enough, to guarantee that the SGD process that runs before SWA has converged. For Case II, we set $L$ to a small value like 30, to ensure that the SGD that runs preceding SWA does not converge. For the SWA procedure, the cycle length $c$ takes a value that makes a cycle equal to an epoch. We adopt the same CHC LRS as used in [5] for the SWA procedure. The mini-batch size is set at 128 for all experiments.
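The three-segment backbone LRS described above can be written as a small function. This is a sketch under stated assumptions: the text specifies that the first half is constant at $C_h$ and the ending segment is constant at $C_l$, but not where the linear decay ends, so the `decay_end=0.9` fraction below is an assumption, and the function name is hypothetical.

```python
def backbone_lr(epoch, total_epochs=160, c_high=0.05, c_low=0.01,
                decay_end=0.9):
    """Piecewise LRS: constant c_high for the first half, linear decay
    until `decay_end` * total_epochs, then constant c_low.

    decay_end is an assumed value; it is not given in the text.
    """
    t = epoch / total_epochs
    if t <= 0.5:
        return c_high
    if t >= decay_end:
        return c_low
    frac = (t - 0.5) / (decay_end - 0.5)   # progress through the decay
    return c_high + frac * (c_low - c_high)

# Learning rate at each of the 160 epochs.
schedule = [backbone_lr(e) for e in range(1, 161)]
```

The defaults reproduce the values quoted for Figure 6 ($L = 160$, $C_h = 0.05$, $C_l = 0.01$); for Case II one would pass a small `total_epochs` such as 30.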

Experiment on Imagenet
We conduct the same ablation study as in Section 3.2 on the Imagenet dataset. We run SWA based on the backbone DNN models VGG16, ResNet-50, ResNet-152, and DenseNet-161, which are contained in PyTorch. The results are presented in Figure 7.

Experiments with a toy CNN model
In this experiment, we remove the momentum module from SGD. We train a toy CNN model on CIFAR-10 and get an over-fitting result as shown in Figure 8. We collect the weight value at the end of each epoch and calculate its corresponding TA value. The maximum TA value of 0.683 appears at the 45th epoch. The TA corresponding to the last iterate of SGD is 0.680. We replace the last $L = 5$ iterations of SGD with the SWA procedure and get a TA value of 0.679 for $w_{\mathrm{swa}}$. Changing the value of $L$ to 20 gives a TA value of 0.680 for $w_{\mathrm{swa}}$. This indicates that SWA does not lead to wider optima in this case. A similar phenomenon happens when we replace the toy CNN model with PreResNet-164. We use the code open-sourced by [5], while disabling the momentum and the L2-based weight regularization to remove their effects on SWA.
As is shown in Figure 9, SGD converges after about the 120th epoch with TA achieving 89.24%. We run SWA after the 140th epoch and get a TA of 89.17%, which is smaller than the TA given by the converged SGD.

Experiments with graph data
We show experimental settings associated with the graph data experiments presented in Section 3.1 in Tables 5-8.

Figure 1: Ablation study of the CHC LRS and the WA operation for DNNs that converge well. The legend "SGD" denotes the TA value associated with the NN weight given by the backbone SGD at the time point when SWA is started. The legend "SWA samples" denotes TA values associated with NN weights sampled during the SWA procedure. The legend "SWA" denotes the TA value associated with the mean of NN weights sampled during the SWA procedure. The sub-figures in the left/middle/right column correspond to VGG16/PreResNet-164/WideResNet-28-10. The sub-figures in the top/bottom row correspond to dataset CIFAR-10/CIFAR-100.

Figure 2: Ablation study of the CHC LRS and the WA operation for DNNs that do not converge well. The legends are defined in the same way as in Figure 1. The sub-figures in the left/middle/right column correspond to VGG16/PreResNet-164/WideResNet-28-10. The sub-figures in the top/bottom row correspond to dataset CIFAR-10/CIFAR-100.

Figure 4: Cross-entropy train loss and test error as a function of a point on the line connecting SWA and DSWA (or TSWA) solutions on CIFAR-100. DSWA and TSWA are initialized by a non-converged preceding SGD procedure. Left: PreResNet-164. Right: VGG16.

Figure 5: Cross-entropy train loss and test error as a function of a point on the line connecting SWA and DSWA (or TSWA) solutions on CIFAR-10. DSWA and TSWA are initialized by a converged preceding SGD procedure. Left: PreResNet-164. Right: VGG16.

Figure 6: An example of the learning rate schedules used in our experiments.

Figure 7: Ablation study using the Imagenet dataset. The legends are defined in the same way as in Figure 1. The top left, top right, bottom left, and bottom right panels show results corresponding to VGG16, ResNet-50, ResNet-152, and DenseNet-161, respectively. Note that SWA starts from the pre-trained models contained in PyTorch, so the horizontal axis starts at epoch 1.

Figure 8: The over-fitting result obtained when training a toy CNN model on CIFAR-10. This CNN has 9 layers: the input layer, a convolution layer, a max-pooling layer, another convolution layer, another max-pooling layer, the flatten layer, 2 fully connected layers, and a softmax layer.

Table 5: The parameter setting for the GNN experiments. The baseline optimizer is Adam with weight decay factor 0.0005. $L$ denotes the total number of epochs, $\alpha$ the learning rate of the Adam optimizer, $\alpha_{\mathrm{SWA}}$ the constant learning rate used by SWA, and $t_{\mathrm{SWA}}$ the starting point to launch SWA.

Table 6: The parameter setting for the graph classification task on dataset NCI1. The baseline optimizer is Adam with weight decay factor 0.0005. $L$ denotes the total number of epochs, $\alpha$ the learning rate of the Adam optimizer, $\alpha_{\mathrm{SWA}}$ the constant learning rate used by SWA, and $t_{\mathrm{SWA}}$ the starting point to launch SWA.

Table 7: The parameter setting for the graph classification task on dataset D&D. The baseline optimizer is Adam with weight decay factor 0.0005. $L$ denotes the total number of epochs, $\alpha$ the learning rate of the Adam optimizer, $\alpha_{\mathrm{SWA}}$ the constant learning rate used by SWA, and $t_{\mathrm{SWA}}$ the starting point to launch SWA.

Table 8: The parameter setting for the graph classification task on dataset PROTEINS. The baseline optimizer is Adam with weight decay factor 0.0005. $L$ denotes the total number of epochs, $\alpha$ the learning rate of the Adam optimizer, $\alpha_{\mathrm{SWA}}$ the constant learning rate used by SWA, and $t_{\mathrm{SWA}}$ the starting point to launch SWA.