Tracking Multiple Losses with Keras
Often we deal with networks that are optimized for multiple losses (e.g., a VAE). In such scenarios, it is useful to track each loss independently, so that its contribution to the overall loss can be fine-tuned. This post details an example of how to do this with Keras.
Let us look at an example model which needs to be trained to minimize the sum of two losses, say mean squared error (MSE) and mean absolute error (MAE). Let $\lambda_{mse}$ be the hyperparameter that controls the contribution of MSE to the total loss, i.e., the total loss is MAE + $\lambda_{mse}$ * MSE. This loss can be implemented as:
import keras.backend as K

lambda_mse = 10  # hyperparameter to be adjusted

def joint_loss(y_true, y_pred):
    # mean squared error
    mse_loss = K.mean(K.square(y_true - y_pred))
    # mean absolute error
    mae_loss = K.mean(K.abs(y_true - y_pred))
    return mae_loss + (lambda_mse * mse_loss)
with the model compiled as:
model.compile(loss = joint_loss, optimizer='Adam')
However, when we run model.fit(...), Keras shows progress like this:
Epoch 1/30
19488/144615 [===>..........................] - ETA: 1:52:37 - loss: 0.4103
Keras shows only the joint loss and does not report the individual MSE and MAE losses, which makes it difficult to track how they evolve over epochs and to adjust $\lambda_{mse}$ accordingly.
In order to track them, we will need to define individual losses as below.
import keras.backend as K

lambda_mse = 10  # hyperparameter to be adjusted

def mse_loss(y_true, y_pred):
    return K.mean(K.square(y_true - y_pred))

def mae_loss(y_true, y_pred):
    return K.mean(K.abs(y_true - y_pred))

def joint_loss(y_true, y_pred):
    return mae_loss(y_true, y_pred) + (lambda_mse * mse_loss(y_true, y_pred))
Then we can use the metrics parameter of model.compile to also track the MAE and MSE. This can be done by compiling the model using
model.compile(loss=joint_loss, optimizer='Adam', metrics=[mae_loss, mse_loss])
Notice that the model is still compiled to optimize the joint loss, but it now also reports the MAE and MSE losses. Executing model.metrics_names will return three values: ['loss', 'mae_loss', 'mse_loss']. Now model.fit(...) will show something like this:
Epoch 1/30
26336/144615 [====>.........................] - ETA: 1:46:54 - loss: 0.4078 - mae_loss: 0.1891 - mse_loss: 0.0219
Now we can see the joint loss and the individual losses that contribute to it. We can also verify that the joint loss is indeed mae_loss + 10 * mse_loss, where 10 is the value chosen for $\lambda_{mse}$.
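As a quick sanity check of that decomposition, the same arithmetic can be reproduced outside Keras with plain NumPy (the arrays below are made-up sample values, not numbers from the training run above):

```python
import numpy as np

lambda_mse = 10  # same hyperparameter as above

# made-up sample targets and predictions, just for the check
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.8, 3.3, 3.9])

mse = np.mean(np.square(y_true - y_pred))  # what Keras reports as mse_loss
mae = np.mean(np.abs(y_true - y_pred))     # what Keras reports as mae_loss
joint = mae + lambda_mse * mse             # what Keras reports as loss
```

Running the same arithmetic by hand on a small batch like this is an easy way to confirm that the reported loss really is the weighted sum of the tracked terms.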
Similarly, you can define your own loss terms and use the metrics parameter in model.compile to track them independently.