
How to Build a More Efficient DataLoader to Load Large Image Datasets?


I am trying to train a deep learning model on a very large image dataset. The model takes a pair of images (A and B) as input. Because the raw images are quite large, I have resized each one to a torch.Tensor of shape (3, 224, 224) and stored each tensor as a separate file on disk; the two images of a pair share the same index.
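For context, the files referenced below would have been produced by a preprocessing step roughly like this (a minimal sketch: the (3, 224, 224) shape and the file-naming scheme come from the dataset code below, while save_pair and the resize pipeline are assumptions):

import torch
from PIL import Image
from torchvision import transforms

to_tensor_224 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # float tensor of shape (3, 224, 224)
])

def save_pair(a_path, b_path, label, idx, out_dir, split="train"):
    # Store each resized image of the pair as its own .pt file;
    # both files of a pair share the same index.
    a = to_tensor_224(Image.open(a_path).convert("RGB"))
    b = to_tensor_224(Image.open(b_path).convert("RGB"))
    torch.save(a, f"{out_dir}/{split}_A_images_{idx}.pt")
    torch.save(b, f"{out_dir}/{split}_B_images_{idx}.pt")
    torch.save(torch.tensor(label), f"{out_dir}/{split}_labels_{idx}.pt")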

However, when loading these files into memory with a Dataset and DataLoader, I ran into the following issues:

  1. CPU memory issue: with num_workers=12, the 200 GB of system memory is exhausted quickly. I tried setting prefetch_factor=1, but it did not help.
  2. Slow initialization: before training starts, there is a long delay at the beginning of each epoch. I read in previous posts that this could be due to worker startup overhead, so I set persistent_workers=True, but that didn't help either (the settings I tried are sketched after this list).
  3. GPU and batch size: I am training with DDP on 4 GPUs, and my current batch size is 1024.
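For reference, the settings mentioned in items 1 and 2 were applied roughly like this (a sketch: num_workers=12, prefetch_factor=1, persistent_workers=True, and the batch size of 1024 come from the question; train_dataset stands in for the dataset defined below):

from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,            # assumed: the ImagePairDataset defined below
    batch_size=1024,
    shuffle=True,
    num_workers=12,
    prefetch_factor=1,        # tried: fewer batches prefetched per worker
    persistent_workers=True,  # tried: keep workers alive across epochs
)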

Are there any suggestions on how I can improve the efficiency of my dataset or DataLoader?


import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

augmentation = transforms.Compose([
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)], p=0.8),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))], p=0.5),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomVerticalFlip(p=0.5),
    normalize,
])


class ImagePairDataset(Dataset):
    def __init__(self, data_save_folder, dataset_name, num_samples, transform=None):
        """
        Args:
            data_save_folder (str): Path to the folder containing the data files.
            dataset_name (str): One of 'train', 'val', or 'test' indicating the dataset split.
            num_samples (int): The number of samples in the dataset split. (train: 3000000, val: 10000, test: 10000)
            transform (callable, optional): Optional transform to be applied on an image tensor.
        """
        self.data_save_folder = data_save_folder
        self.dataset_name = dataset_name
        self.num_samples = num_samples
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        
        # Construct file paths based on idx
        A_image_path = f"{self.data_save_folder}/{self.dataset_name}_A_images_{idx}.pt"
        B_image_path = f"{self.data_save_folder}/{self.dataset_name}_B_images_{idx}.pt"
        label_path = f"{self.data_save_folder}/{self.dataset_name}_labels_{idx}.pt"
        
        # Load the tensors from file paths
        A_image = torch.load(A_image_path)
        B_image = torch.load(B_image_path)
        label = torch.load(label_path)
        
        # Apply transformation if available
        if self.transform:
            A_image = self.transform(A_image)
            B_image = self.transform(B_image)
        
        return A_image, B_image, label
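As a quick sanity check, a single sample can be pulled straight from the dataset (the sample count comes from the docstring above; the folder path is a placeholder):

val_dataset = ImagePairDataset(
    data_save_folder="/path/to/data",  # placeholder
    dataset_name="val",
    num_samples=10000,
    transform=normalize,
)
A_image, B_image, label = val_dataset[0]
print(A_image.shape, B_image.shape)  # torch.Size([3, 224, 224]) each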

    

class ImagePairDataModule(pl.LightningDataModule):
    
    def __init__(self, data_save_folder, train_samples, val_samples, test_samples, batch_size=32, num_workers=4):
        super().__init__()
        self.data_save_folder = data_save_folder
        self.train_samples = train_samples
        self.val_samples = val_samples
        self.test_samples = test_samples
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transform = augmentation
        self.eval_transform = normalize # Only normalization for validation and test

    def setup(self, stage=None):
        
        self.train_dataset = ImagePairDataset(self.data_save_folder, 'train', self.train_samples, transform=self.train_transform)
        self.val_dataset = ImagePairDataset(self.data_save_folder, 'val', self.val_samples, transform=self.eval_transform)
        self.test_dataset = ImagePairDataset(self.data_save_folder, 'test', self.test_samples, transform=self.eval_transform)
        
    def train_dataloader(self):
        # prefetch_factor=1 and persistent_workers=True were also tried here, without improvement
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)



# Initialize the DataModule
data_module = ImagePairDataModule(
    data_save_folder=args.data_save_folder,
    train_samples=train_samples,
    val_samples=val_samples,
    test_samples=test_samples,
    batch_size=args.batch_size,
    num_workers=12,
)
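The DataModule is then handed to a Lightning Trainer configured for the setup described above (a sketch: the 4-GPU DDP configuration comes from the question, while model and max_epochs are placeholders):

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,        # 4 GPUs with DDP, as described in the question
    strategy="ddp",
    max_epochs=10,    # placeholder
)
trainer.fit(model, datamodule=data_module)  # `model` is the LightningModule being trained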