[feature request] Support soft target distribution in cross entropy loss #11959

@ssnl

Description

Currently our cross entropy loss (i.e., nn.CrossEntropyLoss) only supports a hard target class, i.e., wanting to maximize the output (log) probability of a particular class. But in many cases, training w.r.t. a soft target distribution (i.e., wanting the output to match a given distribution) is quite useful too, e.g., to prevent overfitting.

Math

Cross entropy loss applies log softmax to the logits and then computes a negative log likelihood.

Denote the input (logit) vector as x. Log softmax computes a vector y of the same length as x, where y_i = x_i - log( \sum_j exp(x_j) ), representing the log likelihood of each class.
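As a quick sanity check of this definition, here is a minimal pure-Python sketch (the helper name and input values are illustrative, not part of the proposal):

```python
import math

def log_softmax(x):
    # y_i = x_i - log(sum_j exp(x_j)); subtract max(x) first for numerical stability
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]

y = log_softmax([1.0, 2.0, 3.0])
# exponentiating y recovers the softmax probabilities, which sum to 1
assert abs(sum(math.exp(v) for v in y) - 1.0) < 1e-9
```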

  • In the hard target case, if the target class is c, the loss is simply the negative log likelihood loss -y_c.

  • In the soft target case, let the target distribution vector be p (i.e., p_i is the target probability for predicting class i). The loss is the KL divergence

    D( p || softmax(x) ) = \sum_i p_i log( p_i / softmax(x)_i ) = -\sum_i p_i y_i + constant
    

    The constant is independent of x and thus discarded. Our loss formula is just -\sum_i p_i y_i.

    When p_c = 1 for some class c, this reduces to the hard target case.

    The formula for gradient computation can be easily derived from this:

    
    d l / d y_i = -p_i
    
    d y_i / d x_i = 1 - exp(x_i) / \sum_j exp(x_j) = 1 - exp(y_i)
    
    # suppose k != i
    d y_k / d x_i = -exp(x_i) / \sum_j exp(x_j) = -exp(y_i)
    
    # so, by the chain rule, using \sum_k p_k = 1:
    d l / d x_i = exp(y_i) (\sum_k p_k) - p_i = exp(y_i) - p_i (= softmax(x)_i - p_i)
    
    
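The closed-form gradient above can be verified numerically with a central finite difference; this is a pure-Python sketch (the helper names and the particular values of x and p are illustrative):

```python
import math

def softmax(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

def loss(x, p):
    # l = -sum_i p_i * y_i, with y = log_softmax(x)
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return -sum(pi * (xi - lse) for pi, xi in zip(p, x))

x = [0.5, -1.0, 2.0]
p = [0.2, 0.3, 0.5]  # arbitrary target distribution summing to 1
eps = 1e-6
analytic = [si - pi for si, pi in zip(softmax(x), p)]  # softmax(x)_i - p_i
for i in range(len(x)):
    xp = list(x); xp[i] += eps
    xm = list(x); xm[i] -= eps
    numeric = (loss(xp, p) - loss(xm, p)) / (2 * eps)
    assert abs(numeric - analytic[i]) < 1e-6
```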

Possible Implementation

Currently our cross entropy loss implementation takes a batched input x of shape (N, C) with floating point dtype (N is the batch size and C is the number of classes), and a batched vector of target class indices target of shape (N,) with dtype long (an integral type), where target[i] is the index of the desired output class.

Since we also want it to accept a soft target distribution as target, we can allow target to be a batched distribution of shape (N, C), and detect whether the soft or the hard target path is intended based on the shape and dtype of target.
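A rough sketch of this dispatch in terms of existing torch.nn.functional ops (the function name cross_entropy_any and the exact dispatch condition are hypothetical, not the proposed API):

```python
import torch
import torch.nn.functional as F

def cross_entropy_any(x, target):
    """Hypothetical dispatch: hard targets if `target` is an integral (N,)
    index vector, soft targets if it is a floating (N, C) distribution."""
    logp = F.log_softmax(x, dim=1)
    if target.dim() == 1 and not target.is_floating_point():
        # hard target: negative log likelihood of the indexed class
        return F.nll_loss(logp, target)
    # soft target: -sum_i p_i y_i, averaged over the batch
    return -(target * logp).sum(dim=1).mean()

x = torch.randn(4, 3)
hard = torch.tensor([0, 2, 1, 2])
# a one-hot soft target should reproduce the hard-target loss
soft = F.one_hot(hard, num_classes=3).float()
assert torch.allclose(cross_entropy_any(x, hard), cross_entropy_any(x, soft))
```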

cc @gchanan

Metadata

Labels

  • function request: A request for a new function or the addition of new arguments/modes to an existing function.
  • module: loss: Problem is related to loss function.
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
