Loss function for a class-imbalanced binary classifier in TensorFlow

Asked on December 18, 2018 in TensorFlow.


  • 1 Answer

        Add class weights to the loss function by multiplying the logits by those weights. The regular cross-entropy loss is:

    loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
                   = -x[class] + log(\sum_j exp(x[j]))
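
        As a sanity check, the two lines of this formula agree numerically. A minimal NumPy sketch (the logits and class index here are made-up values):

    import numpy as np

    x = np.array([2.0, 0.5])  # logits for one example
    cls = 0                   # index of the true class

    # Second form: -x[class] + log(\sum_j exp(x[j]))
    loss = -x[cls] + np.log(np.sum(np.exp(x)))

    # First form: -log(softmax(x)[class])
    softmax = np.exp(x) / np.sum(np.exp(x))
    assert np.isclose(loss, -np.log(softmax[cls]))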
    

        In the weighted case:

    loss(x, class) = -weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j]))
    

        Hence, by multiplying the logits by the class weights, you re-scale the predictions of each class.

    ratio = 31.0 / (500.0 + 31.0)
    class_weight = tf.constant([ratio, 1.0 - ratio])
    logits = ... # shape [batch_size, 2]
    # Re-scale each class's logit by its class weight.
    weighted_logits = tf.multiply(logits, class_weight) # shape [batch_size, 2]
    xent = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=weighted_logits, name="xent_raw")
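
        Note that xent is a per-example loss of shape [batch_size]; to train on it you would typically reduce it to a scalar, for example:

    loss = tf.reduce_mean(xent, name="xent_mean")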
    

        There is also a standard loss function which supports per-example weights within a batch:

    tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits, weights=weights)
    

        Here the weights have to be transformed from per-class weights into one weight per example, as sketched below.
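
        A minimal sketch of that transformation, assuming class_weight is the per-class weight tensor from the snippet above and labels holds integer class indices (the sparse variant expects class IDs, not one-hot vectors):

    # Look up each example's class weight by its label: shape [batch_size].
    weights = tf.gather(class_weight, labels)
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits, weights=weights)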

    Answered on December 18, 2018.

