"""Tests for the labeled-data utilities in :mod:`kitcar_ml.utils.data`."""

import copy
import itertools
import os
import random
from typing import List

import hypothesis
import hypothesis.strategies as st

from kitcar_ml.utils.data.data_loader.utils import load_labeled_dataset, sample_generator
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset

# Path of the YAML label file that test_creating_labels() writes and the
# other tests read; main() removes it once the read-only tests are done.
# Use os.path.join instead of string concatenation for portable paths.
LABEL_FILE_PATH = os.path.join(os.path.dirname(__file__), "labeled_dir", "labels.yaml")


def abs_path(rel):
    """Resolve *rel* against the directory containing this test module."""
    here = os.path.dirname(__file__)
    return os.path.join(here, rel)
def test_creating_labels():
    """Round-trip check: a hand-built dataset survives a YAML save/load cycle."""
    built = LabeledDataset(_base_path=os.path.dirname(LABEL_FILE_PATH))
    built.classes = {0: "class_0", 1: "class_1"}
    built.attributes = ["img_path", "class_id"]
    # Two labels for img0.png, one for img1.png.
    for key, label in (
        ("img0.png", ["img0.png", 0]),
        ("img0.png", ["img0.png", 1]),
        ("img1.png", ["img1.png", 1]),
    ):
        built.append_label(key, label)
    built.save_as_yaml(LABEL_FILE_PATH)
    reloaded = LabeledDataset.from_yaml(LABEL_FILE_PATH)
    assert built == reloaded
def test_labeled_dataset():
    """Check indexing and length of the dataset written by test_creating_labels().

    Relies on the label file created by test_creating_labels(), so it must run
    after that test (main() preserves this ordering).
    """
    dataset = LabeledDataset.from_yaml(LABEL_FILE_PATH)
    # Idiomatic subscript syntax instead of calling __getitem__ directly;
    # each item is a (sample, labels) pair, so [1] selects the label list.
    assert dataset[0][1] == [["img0.png", 0], ["img0.png", 1]]
    assert dataset[1][1] == [["img1.png", 1]]
    assert len(dataset) == 2
def test_loading_labeled_dataset():
    """Exercise load_labeled_dataset() and sample_generator() on the test labels."""
    loader = load_labeled_dataset(LABEL_FILE_PATH, 1, 1, True, num_workers=0)
    assert len(loader) == 1

    loader = load_labeled_dataset(LABEL_FILE_PATH, None, 1, True, num_workers=0)
    assert len(loader) == 2

    # The generator yields exactly the requested number of samples.
    gen = sample_generator(loader, 5)
    assert sum(1 for _ in gen) == 5

    # Check that the first yielded sample carries the expected labels.
    gen = sample_generator(loader, 5)
    _, labels = next(gen)
    # Somewhat weird behavior of the dataloader: the label list is nested
    # one level deeper than the raw dataset labels.
    assert labels == [
        [
            ["img0.png", 0],
            ["img0.png", 1],
        ]
    ]
@hypothesis.given(st.lists(st.text("abcdef12345", min_size=1, max_size=20), max_size=100))
def test_adjust_classes(other_classes: List[str]):
    """adjust_classes() must remap class ids without losing or reordering labels."""
    old_ds = LabeledDataset.from_yaml(abs_path("./datasets/dummy/labels.yaml"))
    new_ds = copy.deepcopy(old_ds)

    # Build a shuffled class mapping containing the old classes plus extras.
    merged_names = list(set(new_ds.classes.values()) | set(other_classes))
    random.shuffle(merged_names)
    new_ds.adjust_classes(dict(enumerate(merged_names)))

    cls_idx = old_ds.attributes.index("class_id")
    assert set(old_ds.labels.keys()) == set(new_ds.labels.keys())
    assert old_ds.attributes == new_ds.attributes
    assert set(new_ds.classes.values()).issuperset(set(old_ds.classes.values()))

    # Invert the new mapping so we can look up the expected new id by name.
    name_to_new_id = {name: idx for idx, name in new_ds.classes.items()}
    for key in new_ds.labels.keys():
        assert len(new_ds.labels[key]) == len(
            old_ds.labels[key]
        ), f"Some labels have gone missing for key {key}..."
        # Expect that labels retain order...
        for new_label, old_label in zip(new_ds.labels[key], old_ds.labels[key]):
            class_name = old_ds.classes[old_label[cls_idx]]
            # Recreate what the old label should look like after remapping.
            expected = copy.deepcopy(old_label)
            expected[cls_idx] = name_to_new_id[class_name]
            assert (
                new_label == expected
            ), f"There's no label corresponding to the old {old_label}"
@hypothesis.given(st.lists(st.text("abcdef12345", min_size=1, max_size=20), max_size=100))
def test_adjust_attributes(other_attributes: List[str]):
    """adjust_attributes() must keep old values aligned and pad new attributes with None."""
    old_ds = LabeledDataset.from_yaml(abs_path("./datasets/dummy/labels.yaml"))
    new_ds = copy.deepcopy(old_ds)

    # Shuffle the union of the existing and the generated attribute names.
    new_attributes = list(set(new_ds.attributes) | set(other_attributes))
    random.shuffle(new_attributes)
    new_ds.adjust_attributes(new_attributes)

    assert set(old_ds.labels.keys()) == set(new_ds.labels.keys())
    assert new_attributes == new_ds.attributes

    for key in new_ds.labels.keys():
        assert len(new_ds.labels[key]) == len(
            old_ds.labels[key]
        ), f"Some labels have gone missing for key {key}..."
        # Expect that labels retain order...
        for new_label, old_label in zip(new_ds.labels[key], old_ds.labels[key]):
            for position, value in enumerate(new_label):
                attr = new_ds.attributes[position]
                if attr in old_ds.attributes:
                    # Known attributes keep their old value.
                    assert value == old_label[old_ds.attributes.index(attr)]
                else:
                    # Newly introduced attributes are filled with None.
                    assert value is None
def test_merging():
    """Merging datasets in any order yields the union of their classes and attributes."""
    ds1 = LabeledDataset.from_yaml(abs_path("./datasets/dummy/labels.yaml"))
    ds2 = LabeledDataset.from_yaml(abs_path("./datasets/dummy1/labels.yaml"))
    ds3 = LabeledDataset.from_yaml(abs_path("./datasets/dummy2/labels.yaml"))

    for datasets in itertools.permutations((ds1, ds2, ds3)):
        merged = LabeledDataset.merge_datasets(*datasets)

        # Expected result: the union over all inputs, regardless of order.
        expected_attributes = set()
        expected_classes = set()
        for ds in datasets:
            expected_attributes |= set(ds.attributes)
            expected_classes |= set(ds.classes.values())

        assert set(merged.classes.values()) == expected_classes
        assert set(merged.attributes) == expected_attributes
def main():
    """Run every test manually (outside pytest), cleaning up the generated label file."""
    # These three depend on the label file created by test_creating_labels().
    for check in (test_creating_labels, test_labeled_dataset, test_loading_labeled_dataset):
        check()
    # The generated YAML is no longer needed once the loader tests have run.
    os.remove(LABEL_FILE_PATH)
    for check in (test_adjust_classes, test_adjust_attributes, test_merging):
        check()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    main()