[feature request] Support soft target distribution in cross entropy loss #11959
Description
Currently our cross entropy loss (i.e., nn.CrossEntropyLoss) only supports a hard target class, i.e., it maximizes the output (log) probability of a particular class. But training w.r.t. a soft target distribution (i.e., wanting the output to match a particular distribution) is often useful too, e.g., to prevent overfitting.
Math
Cross entropy loss applies log softmax to the input logits and then computes a negative log likelihood on the result.
Denote the input vector as x. Log softmax computes a vector y of the same length as x, where y_i = x_i - log( \sum_j exp(x_j) ), representing the log likelihood of each class.
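As a quick illustration of the log softmax formula, here is a minimal plain-Python sketch. The function name `log_softmax` and the max-subtraction step (added for numerical stability) are assumptions made for illustration; this is not the PyTorch implementation.

```python
import math

def log_softmax(x):
    """Return y with y_i = x_i - log(sum_j exp(x_j)) for a list of floats x."""
    # Subtracting max(x) does not change the result mathematically,
    # but it keeps exp() from overflowing for large logits.
    m = max(x)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - log_sum_exp for v in x]

y = log_softmax([1.0, 2.0, 3.0])
# exp(y) is the softmax of x, so its entries should sum to 1.
total = sum(math.exp(v) for v in y)
```

Every entry of `y` is a log probability, so each one is at most zero.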
- In the hard target case, if the target class is c, the loss is simply the negative log likelihood loss -y_c.
- In the soft target case, let the target distribution vector be p (i.e., p_i is the target probability for predicting class i). The loss is the KL divergence

      D( p || softmax(x) ) = \sum_i p_i log( p_i / softmax(x)_i ) = -\sum_i p_i y_i + constant

  The constant, \sum_i p_i log p_i, is independent of x and thus discarded. Our loss formula is just l = -\sum_i p_i y_i. When p_c = 1 for some class c, this simplifies to the hard target case.

The formula for gradient computation can be easily derived from this:

    d l / d y_i = -p_i
    d y_i / d x_i = 1 - exp(x_i) / \sum_j exp(x_j) = 1 - exp(y_i)
    # suppose k != i
    d y_k / d x_i = -exp(x_i) / \sum_j exp(x_j) = -exp(y_i)
    # so
    d l / d x_i = exp(y_i) (\sum_k p_k) - p_i = exp(y_i) - p_i (= softmax(x)_i - p_i)
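The gradient formula can be sanity-checked numerically. The sketch below uses plain Python, with all function names hypothetical and p assumed to sum to 1. It compares the analytic gradient exp(y_i) - p_i against central finite differences of the loss -\sum_i p_i y_i.

```python
import math

def log_softmax(x):
    # Numerically stable y_i = x_i - log(sum_j exp(x_j)).
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]

def soft_cross_entropy(x, p):
    """Loss -sum_i p_i * y_i for logits x and target distribution p."""
    y = log_softmax(x)
    return -sum(pi * yi for pi, yi in zip(p, y))

def soft_cross_entropy_grad(x, p):
    """Analytic gradient d l / d x_i = exp(y_i) - p_i (assumes sum(p) == 1)."""
    y = log_softmax(x)
    return [math.exp(yi) - pi for pi, yi in zip(p, y)]

def numeric_grad(x, p, eps=1e-6):
    """Central finite differences, used only to cross-check the formula."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((soft_cross_entropy(hi, p) - soft_cross_entropy(lo, p)) / (2 * eps))
    return grad

x = [0.5, -1.0, 2.0]
p = [0.2, 0.3, 0.5]
analytic = soft_cross_entropy_grad(x, p)
numeric = numeric_grad(x, p)
max_diff = max(abs(a - n) for a, n in zip(analytic, numeric))

# With a one-hot target the loss reduces to the hard-target value -y_c.
hard = soft_cross_entropy(x, [0.0, 0.0, 1.0])
```

Because both softmax(x) and p sum to 1, the entries of the analytic gradient sum to zero, which gives a second cheap consistency check.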
Possible Implementation
Currently our cross entropy loss implementation takes in a batched input x of shape (N, C) with a floating point dtype (N is the batch size and C is the number of classes), and a batched vector of target class indices target of shape (N) with dtype long (an integral type), where target[i] is the index of the desired output class.
Since we want it to also accept a soft target distribution, we can allow target to alternatively be a batched distribution of shape (N, C), and detect whether a soft or hard target is intended based on its shape and dtype.
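To make the dispatch idea concrete, here is a rough plain-Python sketch rather than PyTorch code. Nested lists stand in for tensors, an integer entry in `target` plays the role of a long class index, and a list of floats plays the role of a row of the (N, C) distribution. The function names and the mean reduction over the batch are assumptions made for illustration.

```python
import math

def log_softmax(row):
    # Numerically stable y_i = x_i - log(sum_j exp(x_j)).
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]

def cross_entropy(x, target):
    """Mean cross entropy over a batch.

    x:      N rows of C logits.
    target: either N integer class indices (hard targets) or
            N rows of C probabilities (soft targets).
    """
    if len(x) != len(target):
        raise ValueError("x and target must have the same batch size")
    losses = []
    for row, t in zip(x, target):
        y = log_softmax(row)
        if isinstance(t, int):
            # Hard target: negative log likelihood of class t.
            losses.append(-y[t])
        else:
            # Soft target: -sum_i p_i * y_i.
            if len(t) != len(row):
                raise ValueError("soft target row must have C entries")
            losses.append(-sum(pi * yi for pi, yi in zip(t, y)))
    return sum(losses) / len(losses)

x = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
hard = cross_entropy(x, [2, 0])
soft = cross_entropy(x, [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
```

Here `hard` and `soft` agree because one-hot soft targets reduce to the hard target case. A tensor-based version would inspect the shape and dtype of `target` once for the whole batch instead of checking each row.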
cc @gchanan