最新消息:雨落星辰是一个专注网站SEO优化、网站SEO诊断、搜索引擎研究、网络营销推广、网站策划运营及站长类的自媒体原创博客

python - Does order of transforms applied for data augmentation matter in Torchvision transforms? - Stack Overflow

programmeradmin0浏览0评论

I have the following Custom dataset class for an image segmentation task.

class LoadDataset(Dataset):
    def __init__(self, img_dir, mask_dir, apply_transforms = None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transforms = apply_transforms
        self.img_paths, self.mask_paths = self.__get_all_paths()
        self.__pil_to_tensor = transforms.PILToTensor()
        self.__float_tensor = transforms.ToDtype(torch.float32, scale = True)
        self.__grayscale = transforms.Grayscale()

    def __get_all_paths(self):
        img_paths = [os.path.join(self.img_dir, img_name.name) for img_name in os.scandir(self.img_dir) if os.path.isfile(img_name)]
        mask_paths = [os.path.join(self.mask_dir, mask_name.name) for mask_name in os.scandir(self.mask_dir) if os.path.isfile(mask_name)]
        img_paths = sorted(img_paths)
        mask_paths = sorted(mask_paths)
        return img_paths, mask_paths

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path, mask_path = self.img_paths[index], self.mask_paths[index]
        img_PIL = Image.open(img_path)
        mask_PIL = Image.open(mask_path)
        img_tensor = self.__pil_to_tensor(img_PIL)
        img_tensor = self.__float_tensor(img_tensor)
        mask_tensor = self.__pil_to_tensor(mask_PIL)
        mask_tensor = self.__float_tensor(mask_tensor)
        mask_tensor = self.__grayscale(mask_tensor)
        if self.transforms:
            img_tensor, mask_tensor = self.transforms(img_tensor, mask_tensor)
        return img_tensor, mask_tensor

When I am applying the following transformation:

transforms.RandomHorizontalFlip()

either the image or the mask is being flipped. But if I change the order of the transformations in __getitem__ to the following, then it works fine.

def __getitem__(self, index):
    img_path, mask_path = self.img_paths[index], self.mask_paths[index]
    img_PIL = Image.open(img_path)
    mask_PIL = Image.open(mask_path)
    if self.transforms:
        img_PIL, mask_PIL = self.transforms(img_PIL, mask_PIL)
    img_tensor = self.__pil_to_tensor(img_PIL)
    mask_tensor = self.__pil_to_tensor(mask_PIL)
    img_tensor = self.__float_tensor(img_tensor)
    mask_tensor = self.__float_tensor(mask_tensor)
    mask_tensor = self.__grayscale(mask_tensor)
    return img_tensor, mask_tensor

Does the order transformation matter? I'm using torchvision.transforms.v2 for all the transformations.

I have the following Custom dataset class for an image segmentation task.

class LoadDataset(Dataset):
    def __init__(self, img_dir, mask_dir, apply_transforms = None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transforms = apply_transforms
        self.img_paths, self.mask_paths = self.__get_all_paths()
        self.__pil_to_tensor = transforms.PILToTensor()
        self.__float_tensor = transforms.ToDtype(torch.float32, scale = True)
        self.__grayscale = transforms.Grayscale()

    def __get_all_paths(self):
        img_paths = [os.path.join(self.img_dir, img_name.name) for img_name in os.scandir(self.img_dir) if os.path.isfile(img_name)]
        mask_paths = [os.path.join(self.mask_dir, mask_name.name) for mask_name in os.scandir(self.mask_dir) if os.path.isfile(mask_name)]
        img_paths = sorted(img_paths)
        mask_paths = sorted(mask_paths)
        return img_paths, mask_paths

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path, mask_path = self.img_paths[index], self.mask_paths[index]
        img_PIL = Image.open(img_path)
        mask_PIL = Image.open(mask_path)
        img_tensor = self.__pil_to_tensor(img_PIL)
        img_tensor = self.__float_tensor(img_tensor)
        mask_tensor = self.__pil_to_tensor(mask_PIL)
        mask_tensor = self.__float_tensor(mask_tensor)
        mask_tensor = self.__grayscale(mask_tensor)
        if self.transforms:
            img_tensor, mask_tensor = self.transforms(img_tensor, mask_tensor)
        return img_tensor, mask_tensor

When I am applying the following transformation:

transforms.RandomHorizontalFlip()

either the image or the mask is being flipped. But if I change the order of the transformations in __getitem__ to the following, then it works fine.

def __getitem__(self, index):
    img_path, mask_path = self.img_paths[index], self.mask_paths[index]
    img_PIL = Image.open(img_path)
    mask_PIL = Image.open(mask_path)
    if self.transforms:
        img_PIL, mask_PIL = self.transforms(img_PIL, mask_PIL)
    img_tensor = self.__pil_to_tensor(img_PIL)
    mask_tensor = self.__pil_to_tensor(mask_PIL)
    img_tensor = self.__float_tensor(img_tensor)
    mask_tensor = self.__float_tensor(mask_tensor)
    mask_tensor = self.__grayscale(mask_tensor)
    return img_tensor, mask_tensor

Does the order transformation matter? I'm using torchvision.transforms.v2 for all the transformations.

Share Improve this question edited Jan 20 at 8:14 Shaido 28.3k25 gold badges74 silver badges81 bronze badges asked Jan 19 at 14:02 Amit SurAmit Sur 336 bronze badges
Add a comment  | 

1 Answer 1

Reset to default 0

Yes, the order of transformations matters. In this case, the transform to tensors makes the difference. When v2.RandomHorizontalFlip is given two tensors, the flip will be applied independently. However, when two PIL images are given, the same transform will be applied to both images, thus keeping the image and mask aligned.

For a more consistent handling, you can try using TVTensors for the data augmentation. Using these, you can specify the type of each data input before transforming them. For example:

from torchvision import tv_tensors

img_tensor = tv_tensors.Image(img_tensor)
mask_tensor= tv_tensors.Mask(mask_tensor)
发布评论

评论列表(0)

  1. 暂无评论