# Source code for wormpose.commands.generate_training_data

#!/usr/bin/env python

"""
Generates the training and evaluation data from a dataset.
"""

import logging
import multiprocessing as mp
import os
import random
import shutil
import tempfile
import time
from argparse import Namespace

import numpy as np

from wormpose.commands import _log_parameters
from wormpose.config import default_paths
from wormpose.config.default_paths import (
    SYNTH_TRAIN_DATASET_NAMES,
    REAL_EVAL_DATASET_NAMES,
    CONFIG_FILENAME,
)
from wormpose.config.experiment_config import save_config, ExperimentConfig
from wormpose.dataset.image_processing.options import (
    add_image_processing_arguments,
    WORM_IS_LIGHTER,
)
from wormpose.dataset.loader import get_dataset_name
from wormpose.dataset.loader import load_dataset
from wormpose.dataset.loaders.resizer import add_resizing_arguments, ResizeOptions
from wormpose.machine_learning import eval_data_generator
from wormpose.machine_learning.synthetic_data_generator import SyntheticDataGenerator
from wormpose.machine_learning.tfrecord_file import TfrecordLabeledDataWriter
from wormpose.pose.postures_model import PosturesModel

# Module-level logger: basicConfig attaches a default stderr handler,
# and DEBUG level makes parameter/progress logging from this command visible.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def _parse_arguments(kwargs: dict):
    """Fill in defaults for missing or None options, create a unique temporary
    working directory and the resize options, then wrap everything in a
    Namespace.

    :param kwargs: Raw keyword options; mutated in place with the defaults.
    :return: argparse-style Namespace exposing the completed options.
    """
    # Factories keep the defaults lazy: each one is evaluated only when the
    # corresponding option is missing or explicitly None (e.g. PosturesModel
    # is only instantiated if no custom postures generator was supplied).
    default_factories = {
        "num_process": os.cpu_count,
        "temp_dir": tempfile.gettempdir,
        "num_train_samples": lambda: int(5e5),
        "num_eval_samples": lambda: int(1e4),
        "work_dir": lambda: default_paths.WORK_DIR,
        "postures_generation": lambda: PosturesModel().generate,
        "video_names": lambda: None,
        "random_seed": lambda: None,
        WORM_IS_LIGHTER: lambda: False,
    }
    for option, make_default in default_factories.items():
        if kwargs.get(option) is None:
            kwargs[option] = make_default()

    # All intermediate files go into a fresh unique subdirectory of the
    # requested temp dir, so concurrent runs cannot collide.
    kwargs["temp_dir"] = tempfile.mkdtemp(dir=kwargs["temp_dir"])
    kwargs["resize_options"] = ResizeOptions(**kwargs)

    _log_parameters(logger.info, kwargs)

    return Namespace(**kwargs)


def generate(dataset_loader: str, dataset_path: str, **kwargs):
    """
    Generate synthetic images (training data) and processed real images (evaluation data)
    and save them to TFrecord files using multiprocessing

    This is a generator: it yields progress values in [0, 1] while the
    synthetic data is being produced, then finishes the evaluation data,
    copies the results into the experiment folder and saves the config.

    :param dataset_loader: Name of the dataset loader, for example "tierpsy"
    :param dataset_path: Root path of the dataset containing videos of worm
    """
    _log_parameters(logger.info, {"dataset_loader": dataset_loader, "dataset_path": dataset_path})
    args = _parse_arguments(kwargs)

    # "spawn" keeps worker start-up behavior consistent across platforms.
    mp.set_start_method("spawn", force=True)
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)

    # Setup folders: one experiment directory per dataset, and a fresh
    # training-data subfolder (any previous one is discarded).
    if not os.path.exists(args.work_dir):
        os.mkdir(args.work_dir)
    experiment_dir = os.path.join(args.work_dir, get_dataset_name(dataset_path))
    os.makedirs(experiment_dir, exist_ok=True)
    tfrecords_dataset_root = os.path.join(experiment_dir, default_paths.TRAINING_DATA_DIR)
    if os.path.exists(tfrecords_dataset_root):
        shutil.rmtree(tfrecords_dataset_root)

    dataset = load_dataset(
        dataset_loader=dataset_loader,
        dataset_path=dataset_path,
        selected_video_names=args.video_names,
        **vars(args),
    )

    start = time.time()
    synthetic_data_generator = SyntheticDataGenerator(
        num_process=args.num_process,
        temp_dir=args.temp_dir,
        dataset=dataset,
        postures_generation_fn=args.postures_generation,
        enable_random_augmentations=True,
        writer=TfrecordLabeledDataWriter,
        random_seed=args.random_seed,
    )
    gen = synthetic_data_generator.generate(
        num_samples=args.num_train_samples,
        file_pattern=os.path.join(args.temp_dir, SYNTH_TRAIN_DATASET_NAMES),
    )
    # Relay the workers' progress to the caller, then signal completion.
    for progress in gen:
        yield progress
    yield 1.0

    # Infer the posture dimensionality by drawing one sample from the
    # postures generator.
    theta_dims = len(next(args.postures_generation()))
    num_eval_samples = eval_data_generator.generate(
        dataset=dataset,
        num_samples=args.num_eval_samples,
        theta_dims=theta_dims,
        file_pattern=os.path.join(args.temp_dir, REAL_EVAL_DATASET_NAMES),
    )

    # Move the finished tfrecord files from the temp dir into the
    # experiment folder and record the experiment configuration.
    shutil.copytree(args.temp_dir, tfrecords_dataset_root)
    save_config(
        ExperimentConfig(
            dataset_loader=dataset_loader,
            image_shape=dataset.image_shape,
            theta_dimensions=theta_dims,
            num_train_samples=args.num_train_samples,
            num_eval_samples=num_eval_samples,
            resize_factor=args.resize_options.resize_factor,
            video_names=dataset.video_names,
            worm_is_lighter=getattr(args, WORM_IS_LIGHTER),
        ),
        os.path.join(experiment_dir, CONFIG_FILENAME),
    )

    end = time.time()
    logger.info(f"Done generating training data in : {end - start:.1f}s")
def main():
    """Command-line entry point: parse arguments, run :func:`generate`
    and log the integer-percent progress as it advances."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_loader", type=str)
    parser.add_argument("dataset_path", type=str)
    parser.add_argument(
        "--video_names",
        type=str,
        nargs="+",
        help="Only generate training data for a subset of videos. "
        "If not set, will include all videos in dataset_path.",
    )
    parser.add_argument("--num_train_samples", type=int, help="How many training samples to generate")
    parser.add_argument("--num_eval_samples", type=int, help="How many evaluation samples to generate")
    parser.add_argument("--temp_dir", type=str, help="Where to store temporary intermediate results")
    parser.add_argument("--work_dir", type=str, help="Root folder for all experiments")
    parser.add_argument("--num_process", type=int, help="How many worker processes")
    parser.add_argument("--random_seed", type=int, help="Optional random seed for deterministic results")
    add_resizing_arguments(parser)
    add_image_processing_arguments(parser)
    args = parser.parse_args()

    # Only log when the integer percentage changes, to avoid flooding
    # the output with near-duplicate progress lines.
    last_progress = None
    for progress in generate(**vars(args)):
        prog_percent = int(progress * 100)
        if prog_percent != last_progress:
            logger.info(f"Generating training data: {prog_percent}% done")
        last_progress = prog_percent


if __name__ == "__main__":
    main()