Source code for wormpose.machine_learning.loss

"""
Definition of the loss function for the network
"""

import tensorflow as tf


[docs]def angle_diff(a, b): """ Root Mean Square Error of the angle difference. The angle difference function takes into account the periodicity of angles """ diff = tf.atan2(tf.sin(a - b), tf.cos(a - b)) return tf.sqrt(tf.reduce_mean(tf.square(diff), axis=1))
[docs]def symmetric_angle_difference(y_true, y_pred): """ We calculate the angle difference between the prediction and the two possible labels, and pick the minimum of the two, we average the result on the batch """ dists = [angle_diff(y_pred, y_true[:, 0]), angle_diff(y_pred, y_true[:, 1])] mins = tf.reduce_min(dists, axis=0) loss = tf.reduce_mean(mins) return loss