# Source code for wormpose.machine_learning.eval_data_generator
"""
Generates evaluation data: picks random real processed images with their labels and saves them to a TFRecord file
"""
import csv
import logging
import os
import numpy as np
from wormpose.dataset import Dataset
from wormpose.images.real_dataset import RealDataset
from wormpose.machine_learning import tfrecord_file
from wormpose.pose.centerline import skeleton_to_angle, flip_theta
# Set up default logging on the root logger (basicConfig does nothing if the root logger already has handlers)
logging.basicConfig()
# Logger dedicated to this module, named after the module
logger = logging.getLogger(__name__)
# Let this module's logger handle every message from the DEBUG level and up
logger.setLevel(logging.DEBUG)
def generate(
    dataset: Dataset,
    num_samples: int,
    theta_dims: int,
    file_pattern: str,
) -> int:
    """
    Generates an evaluation dataset composed of processed real images, saved to a .tfrecord file

    Only the frames whose skeleton contains no NaN value are considered labelled and can be picked.
    Next to the .tfrecord file, a .csv file with the same base name lists the source of each sample
    (video name and frame index), in the same order as the records.

    :param dataset: WormPose Dataset
    :param num_samples: How many images to generate
    :param theta_dims: Dimensions of theta for the labels
    :param file_pattern: Path of the output files with "index" variable
        example: "path_to_out/eval_{index}.tfrecord"
    :return: How many samples were actually generated
        (lower than the requested num_samples if there is not enough labelled data)
    :raises RuntimeError: if the dataset contains no labelled frame at all
    """
    # collect, for each video, the indexes of the frames that have a complete (NaN-free) skeleton
    labelled_frames = {}
    for video_name in dataset.video_names:
        skeletons = dataset.features_dataset[video_name].skeletons
        skel_is_not_nan = ~np.any(np.isnan(skeletons), axis=(1, 2))
        labelled_indexes = np.where(skel_is_not_nan)[0]
        if len(labelled_indexes) > 0:
            labelled_frames[video_name] = labelled_indexes

    if not labelled_frames:
        raise RuntimeError("Can't create evaluation data because couldn't find any labelled frame in the dataset.")

    len_labelled_frames = sum(len(indexes) for indexes in labelled_frames.values())
    if len_labelled_frames < num_samples:
        # report through the module logger (not the root logger) so the message carries this module's name
        logger.warning(
            "Not enough labelled frames in the dataset "
            "to create an evaluation set of %s unique samples, "
            "using all available %s samples instead.",
            num_samples,
            len_labelled_frames,
        )
        num_samples = len_labelled_frames

    real_dataset = RealDataset(dataset.frame_preprocessing, dataset.image_shape)
    tfrecord_filename = file_pattern.format(index=0)
    csv_infos_filename = os.path.splitext(tfrecord_filename)[0] + ".csv"

    # get num_samples total random labelled frames from the videos
    eval_frames = _populate_eval_frames(labelled_frames, num_samples)

    # write the eval.tfrecord file with the images and the labels, save also the source infos in a separate eval.csv
    # the frames are not shuffled by video, all the frames from one video are consecutive in the file
    # newline="" lets the csv module control the line endings itself (avoids blank rows on Windows)
    with tfrecord_file.Writer(tfrecord_filename) as record_writer, open(
        csv_infos_filename, "w", newline=""
    ) as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL)
        for video_name, cur_video_eval_indexes in eval_frames.items():
            # the skeletons of a video do not depend on the frame, look them up once per video
            video_skeletons = dataset.features_dataset[video_name].skeletons
            with dataset.frames_dataset.open(video_name) as frames:
                for eval_frame_index in cur_video_eval_indexes:
                    image_data, _ = real_dataset.process_frame(frames[eval_frame_index])
                    cur_theta = skeleton_to_angle(video_skeletons[eval_frame_index], theta_dims=theta_dims)
                    # the flipped theta is stored as a second label alongside the original one
                    cur_theta_flipped = flip_theta(cur_theta)
                    record_writer.write(image_data, cur_theta, cur_theta_flipped)
                    csv_writer.writerow([video_name, int(eval_frame_index)])

    return num_samples
def _populate_eval_frames(labelled_frames: dict, num_samples: int):
    """
    Picks num_samples random labelled frames across all the videos, without repetition

    For each pick, a video is drawn uniformly among the videos that still have unpicked frames,
    then a frame is drawn uniformly among the remaining frames of that video.

    :param labelled_frames: maps a video name to the array of its labelled frame indexes,
        this dictionary and its arrays are left untouched
    :param num_samples: total number of frames to pick
    :return: dictionary mapping a video name to the sorted list of the frame indexes picked in that video
    :raises ValueError: if num_samples is higher than the number of available labelled frames
    """
    # work on a private dictionary so the caller's one is not emptied while picking;
    # np.delete returns new arrays, so the original arrays are never modified either
    remaining = {name: indexes for name, indexes in labelled_frames.items() if len(indexes) > 0}

    available = sum(len(indexes) for indexes in remaining.values())
    if num_samples > available:
        raise ValueError(f"Can't pick {num_samples} unique frames, only {available} labelled frames are available.")

    eval_frames = {}
    for _ in range(num_samples):
        # pick randomly a video name and a valid frame index for that video, without repetition
        chosen_video_name = np.random.choice(list(remaining))
        valid_index_in_video = remaining[chosen_video_name]
        pick_index_at = np.random.randint(0, len(valid_index_in_video))
        eval_frames.setdefault(chosen_video_name, []).append(valid_index_in_video[pick_index_at])

        # remove chosen frame from the available frames, and the video itself once it is exhausted
        left_in_video = np.delete(valid_index_in_video, pick_index_at)
        if len(left_in_video) == 0:
            del remaining[chosen_video_name]
        else:
            remaining[chosen_video_name] = left_in_video

    # sort for faster read access
    for picked_indexes in eval_frames.values():
        picked_indexes.sort()
    return eval_frames