"""Source code for kitcar_ml.utils.data.test.test_import_kitcar_xml_labels: tests for the KITcar XML label importer."""

import os
import shutil

from kitcar_ml.utils.data import import_kitcar_xml_labels
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset


def test_xml_dataset_import():
    """Import the XML-labeled test data and check the resulting dataset.

    Runs ``import_kitcar_xml_labels`` on ``xml_labeled_test_data/input`` and
    writes the result to ``xml_labeled_test_data/output``. It then verifies:

    * the output contains ``labels.yaml`` and one PNG per expected image;
    * the dataset loaded via ``LabeledDataset.from_yaml`` has the expected
      number of images and classes and the required attributes;
    * every image carries exactly the labels taken from the annotation files.

    The output directory is removed afterwards, even if an assertion fails.
    """
    print("Test xml dataset import.")
    data_dir = os.path.join(os.path.dirname(__file__), "xml_labeled_test_data")
    input_dir = os.path.join(data_dir, "input")
    output_dir = os.path.join(data_dir, "output")

    # Expected labels per output image, extracted from the annotation files.
    expected_labels = {
        "0.png": [
            {"x1": 395, "x2": 426, "y1": 86, "y2": 113, "class_id": 0},
            {"x1": 530, "x2": 558, "y1": 41, "y2": 69, "class_id": 1},
        ],
        "1.png": [
            {"x1": 320, "x2": 358, "y1": 46, "y2": 61, "class_id": 2},
        ],
        "2.png": [
            {"x1": 21, "x2": 187, "y1": 34, "y2": 121, "class_id": 2},
            {"x1": 1050, "x2": 1136, "y1": 41, "y2": 110, "class_id": 3},
        ],
        # This image has no annotated objects.
        "3.png": [],
    }

    # Start from a clean output dir so leftovers from an aborted earlier run
    # cannot make the file-existence assertions pass spuriously.
    shutil.rmtree(output_dir, ignore_errors=True)
    try:
        import_kitcar_xml_labels.import_kitcar_xml_labels(input_dir, output_dir)

        outs = os.listdir(output_dir)
        assert "labels.yaml" in outs
        for img_name in expected_labels:
            assert img_name in outs

        dataset = LabeledDataset.from_yaml(os.path.join(output_dir, "labels.yaml"))

        # Dataset should contain 4 images
        assert len(dataset) == len(expected_labels)
        # Dataset should have 4 different classes
        assert len(dataset.classes) == 4
        # Labels should have at least the attributes x1, x2, y1, y2, class_id
        assert {"x1", "y1", "x2", "y2", "class_id"}.issubset(set(dataset.attributes))

        for img_name, labels in expected_labels.items():
            dataset_labels = dataset.labels[img_name]
            # Every annotated object must be imported, and nothing more;
            # without this, an empty expectation would assert nothing.
            assert len(dataset_labels) == len(labels)
            for dataset_label, label in zip(dataset_labels, labels):
                for attribute_name, attribute_value in label.items():
                    index = dataset.attributes.index(attribute_name)
                    assert dataset_label[index] == attribute_value
    finally:
        # Always clear the output dir. ignore_errors avoids masking the real
        # failure if the importer raised before creating the directory.
        shutil.rmtree(output_dir, ignore_errors=True)
def main():
    """Run the XML dataset import test directly, outside of pytest."""
    test_xml_dataset_import()
if __name__ == "__main__":
    # Allow executing this test module as a plain script.
    main()