05 Jul 2022
In the previous post we showed that cross-entropy (spoiler alert if you haven’t read it) is convex. However, training a model using cross-entropy loss doesn’t have to be a convex optimization problem.
In the case of cross-entropy, the total loss $L$ of a model is
\[L = \sum_{i=1}^{n}L(y_i, \hat{\boldsymbol{y}}_i) = \sum_{i=1}^{n}L(y_i, f(\boldsymbol{x}_i, \theta))\]
where $\hat{\boldsymbol{y}}_i$ and $y_i$ are the model prediction and the true label of the $i$-th sample, respectively. The model $\hat{\boldsymbol{y}}_i = f(\boldsymbol{x}_i, \theta)$ computes predictions from its parameters $\theta$ and the sample features $\boldsymbol{x}_i$. Training the model means finding the parameters $\theta$ in the parameter space $\Theta$ that minimize the total loss
\[\arg\min_{\theta \in \Theta}L(\theta)\]
Minimizing $L(\theta)$ with respect to $\theta$ involves the model $f$, which in general is not convex in $\theta$. Cross-entropy composed with an affine function of $\theta$ would stay convex, but composed with a nonlinear model such as a neural network it can lose that property, so $L(\theta)$ does not have to be convex even though cross-entropy is.
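We can test this in practice by defining a model and evaluating the Hessian matrix of $L(\theta)$ at some point $\theta$. A twice-differentiable function is convex only if its Hessian is positive semi-definite everywhere, so a single point where the Hessian is not positive semi-definite is enough to prove that $L(\theta)$ is non-convex.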
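A minimal sketch of this test, assuming the Hessian is available as a symmetric JAX array: `is_psd` is a hypothetical helper (not from the original notebook) that checks the smallest eigenvalue returned by `jnp.linalg.eigvalsh`, with a small tolerance for float32 round-off.

```python
import jax.numpy as jnp

def is_psd(H, tol=1e-6):
    """True if the symmetric matrix H is positive semi-definite (up to tol).

    eigvalsh assumes H is symmetric, as Hessians of twice continuously
    differentiable functions are, and returns real eigenvalues in ascending
    order, so checking the first one is enough.
    """
    eigvals = jnp.linalg.eigvalsh(H)
    return bool(eigvals[0] >= -tol)

print(is_psd(jnp.array([[2.0, 0.0], [0.0, 2.0]])))   # True: Hessian of x1^2 + x2^2
print(is_psd(jnp.array([[2.0, 0.0], [0.0, -2.0]])))  # False: Hessian of x1^2 - x2^2
```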
We will use JAX to do the difficult part of computing the Hessian. JAX offers `jacrev`, which produces the Jacobian of a given function using reverse-mode automatic differentiation (backpropagation). Making a `hessian` function with `jacrev` is easy: the Hessian is the Jacobian of the gradient, so it is enough to apply `jacrev` twice.
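As a minimal sketch, the helper can nest `jacrev` twice; this is one plausible version rather than the notebook's exact code. JAX also provides `jax.hessian`, which composes `jacfwd` over `jacrev` and is usually the faster choice.

```python
import jax.numpy as jnp
from jax import jacrev

def hessian(f):
    """Hessian of a scalar function f: the Jacobian of its gradient.

    Both differentiation passes use reverse-mode AD (jacrev).
    """
    return jacrev(jacrev(f))

# Quick check: f(x) = sum(x_k^2) has Hessian 2 * I everywhere
f = lambda x: jnp.sum(x ** 2)
print(hessian(f)(jnp.array([1.0, -3.0, 0.5])))  # 2 on the diagonal, 0 elsewhere
```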
Let’s start with a simple example to see if it actually works. We want to get the Hessian of a function $g(x_1, x_2) = 2x_1^3 + 4x_2^2$ which is
\[\mathrm{H}_g = \begin{pmatrix} \frac{\partial^2 g}{\partial x_1^2} & \frac{\partial^2 g}{\partial x_1 \partial x_2 } \\ \frac{\partial^2 g}{\partial x_1 \partial x_2} & \frac{\partial^2 g}{\partial x_2^2} \\ \end{pmatrix} = \begin{pmatrix} 12x_1 & 0 \\ 0 & 8 \\ \end{pmatrix}\]
If we create the same function `g` in Python and evaluate its Hessian at `x = [2, 4]`, the result is `[[24, 0], [0, 8]]`, exactly as expected.
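The check on $g$ might look like this (a sketch reusing a nested-`jacrev` helper). The input is a float array because JAX does not differentiate with respect to integer inputs.

```python
import jax.numpy as jnp
from jax import jacrev

def hessian(f):
    # Jacobian of the gradient, both computed with reverse-mode AD
    return jacrev(jacrev(f))

def g(x):
    # g(x1, x2) = 2 x1^3 + 4 x2^2
    return 2 * x[0] ** 3 + 4 * x[1] ** 2

x = jnp.array([2.0, 4.0])
print(hessian(g)(x))  # [[24, 0], [0, 8]]
```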
We now apply the same `hessian` function to a real model to investigate its convexity.
The model in question is a simple neural network classifier consisting of two dense layers, implemented in a `forward_pass` function.
The first layer uses the Gaussian error linear unit (GELU) as its activation function, defined as
\[\mathrm{GELU}(x) = x\Phi(x) = \frac{x}{\sqrt{2\pi}}\int_{-\infty}^{x}e^{\frac{-t^2}{2}}\mathrm{d}t\]
and commonly approximated as
\[\mathrm{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)\]
where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution $\mathcal{N}(0, 1)$. GELU is somewhat similar to the rectified linear unit $\mathrm{ReLU}(x) = \max(x, 0)$, but unlike ReLU it is non-convex: its second derivative $\mathrm{GELU}''(x) = \phi(x)\left(2 - x^2\right)$, where $\phi$ is the standard normal density, is negative for $|x| > \sqrt{2}$.
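GELU's second derivative, $\phi(x)(2 - x^2)$ with $\phi$ the standard normal density, can be confirmed numerically. A short sketch, assuming the exact erf-based GELU from `jax.nn.gelu(x, approximate=False)`, compares a nested `jax.grad` with that closed form:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def gelu_exact(x):
    # Exact GELU x * Phi(x); approximate=False selects the erf-based form
    return jax.nn.gelu(x, approximate=False)

# Second derivative via nested jax.grad (scalar in, scalar out)
gelu_dd = jax.grad(jax.grad(gelu_exact))

xs = jnp.array([-3.0, -1.0, 0.0, 1.0, 3.0])
autodiff = jax.vmap(gelu_dd)(xs)
closed_form = norm.pdf(xs) * (2.0 - xs ** 2)  # phi(x) * (2 - x^2)

print(autodiff)     # negative at x = -3 and x = 3
print(closed_form)  # matches autodiff up to float32 round-off
```

A negative second derivative anywhere is enough to rule out convexity of the activation itself.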
The model output $\boldsymbol{z}_i$ of the $i$-th sample is passed to the softmax function
\[\sigma(\boldsymbol{z})_j = \frac{e^{z_j}}{\sum_{k=1}^{m}e^{z_k}}\]
where $m$ is the number of labels. The result is a probability vector \(\hat{\boldsymbol{y}}_i\) (a multinoulli distribution), where \(\hat{y}_{ij} = \sigma(\boldsymbol{z}_i)_j\) is the predicted probability that the $i$-th sample belongs to label $j$.
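The model described so far can be sketched as follows. This is not the original notebook's code: the dictionary layout of `params`, the helper `init_params`, and the layer sizes (5 input features, 10 hidden units, 2 labels) are assumptions, chosen so that the parameter count comes out at 82, the number reported further down. `jax.nn.gelu` uses the tanh approximation by default.

```python
import jax
import jax.numpy as jnp

def init_params(key, n_features=5, n_hidden=10, n_labels=2):
    """Hypothetical 5 -> 10 -> 2 network: 5*10 + 10 + 10*2 + 2 = 82 parameters."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (n_features, n_hidden)),
        "b1": jnp.zeros(n_hidden),
        "W2": jax.random.normal(k2, (n_hidden, n_labels)),
        "b2": jnp.zeros(n_labels),
    }

def forward_pass(params, X):
    """Two dense layers; returns one probability vector (row) per sample."""
    h = jax.nn.gelu(X @ params["W1"] + params["b1"])  # tanh-approximated GELU by default
    z = h @ params["W2"] + params["b2"]                # logits z_i
    return jax.nn.softmax(z, axis=-1)                  # y_hat_ij = sigma(z_i)_j

params = init_params(jax.random.PRNGKey(0))
X = jax.random.normal(jax.random.PRNGKey(1), (3, 5))  # 3 samples, 5 features
y_hat = forward_pass(params, X)
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(y_hat.shape)        # (3, 2)
print(y_hat.sum(axis=1))  # each row sums to 1 (up to round-off)
print(n_params)           # 82
```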
Now it’s time to link the model with its loss function. In the last post we saw that the cross-entropy loss is just the negative sum of the log-probabilities assigned to the true labels, \(L = -\sum_{i=1}^{n}\log s_{i}\), where \(s_i\) is the entry of \(\hat{\boldsymbol{y}}_i\) that holds the predicted probability of the true label of the $i$-th sample.
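Written out with the true labels one-hot encoded, that is $y_{ij} = 1$ if label $j$ is the true label of the $i$-th sample and $y_{ij} = 0$ otherwise, the sum over samples and labels collapses to the same formula, since only the true-label term of each inner sum is non-zero:
\[L = -\sum_{i=1}^{n}\sum_{j=1}^{m} y_{ij}\log\hat{y}_{ij} = -\sum_{i=1}^{n}\log s_i\]
For instance, a sample whose true label receives probability $0.8$ contributes $-\log 0.8 \approx 0.223$ to $L$.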
In the `cross_entropy_loss` function, the array `y` contains the true labels (one-hot encoded) of each sample and `y_hat` is the prediction. For example, with 3 samples and 2 labels, both `y` and `y_hat` are 3x2 arrays whose rows each sum to one.
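Next we create a wrapper around `cross_entropy_loss` so that the model `params` can be passed as a flat array. This will make our life easier when working with the Hessian matrices computed by JAX: the Hessian of a function of a flat vector is a single 2-D matrix, rather than a nested structure with one block for every pair of parameter arrays.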
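One way to build such a wrapper, assuming the parameters are stored in a dictionary of arrays, is `jax.flatten_util.ravel_pytree`. `flatten_loss` and the toy loss below are hypothetical names used only for illustration; the last lines show that the Hessian comes back as a single 2-D matrix.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def flatten_loss(loss_fn, params):
    """Turn loss_fn(params_dict) into a function of one flat parameter vector.

    ravel_pytree concatenates every array in `params` into a 1-D vector and
    returns `unravel`, which rebuilds the original dict from such a vector.
    """
    flat_params, unravel = ravel_pytree(params)
    return flat_params, lambda flat: loss_fn(unravel(flat))

# Toy parameter dict and loss, standing in for the model and cross_entropy_loss
params = {"W": jnp.ones((2, 3)), "b": jnp.zeros(3)}
toy_loss = lambda p: jnp.sum(p["W"] ** 2) + jnp.sum(p["b"])

flat_params, flat_loss = flatten_loss(toy_loss, params)
print(flat_params.shape)                         # (9,)
print(flat_loss(flat_params), toy_loss(params))  # both 6.0

# With a flat input, the Hessian is one 9x9 matrix instead of nested dicts
H = jax.hessian(flat_loss)(flat_params)
print(H.shape)                                   # (9, 9)
```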
Now everything is almost ready; we just need to generate random data `X`, `y` and initialize the model parameters. Computing and evaluating the Hessian then takes just a few lines of code.
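Putting the pieces together, a self-contained sketch of the whole pipeline might look like this. The layer sizes, random seed, number of samples, and the 1/sqrt(fan-in) weight scaling are assumptions rather than the notebook's settings, so the eigenvalues it prints will differ from the original run.

```python
import jax
import jax.numpy as jnp
from jax import jacrev
from jax.flatten_util import ravel_pytree

# Hypothetical sizes: 5 features, 10 hidden units, 2 labels -> 82 parameters
N_SAMPLES, N_FEATURES, N_HIDDEN, N_LABELS = 100, 5, 10, 2

def forward_pass(params, X):
    """Dense + GELU, then dense + softmax: one probability row per sample."""
    h = jax.nn.gelu(X @ params["W1"] + params["b1"])
    return jax.nn.softmax(h @ params["W2"] + params["b2"], axis=-1)

def cross_entropy_loss(params, X, y):
    """-sum_i log s_i, with one-hot y selecting the true-label probability."""
    return -jnp.sum(y * jnp.log(forward_pass(params, X)))

def hessian(f):
    return jacrev(jacrev(f))

# Random features, random one-hot labels, and weights scaled by 1/sqrt(fan-in)
key_x, key_y, key_w1, key_w2 = jax.random.split(jax.random.PRNGKey(42), 4)
X = jax.random.normal(key_x, (N_SAMPLES, N_FEATURES))
y = jax.nn.one_hot(jax.random.randint(key_y, (N_SAMPLES,), 0, N_LABELS), N_LABELS)
params = {
    "W1": jax.random.normal(key_w1, (N_FEATURES, N_HIDDEN)) / jnp.sqrt(N_FEATURES),
    "b1": jnp.zeros(N_HIDDEN),
    "W2": jax.random.normal(key_w2, (N_HIDDEN, N_LABELS)) / jnp.sqrt(N_HIDDEN),
    "b2": jnp.zeros(N_LABELS),
}

# Flatten the parameter dict so the Hessian is a single 2-D matrix
flat_params, unravel = ravel_pytree(params)
flat_loss = lambda flat: cross_entropy_loss(unravel(flat), X, y)

H = hessian(flat_loss)(flat_params)
print(H.shape)  # (82, 82)

# H is symmetric, so eigvalsh gives real eigenvalues in ascending order;
# any negative one means H is not positive semi-definite at flat_params
eigvals = jnp.linalg.eigvalsh(H)
print(eigvals[:4])
```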
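The output `H` is an 82x82 matrix because our model has 82 parameters and we pass them as a flat array. If we look at the eigenvalues of the matrix, `jnp.linalg.eig(H)[0][:4]`, we can see that they are `complex64`. This is only because `jnp.linalg.eig` is a general eigensolver that returns complex values; `H` is real and symmetric, so its eigenvalues are real and their imaginary parts are zero up to round-off. Some of the real parts are negative, and a single negative eigenvalue is enough to show that the Hessian at `flat_params` is not positive semi-definite and the loss is non-convex.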
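The `complex64` dtype can be reproduced on a small, made-up symmetric matrix with one negative eigenvalue. `jnp.linalg.eigvalsh`, which assumes a symmetric input, returns real eigenvalues in ascending order and is arguably the more convenient tool for this check:

```python
import jax.numpy as jnp

# A small symmetric matrix with one negative eigenvalue (a saddle-shaped Hessian)
H = jnp.array([[2.0, 1.0],
               [1.0, -3.0]])

w_general = jnp.linalg.eig(H)[0]      # general solver: complex dtype
w_symmetric = jnp.linalg.eigvalsh(H)  # symmetric solver: real, ascending order

print(w_general.dtype)           # complex64
print(w_general.imag)            # zeros: a real symmetric matrix has real eigenvalues
print(w_symmetric)               # approx [-3.19, 2.19]
print(bool(w_symmetric[0] < 0))  # True, so H is not positive semi-definite
```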
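If this has not convinced you, we can try a more intuitive approach. Non-convex functions can have multiple separate local minima, whereas every local minimum of a convex function is also a global one. Can we find several of them in the case of our $L(\theta)$? A neat trick mentioned in this lecture on model optimization is to define a plane by choosing 3 points $\theta_1, \theta_2, \theta_3$ in the parameter space and plotting $L(\theta_1 + \alpha\theta_2 + \beta\theta_3)$ for $\alpha, \beta$ values in a selected range. The restriction of a convex function to a plane is itself convex, so separate local minima on the plane are enough to prove that $L(\theta)$ is non-convex.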
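A minimal sketch of the plane trick, assuming `loss_fn` takes a flat parameter vector. `loss_on_plane` and `count_grid_minima` are hypothetical helpers, and a double-well toy function stands in for $L(\theta)$ so the result can be checked by hand: on the chosen plane it has four minima, near $(\alpha, \beta) = (\pm 1, \pm 1)$.

```python
import jax
import jax.numpy as jnp

def loss_on_plane(loss_fn, theta1, theta2, theta3, alphas, betas):
    """Evaluate loss_fn(theta1 + a*theta2 + b*theta3) on an (alpha, beta) grid.

    Returns an array of shape (len(alphas), len(betas)).
    """
    def at(a, b):
        return loss_fn(theta1 + a * theta2 + b * theta3)
    # Inner vmap runs over betas, outer vmap over alphas
    return jax.vmap(lambda a: jax.vmap(lambda b: at(a, b))(betas))(alphas)

def count_grid_minima(Z):
    """Count interior grid points strictly lower than all 8 neighbours."""
    n_rows, n_cols = Z.shape
    center = Z[1:-1, 1:-1]
    is_min = jnp.ones_like(center, dtype=bool)
    for di in (-1, 0, 1):
        for dj in (-1, 0, 1):
            if di == 0 and dj == 0:
                continue
            neighbour = Z[1 + di:n_rows - 1 + di, 1 + dj:n_cols - 1 + dj]
            is_min = is_min & (center < neighbour)
    return int(is_min.sum())

# Toy non-convex "loss": a double well in every coordinate (stand-in for L)
toy_loss = lambda theta: jnp.sum((theta ** 2 - 1.0) ** 2)

theta1 = jnp.zeros(3)
theta2 = jnp.array([1.0, 0.0, 0.0])
theta3 = jnp.array([0.0, 1.0, 0.0])
grid = jnp.linspace(-2.0, 2.0, 41)
Z = loss_on_plane(toy_loss, theta1, theta2, theta3, grid, grid)
print(Z.shape)               # (41, 41)
print(count_grid_minima(Z))  # 4
```

Passing `Z` to a contour plot (for example matplotlib's `contourf`) gives the kind of picture described next; for the model, `loss_fn` would be the flat-parameter loss and $\theta_1, \theta_2, \theta_3$ random 82-dimensional vectors.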
After playing with random seeds for a while, I found this region with more than one local minimum, which clearly illustrates the non-convexity of $L(\theta)$.
A notebook containing all the code is available on my GitHub.