Source code for wormpose.machine_learning.best_models_saver
"""
Implements a Keras callback to save the top N best models on evaluation data.
"""
import heapq
import json
import logging
import os
import tensorflow as tf
# Install the default handler on the root logger at import time
# (logging.basicConfig does nothing if the root logger already has handlers).
logging.basicConfig()
# Module-level logger named after this module, emitting DEBUG messages and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class _BestModelsHeap(object):
    """
    Keeps track of the top N named objects with the smallest value.

    ``heapq`` only provides a min-heap, so entries are stored as
    ``(-value, name)``: the root of the heap is then the entry with the
    largest value (the worst one), which is the one to evict when full.
    """

    def __init__(self, values, names, maxlen):
        """
        :param values: initial values, smaller is better
        :param names: names associated element-wise with ``values``
        :param maxlen: maximum number of entries kept in the heap
        """
        self.maxlen = maxlen
        self._queue = []
        # Insert through append() so the initial content is capped at maxlen
        # too; pushing directly would leave the heap oversized forever.
        for val, name in zip(values, names):
            self.append(val, name)

    def append(self, value, name):
        """
        Adds a named value; once the heap is full, the entry with the largest
        value (possibly the new one) is discarded.
        """
        entry = (-value, name)
        if len(self._queue) < self.maxlen:
            heapq.heappush(self._queue, entry)
        else:
            # Push then pop the root: drops the largest value in one operation
            heapq.heappushpop(self._queue, entry)

    def get_sorted(self):
        """
        :return: tuple (values, names), both ordered from smallest value to largest
        """
        # Stored keys are negated, so descending key order is ascending value order
        ordered = sorted(self._queue, reverse=True)
        return [-neg_value for neg_value, _ in ordered], [name for _, name in ordered]
# File name template for the tracked model files, filled in by BestModels with
# the epoch number (epoch + 1, zero-padded to two digits) and the validation
# loss rounded to two decimals.
_NAME_PATTERN = "model.{epoch:02d}-{val_loss:.2f}.hdf5"
class BestModels(tf.keras.callbacks.Callback):
    """
    Keras callback keeping only the N best models (smallest validation loss)
    of a directory, plus the most recent one.

    The callback does not write model files itself: it only does the
    bookkeeping (persisted in ``models_info.json``) and deletes the files that
    are no longer needed. Model files named after ``_NAME_PATTERN`` are
    expected to be written into ``models_dir`` by another component
    (presumably a ModelCheckpoint callback using ``models_name_pattern`` -
    to be confirmed against the caller).

    WARNING: at the end of each epoch, every file of ``models_dir`` that is not
    one of the best models, the last model or the json backup is deleted.

    The attributes ``epoch``, ``last_model_path`` and ``best_model_path``
    describe the state restored from the json backup at construction time;
    they are not refreshed by ``on_epoch_end``.
    """

    def __init__(self, models_dir, models_to_keep=5):
        """
        :param models_dir: directory containing the model files and the json backup
        :param models_to_keep: how many of the best models are kept on disk
        """
        super().__init__()

        self.models_name_pattern = os.path.join(models_dir, _NAME_PATTERN)
        self.models_dir = models_dir
        self.backup_filepath = os.path.join(self.models_dir, "models_info.json")

        # Restore the state of a previous run when a usable backup exists,
        # otherwise start from scratch (best effort: a missing, unreadable or
        # malformed backup is not an error).
        try:
            with open(self.backup_filepath, "r") as backup_file:
                backup = json.load(backup_file)
            val_loss = backup["best_val_loss"]
            names = backup["best_models"]
            epoch = backup["epoch"]
            last = backup["last_model"]
        except (OSError, ValueError, KeyError, TypeError):
            val_loss = []
            names = []
            epoch = 0
            last = ""

        self.epoch = epoch
        self.last_model_path = os.path.join(models_dir, last)
        # The backup stores the models sorted from best to worst
        self.best_model_path = os.path.join(models_dir, names[0]) if len(names) > 0 else self.last_model_path
        self.models_heap = _BestModelsHeap(values=val_loss, names=names, maxlen=models_to_keep)

    def on_epoch_end(self, epoch, logs=None):
        """
        At the end of each epoch, update the best models and save to json

        :param epoch: index of the epoch that just ended (the file name and the backup use epoch + 1)
        :param logs: metrics of the epoch, must contain "val_loss"
        :raises KeyError: if ``logs`` has no "val_loss" entry
        :raises TypeError: if ``logs`` is None
        """
        # Plain Python float: safe to serialize with json whatever scalar type Keras provides
        val_loss = float(logs["val_loss"])
        last_model_name = os.path.basename(self.models_name_pattern.format(epoch=epoch + 1, val_loss=val_loss))

        self.models_heap.append(val_loss, last_model_name)
        best_val_loss, best_model_names = self.models_heap.get_sorted()

        # Clean the folder: only the best models, the last model and the backup survive
        keep = set(best_model_names)
        keep.add(os.path.basename(self.backup_filepath))
        keep.add(last_model_name)
        for entry in os.listdir(self.models_dir):
            if entry in keep:
                continue
            entry_path = os.path.join(self.models_dir, entry)
            # os.remove can't delete directories, leave them alone instead of crashing
            if os.path.isdir(entry_path):
                continue
            os.remove(entry_path)

        with open(self.backup_filepath, "w") as backup_file:
            json.dump(
                {
                    "epoch": epoch + 1,
                    "best_val_loss": best_val_loss,
                    "best_models": best_model_names,
                    "last_model": last_model_name,
                },
                backup_file,
                indent=4,
            )