"""Source code for kitcar_ml.utils.data.labeled_dataset."""

import copy
import itertools
import math
import os
import random
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Dict, List, Optional, Sequence, Tuple

from kitcar_utils.basics.init_options import InitOptions
from kitcar_utils.basics.save_options import SaveOptions
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from kitcar_ml.utils.data.image_folder import find_images


@dataclass
class LabeledDataset(Dataset, SaveOptions, InitOptions):
    """Dataset of images with labels."""

    attributes: Optional[Sequence[str]] = None
    """Description of what each label means. Similar to headers in a table. """

    classes: Dict[int, str] = field(default_factory=dict)
    """Description of what the class ids represent."""

    labels: Dict[str, List[Sequence[str]]] = field(default_factory=dict)
    """Collection of all labels structured as a dictionary."""

    _base_path: Optional[str] = None
    """Path to the root of the dataset.

    Only needs to be set if the dataset is used to load data.
    """

    # Deliberately unannotated: this makes it a plain class attribute rather
    # than a dataclass field, so the cache is never serialized or compared.
    _cache = None
    """Helper variable for dataset caching."""

    @property
    def available_files(self) -> List[str]:
        # Basenames of all images below the base path, skipping files that
        # vanished and anything with "debug" in its path.
        return [
            os.path.basename(file)
            for file in find_images(self._base_path)
            if os.path.exists(file) and "debug" not in file
        ]

    def __getitem__(self, index: int) -> Tuple[Image.Image, List[str]]:
        """Return an image and its label.

        Args:
            index: Index of returned datapoint.
        """
        # Use cache if exists
        if self._cache is not None:
            return self._cache[index]

        # NOTE: materializes all items on every access -> O(n) per lookup.
        key, label = list(self.labels.items())[index]

        # Get image path
        path = os.path.join(self._base_path, key)
        # Load image
        img = Image.open(path).convert("RGB")
        return img, label

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.labels)
[docs] def cache(self, disable_tqdm: bool = False): """Cache this dataset.""" self._cache = [ data for data in tqdm( self, desc=f"Caching {self.__class__.__name__}", disable=disable_tqdm ) ]
[docs] def filter_labels(self): """Remove labels that have no corresponding image.""" all_files = self.available_files self.labels = {key: label for key, label in self.labels.items() if key in all_files}
[docs] def append_label(self, key: str, label: List[Sequence[str]]): """Add a new label to the dataset. A single image (or any abstract object) can have many labels. """ if key not in self.labels: self.labels[key] = [] self.labels[key].append(label)
[docs] def save_as_yaml(self, file_path: str): """Save the dataset to a yaml file. Override the default method to temporarily remove base_path and prevent writing it to the yaml file. Args: file_path: The output file. """ bp = self._base_path del self._base_path super().save_as_yaml(file_path) self._base_path = bp
[docs] def make_ids_continuous(self): """Reformat dataset to have continuous class ids.""" ids = sorted(self.classes.keys()) for new_id, old_id in enumerate(ids): self.replace_id(old_id, new_id)
[docs] def replace_id(self, search_id: int, replace_id: int): """Replace id (search) with another id (replace) in the whole dataset. Args: search_id: The id being searched for. replace_id: The replacement id that replaces the search ids """ # Replace in classes dict self.classes[replace_id] = self.classes.pop(search_id) # Replace in labels dict index = self.attributes.index("class_id") for label in itertools.chain(*self.labels.values()): if label[index] == search_id: label[index] = replace_id
[docs] def split(self, fractions: List[float], shuffle: bool = True) -> List["LabeledDataset"]: """Split this dataset into multiple.""" assert ( round(sum(fractions), 5) == 1.0 ), f"Fractions should sum up to 1 (not {round(sum(fractions),5)})" new_datasets = [ self.__class__( attributes=self.attributes, classes=self.classes, _base_path=self._base_path, ) for _ in fractions ] labels = list(self.labels.items()) if shuffle: random.shuffle(labels) counts = (int(math.ceil(len(self) * frac)) for frac in ([0] + fractions)) indices = list(accumulate(counts)) start_indices = indices[:-1] end_indices = indices[1:] # end_indices[-1] could be bigger than len(self) # But even if it is larger, Python would make sure # that all labels are used exactly once. # So we do not need a test for it. for dataset, from_index, to_index in zip(new_datasets, start_indices, end_indices): dataset.labels = dict(labels[from_index:to_index]) return new_datasets
[docs] @classmethod def from_yaml(cls, file_path: str) -> "LabeledDataset": """Load a Labeled Dataset from a yaml file. Args: file_path: The path to the yaml file to load """ instance = super().from_yaml(file_path) instance._base_path = os.path.dirname(file_path) return instance
[docs] @classmethod def split_file( cls, file: str, parts: Dict[str, float], shuffle: bool = True ) -> List["LabeledDataset"]: """Split a dataset file into multiple datasets. Args: file: The path to the yaml file which gets split parts: A dict of names and and fractions shuffle: Split the labels randomly """ # Read dataset and split it dataset = cls.from_yaml(file) new_datasets = dataset.split(list(parts.values()), shuffle) # Save the split yaml files for name, dataset in zip(parts.keys(), new_datasets): dataset.save_as_yaml(os.path.join(os.path.dirname(file), f"{name}.yaml")) return new_datasets
[docs] @classmethod def filter_file(cls, file: str) -> "LabeledDataset": """Filter broken file dependencies of a yaml file. Args: file: The path to the yaml file to filter """ labeled_dataset = cls.from_yaml(file) labeled_dataset.filter_labels() labeled_dataset.save_as_yaml(file) return labeled_dataset
[docs] def adjust_classes(self, other_classes: Dict[int, str], class_attribute="class_id"): """Change/add classes and class ids to the dataset. Args: other_classes: New mapping of classes. class_attribute: Optional name of the class attribute. """ assert set(other_classes.values()).issuperset( set(self.classes.values()) ), "There are classes the other dataset does not contain." assert class_attribute in self.attributes, "There's no class attribute." inv_other_classes = {v: k for k, v in other_classes.items()} # Mapping from old to new classes. classes_translation = { k: inv_other_classes[name] for k, name in self.classes.items() } cls_idx = self.attributes.index(class_attribute) # Adjust classes in labels. for label in (label for _, labels in self.labels.items() for label in labels): label[cls_idx] = classes_translation[label[cls_idx]] self.classes = other_classes
[docs] def prune_classes(self, classes: List[str]): """Keep only labels of given classes.""" if not set(classes).issubset(set(self.classes.values())): print("WARNING: The dataset does not contain all required classes.") class_idx = self.attributes.index("class_id") for key, labels in self.labels.items(): self.labels[key] = [ label for label in labels if self.classes[label[class_idx]] in classes ] self.classes = {k: v for k, v in self.classes.items() if v in classes}
[docs] def adjust_attributes(self, other_attributes: List[str]): """Change the order/add attributes to the dataset. Example: >>> from kitcar_ml.utils.data.labeled_dataset import LabeledDataset >>> ds = LabeledDataset() >>> ds.attributes = ["class_name", "x", "y"] >>> ds.append_label("img1", ["cat", 2, 4]) >>> ds.append_label("img2", ["dog", 4, 6]) >>> ds LabeledDataset(attributes=['class_name', 'x', 'y'], \ classes={}, labels={'img1': [['cat', 2, 4]], 'img2': [['dog', 4, 6]]}, _base_path=None) >>> ds.adjust_attributes(["class_name","color", "x", "y"]) >>> ds LabeledDataset(attributes=['class_name', 'color', 'x', 'y'], classes={}, \ labels={'img1': [['cat', None, 2, 4]], 'img2': [['dog', None, 4, 6]]}, _base_path=None) Args: other_attributes: New list of attributes. """ assert set(other_attributes).issuperset( set(self.attributes) ), "There are attributes the other dataset does not contain." new_attr_indices = { i: other_attributes.index(attr) for i, attr in enumerate(self.attributes) } def label_translation(label): """Turn old label into adjusted label.""" new_label = len(other_attributes) * [None] for i in range(len(self.attributes)): new_label[new_attr_indices[i]] = label[i] return new_label for labels in self.labels.values(): for i in range(len(labels)): labels[i] = label_translation(labels[i]) self.attributes = other_attributes
[docs] @classmethod def merge_datasets(cls, *datasets: "LabeledDataset") -> "LabeledDataset": """Merge multiple labeled datasets together into one. If you have folders with images and multiple yaml files declaring the datasets, this method can merge them into a single dataset. This doesn't copy or move any images! Args: datasets: Sequence of datasets that should be merged together. """ if len(datasets) > 2: # Use recursion to merge more than two datasets datasets = [ datasets[0], cls.merge_datasets(*datasets[1:]), ] elif len(datasets) == 1: return datasets[0] # Now: exactly two datasets left! assert len(datasets) == 2 ds1, ds2 = datasets """Adjust attributes.""" all_attributes = copy.deepcopy(ds1.attributes) # Why not just set(ds1.attrs + ds2.attrs)? # Using the list approach retains the order of the attributes for attr in ds2.attributes: if attr not in all_attributes: all_attributes.append(attr) ds1.adjust_attributes(all_attributes) ds2.adjust_attributes(all_attributes) """Adjust classes.""" if len(set(ds1.classes.values()).intersection(set(ds1.classes.values()))) == 0: # Warning: Classes are completely disjoint!! pass # Create list of both datasets classes all_classes = { i: c for i, c in enumerate( sorted(set(ds1.classes.values()).union(set(ds2.classes.values()))) ) } ds1.adjust_classes(all_classes) ds2.adjust_classes(all_classes) merged_dataset = cls() merged_dataset.attributes = copy.deepcopy(all_attributes) merged_dataset.classes = copy.deepcopy(all_classes) # Not so nice: # Use ds1 base bath merged_dataset._base_path = ds1._base_path """Adjust labels.""" merged_dataset.labels = copy.deepcopy(ds1.labels) # Adjust path to ds2 labels: diff_base_path = os.path.relpath(ds2._base_path, ds1._base_path) for file_path, labels in ds2.labels.items(): # New path by prepending relative path to new base path new_path = os.path.join(diff_base_path, file_path) for label in labels: merged_dataset.append_label(new_path, label) return merged_dataset