Loss function for class-imbalanced binary classifier in TensorFlow
You can add class weights to the loss function by multiplying the logits. The regular cross entropy loss is:
loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j]))) = -x[class] + log(\sum_j exp(x[j]))
For the weighted case, each logit x[j] is scaled by the weight of its own class:
loss(x, class) = -weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j]))
Hence, by applying class weights you are rescaling the predictions of each class, because the weights multiply the logits.
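    import tensorflow as tf

    # Fraction of examples in the rarer class (31 out of 531).
    ratio = 31.0 / (500.0 + 31.0)
    class_weight = tf.constant([ratio, 1.0 - ratio])
    logits = ...  # shape [batch_size, 2]
    # Broadcasting scales column j of the logits by class_weight[j].
    weighted_logits = tf.multiply(logits, class_weight)  # shape [batch_size, 2]
    xent = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=weighted_logits, name="xent_raw")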
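To see what scaling the logits does to the loss, here is a minimal pure-Python sketch for a single example, without TensorFlow. The helper names `softmax_xent` and `weighted_logit_xent` and the logit values are made up for illustration; only `ratio` and `class_weight` come from the snippet above.

```python
import math

def softmax_xent(logits, label):
    """Cross entropy for one example: -x[label] + log(sum_j exp(x[j]))."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return -logits[label] + log_sum_exp

def weighted_logit_xent(logits, label, class_weight):
    """Scale each logit by its class weight, then apply the plain loss."""
    scaled = [w * x for w, x in zip(class_weight, logits)]
    return softmax_xent(scaled, label)

ratio = 31.0 / (500.0 + 31.0)
class_weight = [ratio, 1.0 - ratio]

logits = [2.0, -1.0]  # hypothetical logits for one example
plain = softmax_xent(logits, 1)
weighted = weighted_logit_xent(logits, 1, class_weight)
print("plain loss:", plain)
print("loss with weighted logits:", weighted)
```

With all class weights equal to 1.0 the two functions agree, which is a quick sanity check that the weighting only acts through the logits.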
There is also a standard loss function that supports a weight for each example in the batch:
tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits, weights=weights)
Here, weights must be converted from class weights to one weight per example, i.e. a tensor of shape [batch_size]. See the tf.losses documentation for details.
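Converting class weights to per-example weights is just a lookup by label. Below is a minimal pure-Python sketch of that step; the batch labels and per-example loss values are hypothetical. In TensorFlow the same lookup can be written as `tf.gather(class_weight, labels)`, assuming `labels` holds integer class ids.

```python
ratio = 31.0 / (500.0 + 31.0)
class_weight = [ratio, 1.0 - ratio]  # one weight per class

labels = [0, 1, 0, 0]  # hypothetical integer labels for a batch of 4

# One weight per example: look up the weight of each example's class.
example_weights = [class_weight[y] for y in labels]

# Hypothetical unweighted per-example losses, scaled by their weights.
per_example_loss = [0.2, 1.5, 0.1, 0.3]
weighted_loss = [w * l for w, l in zip(example_weights, per_example_loss)]

print(example_weights)
print(weighted_loss)
```

The resulting `example_weights` list has one entry per example, which is the shape the `weights` argument expects.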