Wide neural networks of any depth evolve as linear models under gradient descent

Jaehoon Lee, Lechao Xiao, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Narain Sohl-Dickstein, Jeffrey S. Pennington. Journal of Statistical Mechanics: Theory and Experiment.
A longstanding goal in deep learning research has been to precisely characterize training and generalization. […] While these theoretical results are only exact in the infinite-width limit, we nevertheless find excellent empirical agreement between the predictions of the original network and those of the linearized version, even for finite, practically sized networks. This agreement is robust across different architectures, optimization methods, and loss functions.
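The linearization at the heart of the paper can be sketched numerically: replace the network by its first-order Taylor expansion around the initial weights and compare the two models' outputs. The following is a minimal illustration, not the paper's experiments; the two-layer tanh architecture, the 512-unit width, and the perturbation size are arbitrary choices.

```python
import numpy as np

# A toy two-layer network and its linearization around the initial weights:
#   f_lin(theta) = f(theta0) + grad_f(theta0) . (theta - theta0)
rng = np.random.default_rng(0)
d, width = 3, 512                       # input dim, hidden width

W1 = rng.normal(0, 1 / np.sqrt(d), (width, d))
W2 = rng.normal(0, 1 / np.sqrt(width), (1, width))
params0 = (W1.copy(), W2.copy())

def f(params, x):
    A, B = params
    return (B @ np.tanh(A @ x)).item()

def grad_f(params, x):
    """Gradient of the scalar output w.r.t. all parameters, flattened."""
    A, B = params
    h = np.tanh(A @ x)
    dW2 = h                                              # shape (width,)
    dW1 = (B.flatten() * (1 - h**2))[:, None] * x[None, :]
    return np.concatenate([dW1.ravel(), dW2.ravel()])

def f_lin(params, x):
    delta = np.concatenate([(params[0] - params0[0]).ravel(),
                            (params[1] - params0[1]).ravel()])
    return f(params0, x) + grad_f(params0, x) @ delta

x = rng.normal(size=d)

# A small random parameter perturbation: for wide networks the linearized
# model tracks the original closely.
eps = 1e-2
params1 = (W1 + eps * rng.normal(size=W1.shape),
           W2 + eps * rng.normal(size=W2.shape))
print(abs(f(params1, x) - f_lin(params1, x)))  # small for wide networks
```

The paper's result is that, in the infinite-width limit, this agreement holds not just for a single perturbation but along the entire gradient-descent trajectory.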


Benefits of Jointly Training Autoencoders: An Improved Neural Tangent Kernel Analysis

This paper rigorously proves the linear convergence of gradient descent in two weakly-trained and jointly-trained regimes and indicates the considerable benefits of joint training over weak training in finding global optima, achieving a dramatic decrease in the required level of over-parameterization.

Learning Curves for Deep Neural Networks: A field theory perspective

A renormalization-group approach is used to show that noiseless GP inference using NTK, which lacks a good analytical handle, can be well approximated by noisy GP inference on a related kernel the authors call the renormalized NTK.

Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the Neural Tangent Kernel

A large-scale phenomenological analysis of training reveals a striking correlation between a diverse set of metrics over training time, governed by a rapid chaotic to stable transition in the first few epochs, that together poses challenges and opportunities for the development of more accurate theories of deep learning.

Disentangling trainability and generalization in deep learning

This paper discusses challenging issues in the context of wide neural networks at large depths and finds that there are large regions of hyperparameter space where networks can only memorize the training set, in the sense that they reach perfect training accuracy but completely fail to generalize outside the training set.

Exploring the Uncertainty Properties of Neural Networks' Implicit Priors in the Infinite-Width Limit

This work uses the NNGP with a softmax link function to build a probabilistic model for multi-class classification and marginalize over the latent Gaussian outputs to sample from the posterior, leveraging recent theoretical advances that characterize the function-space prior of an ensemble of infinitely-wide NNs as a Gaussian process.

On the Optimization Dynamics of Wide Hypernetworks

This work partially solves an open problem, showing that the convergence rate of the r-th order term of the Taylor expansion of the cost function along the optimization trajectories of SGD is n, improving upon the bound suggested by the conjecture of Dyer & Gur-Ari while matching their empirical observations.

An analytic theory of shallow networks dynamics for hinge loss classification

This paper studies in detail the training dynamics of a simple type of neural network: a single hidden layer trained to perform a classification task. It shows that, in a suitable mean-field limit, this case maps to a single-node learning problem with a time-dependent dataset determined self-consistently from the average node population.

Asymptotics of Wide Convolutional Neural Networks

It is found that the difference in performance between finite- and infinite-width models vanishes at a definite rate with respect to model width, consistent with finite-width models generalizing either better or worse than their infinite-width counterparts.

Learning Curves for Deep Neural Networks: A Gaussian Field Theory Perspective

This work constructs a versatile field-theory formalism for supervised deep learning, involving the renormalization group, Feynman diagrams, and replicas, and shows that this approach leads to highly accurate predictions of learning curves for truly deep DNNs trained on polynomial regression problems.

What can linearized neural networks actually say about generalization?

It is shown that the linear approximations can indeed rank the learning complexity of certain tasks for neural networks, even when they achieve very different performances, and that networks overfit to these tasks mostly due to the evolution of their kernel during training, thus, revealing a new type of implicit bias.

A Convergence Theory for Deep Learning via Over-Parameterization

This work proves why stochastic gradient descent can find global minima on the training objective of DNNs in polynomial time, and implies an equivalence between over-parameterized neural networks and the neural tangent kernel (NTK) in the finite (and polynomial) width setting.

Gradient descent optimizes over-parameterized deep ReLU networks

The key idea of the proof is that Gaussian random initialization followed by gradient descent produces a sequence of iterates that stay inside a small perturbation region centered at the initial weights, in which the training loss function of the deep ReLU networks enjoys nice local curvature properties that ensure the global convergence of gradient descent.

On Lazy Training in Differentiable Programming

This work shows that this "lazy training" phenomenon is not specific to over-parameterized neural networks, and is due to a choice of scaling that makes the model behave as its linearization around the initialization, thus yielding a model equivalent to learning with positive-definite kernels.
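The scaling argument behind lazy training can be sketched with a one-parameter toy model: multiply the model output by a factor alpha and the loss by 1/alpha^2. This is a minimal illustration of the mechanism, not code from the paper; the tanh "model", target value, and step counts are arbitrary choices.

```python
import numpy as np

# With large alpha, gradient descent barely moves the parameter from its
# initialization, so the nonlinear model behaves like its linearization,
# yet the scaled output alpha*h(theta) still fits the target.
def train(alpha, y=0.9, steps=2000, lr=0.1):
    theta0 = 0.0                        # initialize so that h(theta0) = 0
    theta = theta0
    h = np.tanh                         # a nonlinear scalar "model"
    for _ in range(steps):
        # loss F(theta) = (alpha*h(theta) - y)^2 / (2*alpha^2)
        theta -= lr * (alpha * h(theta) - y) * (1 - h(theta) ** 2) / alpha
    return abs(theta - theta0)          # distance traveled from init

move_small = train(alpha=1.0)           # parameter moves O(1)
move_large = train(alpha=100.0)         # parameter moves O(1/alpha)
print(move_small, move_large)
```

The large-alpha run reaches its target while staying so close to the initialization that the first-order Taylor expansion of the model remains accurate, which is the sense in which lazy training reduces to kernel learning.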

The Effect of Network Width on Stochastic Gradient Descent and Generalization: an Empirical Study

It is found that the optimal SGD hyper-parameters are determined by a "normalized noise scale," which is a function of the batch size, learning rate, and initialization conditions, and in the absence of batch normalization, the optimal normalized noise scale is directly proportional to width.

Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation

This work opens a way toward design of even stronger Gaussian Processes, initialization schemes to avoid gradient explosion/vanishing, and deeper understanding of SGD dynamics in modern architectures.

Deep Neural Networks as Gaussian Processes

The exact equivalence between infinitely wide deep networks and GPs is derived, and it is found that test performance increases as finite-width trained networks are made wider and more similar to a GP, and thus that GP predictions typically outperform those of finite-width networks.
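The NNGP kernel underlying this equivalence can be computed layer by layer in closed form for ReLU networks via the standard arc-cosine recursion (Cho & Saul). The sketch below is a minimal illustration, not code from the paper; the depth and the weight/bias variances sigma_w, sigma_b are assumed values.

```python
import numpy as np

# NNGP kernel recursion for a fully-connected ReLU network:
#   K^{l+1}(x, x') = sigma_w^2 * E[relu(u) relu(v)] + sigma_b^2,
# where (u, v) is Gaussian with covariance given by the layer-l kernel.
def nngp_kernel(x1, x2, depth=3, sigma_w=1.4, sigma_b=0.1):
    d = len(x1)
    # base case: kernel after the first (linear) layer
    k11 = sigma_w**2 * (x1 @ x1) / d + sigma_b**2
    k22 = sigma_w**2 * (x2 @ x2) / d + sigma_b**2
    k12 = sigma_w**2 * (x1 @ x2) / d + sigma_b**2
    for _ in range(depth):
        # arc-cosine formula for E[relu(u) relu(v)]
        c = np.clip(k12 / np.sqrt(k11 * k22), -1.0, 1.0)
        theta = np.arccos(c)
        ev = (np.sqrt(k11 * k22) / (2 * np.pi)
              * (np.sin(theta) + (np.pi - theta) * np.cos(theta)))
        e11, e22 = k11 / 2, k22 / 2     # E[relu(u)^2] = k11 / 2
        k11 = sigma_w**2 * e11 + sigma_b**2
        k22 = sigma_w**2 * e22 + sigma_b**2
        k12 = sigma_w**2 * ev + sigma_b**2
    return k12

x = np.array([1.0, -0.5, 0.3])
print(nngp_kernel(x, x))    # prior variance of the GP at input x
```

Evaluating this kernel on all pairs of training and test points gives the covariance matrices needed for exact GP inference with the infinite-width network prior.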

Sensitivity and Generalization in Neural Networks: an Empirical Study

It is found that trained neural networks are more robust to input perturbations in the vicinity of the training data manifold, as measured by the norm of the input-output Jacobian of the network, and that it correlates well with generalization.
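The sensitivity measure used there, the norm of the input-output Jacobian, is easy to estimate numerically. The following is a minimal sketch on a toy two-layer network (the architecture and finite-difference scheme are illustrative choices, not the paper's setup).

```python
import numpy as np

# Estimate the Frobenius norm of the input-output Jacobian of a small
# network at a point, via central finite differences.
rng = np.random.default_rng(1)
d_in, width, d_out = 4, 32, 2
W1 = rng.normal(0, 1 / np.sqrt(d_in), (width, d_in))
W2 = rng.normal(0, 1 / np.sqrt(width), (d_out, width))

def net(x):
    return W2 @ np.tanh(W1 @ x)

def jacobian_norm(x, eps=1e-5):
    J = np.empty((d_out, d_in))
    for j in range(d_in):
        e = np.zeros(d_in)
        e[j] = eps
        J[:, j] = (net(x + e) - net(x - e)) / (2 * eps)
    return np.linalg.norm(J)            # Frobenius norm

x = rng.normal(size=d_in)
print(jacobian_norm(x))  # larger = more sensitive to input perturbations
```

Comparing this quantity on points near versus far from the training data is the kind of measurement the paper uses to relate sensitivity to generalization.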

A Mean Field Theory of Batch Normalization

The theory shows that gradient signals grow exponentially in depth and that these exploding gradients cannot be eliminated by tuning the initial weight variances or by adjusting the nonlinear activation function, so vanilla batch-normalized networks without skip connections are not trainable at large depths for common initialization schemes.

A mean field view of the landscape of two-layer neural networks

A compact description of the SGD dynamics is derived in terms of a limiting partial differential equation that allows for “averaging out” some of the complexities of the landscape of neural networks and can be used to prove a general convergence result for noisy SGD.

On the Convergence Rate of Training Recurrent Neural Networks

It is shown that when the number of neurons is sufficiently large, meaning polynomial in the training data size and in the linear convergence rate, SGD is capable of minimizing the regression loss at a linear convergence rate; this gives theoretical evidence of how RNNs can memorize data.