Source code for wormpose.commands.calibrate_dataset

#!/usr/bin/env python

"""
Calculates the image similarity on a random selection of labeled frames from a dataset.
"""

import logging
import os
import random
from argparse import Namespace
from typing import Tuple

import h5py
import numpy as np

from wormpose.commands import _log_parameters
from wormpose.config import default_paths
from wormpose.dataset import Dataset
from wormpose.dataset.image_processing.options import add_image_processing_arguments
from wormpose.dataset.loader import get_dataset_name
from wormpose.dataset.loader import load_dataset
from wormpose.images.scoring.centerline_accuracy_check import CenterlineAccuracyCheck
from wormpose.pose.centerline import skeletons_to_angles
from wormpose.dataset.loaders.resizer import add_resizing_arguments, ResizeOptions

# Apply the default logging configuration (only has an effect if the root logger has no handler yet).
logging.basicConfig()
# Logger dedicated to this module, lowered to DEBUG so that every message it emits is let through.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class _ScoresWriter(object):
    def __init__(self):
        self.all_scores = []

    def add(self, kwargs):
        self.all_scores.append(kwargs["score"])

    def write(self, results_file):
        with h5py.File(results_file, "a") as f:
            f.create_dataset("scores", data=self.all_scores)


class _ImagesAndScoresWriter(_ScoresWriter):
    """
    Scores writer that additionally keeps the real and the synthetic image
    exposed by the accuracy checker for every result added.
    """

    def __init__(self):
        super().__init__()
        self.all_real = []
        self.all_synth = []

    def add(self, kwargs):
        """
        Record one result: the last real and synthetic images, then the score.

        :param kwargs: dict of values, the entries "centerline_accuracy" and "score" are read
        """
        accuracy_check: CenterlineAccuracyCheck = kwargs["centerline_accuracy"]

        # np.array copies the images, so the stored ones are independent of the checker's buffers
        real_image = np.array(accuracy_check.last_real_image)
        self.all_real.append(real_image)
        synth_image = np.array(accuracy_check.last_synth_image)
        self.all_synth.append(synth_image)

        super().add(kwargs)

    def write(self, results_file):
        """
        Save the collected images in the datasets "real_images" and "synth_images",
        then let the parent class save the scores in the same file.

        :param results_file: path of the HDF5 file, opened in append mode
        """
        images_by_dataset = {
            "real_images": self.all_real,
            "synth_images": self.all_synth,
        }
        with h5py.File(results_file, "a") as h5_file:
            for dataset_name, images in images_by_dataset.items():
                h5_file.create_dataset(dataset_name, data=images)

        super().write(results_file)


class _Calibrator(object):
    """
    Scores a random selection of the labelled frames of a video and saves the results.

    Calling an instance with a video name runs the accuracy check on each selected frame,
    hands each result to a writer, and writes one results file for the video.
    """

    def __init__(
        self,
        dataset: Dataset,
        results_dir: str,
        image_shape: Tuple[int, int],
        num_samples: int,
        theta_dims: int,
    ):
        """
        :param dataset: dataset providing the features, the frames and the frame preprocessing
        :param results_dir: folder where the results file of each video is written
        :param image_shape: image shape forwarded to the accuracy check
        :param num_samples: maximum number of labelled frames to evaluate per video
        :param theta_dims: forwarded to skeletons_to_angles when converting the labelled skeletons
        """
        self.dataset = dataset
        self.results_dir = results_dir
        self.image_shape = image_shape
        self.num_samples = num_samples
        self.theta_dims = theta_dims

    def __call__(self, video_name: str, writer: _ScoresWriter):
        """
        Evaluate image metric score on labelled frames.

        :param video_name: name of the video to evaluate, as known by the dataset
        :param writer: receives a dict with the entries "score" and "centerline_accuracy"
            for each evaluated frame, and is then asked to write the results file
        :return: path of the results file, named after the video
        """
        features = self.dataset.features_dataset[video_name]
        labelled_thetas = skeletons_to_angles(features.skeletons, theta_dims=self.theta_dims)
        labelled_indexes = features.labelled_indexes

        centerline_accuracy = CenterlineAccuracyCheck(
            frame_preprocessing=self.dataset.frame_preprocessing,
            image_shape=self.image_shape,
        )

        with self.dataset.frames_dataset.open(video_name) as frames:
            # Never ask for more frames than there are labelled ones
            frames_amount = min(self.num_samples, len(labelled_indexes))

            # Draw distinct frame indexes among the labelled ones
            random_label_index = np.random.choice(labelled_indexes, frames_amount, replace=False)
            thetas = labelled_thetas[random_label_index]

            for theta, index in zip(thetas, random_label_index):
                cur_frame = frames[index]

                # The same frame is used both as the template and as the image to compare with
                score, _ = centerline_accuracy(
                    theta=theta,
                    template_skeleton=features.skeletons[index],
                    template_measurements=features.measurements,
                    template_frame=cur_frame,
                    real_frame_orig=cur_frame,
                )
                # Pass only what the writers read, instead of exposing every local variable
                writer.add({"score": score, "centerline_accuracy": centerline_accuracy})

        results_file = os.path.join(self.results_dir, video_name + "_calibration.h5")
        # The writers open the file in append mode: remove any previous results first
        if os.path.exists(results_file):
            os.remove(results_file)
        writer.write(results_file=results_file)

        logger.info(
            f"Evaluated known skeletons reconstruction for {video_name},"
            f" average score {np.mean(writer.all_scores):.4f}"
        )
        return results_file


def _parse_arguments(kwargs: dict):
    """
    Complete the options with their fallback values and expose them as a Namespace.

    The kwargs dict is modified in place: every option below that is absent or None
    receives its fallback value, and a "resize_options" entry is added.

    :param kwargs: options of the command
    :return: Namespace containing all the entries of kwargs
    """
    fallback_values = {
        "num_samples": 500,
        "work_dir": default_paths.WORK_DIR,
        "theta_dims": 100,
        "video_names": None,
        "save_images": False,
        "random_seed": None,
    }
    for option, fallback in fallback_values.items():
        if kwargs.get(option) is None:
            kwargs[option] = fallback

    kwargs["resize_options"] = ResizeOptions(**kwargs)

    _log_parameters(logger.info, kwargs)
    return Namespace(**kwargs)


def calibrate(dataset_loader: str, dataset_path: str, **kwargs):
    """
    Calculate the image score for a certain number of labelled frames in the dataset,
    this will give an indication on choosing the image similarity threshold
    when predicting all frames in the dataset.

    This function is a generator: nothing is evaluated until it is iterated,
    and one video is processed per iteration.

    :param dataset_loader: Name of the dataset loader, for example "tierpsy"
    :param dataset_path: Root path of the dataset containing videos of worm
    :param kwargs: Optional settings, completed by _parse_arguments
        (num_samples, work_dir, theta_dims, video_names, save_images, random_seed)
        and also forwarded to the dataset loader
    :return: yields a tuple (video name, path of the results file) for each video
    """
    _log_parameters(logger.info, {"dataset_loader": dataset_loader, "dataset_path": dataset_path})
    args = _parse_arguments(kwargs)

    # Seed both random generators so that a given random_seed gives reproducible results
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)

    dataset_name = get_dataset_name(dataset_path)
    experiment_dir = os.path.join(args.work_dir, dataset_name)
    calibration_results_dir = os.path.join(experiment_dir, default_paths.CALIBRATION_RESULTS_DIR)
    os.makedirs(calibration_results_dir, exist_ok=True)

    dataset = load_dataset(
        dataset_loader,
        dataset_path,
        selected_video_names=args.video_names,
        **vars(args),
    )

    calibrator = _Calibrator(
        dataset=dataset,
        results_dir=calibration_results_dir,
        image_shape=dataset.image_shape,
        num_samples=args.num_samples,
        theta_dims=args.theta_dims,
    )

    writer_class = _ImagesAndScoresWriter if args.save_images else _ScoresWriter

    for video_name in dataset.video_names:
        # A writer accumulates everything it is given: use a new one for each video,
        # otherwise the results of a video would also contain those of the previous videos
        results_file = calibrator(video_name=video_name, writer=writer_class())
        yield video_name, results_file


def main():
    """
    Command line entry point: parse the arguments and run the calibration on all the videos.
    """
    import argparse

    cli = argparse.ArgumentParser()

    # Mandatory positional arguments
    cli.add_argument("dataset_loader", type=str)
    cli.add_argument("dataset_path", type=str)

    # Optional arguments: when left unset, calibrate applies its own fallback values
    cli.add_argument(
        "--video_names",
        type=str,
        nargs="+",
        help="Only evaluate using a subset of videos. If not set, will include all videos in dataset_path.",
    )
    cli.add_argument(
        "--num_samples",
        type=int,
        help="How many frames to perform the calibration in order to evaluate the image metric",
    )
    cli.add_argument("--theta_dims", type=int)
    cli.add_argument("--work_dir", type=str, help="Root folder for all experiments")
    cli.add_argument(
        "--save_images",
        default=False,
        action="store_true",
        help="Also save the images used for the calibration",
    )
    cli.add_argument("--random_seed", type=int, help="Optional random seed for deterministic results")

    # Options declared by other modules of the package
    add_resizing_arguments(cli)
    add_image_processing_arguments(cli)

    options = vars(cli.parse_args())

    # calibrate is a generator: go through it entirely so that every video gets processed
    list(calibrate(**options))


if __name__ == "__main__":
    main()