import copy
import itertools
import math
import os
import random
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Dict, List, Optional, Sequence, Tuple
from kitcar_utils.basics.init_options import InitOptions
from kitcar_utils.basics.save_options import SaveOptions
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from kitcar_ml.utils.data.image_folder import find_images
# NOTE(review): the "[docs]" prefixes throughout this file look like artifacts of a
# Sphinx HTML export, not real source tokens — confirm against the repository copy.
[docs]@dataclass
class LabeledDataset(Dataset, SaveOptions, InitOptions):
"""Dataset of images with labels.

Keys of ``labels`` are image file names (joined with ``_base_path`` when an
image is loaded); each value is a list of label rows whose columns are
described by ``attributes``.
"""
attributes: Optional[Sequence[str]] = None
"""Description of what each label means.
Similar to headers in a table.
"""
classes: Dict[int, str] = field(default_factory=dict)
"""Description of what the class ids represent."""
labels: Dict[str, List[Sequence[str]]] = field(default_factory=dict)
"""Collection of all labels structured as a dictionary."""
_base_path: Optional[str] = None
"""Path to the root of the dataset.
Only needs to be set if the dataset is used to load data.
"""
# NOTE(review): deliberately left without a type annotation so it stays a plain
# class attribute rather than becoming a dataclass field — confirm this is
# intentional (an annotation would also make it part of serialization).
_cache = None
"""Helper variable for dataset caching."""
@property
def available_files(self) -> List[str]:
    """Return basenames of all usable images under the dataset root.

    Debug images (path contains ``"debug"``) and files that no longer
    exist on disk are excluded.
    """

    def _usable(path: str) -> bool:
        return os.path.exists(path) and "debug" not in path

    return [os.path.basename(p) for p in find_images(self._base_path) if _usable(p)]
def __getitem__(self, index: int) -> Tuple[Image.Image, List[str]]:
    """Return an image and its label.

    Args:
        index: Index of returned datapoint.
    """
    # Serve from the in-memory cache when cache() has been called.
    cached = self._cache
    if cached is not None:
        return cached[index]
    # Look the datapoint up by position in the labels dict.
    filename, label = list(self.labels.items())[index]
    image = Image.open(os.path.join(self._base_path, filename)).convert("RGB")
    return image, label
def __len__(self):
    """Return the number of labeled images in the dataset."""
    return len(self.labels)
def cache(self, disable_tqdm: bool = False):
    """Materialize every datapoint into memory.

    Afterwards ``__getitem__`` serves datapoints from the cache instead of
    reloading images from disk.

    Args:
        disable_tqdm: Suppress the progress bar.
    """
    progress = tqdm(self, desc=f"Caching {self.__class__.__name__}", disable=disable_tqdm)
    self._cache = list(progress)
def filter_labels(self):
    """Remove labels that have no corresponding image.

    Membership is tested against a set of available basenames, so the
    filter runs in O(labels + files) instead of the original
    O(labels * files) list scan.
    """
    available = set(self.available_files)
    self.labels = {key: label for key, label in self.labels.items() if key in available}
def append_label(self, key: str, label: Sequence[str]):
    """Add a new label to the dataset.

    A single image (or any abstract object) can have many labels.

    Args:
        key: Name of the image (key into :attr:`labels`).
        label: One label row. Note: the original annotation
            ``List[Sequence[str]]`` was wrong — ``labels`` maps each key to a
            list of rows and exactly one row is appended here.
    """
    # setdefault creates the per-key list on first use.
    self.labels.setdefault(key, []).append(label)
def save_as_yaml(self, file_path: str):
    """Save the dataset to a yaml file.

    Overrides the default method to temporarily remove ``_base_path`` so it
    is not written to the yaml file.

    Args:
        file_path: The output file.
    """
    bp = self._base_path
    # Deleting the instance attribute hides it from serialization
    # (presumably SaveOptions serializes the instance __dict__ — confirm).
    del self._base_path
    try:
        super().save_as_yaml(file_path)
    finally:
        # Restore even if saving raises; the original left the attribute
        # deleted on error.
        self._base_path = bp
def make_ids_continuous(self):
    """Reformat dataset to have continuous class ids (``0..n-1``).

    Ids are reassigned in sorted order, so the relative order of classes is
    preserved.

    The previous implementation replaced ids one at a time via
    :meth:`replace_id`; when a freshly assigned id equals an old id that is
    processed later (possible with negative class ids), those sequential
    replacements chain and corrupt both ``classes`` and ``labels``.
    Translating everything through a single old->new mapping avoids that.
    """
    if not self.classes:
        return  # nothing to renumber (also avoids requiring a class_id attribute)
    old_ids = sorted(self.classes)
    mapping = {old_id: new_id for new_id, old_id in enumerate(old_ids)}
    # Rebuild the classes dict keyed by the new continuous ids.
    self.classes = {mapping[old_id]: self.classes[old_id] for old_id in old_ids}
    # Translate the class id column of every label row in one pass.
    index = self.attributes.index("class_id")
    for label in itertools.chain.from_iterable(self.labels.values()):
        if label[index] in mapping:
            label[index] = mapping[label[index]]
def replace_id(self, search_id: int, replace_id: int):
    """Replace one class id with another in the whole dataset.

    NOTE(review): if ``replace_id`` already exists in :attr:`classes`, its
    entry is silently overwritten — callers must ensure the ids do not
    collide.

    Args:
        search_id: The id being searched for.
        replace_id: The replacement id that replaces the search ids
    """
    # Move the class description to the new id.
    self.classes[replace_id] = self.classes.pop(search_id)
    # Rewrite the class id column of every matching label row.
    column = self.attributes.index("class_id")
    for row in itertools.chain.from_iterable(self.labels.values()):
        if row[column] == search_id:
            row[column] = replace_id
def split(self, fractions: List[float], shuffle: bool = True) -> List["LabeledDataset"]:
    """Split this dataset into multiple datasets.

    The returned datasets share ``attributes`` and ``classes`` with this
    dataset (same object references, no copies).

    Args:
        fractions: Relative sizes of the parts; must sum to 1.
        shuffle: Shuffle the labels before splitting.
    """
    total = round(sum(fractions), 5)
    assert total == 1.0, f"Fractions should sum up to 1 (not {total})"
    parts = [
        self.__class__(
            attributes=self.attributes,
            classes=self.classes,
            _base_path=self._base_path,
        )
        for _ in fractions
    ]
    items = list(self.labels.items())
    if shuffle:
        random.shuffle(items)
    # Cumulative slice boundaries; the leading 0 yields the first start index.
    boundaries = list(
        accumulate(int(math.ceil(len(self) * frac)) for frac in [0] + fractions)
    )
    # The last boundary may exceed len(self); Python slicing clamps it, so
    # every label still ends up in exactly one part.
    for part, lo, hi in zip(parts, boundaries, boundaries[1:]):
        part.labels = dict(items[lo:hi])
    return parts
@classmethod
def from_yaml(cls, file_path: str) -> "LabeledDataset":
    """Load a LabeledDataset from a yaml file.

    The directory containing the file becomes the dataset's base path.

    Args:
        file_path: The path to the yaml file to load
    """
    dataset = super().from_yaml(file_path)
    dataset._base_path = os.path.dirname(file_path)
    return dataset
@classmethod
def split_file(
    cls, file: str, parts: Dict[str, float], shuffle: bool = True
) -> List["LabeledDataset"]:
    """Split a dataset file into multiple datasets.

    Each part is saved as ``<name>.yaml`` next to the input file.

    Args:
        file: The path to the yaml file which gets split
        parts: A dict of names and and fractions
        shuffle: Split the labels randomly
    """
    source = cls.from_yaml(file)
    pieces = source.split(list(parts.values()), shuffle)
    # Write one yaml file per named part, next to the source file.
    out_dir = os.path.dirname(file)
    for name, piece in zip(parts, pieces):
        piece.save_as_yaml(os.path.join(out_dir, f"{name}.yaml"))
    return pieces
@classmethod
def filter_file(cls, file: str) -> "LabeledDataset":
    """Filter broken file dependencies of a yaml file.

    Loads the dataset, drops labels without an image on disk, and writes
    the result back to the same file.

    Args:
        file: The path to the yaml file to filter
    """
    dataset = cls.from_yaml(file)
    dataset.filter_labels()
    dataset.save_as_yaml(file)
    return dataset
def adjust_classes(self, other_classes: Dict[int, str], class_attribute="class_id"):
    """Change/add classes and class ids to the dataset.

    Args:
        other_classes: New mapping of classes.
        class_attribute: Optional name of the class attribute.
    """
    assert set(self.classes.values()) <= set(
        other_classes.values()
    ), "There are classes the other dataset does not contain."
    assert class_attribute in self.attributes, "There's no class attribute."
    # Map class names to their new ids, then old ids to new ids.
    name_to_new_id = {name: new_id for new_id, name in other_classes.items()}
    translation = {
        old_id: name_to_new_id[name] for old_id, name in self.classes.items()
    }
    cls_idx = self.attributes.index(class_attribute)
    # Rewrite the class id column of every label row in place.
    for rows in self.labels.values():
        for row in rows:
            row[cls_idx] = translation[row[cls_idx]]
    self.classes = other_classes
def prune_classes(self, classes: List[str]):
    """Keep only labels of given classes."""
    wanted = set(classes)
    if not wanted.issubset(self.classes.values()):
        print("WARNING: The dataset does not contain all required classes.")
    class_idx = self.attributes.index("class_id")
    # Drop every label row whose class name is not wanted.
    for key in self.labels:
        self.labels[key] = [
            row for row in self.labels[key] if self.classes[row[class_idx]] in wanted
        ]
    # Drop the pruned classes from the id mapping as well.
    self.classes = {cid: name for cid, name in self.classes.items() if name in wanted}
def adjust_attributes(self, other_attributes: List[str]):
    """Change the order/add attributes to the dataset.

    Missing attribute columns are filled with ``None``.

    Example:
        >>> from kitcar_ml.utils.data.labeled_dataset import LabeledDataset
        >>> ds = LabeledDataset()
        >>> ds.attributes = ["class_name", "x", "y"]
        >>> ds.append_label("img1", ["cat", 2, 4])
        >>> ds.append_label("img2", ["dog", 4, 6])
        >>> ds
        LabeledDataset(attributes=['class_name', 'x', 'y'], \
classes={}, labels={'img1': [['cat', 2, 4]], 'img2': [['dog', 4, 6]]}, _base_path=None)
        >>> ds.adjust_attributes(["class_name","color", "x", "y"])
        >>> ds
        LabeledDataset(attributes=['class_name', 'color', 'x', 'y'], classes={}, \
labels={'img1': [['cat', None, 2, 4]], 'img2': [['dog', None, 4, 6]]}, _base_path=None)

    Args:
        other_attributes: New list of attributes.
    """
    assert set(other_attributes).issuperset(
        set(self.attributes)
    ), "There are attributes the other dataset does not contain."
    # Where each current attribute column lands in the new layout.
    positions = [other_attributes.index(attr) for attr in self.attributes]

    def _translate(old_label):
        """Turn an old label row into a row matching the new attribute layout."""
        new_label = [None] * len(other_attributes)
        for old_idx, new_idx in enumerate(positions):
            new_label[new_idx] = old_label[old_idx]
        return new_label

    for rows in self.labels.values():
        rows[:] = [_translate(row) for row in rows]
    self.attributes = other_attributes
@classmethod
def merge_datasets(cls, *datasets: "LabeledDataset") -> "LabeledDataset":
    """Merge multiple labeled datasets together into one.

    If you have folders with images and multiple yaml files declaring the
    datasets, this method can merge them into a single dataset.
    This doesn't copy or move any images!

    NOTE: the input datasets are mutated in place (their attributes and
    classes are adjusted to the merged layout).

    Args:
        datasets: Sequence of datasets that should be merged together.
    """
    if len(datasets) > 2:
        # Use recursion to merge more than two datasets
        datasets = [
            datasets[0],
            cls.merge_datasets(*datasets[1:]),
        ]
    elif len(datasets) == 1:
        return datasets[0]
    # Now: exactly two datasets left!
    assert len(datasets) == 2
    ds1, ds2 = datasets

    # --- Adjust attributes ---
    all_attributes = copy.deepcopy(ds1.attributes)
    # Why not just set(ds1.attrs + ds2.attrs)?
    # Using the list approach retains the order of the attributes.
    for attr in ds2.attributes:
        if attr not in all_attributes:
            all_attributes.append(attr)
    ds1.adjust_attributes(all_attributes)
    ds2.adjust_attributes(all_attributes)

    # --- Adjust classes ---
    # BUG FIX: the original intersected ds1's classes with themselves, so the
    # intended disjoint-classes warning could never trigger.
    if not set(ds1.classes.values()) & set(ds2.classes.values()):
        print("WARNING: The merged datasets have no classes in common.")
    # Continuous ids over the union of class names, sorted for determinism.
    all_classes = dict(
        enumerate(sorted(set(ds1.classes.values()) | set(ds2.classes.values())))
    )
    ds1.adjust_classes(all_classes)
    ds2.adjust_classes(all_classes)

    merged_dataset = cls()
    merged_dataset.attributes = copy.deepcopy(all_attributes)
    merged_dataset.classes = copy.deepcopy(all_classes)
    # Not so nice: arbitrarily keep ds1's base path for the merged dataset.
    merged_dataset._base_path = ds1._base_path

    # --- Adjust labels ---
    merged_dataset.labels = copy.deepcopy(ds1.labels)
    # ds2's keys are relative to ds2's root; re-anchor them to ds1's root.
    diff_base_path = os.path.relpath(ds2._base_path, ds1._base_path)
    for file_path, labels in ds2.labels.items():
        new_path = os.path.join(diff_base_path, file_path)
        for label in labels:
            merged_dataset.append_label(new_path, label)
    return merged_dataset