05 Jul 2022
In the previous post we showed that cross-entropy (spoiler alert if you haven’t read it) is convex. However, training a model using cross-entropy loss doesn’t have to be a convex optimization problem.
In the case of cross-entropy, the total loss $L$ of a model is
\[L = \sum_{i=1}^{n}L(y_i, \hat{\boldsymbol{y}}_i) = \sum_{i=1}^{n}L(y_i, f(\boldsymbol{x}_i, \theta))\]where $\hat{\boldsymbol{y}}_i$ and $y_i$ are the model prediction and the true label of the $i$-th sample, respectively. The model $\hat{\boldsymbol{y}}_i = f(\boldsymbol{x}_i, \theta)$ computes predictions based on its parameters $\theta$ and sample features $\boldsymbol{x}_i$. Training the model means finding a set of parameters $\theta$ in the parameter space $\Theta$ that minimizes the total loss
\[\arg\min_{\theta \in \Theta}L(\theta)\]Minimizing $L(\theta)$ with respect to $\theta$ involves the function $f$, which doesn’t have to be convex.
We can test this in practice by defining a model and evaluating its Hessian matrix at some point $\theta$. If the matrix is not positive semi-definite, then $L(\theta)$ is non-convex.
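For the check itself, one simple option is to look at the eigenvalues: a symmetric matrix is positive semi-definite exactly when all of its eigenvalues are non-negative. A minimal sketch (the helper name is_psd is mine, and the small tolerance accounts for floating-point noise):

import jax.numpy as jnp

def is_psd(H, tol=1e-6):
    # eigvalsh assumes a symmetric matrix and returns real eigenvalues
    return bool(jnp.all(jnp.linalg.eigvalsh(H) >= -tol))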
We will use JAX to do the difficult part of computing the Hessian. JAX offers jacrev, which produces the Jacobian of a given function using reverse-mode automatic differentiation (backpropagation). Making Hessians with jacrev is easy:
from jax import jacrev

def hessian(f):
    return jacrev(jacrev(f))
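As an aside, reverse-over-reverse is not the only option: JAX also ships a built-in jax.hessian, and composing jacfwd over jacrev (forward-over-reverse) is generally the recommended way to build a Hessian. A sketch of that variant, which should produce the same matrices as the helper above:

from jax import jacfwd, jacrev

def hessian_fwd_rev(f):
    # forward-over-reverse, usually the more efficient composition for Hessians
    return jacfwd(jacrev(f))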
Let’s start with a simple example to see if it actually works. We want to get the Hessian of a function $g(x_1, x_2) = 2x_1^3 + 4x_2^2$ which is
\[\mathrm{H}_g = \begin{pmatrix} \frac{\partial^2 g}{\partial x_1^2} & \frac{\partial^2 g}{\partial x_1 \partial x_2 } \\ \frac{\partial^2 g}{\partial x_1 \partial x_2} & \frac{\partial^2 g}{\partial x_2^2} \\ \end{pmatrix} = \begin{pmatrix} 12x_1 & 0 \\ 0 & 8 \\ \end{pmatrix}\]If we create the same function g in Python and evaluate its Hessian at x = [2, 4],
import jax.numpy as jnp

def g(x):
    return 2*x[0]**3 + 4*x[1]**2

x = jnp.array([2., 4.])
hessian(g)(x)
the result is [[24, 0], [0, 8]]. Exactly as expected.
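Note that a single evaluation can already expose non-convexity: at a point with negative $x_1$, say x = [-2, 4], the Hessian of $g$ is

hessian(g)(jnp.array([-2., 4.]))  # [[-24, 0], [0, 8]]

which has a negative eigenvalue, so $g$ itself is not convex.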
We now apply the same hessian function to a real model to investigate its convexity.
The model in question is a simple neural network classifier consisting of two dense layers. It is implemented in the forward_pass function in the following snippet.
def gelu(x):
    a = jnp.sqrt(2 / jnp.pi)
    b = (x + 0.044715 * x ** 3)
    return 0.5 * x * (1 + jnp.tanh(a * b))

def softmax(x):
    exp_x = jnp.exp(x)
    return exp_x / jnp.sum(exp_x)

def forward_pass(params, X):
    W_0, b_0 = params["W_0"], params["b_0"]
    activations = gelu(jnp.dot(X, W_0) + b_0)  # first layer
    W_1, b_1 = params["W_1"], params["b_1"]
    output = jnp.dot(activations, W_1) + b_1  # second layer
    return jnp.apply_along_axis(softmax, 1, output)
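As a quick sanity check (a minimal sketch with made-up shapes and names, not part of the original post), the forward pass should return one probability vector per sample, with each row summing to one:

import numpy as np

rng = np.random.default_rng(0)
demo_params = {
    "W_0": rng.standard_normal((5, 10)),
    "b_0": rng.standard_normal(10),
    "W_1": rng.standard_normal((10, 2)),
    "b_1": rng.standard_normal(2),
}
demo_X = rng.standard_normal((3, 5))
probs = forward_pass(demo_params, demo_X)  # shape (3, 2)
print(probs.sum(axis=1))                   # each row sums to ~1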
The first layer uses the Gaussian error linear unit (GELU) as its activation function, which is defined as
\[\mathrm{GELU}(x) = x\Phi(x) = \frac{x}{\sqrt{2\pi}}\int_{-\infty}^{x}e^{\frac{-t^2}{2}}\mathrm{d}t\] \[\mathrm{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)\]where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution $\mathcal{N}(0, 1)$. GELU is somewhat similar to the rectified linear unit $\mathrm{ReLU}(x) = \max(x, 0)$, but unlike ReLU it is non-convex.
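The gelu function in the snippet implements exactly this tanh approximation. If you want to double-check it, JAX provides jax.nn.gelu, which uses the same approximation by default, so the two should agree up to floating-point error (a quick check, not from the original post):

import jax

xs = jnp.linspace(-3., 3., 7)
print(jnp.allclose(gelu(xs), jax.nn.gelu(xs)))  # True: jax.nn.gelu defaults to the tanh approximation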
The model output $\boldsymbol{z}_i$ for the $i$-th sample is passed to the softmax function
\[\sigma(\boldsymbol{z})_j = \frac{e^{z_j}}{\sum_{k=1}^{m}e^{z_k}}\]producing a probability vector \(\hat{\boldsymbol{y}}_i\) (multinoulli distribution) where \(\hat{y}_{ij} = \sigma(\boldsymbol{z}_i)_j\) is the predicted probability that the $i$-th sample belongs to label $j$.
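A small practical aside, not needed for the convexity argument: the softmax implementation above exponentiates raw logits, which can overflow for large values. Subtracting the maximum first gives the same result while staying numerically stable (the name softmax_stable is mine):

def softmax_stable(x):
    # subtracting the max leaves softmax unchanged but prevents overflow in exp
    shifted = x - jnp.max(x)
    exp_x = jnp.exp(shifted)
    return exp_x / jnp.sum(exp_x)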
Now it’s time to link the model with its loss function. In the last post we saw that the cross-entropy loss is just the negative sum of true-label log-probabilities \(L = -\sum_{i=1}^{n}\log s_{i}\), where \(s_i \in \hat{\boldsymbol{y}}_i\) is the predicted probability of the true label of the $i$-th sample.
def cross_entropy_loss(params, X, y):
    y_hat = forward_pass(params, X)
    true_label_probs = y_hat[y]  # y is a boolean one-hot mask selecting the true-label probabilities
    return -jnp.sum(jnp.log(true_label_probs))
In the snippet, the array y contains the true labels (one-hot encoded) of each sample and y_hat is the prediction. An example with 3 samples and 2 labels is

y = [[0, 1], [1, 0], [1, 0]]
y_hat = [[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]]
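For this toy example the true-label probabilities are 0.9, 0.4 and 0.7, so the loss works out to $L = -(\log 0.9 + \log 0.4 + \log 0.7) \approx 1.38$. A quick check (the _demo names are just for illustration):

y_demo = jnp.array([[0, 1], [1, 0], [1, 0]], dtype=bool)
y_hat_demo = jnp.array([[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]])
print(-jnp.sum(jnp.log(y_hat_demo[y_demo])))  # ~1.378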
Next we create a wrapper around cross_entropy_loss so that the model params can be passed as a flat array. This will make our life easier when working with the Hessian matrices computed by JAX.
def size(shape):
    # number of elements for a 1-D or 2-D shape tuple
    try:
        return shape[0] * shape[1]
    except IndexError:
        return shape[0]

def reshape_params(flat_params, shapes):
    # slice the flat parameter vector and reshape each slice back to its original shape
    params = {}
    start = 0
    for name, shape in shapes.items():
        slice_ = flat_params[start:start + size(shape)]
        params[name] = slice_.reshape(*shape)
        start += size(shape)
    return params

def cross_entropy_loss_flat(flat_params, shapes, X, y):
    params = reshape_params(flat_params, shapes)
    return cross_entropy_loss(params, X, y)
Now everything is almost ready; we just need to generate random data X, y and initialize the model parameters.
import numpy as np
import pandas as pd

n_samples = 40
n_features = 5
n_labels = 2

shapes = {
    "W_0": (n_features, 10),
    "b_0": (10,),
    "W_1": (10, n_labels),
    "b_1": (n_labels,)
}

# total_params = 82
total_params = sum(size(v) for v in shapes.values())

np.random.seed(0)
flat_params = np.random.randn(total_params)
X = np.random.randn(n_samples, n_features)
y = np.random.rand(n_samples).round().astype("int32")
# one-hot encode and cast to bool so y can be used as a mask in cross_entropy_loss
y = pd.get_dummies(y).values.astype(bool)
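Before moving on, it is worth checking that the flattening round-trips correctly, i.e. that reshape_params recovers arrays with the shapes declared in shapes (a quick sanity check, not part of the original snippet):

reshaped = reshape_params(flat_params, shapes)
assert {name: p.shape for name, p in reshaped.items()} == shapes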
Computing and evaluating the Hessian is just a few lines of code.
from functools import partial

loss = partial(
    cross_entropy_loss_flat, shapes=shapes, X=X, y=y
)
H = hessian(loss)(flat_params)
The output H is an 82x82 matrix because our model has 82 parameters and we are passing them in flat shape. If we look at the first few eigenvalues of the matrix, jnp.linalg.eig(H)[0][:6],
DeviceArray(
    [657.44336+0.j, 402.86493+0.j, -318.71735+0.j,
     -287.853+0.j, -294.13333+0.j, -155.14163+0.j],
    dtype=complex64
)
we can see that they are complex64 and their real components are often negative. This shows that the Hessian at flat_params is not positive semi-definite and the loss is non-convex.
If this has not convinced you, we can try a more intuitive approach. Non-convex functions can have multiple local minima. Can we find them in the case of our $L(\theta)$? A neat trick mentioned in this lecture on model optimization is to define a plane by picking 3 points $\theta_1, \theta_2, \theta_3$ in the parameter space and plotting $L(\theta_1 + \alpha\theta_2 + \beta\theta_3)$ for $\alpha, \beta$ values in a selected range.
flat_params_2 = np.random.randn(total_params)
flat_params_3 = np.random.randn(total_params)

lims = np.linspace(-2, 2, 100)
losses = [
    [
        loss(
            flat_params
            + alpha * flat_params_2
            + beta * flat_params_3
        )
        for alpha in lims
    ]
    for beta in lims
]
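To visualize the landscape, a minimal matplotlib sketch (the plotting details here are mine, not from the original notebook) could look like:

import matplotlib.pyplot as plt

# rows of `losses` correspond to beta, columns to alpha
plt.contourf(lims, lims, np.array(losses, dtype=float), levels=50)
plt.xlabel("alpha")
plt.ylabel("beta")
plt.colorbar(label="loss")
plt.show()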
After playing with random seeds for a while, I found this region with more than one local minimum, clearly illustrating the non-convexity of $L(\theta)$.
A notebook containing all the code is available on my GitHub.