Source code for wormpose.dataset.loader

"""
The Dataset loader: instantiates the FramesDataset, FeaturesDataset, FramePreprocessing, ResultsExporter (optional).
Also handles the resizing options.
"""

import os
from typing import List, Optional, Tuple

import pkg_resources

from wormpose.dataset.base_dataset import (
    BaseFramesDataset,
    BaseFramePreprocessing,
    BaseResultsExporter,
)
from wormpose.dataset.features import Features, calculate_crop_window_size, FeaturesDict
from wormpose.dataset.image_processing.options import WORM_IS_LIGHTER
from wormpose.dataset.loaders.resizer import (
    frames_dataset_resizer,
    features_dataset_resizer,
    ResizeOptions,
)


class Dataset(object):
    """Bundles every component needed to work with one worm dataset:
    the raw frames, the precomputed features, the frame preprocessing logic,
    the crop window shape, and an (optional, possibly no-op) results exporter.
    """

    def __init__(
        self,
        video_names: List[str],
        frames_dataset: BaseFramesDataset,
        features_dataset: FeaturesDict,
        frame_preprocessing: BaseFramePreprocessing,
        image_shape: Tuple[int, int],
        results_exporter: BaseResultsExporter,
    ):
        # Plain value-holder: each argument is stored as-is on the instance.
        self.video_names = video_names
        self.frames_dataset = frames_dataset
        self.features_dataset = features_dataset
        self.frame_preprocessing = frame_preprocessing
        self.image_shape = image_shape
        self.results_exporter = results_exporter

    def num_frames(self, video_name):
        """Return how many frames the given video contains.

        Opens the video through the frames dataset (released on exit of the
        context manager) and reports its length.
        """
        with self.frames_dataset.open(video_name) as video_frames:
            frame_count = len(video_frames)
        return frame_count


def get_dataset_name(dataset_path: str) -> str:
    """
    Each dataset gets assigned a name: the root folder of the dataset

    :param dataset_path: Full path of the dataset
    :return: Name identifier for the dataset, simply use the basename of the path
        Each different dataset must have a unique basename in order to process several at once
    """
    # normpath first so a trailing separator ("/data/mydataset/") doesn't
    # make basename return an empty string
    return os.path.basename(os.path.normpath(dataset_path))
class _DummyResultsExporter(BaseResultsExporter):
    """
    Results exporter that does nothing, used when a dataset loader module
    does not provide a ResultsExporter of its own.
    """

    def export(self, video_name: str, **kwargs):
        pass


def load_dataset(
    dataset_loader: str,
    dataset_path: str,
    selected_video_names: Optional[List[str]] = None,
    resize_options: ResizeOptions = None,
    **kwargs,
) -> Dataset:
    """
    Instantiate a Dataset from a registered dataset loader.

    Looks up `dataset_loader` among the "worm_dataset_loaders" package entry
    points, loads its FramesDataset / FeaturesDataset / FramePreprocessing
    classes, and optionally resizes frames and features.

    :param dataset_loader: Name of the dataset loader entry point (ex: "tierpsy")
    :param dataset_path: Root path of the dataset on disk
    :param selected_video_names: Optional subset of videos to load (default: all)
    :param resize_options: Optional resizing configuration; when the resolved
        resize factor differs from 1.0, frames and features are reloaded through
        resizing wrappers and the crop window size is recomputed
    :param kwargs: Extra options, ex WORM_IS_LIGHTER to indicate that the worm
        is lighter than the background
    :return: A fully initialized Dataset
    :raises NotImplementedError: if no entry point matches `dataset_loader`
    """
    for entry_point in pkg_resources.iter_entry_points("worm_dataset_loaders"):
        if entry_point.name != dataset_loader:
            continue

        module = entry_point.load()
        frames_dataset_class = module.FramesDataset
        features_dataset_class = module.FeaturesDataset
        frame_preprocessing_class = module.FramePreprocessing

        frames_dataset, features_dataset, video_names = _load_dataset(
            frames_dataset_class,
            features_dataset_class,
            dataset_path,
            selected_video_names,
        )

        is_foreground_lighter_than_background = kwargs.get(WORM_IS_LIGHTER, False)
        frame_preprocessing = frame_preprocessing_class(is_foreground_lighter_than_background)

        image_shape = calculate_crop_window_size(features_dataset)

        if resize_options is not None:
            resize_options.update_resize_factor(features_dataset)
            if resize_options.resize_factor != 1.0:
                # reload frames and features dataset after resizing, also get new image_shape
                frames_dataset_class = frames_dataset_resizer(
                    frames_dataset_class, resize_factor=resize_options.resize_factor
                )
                features_dataset_class = features_dataset_resizer(
                    features_dataset_class,
                    resize_factor=resize_options.resize_factor,
                )
                frames_dataset, features_dataset, video_names = _load_dataset(
                    frames_dataset_class,
                    features_dataset_class,
                    dataset_path,
                    selected_video_names,
                )
                image_shape = resize_options.get_image_shape(features_dataset)

        # fall back to a no-op exporter when the loader module has none
        results_exporter = (
            module.ResultsExporter(dataset_path)
            if hasattr(module, "ResultsExporter")
            else _DummyResultsExporter()
        )

        return Dataset(
            video_names=video_names,
            features_dataset=features_dataset,
            frames_dataset=frames_dataset,
            frame_preprocessing=frame_preprocessing,
            image_shape=image_shape,
            results_exporter=results_exporter,
        )

    raise NotImplementedError(f"Dataset loader: '{dataset_loader}' not found in the package entry points.")


def _load_dataset(
    frames_dataset_class,
    features_dataset_class,
    dataset_path: str,
    selected_video_names,
):
    """Instantiate the frames and features datasets for the resolved video names.

    :return: tuple (frames_dataset, features_dataset dict keyed by video name, video_names)
    """
    frames_dataset = frames_dataset_class(dataset_path)
    video_names = _resolve_video_names(frames_dataset, selected_video_names)
    raw_features_dataset = features_dataset_class(dataset_path, video_names)
    # wrap each video's raw features into a Features object
    features_dataset = {
        video_name: Features(raw_features_dataset.get_features(video_name))
        for video_name in video_names
    }
    return frames_dataset, features_dataset, video_names


def _resolve_video_names(frames_dataset: BaseFramesDataset, selected_video_names: Optional[List[str]]):
    """Resolve the list of videos to process: all videos of the dataset, or the
    validated user-selected subset.

    :raises ValueError: if a requested video does not exist in the dataset
    """
    video_names = frames_dataset.video_names()
    if selected_video_names is not None:
        for video_name in selected_video_names:
            if video_name not in video_names:
                raise ValueError(f"Requested video '{video_name}' not found in dataset ({video_names})")
        video_names = selected_video_names
    _validate_video_names(video_names)
    return video_names


def _validate_video_names(video_names: List[str]):
    """Sanity-check the video names list: non-empty, all strings, no duplicates.

    :raises ValueError: on an empty list or duplicate names
    :raises NotImplementedError: when names are not all strings
    """
    if len(video_names) == 0:
        raise ValueError("Video names list is empty, no video to analyze")
    if not all(isinstance(n, str) for n in video_names):
        raise NotImplementedError("Only video names as list of strings supported type.")
    if len(video_names) > len(set(video_names)):
        raise ValueError(f"Video names must be unique but duplicates were found: {video_names}.")