Phase diagram of early training dynamics in deep neural networks: effect of the learning rate, depth, and width
Dayal Singh Kalra, Maissam Barkeshli

TL;DR
This paper explores how learning rate, depth, and width influence the early training dynamics of deep neural networks, revealing distinct regimes and a sharpness reduction phase that depends on these parameters.
Contribution
It provides a systematic analysis of optimization regimes in DNNs, identifying critical parameters and phases, including a novel sharpness reduction phenomenon during early training.
Findings
Identification of four distinct training regimes based on Hessian eigenvalues.
Discovery of a sharpness reduction phase influenced by network depth and width.
Critical thresholds in learning rate and architecture parameters that alter training dynamics.
Condensed Matter Theory Center, University of Maryland, College Park
Institute for Physical Science and Technology, University of Maryland, College Park
Department of Physics, University of Maryland, College Park
Joint Quantum Institute, University of Maryland, College Park
Abstract
We systematically analyze optimization dynamics in deep neural networks (DNNs) trained with stochastic gradient descent (SGD) and study the effect of learning rate $\eta$, depth $d$, and width $w$ of the neural network. By analyzing the maximum eigenvalue $\lambda^H_t$ of the Hessian of the loss, which is a measure of sharpness of the loss landscape, we find that the dynamics can show four distinct regimes: (i) an early time transient regime, (ii) an intermediate saturation regime, (iii) a progressive sharpening regime, and (iv) a late time "edge of stability" regime. The early and intermediate regimes (i) and (ii) exhibit a rich phase diagram depending on $\eta \equiv c/\lambda^H_0$, $d$, and $w$. We identify several critical values of $c$, which separate qualitatively distinct phenomena in the early time dynamics of training loss and sharpness. Notably, we discover the opening up of a "sharpness reduction" phase, where sharpness decreases at early times, as $d$ and $1/w$ are increased.
1 Introduction
The optimization dynamics of deep neural networks (DNNs) is a rich problem of great interest. Basic questions about how to choose learning rates and their effect on generalization error and training speed remain intensely studied research problems. Classical intuition from convex optimization has led to the often-made suggestion that in stochastic gradient descent (SGD), the learning rate $\eta$ should satisfy $\eta < 2/\lambda^H$, where $\lambda^H$ is the maximum eigenvalue of the Hessian of the loss, in order to ensure that the network reaches a minimum. However, several recent studies have suggested that it is both possible and potentially preferable to have the learning rate early in training reach $\eta > 2/\lambda^H$ [66, 49, 72]. The idea is that such a choice will induce a temporary training instability, causing the network to 'catapult' out of a local basin into a flatter one with lower $\lambda^H$ where training stabilizes. Indeed, during the early training phase, the local curvature of the loss landscape changes rapidly [42, 1, 37, 15], and the learning rate plays a crucial role in determining the convergence basin [37]. Flatter basins are believed to be preferable because they potentially lead to lower generalization error [31, 32, 42, 12, 39, 14] and allow larger learning rates, leading to potentially faster training.
From a different perspective, the major theme of deep learning is that it is beneficial to increase the model size as much as possible. This has come into sharp focus with the discovery of scaling laws that show power law improvement in generalization error with model and dataset size [40]. This raises the fundamental question of how one can scale DNNs to arbitrarily large sizes while maintaining the ability to learn; in particular, how should initialization and optimization hyperparameters be chosen to maintain a similar quality of learning as the model size is taken to infinity [34, 47, 48, 11, 69, 58, 70, 68]?
Motivated by these ideas, we perform a systematic analysis of the training dynamics of SGD for DNNs as learning rate, depth, and width are tuned, across a variety of architectures and datasets. We monitor both the loss and sharpness ($\lambda^H_t$) trajectories during early training, observing a number of qualitatively distinct phenomena summarized below.
1.1 Our contributions
We study SGD on fully connected networks (FCNs) with the same number of hidden units (width) in each layer, convolutional neural networks (CNNs), and ResNet architectures of varying width and depth with ReLU activation. For CNNs, the width corresponds to the number of channels. We focus on networks parameterized in Neural Tangent Parameterization (NTP) [34] and Standard Parameterization (SP) [62], initialized at criticality [55, 58]; other parameterizations and initializations may show different behavior. Further experimental details are provided in Appendix A. We study both mean-squared error (MSE) and cross-entropy loss functions and the datasets CIFAR-10, MNIST, and Fashion-MNIST. Our findings apply to networks whose depth-to-width ratio $d/w$ lies below a threshold that depends on architecture class and loss function, but is independent of $d$, $w$, and $\eta$. Above this ratio, the dynamics becomes noise-dominated, and separating the underlying deterministic dynamics from random fluctuations becomes challenging, as shown in Appendix E. We use sharpness to refer to $\lambda^H_t$, the maximum eigenvalue of the Hessian $H_t$ at time-step $t$, and flatness refers to $1/\lambda^H_t$.
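As an illustration of the quantity being tracked, the largest Hessian eigenvalue can be estimated with power iteration using only matrix-vector products. The sketch below is a minimal example under stated assumptions, not the measurement procedure used in the experiments: the function name `top_eigenvalue`, the toy matrix `H`, and the helper `hessian_vector_product` are all hypothetical, and a real network would supply a Hessian-vector product of the training loss instead of an explicit matrix.

```python
import math
import random


def top_eigenvalue(matvec, dim, iters=200, seed=0):
    """Power iteration for the dominant eigenvalue of a symmetric operator.

    `matvec(v)` must return the product H v as a list of floats. Power
    iteration converges to the eigenvalue of largest magnitude, which equals
    the maximum eigenvalue when the positive end of the spectrum dominates.
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    estimate = 0.0
    for _ in range(iters):
        hv = matvec(v)
        estimate = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient v^T H v
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:  # v lies in the null space of H
            break
        v = [x / norm for x in hv]
    return estimate


# Toy stand-in for a Hessian (hypothetical values); its eigenvalues are 1 and 3.
H = [[2.0, 1.0], [1.0, 2.0]]


def hessian_vector_product(v):
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]


sharpness = top_eigenvalue(hessian_vector_product, dim=2)
flatness = 1.0 / sharpness
```

Only products of the Hessian with a vector are needed, which is why this kind of estimate remains feasible when the Hessian itself is far too large to store.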
By monitoring the sharpness, we find four clearly separated, qualitatively distinct regimes throughout the training trajectory. Fig. 1 shows an example from a CNN architecture. The four observed regimes are: (i) an early time transient regime where loss and sharpness may drastically change and eventually settle down, (ii) an intermediate saturation regime where the sharpness has lowered and remains relatively constant, (iii) a progressive sharpening regime where sharpness steadily rises, and finally, (iv) a late time regime where the sharpness saturates around $2/\eta$ for MSE loss; whereas for cross-entropy loss, sharpness drops after reaching this maximum value while remaining less than $2/\eta$ [8]. Note that the log scale in Figure 1 highlights the early regimes (i) and (ii); in absolute terms these are much shorter in time than regimes (iii) and (iv).
In this work, we focus on the early transient and intermediate saturation regimes. As learning rate, $d$, and $w$ are tuned, a clear picture emerges, leading to a rich phase diagram, as demonstrated in Section 2. Given the learning rate scaled as $\eta = c/\lambda^H_0$, we characterize four distinct behaviors in the training dynamics in the early transient regime (i):
Sharpness reduction phase ($c < c_{\mathrm{loss}}$): Both the loss and the sharpness monotonically decrease during early training. There is a particularly significant drop in sharpness in the regime $c_{\mathrm{crit}} < c < c_{\mathrm{loss}}$, which motivates us to refer to learning rates lower than $c_{\mathrm{crit}}$ as sub-critical and larger than $c_{\mathrm{crit}}$ as super-critical. We discuss $c_{\mathrm{crit}}$ in detail below. The regime $c_{\mathrm{crit}} < c < c_{\mathrm{loss}}$ opens up significantly with increasing $d$ and $1/w$, which is a new result of this work.
Loss catapult phase ($c_{\mathrm{loss}} < c < c_{\mathrm{sharp}}$): The first few gradient steps take training to a flatter region but with a higher loss. Training eventually settles down in the flatter region as the loss starts to decrease again. The sharpness monotonically decreases from initialization in this early time transient regime.
Loss and sharpness catapult phase ($c_{\mathrm{sharp}} < c < c_{\max}$): In this regime both the loss and sharpness initially start to increase, effectively catapulting to a different point where loss and sharpness can start to decrease again. Training eventually exhibits a significant reduction in sharpness by the end of the early training. The report of a loss and sharpness catapult is also new to this work.
Divergent phase ($c > c_{\max}$): The learning rate is too large for training and the loss diverges.
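The four phases above amount to a simple lookup once the three critical constants are known for a given run. The sketch below writes those constants as `c_loss`, `c_sharp`, and `c_max`; the function name, the treatment of boundary values, and the numerical thresholds are illustrative assumptions and are not taken from the experiments.

```python
def early_phase(c, c_loss, c_sharp, c_max):
    """Map a learning rate constant c to one of the four early-time phases.

    Assumes the ordering c_loss <= c_sharp <= c_max. How the exact boundary
    values are assigned (strict vs. non-strict inequalities) is a choice made
    here for illustration only.
    """
    if not (c_loss <= c_sharp <= c_max):
        raise ValueError("expected c_loss <= c_sharp <= c_max")
    if c < c_loss:
        return "sharpness reduction"
    if c < c_sharp:
        return "loss catapult"
    if c < c_max:
        return "loss and sharpness catapult"
    return "divergent"


# Hypothetical thresholds, chosen only to exercise each branch.
thresholds = dict(c_loss=2.5, c_sharp=3.2, c_max=4.0)
phases = [early_phase(c, **thresholds) for c in (1.0, 3.0, 3.5, 5.0)]
```

Because the critical constants are random variables, a lookup like this applies to a single initialization and batch sequence; the phase diagrams are built from averages over many such runs.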
The critical values $c_{\mathrm{loss}}$, $c_{\mathrm{sharp}}$, $c_{\max}$ are random variables that depend on random initialization, SGD batch selection, and architecture. The averages $\langle c_{\mathrm{loss}} \rangle$, $\langle c_{\mathrm{sharp}} \rangle$, $\langle c_{\max} \rangle$ shown in the phase diagrams exhibit a strong systematic dependence on depth and width. In order to better understand the cause of the sharpness reduction during early training, we study the effect of network output at initialization by (1) centering the network, (2) setting last layer weights to zero, or (3) tuning the overall scale of the output layer. We also analyze the linear connectivity of the loss landscape in the early transient regime and show that for a range of learning rates $c_{\mathrm{loss}} < c < c_{\mathrm{barrier}}$, no barriers exist from the initial state to the final point of the initial transient phase, even though training passes through regions with higher loss than initialization.
Next, we provide a quantitative analysis of the intermediate saturation regime. We find that sharpness during this time typically displays three distinct regimes as the learning rate is tuned, depicted in Fig. 5. By identifying an appropriate order parameter, we can extract a sharp peak corresponding to $c_{\mathrm{crit}}$. For MSE loss $c_{\mathrm{crit}} \approx 2$, whereas for cross-entropy loss the value differs (see Appendix F). For $c \ll c_{\mathrm{crit}}$, the network is effectively in a lazy training regime, with increasing fluctuations as $d$ and/or $1/w$ are increased.
Finally, we show that a single hidden layer linear network, the $uv$ model, displays the same phenomena discussed above, and we analyze the phase diagram in this minimal model.
1.2 Related works
A significant amount of research has identified various training regimes using diverse criteria, e.g., [13, 1, 16, 37, 17, 45, 35, 8, 33]. Here we focus on studies that characterize training regimes with sharpness and learning rates. Several studies have analyzed sharpness at different training times [37, 15, 35, 8, 33]. Ref. [8] studied sharpness at late training times and showed how large-batch gradient descent shows progressive sharpening followed by the edge of stability, which has motivated various theoretical studies [9, 2, 3]. Ref. [37] studied the entire training trajectory of sharpness in models trained with SGD and cross-entropy loss and found that sharpness increases during the early stages of training, reaches a peak, and then decreases. In contrast, we find a sharpness-reduction phase, $c < c_{\mathrm{loss}}$, which becomes more prominent with increasing $d$ and $1/w$, where sharpness only decreases during early training; this also occurs in the catapult phase $c_{\mathrm{loss}} < c < c_{\mathrm{sharp}}$, during which the loss initially increases before decreasing. This discrepancy is likely due to different initialization and learning rate scaling in their work [33].
Ref. [35] examined the effect of hyperparameters on sharpness at late training times. Ref. [20] studied the optimization dynamics of SGD with momentum using sharpness. Ref. [45] classifies training into two different regimes using training loss, providing a significantly coarser description of training dynamics than provided here. Ref. [33] studied the scaling of the maximum learning rate with $d$ and $w$ during early training in FCNs and its relationship with sharpness at initialization.
Refs. [52, 71] present phase diagrams of shallow ReLU networks at infinite width under gradient flow. Previous studies such as [41, 69] show that is the maximum learning rate for convergence as . This limit results in the kernel regime as training time is restricted to in the limit of infinite width, resulting in a lazy training regime for learning rates less than and divergent training for larger learning rates. In contrast, we analyze optimization dynamics at training timescales that grow with width . Specifically, the end of the early time transient period occurs at .
Ref. [49] analyzed the training dynamics at large widths and training times, using the top eigenvalue of the neural tangent kernel (NTK) as a proxy for sharpness. They demonstrated the existence of a new early training phase, which they dubbed the "catapult" phase, $2/\lambda_0 < \eta < \eta_{\max}$, in wide networks trained with MSE loss using SGD, in which training converges after an initial increase in training loss. The existence of this new training regime was further extended to quadratic models with large widths by [72, 53]. Our work extends the above analysis by studying the combined effect of learning rate, depth, and width for both MSE and cross-entropy loss. Specifically, we demonstrate the opening of a sharpness-reduction phase, refine the catapult phase into two phases depending on whether the sharpness also catapults, analyze the phase boundaries as $d$ and $1/w$ are increased, analyze linear mode connectivity in the catapult phase, and examine different qualitative behaviors in the intermediate saturation regime (ii) mentioned above.
2 Phase diagram of early transient regime
For wide enough networks trained with MSE loss using SGD, training converges into a flatter region after an initial increase in the training loss for learning rates [49]. Fig. 2(a, b) shows the first steps of the loss and sharpness trajectories of a shallow ( and ) CNN trained on the CIFAR-10 dataset with MSE loss using SGD. For learning rates, , the loss catapults and training eventually converges into a flatter region, as measured by sharpness. Additionally, we observe that sharpness may also spike initially, similar to the training loss (see Fig. 2 (b)). However, this initial spike in sharpness occurs at relatively higher learning rates (), which we will examine along with the loss catapult. We refer to this spike in sharpness as ‘sharpness catapult.’
An important consideration is the degree to which this phenomenon changes with network depth and width. Interestingly, we found that the training loss in deep networks on average catapults at much larger learning rates than . Fig. 2(d, e) shows that for a deep () CNN, the loss and sharpness may catapult only near the maximum trainable learning rate. In this section, we characterize the properties of the early training dynamics of models with MSE loss. In Appendix F, we show that a similar picture emerges for cross-entropy loss, despite the dynamics being noisier.
2.1 Loss and sharpness catapult during early training
In this subsection, we characterize the effect of finite depth and width on the onset of the loss and sharpness catapult and training divergence. We begin by defining critical constants that correspond to the above phenomena.
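Definition 1 ($c_{\mathrm{loss}}$, $c_{\mathrm{sharp}}$, $c_{\max}$).
For learning rate $\eta = c/\lambda^H_0$, let the training loss and sharpness at step $t$ be denoted by $L_t(c)$ and $\lambda^H_t(c)$. We define $c_{\mathrm{loss}}$ ($c_{\mathrm{sharp}}$) as the minimum learning rate constant such that the loss (sharpness) increases during the initial transient period:
$$c_{\mathrm{loss}} = \min_c \left\{ c \;\middle|\; \max_{t \in [1, T_1]} L_t(c) > L_0(c) \right\}, \qquad c_{\mathrm{sharp}} = \min_c \left\{ c \;\middle|\; \max_{t \in [1, T_1]} \lambda^H_t(c) > \lambda^H_0(c) \right\},$$
and $c_{\max}$ as the maximum learning rate constant such that the loss does not diverge during the initial transient period:
$$c_{\max} = \max_c \left\{ c \;\middle|\; L_t(c) < K \;\; \forall\, t \in [1, T_1] \right\},$$
where $T_1$ denotes the end of the initial transient period and $K$ is a fixed large constant. (Footnote 1: A fixed value of $K$ is used to estimate $c_{\max}$. In all our experiments the initial loss is of order one (see Appendix A), which justifies the use of a fixed value.)
Note that the definition of $c_{\max}$ allows for more flexibility than previous studies [33] in order to investigate a wider range of phenomena occurring near the maximum learning rate. Here, $c_{\mathrm{loss}}$, $c_{\mathrm{sharp}}$, and $c_{\max}$ are random variables that depend on the random initialization and the SGD batch sequence, and we denote the average over this randomness using $\langle \cdot \rangle$.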
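Definition 1 translates directly into a scan over a grid of learning rate constants. The sketch below assumes that loss and sharpness trajectories over the initial transient period have already been recorded for each value of $c$; the function name `critical_constants`, the data layout, the divergence threshold `blowup` (a stand-in for the fixed large constant), and the synthetic trajectories are all hypothetical.

```python
import math


def critical_constants(runs, blowup=1e5):
    """Estimate (c_loss, c_sharp, c_max) on a grid of learning rate constants.

    `runs` maps each c to a pair (losses, sharpnesses), where both lists cover
    steps 0..T1 of the initial transient period. `blowup` is an assumed value
    for the large constant that flags divergence.
    """
    c_loss = c_sharp = c_max = None
    for c in sorted(runs):
        losses, sharpnesses = runs[c]
        diverged = any((not math.isfinite(l)) or l >= blowup for l in losses[1:])
        if not diverged:
            c_max = c  # largest grid value seen so far that stays bounded
        if c_loss is None and any(l > losses[0] for l in losses[1:]):
            c_loss = c  # smallest c whose loss rises above its initial value
        if c_sharp is None and any(s > sharpnesses[0] for s in sharpnesses[1:]):
            c_sharp = c  # smallest c whose sharpness rises above its initial value
    return c_loss, c_sharp, c_max


# Synthetic trajectories (made-up numbers) for four learning rate constants.
runs = {
    1.0: ([1.0, 0.8, 0.6], [4.0, 3.9, 3.8]),
    2.5: ([1.0, 1.5, 0.7], [4.0, 3.5, 3.0]),
    3.5: ([1.0, 3.0, 0.9], [4.0, 4.5, 2.0]),
    5.0: ([1.0, 1e3, 1e9], [4.0, 9.0, 50.0]),
}
c_loss, c_sharp, c_max = critical_constants(runs)
```

The resolution of the estimates is limited by the grid spacing in $c$, and each call corresponds to one initialization and batch sequence, so averages require repeating the scan over several seeds.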
Fig. 3(a-c) illustrates the phase diagram of early training for three different architectures trained on various datasets with MSE loss using SGD. These phase diagrams show how the averaged values $\langle c_{\mathrm{loss}} \rangle$, $\langle c_{\mathrm{sharp}} \rangle$, and $\langle c_{\max} \rangle$ are affected by width. The results show that the averaged values of all the critical constants increase significantly with $1/w$ (note the log scale). At large widths, the loss starts to catapult at $c \approx 2$. As $1/w$ increases, $\langle c_{\mathrm{loss}} \rangle$ increases and eventually converges to $\langle c_{\max} \rangle$ at large $1/w$. By comparison, sharpness starts to catapult at relatively large learning rates at small $1/w$, with $\langle c_{\mathrm{sharp}} \rangle$ continuing to increase with $1/w$ while remaining between $\langle c_{\mathrm{loss}} \rangle$ and $\langle c_{\max} \rangle$. Similar results are observed for different depths, as demonstrated in Appendix C. Phase diagrams obtained by varying $d$ are qualitatively similar to those obtained by varying $1/w$, as shown in Figure 3(d-f). Comparatively, we observe that $\langle c_{\max} \rangle$ may increase or decrease with $1/w$ in different settings while consistently increasing with $d$, as shown in Appendices H and F.
While we plotted the averaged quantities $\langle c_{\mathrm{loss}} \rangle$, $\langle c_{\mathrm{sharp}} \rangle$, $\langle c_{\max} \rangle$, we have observed that their variance also increases significantly with $d$ and $1/w$; in Appendix C we show standard deviations about the averages for different random initializations. Nevertheless, we have found that the inequality $c_{\mathrm{loss}} \leq c_{\mathrm{sharp}} \leq c_{\max}$ typically holds for any given initialization and batch sequence, except for some outliers due to high fluctuations when the averaged critical curves start merging at large $d$ and $1/w$. Fig. 4 shows evidence of this claim. The setup is the same as in Fig. 3. Appendix D presents extensive additional results across various architectures and datasets.
In Appendix F, we show that cross-entropy loss shows similar results with some notable differences. The loss catapults at a relatively higher value and consistently decreases with , while still satisfying .
2.2 Loss connectivity in the early transient period
In the previous subsection, we observed that training loss and sharpness might quickly increase before decreasing (“catapult") during early training for a range of depths and widths. A logical next step is to analyze the region in the loss landscape that the training reaches after the catapult. Several works have analyzed loss connectivity along the training trajectory [21, 51, 64]. Ref. [51] report that training traverses a barrier at large learning rates, aligning with the naive intuition of a barrier between the initial and final points of the loss catapult, as the loss increases during early training. In this section, we will test the credibility of this intuition in real-world models. Specifically, we linearly interpolate the loss between the initial and final point after the catapult and examine the effect of the learning rate, depth, and width. The linearly interpolated loss and barrier are defined as follows.
Definition 2 (interpolated loss barrier).
Let $\theta_0$ represent the initial set of parameters, and let $\theta_{T_1}$ represent the set of parameters at the end of the initial transient period, trained using a learning rate constant $c$. Then, we define the linearly interpolated loss as $L_{\mathrm{int}}(s, c) = L[(1 - s)\,\theta_0 + s\,\theta_{T_1}]$, where $s \in [0, 1]$ is the interpolation parameter. The interpolated loss barrier is defined as the maximum value of the interpolated loss over the range of $s$: $U(c) = \max_{s \in [0, 1]} L_{\mathrm{int}}(s, c) - L(\theta_0)$.
Here we subtracted the loss's initial value such that a positive value indicates a barrier to the final point from initialization. Using the interpolated loss barrier, we define $c_{\mathrm{barrier}}$ as follows.
Definition 3 ($c_{\mathrm{barrier}}$).
Given the initial ($\theta_0$) and final parameters ($\theta_{T_1}$), we define $c_{\mathrm{barrier}}$ as the minimum learning rate constant such that there exists a barrier from $\theta_0$ to $\theta_{T_1}$: $c_{\mathrm{barrier}} = \min_c \{ c \mid U(c) > 0 \}$.
Here, is also a random variable that depends on the initialization and SGD batch sequence. We denote the average over this randomness using as before. Fig. 2(c, f) shows the interpolated loss of CNNs trained on the CIFAR-10 dataset for steps. The experimental setup is the same as in Section 2. For the network with larger width, we observe a barrier emerging at , while the loss starts to catapult at . In comparison, we do not observe any barrier from initialization to the final point at large and . Fig. 3 shows the relationship between and for various models and datasets. We consistently observe that , suggesting that training traverses a barrier only when sharpness starts to catapult during early training. Similar results were observed on increasing instead of as shown in Appendix C. We chose not to characterize the phase diagram of early training using as we did for other critical ’s, as it is somewhat different in character than the other critical constants, which depend only on the sharpness and loss trajectories.
These observations call into question the intuition of catapulting out of a basin for learning rates in the range $c_{\mathrm{loss}} < c < c_{\mathrm{barrier}}$. These results show that for these learning rates, the final point after the catapult already lies in the same basin as initialization and is even connected to it through a linear path, revealing an inductive bias of the training process towards regions of higher loss during the early time transient regime.
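The interpolated loss barrier of Definition 2 can be sketched in a few lines. The example below evaluates the loss on a uniform grid of interpolation parameters between two parameter vectors and subtracts the initial loss; the function name `interpolated_barrier`, the grid size, and the one-dimensional toy losses are illustrative assumptions rather than the setup used for the figures.

```python
def interpolated_barrier(loss_fn, theta_start, theta_end, num_points=101):
    """Maximum of the loss along the straight line from theta_start to
    theta_end, minus the loss at theta_start.

    Because s = 0 is part of the grid, the result is never negative; a
    strictly positive value signals a barrier along the linear path.
    """
    base = loss_fn(theta_start)
    highest = float("-inf")
    for i in range(num_points):
        s = i / (num_points - 1)
        theta = [(1.0 - s) * a + s * b for a, b in zip(theta_start, theta_end)]
        highest = max(highest, loss_fn(theta))
    return highest - base


def double_well(theta):
    # Toy loss with minima at -1 and +1 separated by a bump of height 1.
    return (theta[0] ** 2 - 1.0) ** 2


def bowl(theta):
    # Toy convex loss; moving towards the minimum never crosses a barrier.
    return theta[0] ** 2


barrier_between_wells = interpolated_barrier(double_well, [-1.0], [1.0])
barrier_in_bowl = interpolated_barrier(bowl, [2.0], [1.0])
```

A zero barrier does not require the training trajectory itself to be monotone: the loss can rise during training while the straight line between the endpoints stays below the initial loss, which is the situation described in the text.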
3 Intermediate saturation regime
In the intermediate saturation regime, sharpness does not change appreciably and reflects the cumulative change that occurred during the initial transient period. This section analyzes sharpness in the intermediate saturation regime by studying how it changes with the learning rate, depth, and width of the model. Here, we show results for MSE loss, whereas cross-entropy results are shown in Appendix F.
We measure the sharpness at a time in the middle of the intermediate saturation regime. We choose so that .222time-step is in the middle of regime (ii) for the models studied. Normalizing by allows proper comparison for different learning rates. For further details on sharpness measurement, see Section I.1. Fig. 5(a) illustrates the relationship between and the learning rate for -layer deep CNNs trained on the CIFAR-10 dataset with varying widths. The results indicate that the dependence of on learning rate can be grouped into three distinct stages. (1) At small learning rates, remains relatively constant, with fluctuations increasing as and increase ( in Fig. 5(a)). (2) A crossover regime where is dropping significantly ( in Fig. 5(a)). (3) A saturation stage where stays small and constant with learning rate ( in Fig. 5(a)). In Appendix I, we show that these results are consistent across architectures and datasets for varying values of and . Additionally, the results reveal that in stage (1), where is sub-critical, decreases with increasing and . In other words, for small and in the intermediate saturation regime, the loss is locally flatter as and increase.
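We can precisely extract a critical value of $c$ that separates stages (1) and (2), which corresponds to the onset of an abrupt reduction of sharpness. To do this, we consider the normalized sharpness $\lambda^H_\tau/\lambda^H_0$, measured at the time $\tau$ in the intermediate saturation regime, averaged over initializations, and denote it by $\langle \lambda^H_\tau/\lambda^H_0 \rangle$. Its first two derivatives with respect to $c$, taken with an overall minus sign so that a drop in sharpness counts as positive, characterize the change in sharpness with learning rate. The extrema of the negated second derivative quantitatively define the boundaries between the three stages described above. In particular, using the maximum of the negated second derivative, we define $\langle c_{\mathrm{crit}} \rangle$, which marks the beginning of the sharp decrease in $\langle \lambda^H_\tau/\lambda^H_0 \rangle$ with the learning rate.
Definition 4 ($\langle c_{\mathrm{crit}} \rangle$).
Given the averaged normalized sharpness $\langle \lambda^H_\tau/\lambda^H_0 \rangle$ measured at $\tau$, we define $\langle c_{\mathrm{crit}} \rangle$ to be the learning rate constant that minimizes its second derivative:
$$\langle c_{\mathrm{crit}} \rangle = \arg\min_c \frac{\partial^2}{\partial c^2} \left\langle \frac{\lambda^H_\tau}{\lambda^H_0} \right\rangle.$$
Here, we use $\langle \cdot \rangle$ to denote that the critical constant is obtained from the averaged normalized sharpness. Fig. 5(b, c) show the two negated derivatives obtained from the results in Fig. 5(a). We observe similar results across various architectures and datasets, as shown in Appendix I. Our results show that $\langle c_{\mathrm{crit}} \rangle$ has slight fluctuations as $d$ and $1/w$ are changed but generally stays in the vicinity of $c = 2$. The peak in the negated second derivative becomes wider as $d$ and $1/w$ increase, indicating that the transition between stages (1) and (2) becomes smoother, presumably due to larger fluctuations in the properties of the Hessian at initialization. In contrast to $\langle c_{\mathrm{crit}} \rangle$, $\langle c_{\mathrm{loss}} \rangle$ increases with $d$ and $1/w$, implying the opening of the sharpness reduction phase $c_{\mathrm{crit}} < c < c_{\mathrm{loss}}$ as $d$ and $1/w$ increase. In Appendix F, we show that cross-entropy loss shows qualitatively similar results, but with a different value of $\langle c_{\mathrm{crit}} \rangle$.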
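Definition 4 can be approximated numerically with a central finite difference. The sketch below assumes the averaged normalized sharpness has been tabulated on a uniform grid of learning rate constants; the function name `critical_constant` and the piecewise-linear synthetic curve (flat, then dropping, then saturating, mimicking the three stages) are hypothetical and serve only to exercise the procedure.

```python
def critical_constant(cs, sharpness_ratio):
    """Return the grid value of c that minimizes the central second
    difference of the averaged normalized sharpness.

    Assumes `cs` is a uniformly spaced, increasing grid and that
    `sharpness_ratio[i]` is the averaged normalized sharpness at cs[i].
    """
    if len(cs) != len(sharpness_ratio) or len(cs) < 3:
        raise ValueError("need matching lists with at least three grid points")
    step = cs[1] - cs[0]
    best_c, best_value = None, float("inf")
    for i in range(1, len(cs) - 1):
        second = (
            sharpness_ratio[i - 1] - 2.0 * sharpness_ratio[i] + sharpness_ratio[i + 1]
        ) / step ** 2
        if second < best_value:
            best_c, best_value = cs[i], second
    return best_c


def synthetic_ratio(c):
    # Made-up curve: constant for c <= 2, linear drop on (2, 3], then flat.
    if c <= 2.0:
        return 1.0
    if c <= 3.0:
        return 1.0 - 0.4 * (c - 2.0)
    return 0.6


cs = [1.0 + 0.25 * i for i in range(13)]  # uniform grid from 1.0 to 4.0
ratios = [synthetic_ratio(c) for c in cs]
c_crit = critical_constant(cs, ratios)
```

On noisy measurements a second difference amplifies fluctuations, so averaging over more initializations or smoothing the curve before differentiating would likely be needed in practice.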
4 Effect of network output at initialization on early training
Here we discuss the effect of network output at initialization on the early training dynamics. Throughout, $x$ is the input and $\theta_t$ denotes the set of parameters at time $t$. We consider setting the network output to zero at initialization, $f(x; \theta_0) = 0$, by either (1) considering the "centered" network: $f_c(x; \theta) = f(x; \theta) - f(x; \theta_0)$, or (2) setting the last layer weights to zero at initialization (for details, see Appendix G). Remarkably, both (1) and (2) remove the opening up of the sharpness reduction phase with $1/w$, as shown in Figure 6. The average onset of the loss catapult, diagnosed by $\langle c_{\mathrm{loss}} \rangle$, becomes independent of $1/w$ and $d$.
We also empirically study the impact of the output scale [19, 5, 4] on early training dynamics. Given a network function $f(x; \theta)$, we define the scaled network as $\alpha f(x; \theta)$, where $\alpha$ is a scalar, fixed throughout training. In Appendix H, we show that a large (resp. small) initial output scale relative to the one-hot encodings of the labels causes the sharpness to decrease (resp. increase) during early training. Interestingly, we still observe an increase in $\langle c_{\mathrm{loss}} \rangle$ with $d$ and $1/w$, unlike the case of initializing network output to zero, highlighting the unique impact of output scale on the dynamics.
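The two output modifications discussed in this section, subtracting the output at initialization and multiplying the output by a fixed scalar, can be written as thin wrappers around a network function. The sketch below is a hedged illustration: the names `centered` and `scaled`, the scalar `alpha`, and the two-parameter `toy_network` are hypothetical and stand in for a real model.

```python
def centered(f, theta_init):
    """Wrap f so that its output at the initial parameters is exactly zero."""
    def f_centered(x, theta):
        return f(x, theta) - f(x, theta_init)
    return f_centered


def scaled(f, alpha):
    """Wrap f so that its output is multiplied by a fixed scalar alpha."""
    def f_scaled(x, theta):
        return alpha * f(x, theta)
    return f_scaled


def toy_network(x, theta):
    # Hypothetical two-parameter linear model standing in for a real network.
    first, second = theta
    return first * second * x


theta_init = (0.7, -0.2)
f_c = centered(toy_network, theta_init)
f_s = scaled(toy_network, alpha=4.0)

output_at_init = f_c(1.5, theta_init)   # zero by construction
scaled_output = f_s(1.5, theta_init)    # alpha times the original output
```

Note that `theta_init` is frozen inside the centered wrapper, so the subtracted term stays fixed while `theta` is trained.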
5 Insights from a simple model
Here we analyze a two-layer linear network [57, 60, 49], the $uv$ model, which shows much of the phenomena presented above. Define $f(x) = \frac{1}{\sqrt{w}} v^T u \, x$, with $x, f(x) \in \mathbb{R}$. Here, $u, v \in \mathbb{R}^w$ are the trainable parameters, initialized using the normal distribution, $u_i, v_i \sim \mathcal{N}(0, 1)$ for $i \in \{1, \dots, w\}$. The model is trained with MSE loss on a single training example $(x, y) = (1, 0)$, which simplifies the loss to $L = f^2/2$, and which was also considered in Ref. [49]. Our choice of $y = 0$ is motivated by the results of Sec. 4, which suggest that the empirical results of Sec. 2 are intimately related to the model having a large initial output scale relative to the output labels. We minimize the loss using gradient descent (GD) with learning rate $\eta$. The early time phase diagram also shows similar features to those described in preceding sections (compare Fig. 7(a) and Fig. 3). Below we develop an understanding of this early time phase diagram in the $uv$ model.
The update equations of the $uv$ model in function space can be written in terms of the trace of the Hessian $\mathrm{tr}(H_t)$:
$$f_{t+1} = f_t \left( 1 - \eta \, \mathrm{tr}(H_t) + \frac{\eta^2 f_t^2}{w} \right), \qquad \mathrm{tr}(H_{t+1}) = \mathrm{tr}(H_t) + \frac{\eta f_t^2}{w} \left( \eta \, \mathrm{tr}(H_t) - 4 \right). \tag{1}$$
From the above equations, it is natural to scale the learning rate as $\eta = k/\mathrm{tr}(H_0)$. Note that $c = \eta \lambda^H_0 = k \lambda^H_0/\mathrm{tr}(H_0)$. Also, we denote the critical constants in this scaling by $k$ with the corresponding subscripts (e.g. $k_{\mathrm{loss}}$, $k_{\mathrm{crit}}$), where the definitions follow from Definitions 1 and 4 on replacing sharpness with trace, and use $\langle \cdot \rangle$ to denote an average over initialization. Figure 7(b) shows the phase diagram of early training, with $\lambda^H_t$ replaced with $\mathrm{tr}(H_t)$ as the measure of sharpness and with the learning rate scaled as $\eta = k/\mathrm{tr}(H_0)$. Similar to Figure 7(a), we observe a new phase, $\langle k_{\mathrm{crit}} \rangle < k < \langle k_{\mathrm{loss}} \rangle$, opening up at small width. However, we do not observe the loss-sharpness catapult phase as $\mathrm{tr}(H_t)$ does not increase during training (see Equation 1). We also observe $\langle k_{\mathrm{crit}} \rangle = 2$, independent of width.
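The function-space description can be checked numerically against plain gradient descent on the parameters. The sketch below assumes the model $f = v^T u/\sqrt{w}$ trained on the single example $(x, y) = (1, 0)$ with loss $f^2/2$, for which the Hessian trace is $(\|u\|^2 + \|v\|^2)/w$ and one GD step gives $f_{t+1} = f_t(1 - \eta\,\mathrm{tr}(H_t) + \eta^2 f_t^2/w)$ and $\mathrm{tr}(H_{t+1}) = \mathrm{tr}(H_t) + (\eta f_t^2/w)(\eta\,\mathrm{tr}(H_t) - 4)$. All function names, the width, the seed, and the choice $k = 1$ are illustrative assumptions.

```python
import math
import random


def init_uv(width, seed=0):
    """Draw u and v with independent standard normal entries."""
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(width)]
    v = [rng.gauss(0.0, 1.0) for _ in range(width)]
    return u, v


def output(u, v):
    # f = v . u / sqrt(w), evaluated at the single input x = 1.
    return sum(a * b for a, b in zip(u, v)) / math.sqrt(len(u))


def trace_hessian(u, v):
    # For the loss f^2 / 2 the Hessian trace is (|u|^2 + |v|^2) / w.
    return (sum(a * a for a in u) + sum(b * b for b in v)) / len(u)


def gd_step(u, v, lr):
    """One gradient descent step on L = f^2 / 2 in parameter space."""
    scale = lr * output(u, v) / math.sqrt(len(u))
    new_u = [a - scale * b for a, b in zip(u, v)]
    new_v = [b - scale * a for a, b in zip(u, v)]
    return new_u, new_v


def function_space_step(f, tr, lr, width):
    """The same step written purely in terms of f and tr(H)."""
    f_next = f * (1.0 - lr * tr + lr ** 2 * f ** 2 / width)
    tr_next = tr + (lr * f ** 2 / width) * (lr * tr - 4.0)
    return f_next, tr_next


width = 8
u, v = init_uv(width, seed=1)
f0, tr0 = output(u, v), trace_hessian(u, v)
lr = 1.0 / tr0  # learning rate scaled by the initial trace, with k = 1

u1, v1 = gd_step(u, v, lr)
f_pred, tr_pred = function_space_step(f0, tr0, lr, width)
f_err = abs(output(u1, v1) - f_pred)
tr_err = abs(trace_hessian(u1, v1) - tr_pred)
```

Since the second update adds a term proportional to $\eta\,\mathrm{tr}(H_t) - 4$, the trace cannot grow while the learning rate times the trace stays below 4, which the one-step check above also reflects.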
In Appendix B.3, we show that the critical value of for which increases with , which explains why increases with . Combined with , this implies the opening up of the sharpness reduction phase as is decreased.
To understand the loss-sharpness catapult phase, we require some other measure as does not increase for . As is difficult to analyze, we consider the Frobenius norm as a proxy for sharpness. We define as the minimum learning rate such that increases during early training. Figure 7(c) shows the phase diagram of the model, with as the measure of sharpness, while the learning rate is scaled as . We observe the loss-sharpness catapult phase at small widths. In Appendix B.4, we show that the critical value of for which increases from as increases. This explains the opening up of the loss catapult phase at small in Fig. 7 (c).
Fig. 8 shows the training trajectories of the model with large () and small () widths in a two-dimensional slice of parameters defined by and weight correlation . The above figure reveals that the first few training steps of the small-width network take the system in a flatter direction (as measured by ) as compared to the wider network. This means that the small-width network needs a relatively larger learning rate to get to a point of increased loss (loss catapult). We thus have the opening up of a new regime , in which the loss and sharpness monotonically decrease during early training.
The loss landscape of the model shown in Fig. 8 reveals interesting insights into the loss landscape connectivity results in Section 2.2 and the presence of . Fig. 8 shows how even when there is a loss catapult, as long as the learning rate is not too large, the final point after the catapult can be reached from initialization by a linear path without increasing the loss and passing through a barrier. However if the learning rate becomes large enough, then the final point after the catapult may correspond to a region of large weight correlation, and there will be a barrier in the loss upon linear interpolation.
The model trained on an example with provides insights into the effect of network output at initialization observed in Section 4. In Appendix G, we show that setting and in the dynamical equations results in loss catapult at , implying , irrespective of .
6 Discussion
We have studied the effect of learning rate, depth, and width on the early training dynamics in DNNs trained using SGD with learning rate scaled as $\eta = c/\lambda^H_0$. We analyzed the early transient and intermediate saturation regimes and presented a rich phase diagram of early training with learning rate, depth, and width. We report two new phases, sharpness reduction and loss-sharpness catapult, which have not been reported previously. Furthermore, we empirically investigated the underlying cause of sharpness reduction during early training. Our findings show that setting the network output to zero at initialization effectively leads to the vanishing of the sharpness reduction phase at supercritical learning rates. We further studied loss connectivity in the early transient regime and demonstrated the existence of a regime $c_{\mathrm{loss}} < c < c_{\mathrm{barrier}}$, in which the final point after the catapult lies in the same basin as initialization, connected through a linear path. Finally, we studied these phenomena in a 2-layer linear network ($uv$ model), gaining insights into the opening of the sharpness reduction phase.
We performed a preliminary analysis on the effect of batch size on the presented results in Appendix J. The sharpness trajectories of models trained with a smaller batch size ( vs. ) show similar early training dynamics. In the early transient regime, we observe a qualitatively similar phase diagram. In the intermediate saturation regime, the effect of reducing the batch size is to broaden the transition around .
In Section 2, we noted that for cross-entropy loss, the loss starts to catapult around at large widths, as compared to for MSE loss. Previous work, such as [50], analyzed the catapult dynamics for the model with logistic loss and demonstrated that the loss catapult occurs above . We summarize the main intuition about their analysis in Section B.9. However, a complete understanding of the catapult phenomenon in the context of cross-entropy loss requires a more detailed examination.
The early training dynamics is sensitive to the initialization scheme and optimization algorithm used, and we leave it to future work to explore this dependence and its implications. In this work, we focused on models initialized at criticality [55] as it allows for proper gradient flow through ReLU networks at initialization [23, 58], and studied vanilla SGD for simplicity. However, other initializations [46], parameterizations [69, 70], and optimization procedures [22] may show dissimilarities with the reported phase diagram of early training.
Acknowledgments
We thank Andrey Gromov, Tianyu He, and Shubham Jain for discussions, and Paolo Glorioso, Sho Yaida, Daniel Roberts, and Darshil Doshi for detailed comments on the manuscript. We also express our gratitude to anonymous reviewers for their valuable feedback for improving the manuscript. This work is supported by an NSF CAREER grant (DMR1753240) and the Laboratory for Physical Sciences through the Condensed Matter Theory Center.
Appendix A Experimental details
Datasets:
We considered the MNIST [10], Fashion-MNIST [67], and CIFAR-10 [44] datasets. We standardized the images and used one-hot encoding for the labels.
Models:
We considered fully connected networks (FCNs), Myrtle family CNNs [61], and ResNets (version 1) [29], trained using the JAX [7] and Flax [30] libraries. We use and to denote the depth and width of the network. Below, we provide additional details of the models and clarify what width corresponds to for CNNs and ResNets.
FCNs: We considered ReLU FCNs with constant width in Neural Tangent Parameterization (NTP) / Standard Parameterization (SP), initialized at criticality [55]. The models do not include bias or normalization. The forward pass of the pre-activations from layer to is given by
[TABLE]
where is the ReLU activation and is a constant. For NTP, and the weights are initialized using normal distribution, i.e., . For SP, and the weights are initialized as . For the last layer, we have for NTP and for SP.
For , the dynamics is noisier, and it becomes challenging to separate the underlying deterministic dynamics from random fluctuations (see Appendix E).
CNNs: We considered Myrtle family ReLU CNNs [61] without any bias or normalization in Standard Parameterization (SP), initialized using He initialization [29]. The above model uses a fixed number of channels in each layer, which we refer to as the width of the network. In this case, the forward pass equations for the pre-activations from layer to layer are given by
[TABLE]
where label the spatial location. The weights are initialized as , where is the filter size.
ResNets: We considered version 1 ResNet [29] implementations from Flax examples without Batch Norm or regularization. For ResNets, width corresponds to the number of channels in the first block. For example, the standard ResNet-18 has four blocks with widths , with . We refer to as the width or the widening factor. We considered ResNet-18 and ResNet-34.
All the models are trained with the average loss over the batch , i.e., , where is the loss function. This normalization, along with initialization, ensures that the loss is at initialization.
Bias: Throughout this work, we have primarily focused on models without any bias for simplicity. In Appendix K, we demonstrate that bias does not have an appreciable impact on the results.
Batch size: We use a batch size of 512 and scale the learning rate as in all our experiments, unless specified otherwise. Appendix J shows results for a smaller batch size .
Learning rate: We scale the learning rate constant as , with in steps of . Here, is related to the maximum learning rate constant as .
Sharpness measurement: We measure sharpness using the power iteration method with iterations. We found that iterations suffice both for MSE and cross-entropy loss. For MSE loss, we use randomly selected training examples for evaluating sharpness at each step. In comparison, we found that cross-entropy requires a large number of training examples to obtain a good approximation of sharpness. Given the computational constraints, we use training examples to approximate sharpness for cross-entropy loss.
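As a rough illustration of the power iteration method mentioned above, the following dependency-free sketch estimates the top eigenvalue of a symmetric operator from matrix-vector products alone. It is not the code used for the experiments: the function name `power_iteration`, the toy 2x2 matrix standing in for the Hessian, and the iteration count are assumptions made for illustration. In practice the `matvec` callback would be a Hessian-vector product evaluated on a batch of training examples.

```python
import math
import random


def power_iteration(matvec, dim, num_iters=20, seed=0):
    """Estimate the dominant eigenvalue of a symmetric operator.

    `matvec(v)` must return H @ v as a list of floats. In a training loop
    this would be a Hessian-vector product; here it is a toy stand-in.
    `num_iters` is a placeholder, not the value used in the paper.
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    eigenvalue = 0.0
    for _ in range(num_iters):
        hv = matvec(v)
        # Rayleigh quotient v^T H v for the current unit vector.
        eigenvalue = sum(a * b for a, b in zip(v, hv))
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:
            break
        v = [x / norm for x in hv]
    return eigenvalue


# Toy symmetric "Hessian" (hypothetical); its eigenvalues are 1 and 3.
H = [[2.0, 1.0], [1.0, 2.0]]


def toy_matvec(v):
    return [sum(h * x for h, x in zip(row, v)) for row in H]


sharpness = power_iteration(toy_matvec, dim=2, num_iters=50)
```

Only matrix-vector products are needed, so the full Hessian never has to be formed. That is what makes the method practical for networks with many parameters.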
Averages over initialization and SGD runs: All the critical constants depend on both the random initializations and the SGD runs. In our experiments, we found that the fluctuations from initialization at large outweigh the randomness coming from different SGD runs. Thus, we focus on initialization averages in all our experiments.
A.1 Compute usage
We utilized different computational resources depending on the task complexity. For less demanding tasks, we performed computation for a total of hours, utilizing a seventh of an NVIDIA A100 GPU. For more computationally intensive tasks, we utilized a full NVIDIA A100 GPU for a total of hours.
A.2 Reproducibility
The main results of this paper can be reproduced using the associated GitHub repository: https://github.com/dayal-kalra/early-training.
A.3 Details of Figures in the main text:
Figure 1: A shallow CNN (, ) in SP trained on the CIFAR-10 dataset with MSE loss for epochs using SGD with learning rates and batch size . We measure sharpness at every step for the first epoch, every epoch between and epochs, and every epochs beyond .
Figure 2: (top panel) A wide (, ) and (bottom panel) a deep CNN () in SP trained on the CIFAR-10 dataset with MSE loss for steps using vanilla SGD with learning rates and batch size .
Figure 3: Phase diagrams of early training of neural networks trained with MSE loss using SGD. Panels (a-c) show phase diagrams with width: (a) FCNs () trained on the MNIST dataset, (b) CNNs () trained on the Fashion-MNIST dataset, (c) ResNet () trained on the CIFAR-10 dataset (without batch normalization). Panels (d-f) show phase diagrams with depth: FCNs trained on the Fashion-MNIST dataset for different widths. Each data point in the figure represents an average of ten distinct initializations, and the solid lines represent a second-degree polynomial fitted to the raw data points. Here, where , and can take on one of three values: and .
Figure 4: (a) FCNs in SP with and trained on the MNIST dataset, (b) CNNs in SP with and trained on the Fashion-MNIST dataset, (c) ResNet in SP with and trained on the CIFAR-10 dataset (without batch normalization).
Figure 5: Normalized sharpness measured at against the learning rate constant for -layer CNNs in SP trained on the CIFAR-10 dataset, with . Each data point is an average over five random initializations. Smoothing details are provided in Appendix I.2.
Figure 6: Phase diagrams of layer FCNs trained on the CIFAR-10 dataset using MSE, demonstrating the effect of output scale at initialization: (a) vanilla network, (b) centered network, and (c) network initialized with the last layer set to zero. The values of widths are the same as in Figure 3.
Figure 7: The phase diagram of the model trained with MSE loss using gradient descent with (a) the top eigenvalue of Hessian , (b) the trace of Hessian and (c) the square of the Frobenius norm used as a measure of sharpness. In (a), the learning rate is scaled as , while in (b) and (c), the learning rate is scaled as . The vertical dashed line shows () for reference. Each data point is an average over random initializations.
Figure 8: Training trajectories of the model with (a, b) large () and (c, d) small () width, trained for training steps on a single example with MSE loss using vanilla gradient descent with learning rates (a, c) and (b, d) .
Appendix B Additional results for the uv model
B.1 Details of the uv model
Consider a two-layer linear network in NTP with unit input-output dimensions
[TABLE]
where . Here, are trainable parameters, with each element initialized using the normal distribution, for . The model is trained using MSE loss on a single training example , which simplifies the loss to
[TABLE]
The trace of the Hessian has a simple expression in terms of the norms of the weight vectors
[TABLE]
which is equivalent to the NTK for this model. The Frobenius norm of the Hessian can be written in terms of the loss and
[TABLE]
The gradient descent updates of the model trained using MSE loss on a single training example are given by
[TABLE]
The update equations in function space can be written in terms of the trace of the Hessian .
[TABLE]
Figure 9 shows the training trajectories of the model trained on using MSE loss for training steps. The model shows similar dynamics to those presented in Section 2. It is worth mentioning that the above equations have been analyzed in [49] at large width. In the following subsections, we extend their analysis by incorporating the higher-order terms to analyze the effect of finite width.
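The relationship between the parameter-space and function-space updates can be checked numerically. The sketch below assumes the standard form of this two-layer linear model, f = v·u/√w with input x = 1 and target y = 0, so that the loss is f²/2 and the Hessian trace is λ = (|u|² + |v|²)/w. Under these assumptions, a gradient descent step with learning rate η gives f ← f(1 − ηλ + η²f²/w) and λ ← λ + (ηf²/w)(ηλ − 4). The helper names, the width, and the learning rate constant are illustrative choices.

```python
import math
import random


def init_uv(width, seed=0):
    """Draw u, v in R^width with i.i.d. standard normal entries."""
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(width)]
    v = [rng.gauss(0.0, 1.0) for _ in range(width)]
    return u, v


def output_and_trace(u, v):
    """Return f = v.u / sqrt(w) and lambda = (|u|^2 + |v|^2) / w."""
    w = len(u)
    f = sum(a * b for a, b in zip(u, v)) / math.sqrt(w)
    lam = (sum(a * a for a in u) + sum(b * b for b in v)) / w
    return f, lam


def gd_step(u, v, lr):
    """One gradient descent step on L = f^2 / 2 (assumed example x=1, y=0)."""
    w = len(u)
    f, _ = output_and_trace(u, v)
    scale = lr * f / math.sqrt(w)
    new_u = [a - scale * b for a, b in zip(u, v)]  # u - lr * f * v / sqrt(w)
    new_v = [b - scale * a for a, b in zip(u, v)]  # v - lr * f * u / sqrt(w)
    return new_u, new_v


def function_space_step(f, lam, lr, width):
    """Closed-form update of (f, lambda) implied by the parameter update."""
    new_f = f * (1.0 - lr * lam + lr**2 * f**2 / width)
    new_lam = lam + (lr * f**2 / width) * (lr * lam - 4.0)
    return new_f, new_lam


width = 64  # illustrative width
u, v = init_uv(width)
f, lam = output_and_trace(u, v)
lr = 1.5 / lam  # learning rate constant c = 1.5 relative to the initial trace

max_gap = 0.0
for _ in range(10):
    u, v = gd_step(u, v, lr)
    f, lam = function_space_step(f, lam, lr, width)
    f_param, lam_param = output_and_trace(u, v)
    max_gap = max(max_gap, abs(f - f_param), abs(lam - lam_param))
```

The two descriptions agree up to floating-point error. The terms proportional to 1/w in the recursion are the finite-width corrections that the following subsections analyze.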
B.2 The intermediate saturation regime
The model trained on does not show the progressive sharpening and late-time regimes (iii) and (iv) described in Section 1. Hence, we can measure sharpness at the end of training to analyze how it is reduced upon increasing the learning rate and to compare it with the intermediate saturation regime results in Section 3.
Figure 10(a) shows the normalized sharpness measured at steps for various widths. This behavior reproduces the results observed in the intermediate saturation regime in Section 3. In particular, we can see stages (1) and (2), where starts off fairly independent of learning rate constant , and then dramatically reduces when ; stage (3), where plateaus at a small value as a function of is too close to the divergent phase in this model to be clearly observed. The corresponding derivatives of the averaged normalized sharpness, , and , are shown in Figure 10(b, c). The vertical dashed lines denote estimated for each width, using the maximum of . We observe that for all widths.
B.3 Opening of the sharpness reduction phase in the uv model
This section shows that terms in Equation 10 effectively lead to the opening of the sharpness reduction phase with in the model. In Appendix B.2, we demonstrated that for the model, for all values of widths. Hence, it suffices to show that increases from the value as increases. We do so by finding the smallest such that the averaged loss over initializations increases during early training.
It follows from Equation 10 that the averaged loss increases in the first training step if the following holds
[TABLE]
where denotes the average over initializations. On scaling the learning rate with trace as , we have
[TABLE]
The required two averages have the following expressions as shown in Appendix B.8.
[TABLE]
Inserting the above expressions in Equation 13, on average the loss increases in the very first step if the following inequality holds
[TABLE]
The graphical representation of the above inequality shown in Figure 11(a) is in excellent agreement with the experimental results presented in Figure 11(b).
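The first-step criterion can also be probed by direct sampling. The sketch below assumes the same model form as in Appendix B.1 (f = v·u/√w, example x = 1, y = 0, standard normal initialization) and scales the learning rate with the initial trace of each initialization. It scans a grid of learning rate constants and returns the smallest one for which the loss averaged over initializations increases after one step. The sample count, grid step, and widths are arbitrary choices, and the function names are hypothetical.

```python
import math
import random


def sample_f_and_trace(width, rng):
    """Sample one initialization and return (f, lambda)."""
    u = [rng.gauss(0.0, 1.0) for _ in range(width)]
    v = [rng.gauss(0.0, 1.0) for _ in range(width)]
    f = sum(a * b for a, b in zip(u, v)) / math.sqrt(width)
    lam = (sum(a * a for a in u) + sum(b * b for b in v)) / width
    return f, lam


def estimate_c_loss(width, num_inits=500, c_step=0.01, c_max=4.0, seed=0):
    """Smallest grid value of c at which the averaged loss rises in step one.

    Returns None if no such value is found up to c_max.
    """
    rng = random.Random(seed)
    samples = [sample_f_and_trace(width, rng) for _ in range(num_inits)]
    mean_loss0 = sum(0.5 * f * f for f, _ in samples) / num_inits
    num_grid = int(round(c_max / c_step))
    for k in range(1, num_grid + 1):
        c = k * c_step
        total = 0.0
        for f, lam in samples:
            lr = c / lam  # learning rate scaled with the initial trace
            f1 = f * (1.0 - lr * lam + lr**2 * f**2 / width)
            total += 0.5 * f1 * f1
        if total / num_inits > mean_loss0:
            return c
    return None


c_wide = estimate_c_loss(256)   # large width: expected to sit just above 2
c_narrow = estimate_c_loss(4)   # small width: expected to sit further above 2
```

Because the same samples are reused for every value of c, the scan is a deterministic function of the seed. It remains a Monte Carlo estimate and does not replace the closed-form averages derived in Appendix B.8.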
Let us denote as the minimum learning rate constant such that the average loss increases in the first step. Similarly, let denote the learning rate constant if the loss increases in the first steps. Then, increases from the value as increases as shown in Figure 11(a). By comparison, the trace reduces at any step if . At initialization, this condition becomes . Hence, for , both the loss and trace monotonically decrease in the first training step. These arguments can be extended to later training steps, revealing that the loss and trace will continue to decrease for .
Next, let denote the learning rate corresponding to . Then, we have , implying
[TABLE]
Figure 11(c) shows that for all widths, implying . Hence, increases with as observed in Figure 7(a). In Appendix B.2, we demonstrated that for the model, for all values of widths. Incorporating this with increases with , we have sharpness reduction phase opening up as increases.
B.4 Opening of the loss catapult phase at finite width
In this section, we use the Frobenius norm of the Hessian as a proxy for the sharpness and demonstrate the emergence of the loss-sharpness catapult phase at finite width. In particular, we analyze the expectation value after the first training step near and show that , with the difference increasing with . First, we write in terms of and
[TABLE]
Next, using Equations 1, we write down the change in after the first training step in terms of and
[TABLE]
Next, we substitute to obtain the above equation as a function of
[TABLE]
Finally, we calculate the expectation value of
[TABLE]
by estimating using the approach demonstrated in the previous section
[TABLE]
Inserting in the expression for the averaged Frobenius norm above along with , we have
[TABLE]
At infinite width, the above equation reduces to , and hence, . For any finite width, for . At , , and therefore . In order for the sharpness to catapult, we require and therefore . As increases also increases, which means a higher value of is required to reach a point where . Thus increases with .
B.5 The early training trajectories
Figure 9 shows the early training trajectories of the model with large () and small () widths. The dynamics depicted show several similarities with early training dynamics of real-world models shown in Figure 2. At small widths, the loss catapults at relatively higher learning rates (specifically, at , which is significantly higher than the critical value of ).
B.6 Relationship between critical constants
Figure 12 shows the relationship between various critical constants for the model. The data show that the inequality holds for every random initialization of the model.
B.7 Phase diagrams with error bars
This section shows the variation in the phase diagram boundaries of the model shown in Figure 7(a, b). Figure 13 shows these phase diagrams. Each data point is an average of over initializations. The horizontal bars around each data point indicate the region between and quantile.
B.8 Derivation of the expectation values
Here, we provide the detailed derivation of the averages and . We begin by finding the average
[TABLE]
where denotes the norm of the vectors.
The above integral is non-zero only if . Hence, it is a sum of identical integrals. Without any loss of generality, we solve this integral for and multiply by to obtain the final result, i.e.,
[TABLE]
Consider a transformation of into dimensional spherical coordinates such that
[TABLE]
which yields,
[TABLE]
where denotes the dimensional solid angle element. Here, we denote the radial and angular integrals by and respectively. The radial integral is
[TABLE]
Let and with and , then we have
[TABLE]
where denotes the Gamma function. The angular integral is
[TABLE]
Substituting Equations 32 and 35 into Equation 29, we obtain a very simple expression
[TABLE]
The other integral can be obtained by generalizing the above approach as described below
[TABLE]
The integral is non-zero only if either and or , which we consider separately. Without loss of generality, we find the following integrals
[TABLE]
which have the following expressions
[TABLE]
where denotes the gamma function. On combining the expressions with their multiplicities, we obtain the final result
[TABLE]
B.9 Insights into the catapult effect in cross-entropy loss using the uv model
In this section, we summarize the main intuition behind the discrepancy in the values of for cross-entropy loss at large widths. We consider the model trained on a classification task using logistic loss, as presented in [50].
Consider the model trained on a binary classification task using the logistic loss on two training examples and . Then, the total loss is . Hence, the loss grows monotonically as the output function increases. The update equation of the function is given by:
[TABLE]
where is the learning rate and is the derivative of the loss. At large width, if the condition holds, then the output function continues to decrease. Given that in the above case, this decrease persists for . This result provides some intuition behind the discrepancy.
Appendix C Phase diagrams of early training
This section describes experimental details and shows additional phase diagrams of early training. The results include (1) FCNs in NTP trained on MNIST, Fashion-MNIST, and CIFAR-10 datasets, (2) CNNs in SP trained on Fashion-MNIST and CIFAR-10, and (3) ResNets in SP trained on CIFAR-10 datasets using MSE loss. Figures 14 to 19 show these results. The depths and widths are the same as specified in Appendix A. Each data point is an average over initializations. The horizontal bars around the average data point indicate the region between and quantile. Phase diagrams for cross-entropy results are shown in Appendix F.
Additional experimental details:
We train each model for steps using SGD with learning rates and batch size of , where with in steps of . Here, is related to the maximum trainable learning rate constant as . We have considered random initializations for each model. As mentioned in Appendix A, we do not consider averages over SGD runs as the randomness from initialization outweighs it. Hence, we obtain values for each of the critical values in the following results. For each initialization, we compute the critical constants using Definitions 1 and 3. To avoid a random increase in loss and sharpness due to fluctuations, we round off the values of and to their second decimal places before comparing with 1. We denote the average values using data points and variation using horizontal bars around the average data points, which indicate the region between and quantile. The smooth curves are obtained by fitting a second-degree polynomial with and can take on one of three values: and .
Phase diagrams with depth:
Figure 20 shows the phase diagrams with depth for FCNs in NTP trained on the CIFAR-10 dataset. The phase diagrams look qualitatively similar to the phase diagrams.
Appendix D Relationship between various critical constants
Figure 21 illustrates the relationship between the early training critical constants across models and datasets. The experimental setup is the same as in Appendix C. Typically, we find that holds true. However, there are some exceptions, which are observed at high values of (see Figure 21(d, e)), where the trends of the critical constants converge, and large fluctuations can cause deviations from the inequality.
Appendix E The effect of on the noise in dynamics
In this section, we demonstrate that for FCNs with , the dynamics becomes noise-dominated. This aspect makes it challenging to distinguish the underlying deterministic dynamics from random fluctuations. To demonstrate this, we consider FCNs trained on CIFAR-10 using MSE and cross-entropy loss and use training examples for estimating sharpness.
Figures 22 and 23 show the training loss and sharpness of FCNs with and varying widths, trained on CIFAR-10 using MSE loss. We observe that the sharpness dynamics becomes noisier for .
Figures 24 and 25 show the training dynamics with loss switched to cross-entropy, while keeping the initialization and SGD batch sequence the same as in the MSE loss case. In comparison to MSE loss, the training loss and sharpness dynamics show a higher level of noise, especially for . As a result, it becomes difficult to characterize the training dynamics for .
Appendix F Cross-entropy
In this section, we provide additional results for models trained with cross-entropy (xent) loss and compare them with MSE results. Broadly speaking, models trained with cross-entropy loss show similar characteristics to those trained with MSE loss, such as (i) sharpness reduction during early training, (ii) an increase in critical constants , with and , and (iii) . However, the dynamics of models trained with cross-entropy loss is noisier compared to MSE as shown in the previous section, and characterizing these dynamics can be more complex. In the following experiments, we consider models trained on the CIFAR-10 dataset and use training examples to estimate sharpness.
F.1 Phase diagrams
Figure 26 compares the phase diagrams of FCNs in SP trained on the CIFAR-10 dataset, using both MSE and cross-entropy loss. The estimated critical constants for cross-entropy loss are generally noisier, as quantified by the confidence intervals. In comparison to phase diagrams of models trained with MSE loss, we observe a few notable differences. First, the loss starts to catapult at a value appreciably larger than at large widths. Primarily, . Additionally, generally decreases with . This decreasing trend becomes less sharp at large depths.
Despite these differences, the phase diagrams for both loss functions share various similarities. First, we observe sharpness reduces during early training for (see the first row of Figure 25). Next, we observe that the inequality generally holds for both loss functions as demonstrated in Figure 27, barring some exceptions.
Figure 28 shows the phase diagrams for CNNs and ResNets trained on the CIFAR-10 dataset using cross-entropy loss. The observed critical constants are much noisier as quantified by the confidence intervals. Nevertheless, the phase diagram shows similar trends as mentioned above. For large models, we found that progressive sharpening begins after training steps. For these cases, we only use the first steps to measure sharpness to avoid progressive sharpening. For CNNs, we observed that the dynamics becomes difficult to characterize for and , due to large fluctuations. Consequently, we have opted not to include these particular results.
F.2 Intermediate saturation regime
Figure 29 shows the normalized sharpness measured at for FCNs trained on CIFAR-10 using cross-entropy loss. (Footnote 3: The time step is in the middle of the intermediate saturation regime for most of the models.) For further details on estimating sharpness, see Section I.1.
Similar to MSE loss, we observe an abrupt drop in sharpness at large learning rates. However, this abrupt drop occurs at . The estimated sharpness is noisier (compare with Figure 38), which hinders a reliable estimation of . We speculate that we require a large number of averages for a reliable estimation of for cross-entropy loss. We leave the precise characterization of for cross-entropy loss for future work.
Appendix G The effect of setting model output to zero at initialization
In this section, we demonstrate the effect of network output at initialization on the early training dynamics. In particular, we set the network output to zero at initialization, , by (1) ‘centering’ the network by its initial value or (2) setting the last layer weights to zero at initialization. We show that both (1) and (2) remove the opening of the sharpness reduction phase with . As a result, the average onset of loss catapult occurs at , independent of depth and width.
Throughout this section, we use ‘vanilla’ networks to refer to networks initialized in the standard way. For simplicity, we train FCNs using full batch gradient descent with MSE loss using a subset consisting of examples of the CIFAR-10 dataset.
G.1 The effect of centering networks
Given a network function , we define the centered network as
[TABLE]
where is the network output at initialization. By construction, the network output is zero at initialization. It is noteworthy that centering a network is an unusual way of training deep networks as it doubles the cost of training because of two forward passes.
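A minimal sketch of centering is given below, under the assumption that the centered network is the current output minus the output of a frozen copy at the initial parameters. This matches the description of centering by the initial value and the two forward passes noted above. The wrapper `make_centered` and the toy scalar model are hypothetical and serve only to show the mechanics.

```python
def make_centered(apply_fn, init_params):
    """Return f_c(params, x) = f(params, x) - f(init_params, x)."""
    frozen = list(init_params)  # copy so later in-place updates do not leak in

    def centered_apply(params, x):
        # Two forward passes: the current network and the frozen initial one.
        return apply_fn(params, x) - apply_fn(frozen, x)

    return centered_apply


def linear_apply(params, x):
    """Toy scalar model (hypothetical), used only to exercise the wrapper."""
    weight, bias = params
    return weight * x + bias


theta0 = [0.7, -0.2]
f_centered = make_centered(linear_apply, theta0)
out_at_init = f_centered(theta0, 3.0)             # zero by construction
out_after_update = f_centered([0.9, -0.2], 3.0)   # non-zero once weights move
```

Only the current parameters receive gradients. The frozen copy acts as a constant offset, so the centered and vanilla networks have the same gradients with respect to the parameters at initialization, while the residual entering the loss differs.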
Figure 30 compares the training loss and sharpness dynamics of vanilla networks and centered networks. Unlike vanilla networks, we do not observe a decrease in sharpness for during early training. Rather, we observe a slight increase in sharpness. To distinguish this slight increase from sharpness catapult, we introduce a threshold , comparing normalized sharpness with , to define a sharpness catapult. (Footnote 4: In experiments, we set . We use the same threshold for zero-init networks.) As demonstrated in Section G.3, the uv model trained on a single training example with sheds light on this initial increase in sharpness.
Interestingly, irrespective of depth and width, we observe that loss catapults at , as demonstrated in the phase diagrams in Figure 31(a, b, c). These findings suggest a strong correlation between a large network output at initialization and the opening of the sharpness reduction phase discussed in Section 2.
G.2 The effect of setting the last layer to zero
An alternative way to train networks with is by setting the last layer to zero at initialization. The principle of criticality at initialization [55, 58, 68] does not put any constraints on the last layer weights. Hence, setting the last layer to zero does not affect signal/gradient propagation at initialization. Yet, it places the network in a flat-curvature region at initialization, which allows access to larger learning rates. We refer to these networks as ‘zero-init’ networks.
Figure 30 compares the training dynamics of zero-init networks with vanilla and centered networks. We observe that the dynamics is quite similar to the centered networks: (i) sharpness does not reduce for small learning rates and (ii) loss catapults , irrespective of depth and width. Figure 31(d, e, f) shows the phase diagrams of zero-init networks. Like centered networks, the critical constants do not scale with depth and width. This again suggests that a large network output at initialization is related to the opening of the sharpness reduction phase in the early training results shown in Section 2.
G.3 Insights from the uv model trained on
In this section, we gain insights into the effect of setting network output to zero at initialization using the uv model trained on an example . In particular, we show that loss catapults at and sharpness increases during early training.
Consider the uv model trained on a single training example with (Footnote 5: Note that for , the network is already at a minimum for .)
[TABLE]
This simplifies the loss function to
[TABLE]
where is the residual. The trace of the Hessian is
[TABLE]
The Frobenius norm can be written in terms of the trace and the network output
[TABLE]
The function and residual updates are given by
[TABLE]
Similarly, we can obtain the trace update equations
[TABLE]
Let us analyze them for the networks with zero output at initialization. The loss at the first step increases if
[TABLE]
Setting and scaling the learning rate as , we see that the loss increases at the first step if .
[TABLE]
Next, we analyze the change in trace during the first training step. Setting , we observe that the trace increases for all learning rates
[TABLE]
modulated by the learning rate and width. Finally, we analyze the change in Frobenius norm in the first training step at , which implies ,
[TABLE]
As increases in the first training step, also increases in the first training step.
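The first-step statements above can be checked numerically. The sketch below assumes the two-layer linear model f = v·u/√w with input x = 1, a non-zero target y, and loss (f − y)²/2, and obtains zero output at initialization by setting v = 0, mirroring the zero-init construction. With the learning rate scaled as c divided by the initial trace, the loss after one step changes by the factor (1 − c)², which exceeds one exactly when c > 2. The trace is multiplied by 1 + η²y²/w, which is always larger than one. The function name and parameter values are illustrative.

```python
import math
import random


def first_step_ratios(width, c, y=1.0, seed=0):
    """Loss and trace ratios after one GD step from a zero-output init.

    Assumptions of this sketch: f = v.u / sqrt(w), input x = 1, loss
    (f - y)^2 / 2, and zero output obtained by setting v = 0.
    """
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(width)]
    v = [0.0] * width

    def output(u_vec, v_vec):
        return sum(a * b for a, b in zip(u_vec, v_vec)) / math.sqrt(width)

    def trace(u_vec, v_vec):
        return (sum(a * a for a in u_vec) + sum(b * b for b in v_vec)) / width

    f0, lam0 = output(u, v), trace(u, v)
    lr = c / lam0                      # learning rate scaled with the trace
    residual = f0 - y
    scale = lr * residual / math.sqrt(width)
    u1 = [a - scale * b for a, b in zip(u, v)]
    v1 = [b - scale * a for a, b in zip(u, v)]
    f1, lam1 = output(u1, v1), trace(u1, v1)
    loss_ratio = (f1 - y) ** 2 / (f0 - y) ** 2
    return loss_ratio, lam1 / lam0


below = first_step_ratios(32, c=1.5)  # loss ratio (1 - 1.5)^2 = 0.25
above = first_step_ratios(32, c=2.5)  # loss ratio (1 - 2.5)^2 = 2.25
```

In both cases the trace ratio is larger than one, consistent with sharpness increasing during the first step when the output starts at zero.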
Appendix H The effect of output scale on the training dynamics
Given a neural network function with depth and width , we define the scaled network as , where is referred to as the output scale. In this section, we empirically study the impact of the output scale on the early training dynamics. In particular, we show that a large (resp. small) value of relative to the one-hot encodings of the labels causes the sharpness to decrease (resp. increase) during early training. Interestingly, we still observe an increase in with and , unlike the case of initializing network output to zero, highlighting the unique impact of output scale on the dynamics. For simplicity, we train FCNs using gradient descent with MSE loss using a subset consisting of examples of the CIFAR-10 dataset, as in the previous section.
H.1 The effect of fixed output scale at initialization
In this section, we study the training dynamics of models trained with a fixed output scale at initialization. Given a network output function , we define the ‘scaled network’ as
[TABLE]
where is a scalar, fixed throughout training. By construction, the network output norm equals . For standard initialization, , where is the number of classes.
Figure 32 shows the training dynamics of FCNs for three different values of the output scale . The training dynamics of networks with and share qualitative similarities. In contrast, networks initialized with a smaller output scale () exhibit distinctly different dynamics. In particular, we observe that for large output scales () sharpness decreases during early training, while sharpness increases for small output scales. (Footnote 6: We empirically observed that sharpness reduces for output scales as small as , which is relatively small compared to .) Furthermore, the training dynamics tends to be noisier at small output scales, making it difficult to characterize catapult dynamics amidst these fluctuations. In summary, the training dynamics of networks with small output scale deviate from the training dynamics discussed in the main text, particularly as the sharpness quickly increases during early training.
Figure 33 shows the trends of various critical constants with width for FCNs for three different values of . Similar to vanilla networks, we observe that increases with and . In comparison, sharpness decreases (increases) for large (small) values of . These experiments suggest that the output scale primarily influences the increase/decrease in sharpness during early training and does not affect the scaling of with depth and width.
Note that we do not generate phase diagrams for these experiments as the training dynamics of networks with small output scales at initialization deviate from the training dynamics discussed in the main text.
H.2 Scaling the output scale with width
In this section, we study the training dynamics of models with an output scale scaled with width as , which is commonly used in the literature [19, 6, 4]. We consider three distinct values , where represents the lazy regime, corresponds to the feature learning (rich) regime, and corresponds to standard (vanilla) initialization.
Figure 34 shows the training loss and sharpness trajectories of FCNs trained on for different values. We observe that the training trajectories in the lazy regime look identical to standard initialization. In comparison, the training trajectories in the feature learning regime are distinctly different. We observe that in the standard and lazy regimes, sharpness decreases during early training, whereas sharpness tends to increase in the feature learning regime and eventually oscillates around the edge of stability regime. Moreover, we observe that sharpness can catapult before the training loss in the feature learning regime (compare catapult peaks in Figure 34(e, f)). These results parallel those for the fixed output scale networks studied in the previous section.
Figure 35 summarizes the early training dynamics of FCNs with different values. We observe similar results as in the previous section. The output scale affects the initial increase/decrease of sharpness but does not affect the scaling trend of with depth and width. Moreover, we observe a systematic pattern of scaling with width. In the lazy regime, we observe that increases with , while decreases with in the feature learning regime.
Appendix I Sharpness curves in the intermediate saturation regime
This section shows additional results for Section 3 for MSE loss. Cross-entropy results are shown in Appendix F. Figures 36 to 40 show the normalized sharpness curves for different depths and widths.
I.1 Estimating the sharpness
This paragraph describes the procedure for measuring the sharpness to study the effect of the learning rate, depth, and width in the intermediate saturation regime. We measure the sharpness at a time in the middle of the intermediate saturation regime. We choose so that , for learning rates , where in steps of . The value is chosen such that is in the middle of the intermediate saturation regime. Next, we measure sharpness over a range of steps and average over to reduce fluctuations. We repeat this process for various initializations and obtain the average sharpness.
I.2 Estimating the critical constant
This subsection explains how to estimate from sharpness measured at time . First, we normalize the sharpness with its initial value, and then average over random initializations. Next, we estimate the critical point using the second derivative of the order parameter curve. Even if the obtained averaged normalized sharpness curve is somewhat smooth, the second derivative may become extremely noisy as minor fluctuations amplify on taking derivatives. This can cause difficulties in obtaining . We resolve this issue by estimating the smooth derivatives of the averaged order parameter with the Savitzky–Golay filter [59] using its SciPy implementation [63]. The estimated is shown by vertical lines in the sharpness curves in Figures 36 to 40.
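To make the smoothing step concrete, the sketch below hand-codes the classic 5-point quadratic Savitzky–Golay kernels for the first and second derivative on a uniformly spaced grid. It is a stand-in for the SciPy routine rather than a reproduction of the settings used for the figures: the window length, polynomial order, and the function name `savgol_derivatives` are assumptions, and edge points are dropped instead of being extrapolated.

```python
def savgol_derivatives(values, spacing):
    """First and second derivatives from a 5-point quadratic Savitzky-Golay fit.

    Returns two lists covering the interior points values[2:-2]; the two
    points at each edge are skipped rather than extrapolated.
    """
    first_kernel = (-2.0, -1.0, 0.0, 1.0, 2.0)    # divided by 10 * spacing
    second_kernel = (2.0, -1.0, -2.0, -1.0, 2.0)  # divided by 7 * spacing**2
    first, second = [], []
    for i in range(2, len(values) - 2):
        window = values[i - 2:i + 3]
        d1 = sum(k * y for k, y in zip(first_kernel, window))
        d2 = sum(k * y for k, y in zip(second_kernel, window))
        first.append(d1 / (10.0 * spacing))
        second.append(d2 / (7.0 * spacing**2))
    return first, second


# Sanity check on y = x^2, for which a quadratic fit is exact.
spacing = 0.1
xs = [i * spacing for i in range(21)]
d1, d2 = savgol_derivatives([x * x for x in xs], spacing)
```

Fitting a local polynomial before differentiating is what keeps the second derivative from amplifying point-to-point noise. A wider window smooths more strongly, at the cost of blurring a sharp transition.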
Appendix J The effect of batch size on the reported results
J.1 The early transient regime
Figure 41 shows the phase diagrams of early training dynamics of FCNs with trained on the CIFAR-10 dataset using two different batch sizes. The phase diagram is consistent with the findings presented in Section 2, with one key difference: when is small and small batch sizes are used for training, sharpness may increase from initialization at relatively smaller values of . This is reflected in Figure 41 by moving to the left as is reduced from to . However, this initial increase in sharpness is small compared to the sharpness catapult observed at larger batch sizes. We found that this increase at small batch sizes is due to fluctuations in gradient estimation, which can push sharpness above its initial value by chance.
J.2 The intermediate saturation regime
Figure 42 shows the normalized sharpness, measured at , and its derivatives for various widths and batch sizes. The results are consistent with those in Section 3, except that the peaks of the derivatives and are lower at small batch sizes. Lower peaks correspond to a larger full width at half maximum, which implies a broadening of the transition around at smaller batch sizes.
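The full width at half maximum of a derivative peak can be measured as in the sketch below. It is an illustration only, assuming a single positive peak sampled on an increasing grid that falls below half of its maximum on both sides; the function name `fwhm` is hypothetical.

```python
import numpy as np


def fwhm(x, y):
    """Full width at half maximum of a single positive peak, by linear interpolation."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    i_peak = int(np.argmax(y))
    half = y[i_peak] / 2.0

    # Walk outwards from the peak until the curve drops below half maximum.
    left = i_peak
    while left > 0 and y[left] >= half:
        left -= 1
    right = i_peak
    while right < len(y) - 1 and y[right] >= half:
        right += 1
    if y[left] >= half or y[right] >= half:
        raise ValueError("peak does not fall below half maximum inside the grid")

    # Interpolate the two half-maximum crossings between neighbouring samples.
    x_left = np.interp(half, [y[left], y[left + 1]], [x[left], x[left + 1]])
    x_right = np.interp(half, [y[right], y[right - 1]], [x[right], x[right - 1]])
    return x_right - x_left


# Two Gaussian peaks of equal area: the lower one is the broader one.
grid = np.linspace(-10.0, 10.0, 2001)
narrow = np.exp(-grid**2 / (2 * 1.0**2)) / 1.0
broad = np.exp(-grid**2 / (2 * 2.0**2)) / 2.0
width_narrow, width_broad = fwhm(grid, narrow), fwhm(grid, broad)
```

In this synthetic example the two peaks have the same area, so halving the height doubles the width; whether the measured curves behave the same way depends on the data.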
Appendix K The effect of bias on the reported results
In this section, we show that FCNs with bias yield results similar to those presented in the main text. We considered FCNs in SP initialized with He initialization [29].
Figure 43 shows the phase diagrams of early training for FCNs with bias trained on the CIFAR-10 dataset. The phase diagram is similar to that of the no-bias case (compare with Figure 26).
- [1] Alessandro Achille, Matteo Rovere, and Stefano Soatto. Critical learning periods in deep networks. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=BkeStsCcKQ.
- [2] Atish Agarwala, Fabian Pedregosa, and Jeffrey Pennington. Second-order regression models exhibit progressive sharpening to the edge of stability. arXiv, abs/2210.04860, 2022.
- [3] Sanjeev Arora, Zhiyuan Li, and Abhishek Panigrahi. Understanding gradient descent on edge of stability in deep learning. In International Conference on Machine Learning, 2022.
- [4] Alexander Atanasov, Blake Bordelon, Sabarish Sainathan, and Cengiz Pehlevan. The onset of variance-limited behavior for networks in the lazy and rich regimes. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=JLINxPOVTh7.
- [5] Blake Bordelon and Cengiz Pehlevan. Self-consistent dynamical field theory of kernel evolution in wide neural networks. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=sipwrPCrIS.
- [6] Blake Bordelon and Cengiz Pehlevan. Dynamics of finite width kernel and prediction fluctuations in mean field neural networks. arXiv, abs/2304.03408, 2023.
- [7] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
- [8] Jeremy Cohen, Simran Kaur, Yuanzhi Li, J Zico Kolter, and Ameet Talwalkar. Gradient descent on neural networks typically occurs at the edge of stability. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=jh-rTtvkGeM.
