# Ultralytics YOLO 🚀, GPL-3.0 license

import collections
import collections.abc
from copy import deepcopy

from .augment import LetterBox


class MixAndRectDataset:
    """Wrap a dataset so its transforms can mix in extra samples and use rect batch shapes.

    For each requested sample, the wrapped dataset's transforms are applied in order. A transform that
    exposes ``get_indexes`` (e.g. mosaic or mixup) receives additional deep-copied samples under the
    ``"mix_labels"`` key; that key is removed again after the transform has run. When the wrapped dataset
    is in rectangular mode (``dataset.rect`` is truthy), every ``LetterBox`` transform is retargeted to the
    shape stored for the sample's batch.

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be mixed. It must provide ``imgsz``, ``rect``,
            ``transforms`` (an object with a ``tolist()`` method), item access and ``len()``. In rect
            mode it must also provide ``batch`` and ``batch_shapes``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.imgsz = dataset.imgsz

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the transformed labels for sample ``index``.

        Args:
            index (int): Index of the primary sample in the wrapped dataset.

        Returns:
            Whatever the last transform in ``dataset.transforms`` returns.
        """
        # Work on a copy so the transforms cannot mutate the dataset's stored labels.
        labels = deepcopy(self.dataset[index])
        for transform in self.dataset.transforms.tolist():
            # Multi-image transforms (mosaic and mixup) choose extra samples to combine with this one.
            if hasattr(transform, "get_indexes"):
                mix_indexes = transform.get_indexes(self.dataset)
                if not isinstance(mix_indexes, collections.abc.Sequence):
                    mix_indexes = [mix_indexes]  # a single index was returned
                # Deep-copy the extra samples too, so mixing never alters the underlying dataset.
                labels["mix_labels"] = [deepcopy(self.dataset[i]) for i in mix_indexes]
            if self.dataset.rect and isinstance(transform, LetterBox):
                # Use the target shape of the batch this image belongs to instead of a fixed size.
                # NOTE(review): this mutates the shared LetterBox instance on every call; confirm this is
                # safe for the data-loading setup in use.
                transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
            labels = transform(labels)
            # Extra samples are only meant for the transform that requested them.
            labels.pop("mix_labels", None)
        return labels