import copy
import itertools
import os
import random
from typing import List
import hypothesis
import hypothesis.strategies as st
from kitcar_ml.utils.data.data_loader.utils import load_labeled_dataset, sample_generator
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset
# Path to the temporary labels file written and read by the round-trip tests below.
# Built with os.path.join instead of "+" for portability across separators.
LABEL_FILE_PATH = os.path.join(os.path.dirname(__file__), "labeled_dir", "labels.yaml")
def abs_path(rel: str) -> str:
    """Resolve *rel* against the directory containing this test module."""
    here = os.path.dirname(__file__)
    return os.path.join(here, rel)
def test_creating_labels():
    """Build a small dataset by hand, round-trip it through YAML, and compare."""
    ds = LabeledDataset(_base_path=os.path.dirname(LABEL_FILE_PATH))
    ds.classes = {0: "class_0", 1: "class_1"}
    ds.attributes = ["img_path", "class_id"]
    # Two labels for the first image, one for the second.
    for key, label in (
        ("img0.png", ["img0.png", 0]),
        ("img0.png", ["img0.png", 1]),
        ("img1.png", ["img1.png", 1]),
    ):
        ds.append_label(key, label)
    ds.save_as_yaml(LABEL_FILE_PATH)
    # Loading the file back must reproduce the original dataset exactly.
    assert ds == LabeledDataset.from_yaml(LABEL_FILE_PATH)
def test_labeled_dataset():
    """Check indexing and length of the dataset written by test_creating_labels."""
    dataset = LabeledDataset.from_yaml(LABEL_FILE_PATH)
    # Use the indexing protocol instead of calling __getitem__ directly.
    assert dataset[0][1] == [["img0.png", 0], ["img0.png", 1]]
    assert dataset[1][1] == [["img1.png", 1]]
    # Two distinct keys -> two items.
    assert len(dataset) == 2
def test_loading_labeled_dataset():
    """Exercise load_labeled_dataset batching and the sample generator."""
    # Dataset limit 1, batch size 1 -> exactly one batch.
    loader = load_labeled_dataset(LABEL_FILE_PATH, 1, 1, True, num_workers=0)
    assert len(loader) == 1
    # No limit -> both keys become separate batches.
    loader = load_labeled_dataset(LABEL_FILE_PATH, None, 1, True, num_workers=0)
    assert len(loader) == 2
    # The generator yields exactly the requested number of samples.
    assert sum(1 for _ in sample_generator(loader, 5)) == 5
    # Check if first result is correct
    _, labels = next(sample_generator(loader, 5))
    # Somewhat weird behavior of the dataloader:
    assert labels == [
        [
            ["img0.png", 0],
            ["img0.png", 1],
        ]
    ]
@hypothesis.given(st.lists(st.text("abcdef12345", min_size=1, max_size=20), max_size=100))
def test_adjust_classes(other_classes: List[str]):
    """adjust_classes must remap class ids without losing or reordering labels."""
    old_ds = LabeledDataset.from_yaml(abs_path("./datasets/dummy/labels.yaml"))
    new_ds = copy.deepcopy(old_ds)
    # Mix the existing class names with arbitrary extra names, in random order.
    merged_names = list(set(new_ds.classes.values()) | set(other_classes))
    random.shuffle(merged_names)
    new_ds.adjust_classes(dict(enumerate(merged_names)))
    cls_idx = old_ds.attributes.index("class_id")
    assert set(old_ds.labels.keys()) == set(new_ds.labels.keys())
    assert old_ds.attributes == new_ds.attributes
    assert set(new_ds.classes.values()).issuperset(set(old_ds.classes.values()))
    # Reverse lookup: class name -> new class id.
    name_to_id = {name: idx for idx, name in new_ds.classes.items()}
    for key, new_labels in new_ds.labels.items():
        old_labels = old_ds.labels[key]
        assert len(new_labels) == len(
            old_labels
        ), f"Some labels have gone missing for key {key}..."
        # Expect that labels retain order...
        for new_label, old_label in zip(new_labels, old_labels):
            old_class_name = old_ds.classes[old_label[cls_idx]]
            # Recreate old label with the remapped class id.
            expected = copy.deepcopy(old_label)
            expected[cls_idx] = name_to_id[old_class_name]
            assert (
                new_label == expected
            ), f"There's no label corresponding to the old {old_label}"
@hypothesis.given(st.lists(st.text("abcdef12345", min_size=1, max_size=20), max_size=100))
def test_adjust_attributes(other_attributes: List[str]):
    """adjust_attributes must reorder existing columns and pad new ones with None."""
    old_ds = LabeledDataset.from_yaml(abs_path("./datasets/dummy/labels.yaml"))
    new_ds = copy.deepcopy(old_ds)
    # Union of existing and arbitrary extra attribute names, in random order.
    new_attributes = list(set(new_ds.attributes).union(set(other_attributes)))
    random.shuffle(new_attributes)
    new_ds.adjust_attributes(new_attributes)
    assert set(old_ds.labels.keys()) == set(new_ds.labels.keys())
    assert new_attributes == new_ds.attributes
    for key in new_ds.labels.keys():
        assert len(new_ds.labels[key]) == len(
            old_ds.labels[key]
        ), f"Some labels have gone missing for key {key}..."
        # Expect that labels retain order...
        for new_label, old_label in zip(new_ds.labels[key], old_ds.labels[key]):
            for i, entry in enumerate(new_label):
                new_attr = new_ds.attributes[i]
                if new_attr in old_ds.attributes:
                    # Pre-existing attribute: value carried over from the old column.
                    assert entry == old_label[old_ds.attributes.index(new_attr)]
                else:
                    # Newly introduced attribute: filled with None.
                    assert entry is None
def test_merging():
    """merge_datasets must union classes and attributes regardless of argument order."""
    paths = (
        "./datasets/dummy/labels.yaml",
        "./datasets/dummy1/labels.yaml",
        "./datasets/dummy2/labels.yaml",
    )
    loaded = [LabeledDataset.from_yaml(abs_path(p)) for p in paths]
    for datasets in itertools.permutations(loaded):
        merged = LabeledDataset.merge_datasets(*datasets)
        # Expected result: set union over all inputs.
        expected_attributes = set().union(*(set(ds.attributes) for ds in datasets))
        expected_classes = set().union(*(set(ds.classes.values()) for ds in datasets))
        assert set(merged.classes.values()) == expected_classes
        assert set(merged.attributes) == expected_attributes
def main():
    """Run all tests manually (outside pytest), cleaning up the generated label file.

    The first three tests create and consume a YAML file at LABEL_FILE_PATH;
    the try/finally guarantees it is removed even when one of them fails,
    so a failed run does not leave stale state for the next one.
    """
    try:
        test_creating_labels()
        test_labeled_dataset()
        test_loading_labeled_dataset()
    finally:
        # Remove the file only if test_creating_labels got far enough to write it.
        if os.path.exists(LABEL_FILE_PATH):
            os.remove(LABEL_FILE_PATH)
    test_adjust_classes()
    test_adjust_attributes()
    test_merging()


if __name__ == "__main__":
    main()