Source code for wormpose.demo.synthetic_simple_visualizer
#!/usr/bin/env python
"""
Visualizer for the synthetic images
"""
import random
from typing import Optional, Generator
import numpy as np
from wormpose.dataset.image_processing.options import add_image_processing_arguments
from wormpose.dataset.loader import load_dataset
from wormpose.dataset.loaders.resizer import add_resizing_arguments, ResizeOptions
from wormpose.images.synthetic import SyntheticDataset
from wormpose.pose.postures_model import PosturesModel
class SyntheticSimpleVisualizer(object):
    """
    Utility class to visualize the synthetic images.

    Each synthetic image is built on a template frame of one video, chosen at random
    among the frames whose skeleton contains no NaN value, and a posture (theta)
    drawn from a postures generator.
    """

    def __init__(
        self,
        dataset_loader: str,
        dataset_path: str,
        postures_generator: Optional[Generator] = None,
        video_name: Optional[str] = None,
        **kwargs
    ):
        """
        :param dataset_loader: Name of the dataset loader, forwarded to ``load_dataset``.
        :param dataset_path: Path of the dataset, forwarded to ``load_dataset``.
        :param postures_generator: Iterator yielding the postures (theta) of the synthetic
            images. Defaults to ``PosturesModel().generate()``.
        :param video_name: Video whose frames are used as templates.
            Defaults to the first video of the dataset.
        :param kwargs: Extra options, forwarded to both ``ResizeOptions`` and ``load_dataset``.
        :raises ValueError: If no frame of the video has a skeleton free of NaN values.
        """
        resize_options = ResizeOptions(**kwargs)
        dataset = load_dataset(dataset_loader, dataset_path, resize_options=resize_options, **kwargs)

        if postures_generator is None:
            postures_generator = PosturesModel().generate()
        if video_name is None:
            video_name = dataset.video_names[0]

        features = dataset.features_dataset[video_name]
        self.skeletons = features.skeletons
        self.measurements = features.measurements
        self.output_image_shape = dataset.image_shape

        self.synthetic_dataset = SyntheticDataset(
            frame_preprocessing=dataset.frame_preprocessing,
            output_image_shape=self.output_image_shape,
            enable_random_augmentations=False,
        )

        # Only frames whose whole skeleton is defined (no NaN coordinate) can serve as templates.
        skel_is_not_nan = ~np.any(np.isnan(self.skeletons), axis=(1, 2))
        self.labelled_indexes = np.where(skel_is_not_nan)[0]
        if len(self.labelled_indexes) == 0:
            raise ValueError("No template frames found in the dataset, can't generate synthetic images.")

        self.frames_dataset = dataset.frames_dataset
        self.video_name = video_name
        self.postures_generator = postures_generator

    def generate(self):
        """
        Generate synthetic images until the postures generator is exhausted.

        For every posture ``theta`` drawn from the postures generator, a labelled frame is
        picked at random (with ``np.random``) as template and a synthetic uint8 image is
        rendered from it.

        :return: Generator of ``(image, theta)`` tuples.
            The same ``image`` buffer is reused and overwritten at every step: copy it if it
            must be kept after the next item is requested.
        """
        out_image = np.empty(self.output_image_shape, dtype=np.uint8)

        with self.frames_dataset.open(self.video_name) as frames:
            # Iterate instead of calling next() in a `while True` loop: since Python 3.7
            # (PEP 479) a StopIteration escaping a generator body becomes a RuntimeError,
            # so an exhausted postures generator must end this generator through the loop.
            for theta in self.postures_generator:
                random_label_index = np.random.choice(self.labelled_indexes)
                self.synthetic_dataset.generate(
                    theta=theta,
                    template_skeleton=self.skeletons[random_label_index],
                    template_frame=frames[random_label_index],
                    out_image=out_image,
                    template_measurements=self.measurements,
                )
                yield out_image, theta
def main():
    """
    Command line entry point: show synthetic images one after the other in an OpenCV window.

    Press any key to display the next image, Esc or "q" to quit.
    """
    import argparse

    import cv2

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_loader", type=str)
    parser.add_argument("dataset_path", type=str)
    parser.add_argument("--video_name", type=str)
    parser.add_argument("--random_seed", type=int, help="Optional random seed for deterministic results")
    add_resizing_arguments(parser)
    add_image_processing_arguments(parser)
    args = parser.parse_args()

    # Seed both the stdlib and the numpy random generators (numpy picks the template frames).
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)

    synth_visualizer_gen = SyntheticSimpleVisualizer(**vars(args)).generate()
    try:
        for synth_image, _ in synth_visualizer_gen:
            cv2.imshow("synth_image", synth_image)
            # waitKey() blocks until a key is pressed; keep the low byte of the key code.
            key = cv2.waitKey() & 0xFF
            if key in (27, ord("q")):  # Esc or "q"
                break
    finally:
        # Closing the generator exits its `with` block, releasing the opened video frames.
        synth_visualizer_gen.close()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()