Loss function for class-imbalanced binary classifier in TensorFlow
One option is to add class weights by multiplying the logits. The ordinary cross-entropy loss is:
loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j]))) = -x[class] + log(\sum_j exp(x[j]))
With the logits multiplied element-wise by per-class weights, it becomes:
loss(x, class) = -weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j]))
So the class weights rescale each class's logit, and therefore the softmax prediction for that class, rather than directly scaling each example's contribution to the loss.
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([ratio, 1.0 - ratio])
logits = ...  # shape [batch_size, 2]
# labels: one-hot float tensor, shape [batch_size, 2]
weighted_logits = tf.multiply(logits, class_weight)  # shape [batch_size, 2]
xent = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=weighted_logits, name="xent_raw")
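To see numerically what the weighted-logit formula does, here is a framework-free sketch in plain Python. It is an illustration only: the helper names softmax_xent and weighted_logit_xent and the sample logits are hypothetical, and it works on a single example rather than a TensorFlow batch.

```python
import math


def softmax_xent(logits, cls):
    """Unweighted loss for one example: -x[cls] + log(sum_j exp(x[j]))."""
    m = max(logits)  # shift by the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return -logits[cls] + log_sum_exp


def weighted_logit_xent(logits, cls, class_weight):
    """Scale logit j by class_weight[j] (like tf.multiply(logits, class_weight)),
    then apply the ordinary softmax cross-entropy to the scaled logits."""
    scaled = [w * x for w, x in zip(class_weight, logits)]
    return softmax_xent(scaled, cls)


ratio = 31.0 / (500.0 + 31.0)
class_weight = [ratio, 1.0 - ratio]
logits = [2.0, 1.0]  # hypothetical logits for a single example

print(softmax_xent(logits, 1))
print(weighted_logit_xent(logits, 1, class_weight))
```

Because the weights also enter the log-sum-exp term, the result is not simply class_weight[cls] times the unweighted loss; scaling the logits changes the softmax itself.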
TensorFlow 1.x also has a standard loss function that accepts a weight for each example in the batch (tf.compat.v1.losses.sparse_softmax_cross_entropy in 2.x):
tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits, weights=weights)
The class weights have to be converted into one weight per example: each example gets the weight of its true class.
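The conversion from class weights to per-example weights can be sketched in plain Python as below. This is a hypothetical illustration with integer (sparse) labels; the helper names are made up, and averaging over the batch is a choice made here, since TensorFlow's loss functions have their own reduction options that may normalize differently. In TensorFlow the lookup itself could be written as tf.gather(class_weight, labels).

```python
import math


def sparse_xent(logits, label):
    """Softmax cross-entropy for one example with an integer label."""
    m = max(logits)  # shift by the max logit for numerical stability
    return -logits[label] + m + math.log(sum(math.exp(x - m) for x in logits))


def per_example_weights(labels, class_weight):
    """Each example takes the weight of its own true class."""
    return [class_weight[y] for y in labels]


def weighted_mean_xent(batch_logits, labels, class_weight):
    """Multiply each example's loss by its weight, then average over the batch."""
    weights = per_example_weights(labels, class_weight)
    losses = [sparse_xent(lg, y) for lg, y in zip(batch_logits, labels)]
    return sum(w * loss for w, loss in zip(weights, losses)) / len(losses)


ratio = 31.0 / (500.0 + 31.0)
class_weight = [ratio, 1.0 - ratio]
labels = [0, 1, 0]  # hypothetical integer labels
batch_logits = [[1.5, -0.5], [0.2, 0.1], [2.0, 0.0]]  # hypothetical logits

print(per_example_weights(labels, class_weight))
print(weighted_mean_xent(batch_logits, labels, class_weight))
```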
One way to do that conversion when the labels are one-hot:
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]])  # shape [1, 2]
logits = ...  # shape [batch_size, 2]
# labels: one-hot float tensor, shape [batch_size, 2]
# weight for each data point, depending on its label
weight_per_label = tf.transpose(
    tf.matmul(labels, tf.transpose(class_weight)))  # shape [1, batch_size]
xent = tf.multiply(
    weight_per_label,
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xent_raw"))  # shape [1, batch_size]
loss = tf.reduce_mean(xent)  # scalar
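The matmul in that snippet can be puzzling, so here is a plain-Python sketch that mirrors its shapes. The helpers matmul, transpose and weight_per_label are hypothetical stand-ins written for illustration, not TensorFlow functions.

```python
def matmul(a, b):
    """Plain-Python matrix product; matrices are lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]


def weight_per_label(one_hot_labels, class_weight):
    """Mirror of tf.transpose(tf.matmul(labels, tf.transpose(class_weight))).

    one_hot_labels: [batch_size, 2], class_weight: [1, 2] -> result: [1, batch_size].
    """
    return transpose(matmul(one_hot_labels, transpose(class_weight)))


ratio = 31.0 / (500.0 + 31.0)
class_weight = [[ratio, 1.0 - ratio]]  # shape [1, 2]
labels = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]  # hypothetical one-hot labels
print(weight_per_label(labels, class_weight))  # one weight per example
```

Because each row of labels is one-hot, the product simply picks out the weight of the true class; the final transpose only produces the [1, batch_size] shape that is then broadcast against the [batch_size] cross-entropy vector.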
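For a class-imbalanced binary classifier you can also use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight above 1 (for example, the ratio of negative to positive examples) so that errors on the rare positive class cost more. With pos_weight = 1 it reduces to ordinary sigmoid cross-entropy, so no reweighting takes place.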
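As a sketch of what pos_weight does, the plain-Python function below computes the per-example formula documented for tf.nn.weighted_cross_entropy_with_logits, labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)). It is an illustration, not TensorFlow's implementation; the function names are hypothetical, and the choice pos_weight = 500 / 31 assumes the 31 examples form the positive class.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def weighted_sigmoid_xent(label, logit, pos_weight):
    """label * -log(sigmoid(logit)) * pos_weight + (1 - label) * -log(1 - sigmoid(logit)).

    Uses the identity log(1 - sigmoid(x)) == log(sigmoid(-x)).
    """
    return (-pos_weight * label * log_sigmoid(logit)
            - (1 - label) * log_sigmoid(-logit))


# Hypothetical counts: 500 negatives and 31 positives, so positives are up-weighted.
pos_weight = 500.0 / 31.0
print(weighted_sigmoid_xent(1.0, -1.0, pos_weight))  # missed positive: scaled by pos_weight
print(weighted_sigmoid_xent(0.0, 1.0, pos_weight))   # false positive: not scaled
```

Only the positive-label term is multiplied, so a larger pos_weight trades precision for recall on the rare class.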