import os
import shutil
from kitcar_ml.utils.data import import_kitcar_xml_labels
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset
def test_xml_dataset_import():
    """Import the XML-labeled fixture set and check the resulting dataset.

    Runs ``import_kitcar_xml_labels`` on ``xml_labeled_test_data/input`` (next to
    this file), writing into ``xml_labeled_test_data/output``, then verifies that:

    * ``labels.yaml`` and one PNG per expected image were written,
    * the dataset loaded from ``labels.yaml`` has 4 images and 4 classes,
    * the label attributes include at least ``x1``, ``y1``, ``x2``, ``y2`` and
      ``class_id``,
    * each image carries exactly the labels listed below, in the same order.

    The output directory is removed afterwards, even when an assertion fails.
    """
    print("Test xml dataset import.")
    base_dir = f"{os.path.dirname(__file__)}/xml_labeled_test_data"
    input_dir = f"{base_dir}/input"
    output_dir = f"{base_dir}/output"

    # Expected labels per output image, extracted from the annotation files.
    expected_labels = {
        "0.png": [
            {"x1": 395, "x2": 426, "y1": 86, "y2": 113, "class_id": 0},
            {"x1": 530, "x2": 558, "y1": 41, "y2": 69, "class_id": 1},
        ],
        "1.png": [
            {"x1": 320, "x2": 358, "y1": 46, "y2": 61, "class_id": 2},
        ],
        "2.png": [
            {"x1": 21, "x2": 187, "y1": 34, "y2": 121, "class_id": 2},
            {"x1": 1050, "x2": 1136, "y1": 41, "y2": 110, "class_id": 3},
        ],
        # An image without any annotated objects.
        "3.png": [],
    }

    try:
        import_kitcar_xml_labels.import_kitcar_xml_labels(input_dir, output_dir)

        outs = os.listdir(output_dir)
        assert "labels.yaml" in outs
        # Every expected image (including the unlabeled 3.png) must be written.
        for img_name in expected_labels:
            assert img_name in outs, f"{img_name} missing from {output_dir}"

        dataset = LabeledDataset.from_yaml(f"{output_dir}/labels.yaml")
        # Dataset should contain 4 images
        assert len(dataset) == len(expected_labels) == 4
        # Dataset should have 4 different classes
        assert len(dataset.classes) == 4
        # Labels should have at least the attributes x1,x2,y1,y2,class_id
        assert {"x1", "y1", "x2", "y2", "class_id"}.issubset(set(dataset.attributes))

        for img_name, labels in expected_labels.items():
            dataset_labels = dataset.labels[img_name]
            # Check the count too: otherwise an empty expectation (3.png) passes
            # vacuously and extra spurious labels would go unnoticed.
            assert len(dataset_labels) == len(labels), (
                f"{img_name}: got {len(dataset_labels)} labels, "
                f"expected {len(labels)}"
            )
            for label_index, label in enumerate(labels):
                for attribute_name, attribute_value in label.items():
                    index = dataset.attributes.index(attribute_name)
                    actual = dataset_labels[label_index][index]
                    assert actual == attribute_value, (
                        f"{img_name} label {label_index}: {attribute_name} = "
                        f"{actual}, expected {attribute_value}"
                    )
    finally:
        # Clear the output dir even if an assertion failed, so stale output never
        # leaks into the next run. ignore_errors avoids masking the real failure
        # when the importer raised before creating the directory.
        shutil.rmtree(output_dir, ignore_errors=True)
def main():
    """Run the XML dataset import test directly, without a test runner."""
    test_xml_dataset_import()


if __name__ == "__main__":
    main()