Computing Hessians with JAX

05 Jul 2022

In the previous post we showed that cross-entropy (spoiler alert if you haven’t read it) is convex. However, training a model using cross-entropy loss doesn’t have to be a convex optimization problem.

In the case of cross-entropy, the total loss $L$ of a model is

\[L = \sum_{i=1}^{n}L(y_i, \hat{\boldsymbol{y}}_i) = \sum_{i=1}^{n}L(y_i, f(\boldsymbol{x}_i, \theta))\]

where $\hat{\boldsymbol{y}}_i$ and $y_i$ are the model prediction and the true label of the $i$-th sample, respectively. The model $\hat{\boldsymbol{y}}_i = f(\boldsymbol{x}_i, \theta)$ computes predictions from its parameters $\theta$ and the sample features $\boldsymbol{x}_i$. Training the model means finding a set of parameters $\theta$ in the parameter space $\Theta$ that minimizes the total loss

\[\arg\min_{\theta \in \Theta}L(\theta)\]

Minimizing $L(\theta)$ with respect to $\theta$ involves the function $f$, which doesn't have to be convex.

JAX Framework

We can test this in practice by defining a model and evaluating its Hessian matrix at some point $\theta$. A twice-differentiable function is convex if and only if its Hessian is positive semi-definite everywhere, so if the matrix fails to be positive semi-definite at even a single point, $L(\theta)$ is non-convex.

We will use JAX to do the difficult part of computing the Hessian. JAX offers jacrev, which produces the Jacobian of a given function using reverse-mode automatic differentiation (backpropagation). Making Hessians with jacrev is easy:

from jax import jacrev

def hessian(f):
    return jacrev(jacrev(f))

Let’s start with a simple example to see if it actually works. We want to get the Hessian of a function $g(x_1, x_2) = 2x_1^3 + 4x_2^2$ which is

\[\mathrm{H}_g = \begin{pmatrix} \frac{\partial^2 g}{\partial x_1^2} & \frac{\partial^2 g}{\partial x_1 \partial x_2 } \\ \frac{\partial^2 g}{\partial x_1 \partial x_2} & \frac{\partial^2 g}{\partial x_2^2} \\ \end{pmatrix} = \begin{pmatrix} 12x_1 & 0 \\ 0 & 8 \\ \end{pmatrix}\]

If we create the same function g in Python and evaluate its Hessian at x = [2, 4],

import jax.numpy as jnp

def g(x):
    return 2*x[0]**3 + 4*x[1]**2

x = jnp.array([2., 4.])
hessian(g)(x)

the result is [[24, 0], [0, 8]]. Exactly as expected.
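
For reference, the JAX documentation usually recommends the forward-over-reverse composition jacfwd(jacrev(f)) for Hessians, and JAX also ships a ready-made jax.hessian helper. A quick sketch checking that both agree with our reverse-over-reverse version on g:

from jax import jacfwd, hessian as jax_hessian

# forward-over-reverse is the composition usually recommended for Hessians
hessian_fwd = jacfwd(jacrev(g))

print(jnp.allclose(hessian_fwd(x), hessian(g)(x)))      # True
print(jnp.allclose(jax_hessian(g)(x), hessian(g)(x)))   # True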

The Model

We now apply the same hessian function to a real model to investigate its convexity. The model in question is a simple neural network classifier consisting of two dense layers. It is implemented in the forward_pass function in the following snippet.

def gelu(x):
    a = jnp.sqrt(2 / jnp.pi)
    b = (x + 0.044715 * x ** 3)
    return 0.5 * x * (1 + jnp.tanh(a * b))

def softmax(x):
    exp_x = jnp.exp(x)
    return exp_x / jnp.sum(exp_x)

def forward_pass(params, X):
    W_0, b_0 = params["W_0"], params["b_0"]
    activations = gelu(jnp.dot(X, W_0) + b_0)  # first layer
    
    W_1, b_1 = params["W_1"], params["b_1"]
        
    output = jnp.dot(activations, W_1) + b_1  # second layer
    return jnp.apply_along_axis(softmax, 1, output)

The first layer uses the Gaussian error linear unit (GELU) as its activation function, which is defined as

\[\mathrm{GELU}(x) = x\Phi(x) = \frac{x}{\sqrt{2\pi}}\int_{-\infty}^{x}e^{-\frac{t^2}{2}}\mathrm{d}t\]

\[\mathrm{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)\]

where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution $\mathcal{N}(0, 1)$. GELU is somewhat similar to the rectified linear unit $\mathrm{ReLU}(x) = \max(x, 0)$, but unlike ReLU it is non-convex.

[Figure: the GELU activation function]
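
We can sanity-check that claim with JAX itself: differentiating gelu twice and evaluating the second derivative at a couple of points shows it changing sign, which a convex function cannot do. A small sketch:

from jax import grad, vmap

gelu_dd = vmap(grad(grad(gelu)))           # elementwise second derivative of gelu
print(gelu_dd(jnp.array([-3.0, 0.0])))     # negative at -3, positive at 0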

The model output $\boldsymbol{z}$ is passed to the softmax function

\[\sigma(\boldsymbol{z})_j = \frac{e^{z_j}}{\sum_{k=1}^{m}e^{z_k}}\]

producing a probability vector \(\hat{\boldsymbol{y}}_i\) (multinoulli distribution) where \(\hat{y}_{ij} = \sigma(\boldsymbol{z}_i)_j\) is the predicted probability that the $i$-th sample belongs to label $j$.
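
A side note: the softmax above exponentiates the raw logits, which can overflow for large values. It is fine for this small example, but a common variant subtracts the maximum first, which leaves the result unchanged; a sketch:

def softmax_stable(x):
    # subtracting the max doesn't change the output but avoids overflow in exp
    z = x - jnp.max(x)
    exp_z = jnp.exp(z)
    return exp_z / jnp.sum(exp_z)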

Now it’s time to link the model with its loss function. In the last post we saw that the cross-entropy loss is just the negative sum of the true-label log-probabilities \(L = -\sum_{i=1}^{n}\log s_{i}\), where \(s_i \in \hat{\boldsymbol{y}}_i\) is the predicted probability of assigning the true label to the $i$-th sample.

def cross_entropy_loss(params, X, y):
    y_hat = forward_pass(params, X)
    # boolean mask selects the predicted probability of each sample's true label
    true_label_probs = y_hat[y.astype(bool)]

    return -jnp.sum(jnp.log(true_label_probs))

In the snippet, the array y contains the true labels of each sample (one-hot encoded) and y_hat is the prediction. An example with 3 samples and 2 labels is

y = [[0, 1], [1, 0], [1, 0]]
y_hat = [[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]]
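
In this toy example the true-label probabilities are 0.9, 0.4 and 0.7, so the loss evaluates to $-(\log 0.9 + \log 0.4 + \log 0.7) \approx 1.38$. A quick check with plain NumPy:

import numpy as np

y = np.array([[0, 1], [1, 0], [1, 0]])
y_hat = np.array([[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]])

true_label_probs = y_hat[y.astype(bool)]    # [0.9, 0.4, 0.7]
print(-np.sum(np.log(true_label_probs)))    # ~1.38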

Next we create a wrapper around cross_entropy_loss so that the model params can be passed as a flat array. This will make our life easier when working with the Hessian matrices computed by JAX.

def size(shape):
    # number of elements for a 1-D or 2-D shape tuple
    try:
        return shape[0] * shape[1]
    except IndexError:
        return shape[0]

def reshape_params(flat_params, shapes):
    
    params = {}
    start = 0
    
    for name, shape in shapes.items():
        slice_ = flat_params[start:start + size(shape)]
        params[name] = slice_.reshape(*shape)
        start += size(shape)
        
    return params

def cross_entropy_loss_flat(flat_params, shapes, X, y):
    params = reshape_params(flat_params, shapes)
    return cross_entropy_loss(params, X, y)
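
A quick round-trip check of the helper, using a small hypothetical shapes dictionary rather than the model's real one:

shapes_demo = {"W": (2, 3), "b": (3,)}
flat_demo = jnp.arange(9.0)

params_demo = reshape_params(flat_demo, shapes_demo)
print(params_demo["W"].shape, params_demo["b"].shape)    # (2, 3) (3,)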

Now everything is almost ready; we just need to generate random data X, y and initialize the model parameters.

import numpy as np
import pandas as pd

n_samples = 40
n_features = 5
n_labels = 2

shapes = {
    "W_0": (n_features, 10),
    "b_0": (10,),
    "W_1": (10, n_labels),
    "b_1": (n_labels,)
}

# total_params = 82
total_params = sum(size(v) for v in shapes.values())

np.random.seed(0)
flat_params = np.random.randn(total_params)
X = np.random.randn(n_samples, n_features)
y = np.random.rand(n_samples).round().astype("int32")
y = pd.get_dummies(y).values
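
As a quick sanity check, we can run the forward pass once: the reshaped parameters should give one probability row per sample, with each row summing to 1.

params = reshape_params(flat_params, shapes)
y_hat = forward_pass(params, X)

print(y_hat.shape)                              # (40, 2)
print(jnp.allclose(y_hat.sum(axis=1), 1.0))     # True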

Computing and evaluating the Hessian takes just a few lines of code.

from functools import partial

loss = partial(
    cross_entropy_loss_flat, shapes=shapes, X=X, y=y
)

H = hessian(loss)(flat_params)

The output H is an 82x82 matrix because our model has 82 parameters and we pass them as a flat array. If we look at the first few eigenvalues of the matrix, jnp.linalg.eig(H)[0][:6],

DeviceArray(
    [
        657.44336+0.j, 402.86493+0.j, -318.71735+0.j,
        -287.853+0.j, -294.13333+0.j, -155.14163+0.j
    ],
    dtype=complex64
)

we can see that they are complex64 (with zero imaginary parts) and that several of them have negative real parts. This shows that the Hessian at flat_params is not positive semi-definite and the loss is non-convex.
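
Since the Hessian of a twice-differentiable real-valued function is symmetric, we could equivalently use jnp.linalg.eigvalsh, which returns real eigenvalues and makes the check a one-liner:

eigenvalues = jnp.linalg.eigvalsh(H)    # real eigenvalues of the symmetric Hessian
print(jnp.min(eigenvalues) < 0)         # True, so H is not positive semi-definite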

Visualizing $L(\theta)$

If this has not convinced you, we can try a more intuitive approach. Non-convex functions can have multiple local minima. Can we find them in the case of our $L(\theta)$? A neat trick mentioned in this lecture on model optimization is to define a plane by picking 3 points $\theta_1, \theta_2, \theta_3$ in the parameter space and plotting $L(\theta_1 + \alpha\theta_2 + \beta\theta_3)$ for $\alpha, \beta$ values in a selected range. Since the restriction of a convex function to such a plane is still convex, finding more than one local minimum on the plane is enough to demonstrate non-convexity.

flat_params_2 = np.random.randn(total_params)
flat_params_3 = np.random.randn(total_params)

lims = np.linspace(-2, 2, 100)
losses = [
    [
        loss(
            flat_params
            + alpha * flat_params_2
            + beta * flat_params_3
        )
        for alpha in lims
    ] for beta in lims
]

After playing with random seeds for a while, I found a region with more than one local minimum, clearly illustrating the non-convexity of $L(\theta)$.

[Figure: the loss surface $L(\theta_1 + \alpha\theta_2 + \beta\theta_3)$ with more than one local minimum]

A notebook containing all the code is available on my GitHub.