Natural Gradient Descent
Dec 29, 2022 · 11 min · machine learning, math
Natural gradient descent minimizes the loss function in distribution space, with the KL divergence as its metric, instead of in the usual parameter space with the Euclidean metric.
Preliminaries
Gradient Descent
Assume that $L(\theta)$ is a loss function with first-order continuous partial derivatives, where $\theta \in \mathbb{R}^n$. Now we want to solve an unconstrained optimization problem:

$$\min_{\theta \in \mathbb{R}^n} L(\theta)$$
Gradient descent is an iterative optimization algorithm. It works by starting at a random point on the objective function and iteratively updating $\theta$. This process is repeated until the minimum is found or the algorithm converges. Formally, in the $k$-th iteration, we have:

$$\theta_{k+1} = \theta_k + \alpha_k p_k \tag{1}$$
where $\alpha_k$ is the step size and $p_k$ is the update direction. The first-order Taylor series expansion of $L(\theta)$ at $\theta_k$ is:

$$L(\theta_k + \alpha_k p_k) \approx L(\theta_k) + \alpha_k \nabla L(\theta_k)^\top p_k$$
We want to find the steepest descent direction within the local neighbourhood of $\theta_k$ in the parameter space, that is, to minimize:

$$\min_{p_k} \nabla L(\theta_k)^\top p_k = \min_{p_k} \left\| \nabla L(\theta_k) \right\| \left\| p_k \right\| \cos\phi$$
where $\phi$ is the angle between $p_k$ and $\nabla L(\theta_k)$. Obviously, the minimum is attained when $\cos\phi = -1$, i.e. when $p_k$ and $\nabla L(\theta_k)$ point in opposite directions, and the loss then decreases the most. This is why gradient descent moves in the direction of the negative gradient:

$$p_k = -\nabla L(\theta_k)$$
Now we can write Eq. 1 as:

$$\theta_{k+1} = \theta_k - \alpha \nabla L(\theta_k)$$
where $\alpha$ is the learning rate.
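As a quick illustration, here is a minimal sketch of this update rule in Python (NumPy), applied to a toy quadratic loss made up purely for demonstration:

```python
import numpy as np

# Hypothetical example: gradient descent on L(theta) = 0.5 * ||theta - target||^2.
target = np.array([3.0, -1.0])

def grad(theta):
    # nabla L(theta) for the quadratic loss above
    return theta - target

theta = np.zeros(2)      # arbitrary starting point
alpha = 0.1              # learning rate
for _ in range(100):
    theta = theta - alpha * grad(theta)   # theta_{k+1} = theta_k - alpha * nabla L(theta_k)

print(theta)             # ~ [3.0, -1.0], the minimizer of L
```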
KL Divergence
The Kullback-Leibler divergence, also known as the relative entropy, is a measure of the difference between the current probability distribution $p(x|\theta)$ and a target distribution $p(x|\theta')$, which is defined as:

$$\mathrm{KL}\left[p(x|\theta)\,\|\,p(x|\theta')\right] = \mathbb{E}_{p(x|\theta)}\left[\log\frac{p(x|\theta)}{p(x|\theta')}\right] = \int p(x|\theta)\,\log\frac{p(x|\theta)}{p(x|\theta')}\,dx$$
The KL divergence is zero if and only if the two distributions are equal. Note that the KL divergence is non-symmetric: the KL divergence between $p(x|\theta)$ and $p(x|\theta')$ is not necessarily equal to that between $p(x|\theta')$ and $p(x|\theta)$.
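A tiny numerical sketch of this asymmetry, using two made-up discrete distributions:

```python
import numpy as np

# Two hypothetical discrete distributions over three outcomes.
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])

def kl(p, q):
    # KL[p || q] = sum_x p(x) * log( p(x) / q(x) )
    return np.sum(p * np.log(p / q))

print(kl(p, q))   # ~ 0.085
print(kl(q, p))   # ~ 0.092  -- generally not equal to KL[p || q]
```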
Fisher Information Matrix
I suggest having a look at this great article. In short, the Fisher Information Matrix $F$ is:
- the second moment of the first-order derivative of the log-likelihood function:
  $$F = \mathbb{E}_{p(x|\theta)}\left[\nabla_\theta \log p(x|\theta)\, \nabla_\theta \log p(x|\theta)^\top\right]$$
- the negative expectation of the Hessian matrix of the log-likelihood function (proof):
  $$F = -\mathbb{E}_{p(x|\theta)}\left[\nabla_\theta^2 \log p(x|\theta)\right]$$
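To make the first form concrete, here is a rough Monte Carlo estimate of $F$ for a toy Gaussian model of my own choosing (the analytic answer for this parameterization is $\mathrm{diag}(1/\sigma^2,\ 2)$):

```python
import numpy as np

# Hypothetical model: p(x | theta) = N(x; mu, sigma^2), theta = (mu, log sigma).
rng = np.random.default_rng(0)
mu, log_sigma = 1.0, 0.5
sigma = np.exp(log_sigma)

x = rng.normal(loc=mu, scale=sigma, size=200_000)        # x ~ p(x | theta)

# Score: gradient of log p(x | theta) with respect to (mu, log sigma).
score = np.stack([(x - mu) / sigma**2,                   # d log p / d mu
                  (x - mu)**2 / sigma**2 - 1.0])         # d log p / d log sigma

F = (score @ score.T) / x.size                           # E[ score score^T ]
print(F)                                                 # ~ [[1/sigma^2, 0], [0, 2]]
```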
KL Divergence & Fisher Information
Let $\theta' = \theta + d\theta$; we have:

$$\mathrm{KL}\left[p(x|\theta)\,\|\,p(x|\theta')\right] \approx \frac{1}{2}\, d\theta^\top F\, d\theta$$
Proof
For convenience, we denote $p(x|\theta)$ and $p(x|\theta')$ as $p_\theta$ and $p_{\theta'}$, respectively.
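The expansion itself is standard; a sketch using this notation, expanding $\log p_{\theta'}$ to second order around $\theta$ with $\theta' = \theta + d\theta$:

$$
\begin{aligned}
\mathrm{KL}\left[p_\theta\,\|\,p_{\theta'}\right]
&= \mathbb{E}_{p_\theta}\left[\log p_\theta - \log p_{\theta'}\right] \\
&\approx \mathbb{E}_{p_\theta}\left[\log p_\theta - \log p_\theta - \nabla_\theta \log p_\theta^\top d\theta - \frac{1}{2}\, d\theta^\top \left(\nabla_\theta^2 \log p_\theta\right) d\theta\right] \\
&= -\,\mathbb{E}_{p_\theta}\left[\nabla_\theta \log p_\theta\right]^\top d\theta \;-\; \frac{1}{2}\, d\theta^\top\, \mathbb{E}_{p_\theta}\left[\nabla_\theta^2 \log p_\theta\right] d\theta \\
&= 0 + \frac{1}{2}\, d\theta^\top F\, d\theta,
\end{aligned}
$$

where the first term vanishes because the expected score $\mathbb{E}_{p_\theta}\left[\nabla_\theta \log p_\theta\right]$ is zero, and the second term uses the Hessian form of $F$ from the previous section.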
This means the Fisher Information Matrix defines the local curvature in distribution space, for which the KL divergence is the metric.
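This approximation is easy to verify numerically; here is a small sanity check with the same toy Gaussian model as above, where both the KL divergence and $F$ are available in closed form:

```python
import numpy as np

# Hypothetical model: p(x | theta) = N(x; mu, sigma^2), theta = (mu, log sigma),
# for which F = diag(1 / sigma^2, 2).
mu, log_sigma = 1.0, 0.5
sigma = np.exp(log_sigma)
d = np.array([0.01, -0.02])                      # a small parameter change d_theta

def kl_gauss(mu1, s1, mu2, s2):
    # Exact KL between N(mu1, s1^2) and N(mu2, s2^2).
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 0.5

mu2, sigma2 = mu + d[0], np.exp(log_sigma + d[1])
F = np.diag([1 / sigma**2, 2.0])

print(kl_gauss(mu, sigma, mu2, sigma2))          # exact KL[p_theta || p_{theta + d}]
print(0.5 * d @ F @ d)                           # 0.5 * d^T F d -- nearly identical
```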
Riemannian Manifold
Manifold
A manifold is a topological space that shares the local properties of Euclidean spaces. Every point on a manifold has a small neighborhood around it that can be locally approximated by a tangent plane, which means the curvature of this neighborhood is approximately zero. This allows us to use the Euclidean metric, which is based on the properties of flat space, to measure distances within this small neighborhood. As an intuition, people on Earth experience their surroundings as flat because they cannot perceive the curvature of the Earth, given its enormous size.
More formally: a manifold $\mathcal{M}$ of dimension $n$ is a topological space such that every point $x \in \mathcal{M}$ has a neighbourhood which is homeomorphic to an open set in the Euclidean space $\mathbb{R}^n$. This open set in Euclidean space is called the tangent space, referred to as $T_x\mathcal{M}$.
Riemannian Metric
The distance between two points in a Euclidean space can be easily determined by taking the modulus of the vector connecting the points. However, manifolds are not linear spaces, so we need an alternative way to calculate lengths on a manifold. One possible approach is to consider a continuous, differentiable curve on the manifold, represented by $\gamma: [0, 1] \to \mathcal{M}$, and compute the integral of the distance differential at each point along the curve $\gamma$:

$$\int_0^1 \left\| \gamma'(t) \right\|\, dt$$
Thus, for each point $\gamma(t)$, we have to define a notion of distance in the tangent space $T_{\gamma(t)}\mathcal{M}$ at that point. We then use this notion to calculate the modulus $\|\gamma'(t)\|$ of the tangent vector $\gamma'(t)$ at that point and add up all of the $\|\gamma'(t)\|\,dt$ to obtain the total length of the curve.
To calculate $\|\gamma'(t)\|$:

$$\left\|\gamma'(t)\right\|^2 = \left\langle \gamma'(t),\, \gamma'(t) \right\rangle_{\gamma(t)}$$
The square of the modulus of a vector (the modulus is also known as the Euclidean norm or $\ell_2$ norm) is equal to the inner product of the vector with itself. This means that the inner product $\langle \cdot, \cdot \rangle_x$ defines a metric, which is a way of measuring distance, over a tangent space. A metric that varies smoothly with respect to the point $x$ on a manifold is known as a Riemannian metric, and a manifold equipped with a Riemannian metric is called a Riemannian manifold.
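To make this concrete, here is a small numerical sketch of the curve-length integral for a made-up example: the unit sphere parameterized by polar and azimuth angles, whose induced metric is $\mathrm{diag}(1, \sin^2(\mathrm{polar}))$; the length of a circle of latitude should come out as $2\pi\sin(\mathrm{polar}_0)$:

```python
import numpy as np

def metric(point):
    # Induced Riemannian metric of the unit sphere in (polar, azimuth) coordinates.
    polar, azimuth = point
    return np.diag([1.0, np.sin(polar) ** 2])

def curve_length(gamma, d_gamma, t_grid):
    # Integrate || gamma'(t) || = sqrt( gamma'(t)^T g(gamma(t)) gamma'(t) ) along the curve.
    norms = []
    for t in t_grid:
        v = d_gamma(t)                      # tangent vector gamma'(t)
        g = metric(gamma(t))                # metric at the point gamma(t)
        norms.append(np.sqrt(v @ g @ v))
    return np.trapz(norms, t_grid)

polar0 = np.pi / 3                           # a circle of latitude
gamma = lambda t: np.array([polar0, t])      # gamma(t) = (polar0, t), t in [0, 2*pi]
d_gamma = lambda t: np.array([0.0, 1.0])

t_grid = np.linspace(0.0, 2 * np.pi, 1001)
print(curve_length(gamma, d_gamma, t_grid))  # ~ 2 * pi * sin(polar0) ≈ 5.441
```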
As explained here (page 4):
Roughly, a Riemannian manifold is a smooth set with a smoothly-varying inner product on the tangent spaces.
Formally, for each $x \in \mathcal{M}$, a Riemannian metric $\langle \cdot, \cdot \rangle_x : T_x\mathcal{M} \times T_x\mathcal{M} \to \mathbb{R}$ satisfies:
- $\langle u, v \rangle_x = \langle v, u \rangle_x$ for all $u, v \in T_x\mathcal{M}$
- $\langle u, u \rangle_x \ge 0$ for all $u \in T_x\mathcal{M}$
- $\langle u, u \rangle_x = 0$ if and only if $u = 0$
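Looking ahead to the next section: assuming the Fisher information matrix $F(\theta)$ is positive definite, the bilinear form it induces on parameter perturbations satisfies exactly these properties, which is what will let it play the role of a Riemannian metric:

$$\langle u, v \rangle_\theta := u^\top F(\theta)\, v, \qquad \langle u, v \rangle_\theta = \langle v, u \rangle_\theta, \qquad \langle u, u \rangle_\theta \ge 0, \qquad \langle u, u \rangle_\theta = 0 \iff u = 0.$$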
Natural Gradient Descent
Constrained Optimization
Let's start by looking at natural gradient descent from the perspective of constrained optimization. In traditional gradient descent, the constrained optimization problem we want to solve is:

$$d^* = \arg\min_{d\ \text{s.t.}\ \|d\| \le \epsilon} L(\theta + d)$$

where the distance $\|d\|$ in parameter space is constrained and is calculated using the Euclidean metric.
As mentioned before, the steepest descent direction is the direction of the negative gradient:

$$d^* \propto -\nabla L(\theta)$$
When using gradient descent, the distribution $p(x|\theta)$ parameterized by $\theta$ may change as the optimization progresses. It is important to control the amount of change in this distribution, as large changes can lead to instability in the model. The Euclidean distance, which is based on the properties of flat space, may not be an appropriate measure of this change. Therefore, natural gradient descent works in distribution space and uses the KL divergence to compare the current distribution $p(x|\theta)$ with the updated distribution $p(x|\theta + d)$. Now the constrained optimization problem becomes:

$$d^* = \arg\min_{d}\ L(\theta + d) \quad \text{s.t.} \quad \mathrm{KL}\left[p(x|\theta)\,\|\,p(x|\theta + d)\right] = c$$
We apply the Lagrange multiplier method to it, approximating $L(\theta + d)$ by its first-order Taylor expansion and the KL divergence by the second-order approximation derived earlier:

$$d^* = \arg\min_d\ L(\theta + d) + \lambda\left(\mathrm{KL}\left[p(x|\theta)\,\|\,p(x|\theta + d)\right] - c\right) \approx \arg\min_d\ L(\theta) + \nabla L(\theta)^\top d + \frac{1}{2}\lambda\, d^\top F\, d - \lambda c$$
To solve this minimization, we set its derivative with respect to $d$ to zero:

$$\frac{\partial}{\partial d}\left[L(\theta) + \nabla L(\theta)^\top d + \frac{1}{2}\lambda\, d^\top F\, d - \lambda c\right] = \nabla L(\theta) + \lambda F d = 0$$

$$d = -\frac{1}{\lambda} F^{-1} \nabla L(\theta)$$
Finally, the optimal descent direction is $-F^{-1}\nabla L(\theta)$ (the constant factor $\frac{1}{\lambda}$ can be absorbed into the learning rate). The direction

$$\tilde{\nabla} L(\theta) = F^{-1}\nabla L(\theta)$$

is called the natural gradient.
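Putting the pieces together, here is a minimal sketch of natural gradient descent for the same toy Gaussian model used earlier, where $F = \mathrm{diag}(1/\sigma^2,\ 2)$ is known in closed form (in general it would have to be estimated):

```python
import numpy as np

# Hypothetical model: p(x | theta) = N(x; mu, sigma^2), theta = (mu, log sigma),
# fit to synthetic data by maximizing the likelihood with natural gradient steps.
rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=3.0, size=10_000)

theta = np.array([0.0, 0.0])                        # (mu, log sigma)
alpha = 0.1                                         # learning rate

for _ in range(200):
    mu, sigma = theta[0], np.exp(theta[1])
    # Gradient of the average negative log-likelihood.
    grad = np.array([-np.mean(data - mu) / sigma**2,
                     1.0 - np.mean((data - mu)**2) / sigma**2])
    F = np.diag([1.0 / sigma**2, 2.0])              # Fisher information matrix
    natural_grad = np.linalg.solve(F, grad)         # F^{-1} grad L(theta), the natural gradient
    theta = theta - alpha * natural_grad            # natural gradient step

print(theta[0], np.exp(theta[1]))                   # ~ 2.0, ~ 3.0 (the data's mean and std)
```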
Riemannian Manifold
Now we consider optimizing the objective function on a manifold. It is easy to see that the descent direction depends on how the length $\|d\|$ of an update is calculated, i.e., on what metric the manifold is equipped with.
Traditional gradient descent works in parameter space (a Euclidean space), whose metric is the Euclidean metric. Natural gradient descent operates in distribution space (a Riemannian manifold), in which a point, represented by $\theta$, can be thought of as a parameterized probability distribution $p(x|\theta)$. When the parameter changes from $\theta$ to $\theta + d\theta$, the distance between $p(x|\theta)$ and $p(x|\theta + d\theta)$ is $\mathrm{KL}\left[p(x|\theta)\,\|\,p(x|\theta + d\theta)\right] \approx \frac{1}{2}\, d\theta^\top F\, d\theta$. It can be seen that the Fisher information matrix serves as the Riemannian metric on this Riemannian manifold.
The Fisher information matrix reflects the local curvature of the likelihood probability distribution space, which means that it encodes information about how much the probability distribution changes as the parameters vary. By using the Fisher information matrix as a measure, we can obtain the descent direction that takes into account the curvature of the probability distribution space. This is important because the larger the curvature, the smaller the range of parameter values that can be used to maintain a given likelihood. Natural gradient descent ensures that the optimization process respects the underlying geometry of the probability distribution space and avoids making large, unstable changes to the parameters.
References
- Natural Gradient Descent. Agustinus Kristiadi.
- Differential Geometry. Ben Andrews. Australian National University.
- Differential Geometry. Nigel Hitchin. Oxford University.
- Riemannian Metrics. Claudio Gorodski. University of São Paulo.
- Geometric Deep Learning: Going beyond Euclidean Data. Michael M. Bronstein, et al. IEEE Signal Processing Magazine, 2017.