Loss function for class imbalanced binary classifier in TensorFlow

Asked on December 18, 2018 in Tensorflow.


  • 3 Answer(s)

        You can add class weights to the loss function by multiplying the logits. The cross-entropy loss is:

    loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
                   = -x[class] + log(\sum_j exp(x[j]))
    
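As a quick sanity check, the identity above can be verified numerically (a numpy sketch with arbitrary logits, not part of the answer):

```python
import numpy as np

# Check that -log(softmax(x)[class]) == -x[class] + log(sum_j exp(x[j]))
x = np.array([2.0, -1.0, 0.5])   # arbitrary logits
cls = 0                          # true class index

softmax = np.exp(x) / np.sum(np.exp(x))
loss_direct = -np.log(softmax[cls])
loss_logsumexp = -x[cls] + np.log(np.sum(np.exp(x)))

assert np.isclose(loss_direct, loss_logsumexp)
```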

        In the weighted case:

    loss(x, class) = -weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j]))
    

        Hence, multiplying the logits by the class weights re-scales the predictions of each class.

    ratio = 31.0 / (500.0 + 31.0)
    class_weight = tf.constant([ratio, 1.0 - ratio])
    logits = ... # shape [batch_size, 2]
    weighted_logits = tf.multiply(logits, class_weight) # shape [batch_size, 2]
    xent = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=weighted_logits, name="xent_raw")
    

        Alternatively, there is a standard loss function which supports per-example weights:

    tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits, weights=weights)
    

        To use it, the class weights have to be transformed into a weight per example, by looking up each example's label in the class-weight vector.
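With integer labels, that transformation is just an index lookup (a numpy sketch with made-up labels; in TensorFlow the equivalent would be tf.gather(class_weights, labels)):

```python
import numpy as np

# Turn per-class weights into per-example weights by indexing with the labels.
ratio = 31.0 / (500.0 + 31.0)
class_weights = np.array([ratio, 1.0 - ratio])
labels = np.array([0, 1, 1, 0, 1])   # integer class ids

weights = class_weights[labels]      # shape [batch_size]
assert weights.shape == labels.shape
```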

    Answered on December 18, 2018.

    Here is a solution for the problem:

    ratio = 31.0 / (500.0 + 31.0)
    class_weight = tf.constant([[ratio, 1.0 - ratio]])
    logits = ... # shape [batch_size, 2]
    labels = ... # one-hot labels, shape [batch_size, 2]

    # the weight for each datapoint, depending on its label
    weight_per_label = tf.transpose(tf.matmul(labels,
                           tf.transpose(class_weight)))  # shape [1, batch_size]

    xent = tf.multiply(weight_per_label,
        tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits, name="xent_raw"))  # shape [1, batch_size]
    loss = tf.reduce_mean(xent)  # scalar
    
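The matmul trick above can be checked in plain numpy (a sketch with made-up one-hot labels, not part of the answer):

```python
import numpy as np

# One-hot labels matmul'd with the class-weight row vector pick out
# the weight for each datapoint.
ratio = 31.0 / (500.0 + 31.0)
class_weight = np.array([[ratio, 1.0 - ratio]])   # shape [1, 2]
labels = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [0.0, 1.0]])                   # one-hot, shape [3, 2]

weight_per_label = (labels @ class_weight.T).T    # shape [1, batch_size]
assert weight_per_label.shape == (1, 3)
```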
    Answered on December 18, 2018.

          To handle class imbalance in a binary classifier, use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight to the ratio of negative to positive examples. A pos_weight greater than 1 up-weights the positive class; pos_weight = 1 leaves the loss unweighted.
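For reference, here is a numpy sketch of the loss this function computes per its documentation, where z is the target, x the logit, and q the pos_weight; with q = 1 it reduces to the ordinary sigmoid cross-entropy, which is why you would pick q > 1 for the minority positive class:

```python
import numpy as np

def weighted_xent(x, z, q):
    # q * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)),
    # written with log1p(exp(-x)) for numerical stability
    return q * z * np.log1p(np.exp(-x)) + (1 - z) * (x + np.log1p(np.exp(-x)))

x = np.array([0.5, -0.5])   # logits
z = np.array([1.0, 0.0])    # targets
# q = 1 is the unweighted loss; q = 500/31 up-weights the positive class
unweighted = weighted_xent(x, z, 1.0)
upweighted = weighted_xent(x, z, 500.0 / 31.0)
```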

    Answered on December 18, 2018.

