Is Cross-Entropy Loss Convex?

27 Mar 2022

The other day I saw a list of interesting data science problems, one of which was “show that cross-entropy loss is convex”. Let’s look into it!

Self-Information, Entropy and Cross-Entropy

If an outcome $x$ occurs with probability $p(x)$, its self-information is $-\mathrm{log}_2 p(x)$. Because of the negative logarithm, high probability means low self-information and vice versa. The basic idea is that frequent outcomes are not very interesting and do not provide much information. If you know that a loaded die always shows six, then throwing it and observing the result gives you zero new information.
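
To make this concrete, here is a minimal Python sketch (not from the original problem list) that computes self-information in bits; the die probabilities just mirror the example above:

import math

def self_information(p):
    # Self-information in bits of an outcome with probability p
    return -math.log2(p)

print(self_information(1 / 6))  # fair die showing six: ~2.58 bits
print(self_information(1.0))    # loaded die that always shows six: 0 bits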

Entropy $H(X)$ of a discrete random variable $X$ with $m$ possible outcomes is

\[H(X) = -\sum_{j=1}^m p(x_j)\mathrm{log}_2 p(x_j)\]

From the definition we can see that $H(X)$ is the average self-information of $X$. Similarly to variance and standard deviation, entropy is a measure of dispersion. A uniform distribution, where every outcome has the same probability, has the highest entropy.
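
A short Python sketch of the entropy formula; the skewed die distribution is just an assumed example to contrast with the uniform one:

import math

def entropy(probs):
    # Entropy in bits of a discrete distribution given as a list of probabilities
    return -sum(p * math.log2(p) for p in probs if p > 0)

print(entropy([1 / 6] * 6))                          # uniform die: log2(6) ~ 2.58 bits (maximal)
print(entropy([0.9, 0.02, 0.02, 0.02, 0.02, 0.02]))  # skewed die: ~0.70 bits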

[Figure: Entropy]

Cross-entropy is a measure of mismatch between two probability distributions. Measuring such mismatch is useful when building machine learning models because it can tell us how close our model is to the ground truth. Imagine we are training a model to classify whether there is an apple, orange, or banana in a picture. We give the model a picture of an apple and it returns

{"apple": 0.95, "orange": 0.04, "banana": 0.01}

Our model thinks there is an apple in the picture with $95\%$ probability. The actual ground truth is

{"apple": 1.00, "orange": 0.00, "banana": 0.00}

Both the prediction $q$ and ground truth $p$ are multinoulli probability distributions with 3 possible outcomes. We can compute their cross-entropy $H(p, q)$ as

\[H(p, q) = -\sum_{j=1}^m p(x_j)\mathrm{log}_2 q(x_j)\] \[H(p, q) = -(1 \cdot \mathrm{log}_2 0.95 + 0 \cdot \mathrm{log}_2 0.04 + 0 \cdot \mathrm{log}_2 0.01) = -\mathrm{log}_2 0.95 \approx 0.07\]
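
The same computation as a small Python sketch, using the prediction and ground truth from the example above:

import math

def cross_entropy(p, q):
    # Cross-entropy H(p, q) in bits between two discrete distributions
    return -sum(pj * math.log2(qj) for pj, qj in zip(p, q) if pj > 0)

p = [1.00, 0.00, 0.00]      # ground truth: apple
q = [0.95, 0.04, 0.01]      # model prediction
print(cross_entropy(p, q))  # -log2(0.95) ~ 0.07 bits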

When training a model, the ground truth distribution always contains one outcome with $100\%$ probability (the correct label), which means its entropy is $H(p) = 0$. If we achieved a perfect prediction, the $p$ and $q$ distributions would be identical and by definition $H(p, q) = H(p) = H(q) = 0$. Training is thus a matter of getting as close to zero as possible.

Cross-entropy is always equal to or greater than the entropy of the ground truth distribution $H(p)$; the difference $H(p, q) - H(p) = D_{KL}(p \Vert q)$ is called the Kullback–Leibler divergence.
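
A quick numeric check of this relation; the two distributions below are made up purely to illustrate that $H(p, q) \geq H(p)$:

import math

def entropy(p):
    return -sum(pj * math.log2(pj) for pj in p if pj > 0)

def cross_entropy(p, q):
    return -sum(pj * math.log2(qj) for pj, qj in zip(p, q) if pj > 0)

p = [0.5, 0.3, 0.2]  # hypothetical ground truth distribution
q = [0.4, 0.4, 0.2]  # hypothetical prediction

kl = cross_entropy(p, q) - entropy(p)  # D_KL(p || q)
assert kl >= 0                         # hence H(p, q) >= H(p)
print(cross_entropy(p, q), entropy(p), kl)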

Because the ground truth probabilities of the incorrect labels are always zero, we can simplify the cross-entropy loss $L_i$ for training sample $i$ to $L_i(s_i) = -\mathrm{log}_2 s_i$, where $s_i$ is the probability that the predicted distribution $q_i$ assigns to the true label. The total training loss $L$ is then

\[L(\boldsymbol{s}) = -\sum_{i=1}^n \mathrm{log}_2 s_i\]
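
As a sketch, the total loss for a batch of hypothetical predicted true-label probabilities:

import math

# Assumed predicted probabilities of the true label for n = 4 training samples
s = [0.95, 0.60, 0.80, 0.30]

# Total cross-entropy loss L(s) = -sum_i log2(s_i)
total_loss = -sum(math.log2(si) for si in s)
print(total_loss)  # ~2.87 bits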

Convexity

Our loss $L(\boldsymbol{s})$ is a function of $n$ variables $s_1,\dotsc, s_n$, one for each training sample. A function of multiple variables is convex if its Hessian matrix is positive semi-definite for all values of $\boldsymbol{s}$. The Hessian of a function is the square matrix of its second partial derivatives. Because each $s_i$ appears only in its own term $-\mathrm{log}_2 s_i$, all mixed partial derivatives are zero, and the pure second derivatives are $\frac{\partial^2 L}{\partial s_i^2} = \frac{1}{s_i^2 \ln 2}$. For our loss function $L$, the Hessian is

\[\mathrm{H}_L = \begin{pmatrix} \frac{\partial^2 L}{\partial s_1^2} & \dots & \frac{\partial^2 L}{\partial s_1 \partial s_n } \\ \vdots & \ddots & \vdots \\ \frac{\partial^2 L}{\partial s_n \partial s_1} & \dots & \frac{\partial^2 L}{\partial s_n^2} \\ \end{pmatrix} = \begin{pmatrix} \frac{1}{s_1^2 \ln 2} & \dots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \dots & \frac{1}{s_n^2 \ln 2} \end{pmatrix}\]
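
We can double-check the second derivatives symbolically; the following sketch uses SymPy for a small case with $n = 3$:

import sympy as sp

# Symbolic Hessian of L(s) = -sum_i log2(s_i) for n = 3
s = sp.symbols('s1 s2 s3', positive=True)
L = -sum(sp.log(si, 2) for si in s)
H = sp.hessian(L, s)
print(sp.simplify(H))  # diagonal matrix with entries 1/(s_i**2 * log(2))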

A real symmetric Hessian is positive semi-definite if all its eigenvalues $\lambda_i$ are zero or positive. Our $\mathrm{H}_L$ is a diagonal matrix, so its eigenvalues are the diagonal elements $\lambda_i = \frac{1}{s_i^2 \ln 2}$, which are always greater than zero because $0 < s_i \leq 1$. This shows that $\mathrm{H}_L$ is positive semi-definite and the cross-entropy loss is a convex function.
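
And a quick numerical check with NumPy that the eigenvalues are positive for some arbitrary, made-up values of $\boldsymbol{s}$:

import numpy as np

rng = np.random.default_rng(0)
s = rng.uniform(0.01, 1.0, size=5)     # arbitrary predicted probabilities of the true labels
H = np.diag(1.0 / (s**2 * np.log(2)))  # analytic Hessian of L(s) = -sum_i log2(s_i)
eigenvalues = np.linalg.eigvalsh(H)
assert np.all(eigenvalues > 0)         # all eigenvalues positive => positive (semi-)definite
print(eigenvalues)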

[Figure: Convex]

We like convex functions because they are easier to optimize: any local minimum of a convex function is also a global minimum.