
Error When Using Batch Size Greater Than 1 in PyTorch


I'm building a neural network to predict how an image will be partitioned during compression using VVC (Versatile Video Coding). The model takes a single Y-frame from a YUV420 image as input and uses a CSV file containing the ground truth block positions and sizes for training.

Input and Ground Truth

  • Input: A 1-frame YUV420 10-bit image.

  • Ground Truth: A CSV file with block positions, sizes, and additional partitioning flags.

Example input frame: 388016_320x480_37.yuv (screenshot omitted)

Example ground-truth file: 388016_320x480_37.csv (screenshot omitted)
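For reference, a minimal way to read just the Y plane might look like this (a sketch, assuming the common yuv420p10le layout, i.e. 10-bit samples stored in little-endian 16-bit words; not necessarily how dataset.py does it):

import numpy as np

# Read only the Y plane of a 320x480 10-bit YUV420 file.
# Assumes yuv420p10le: each sample occupies a little-endian 16-bit word.
W, H = 320, 480
with open("388016_320x480_37.yuv", "rb") as f:
    y = np.frombuffer(f.read(W * H * 2), dtype="<u2").reshape(H, W)
print(y.shape, y.max())  # (480, 320), values up to 1023 for 10-bit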

Problem Description:

I implemented train.py and dataset.py, but I'm encountering an error when setting batch_size > 1 in the DataLoader. With a batch size of 1 the model trains correctly, but any larger batch size triggers the runtime error shown in the traceback at the end of this post.

Code Summary:

Below is a simplified version of my custom_collate_fn and DataLoader setup:

import torch
from torch.utils.data import DataLoader

def custom_collate_fn(batch):
    frames = [item[0] for item in batch]  # Y-frame tensors, one per sample
    blocks = [item[1] for item in batch]  # block information, one (variable-length) list per sample
    frames = torch.stack(frames, dim=0)   # stack frames along the batch dimension
    return frames, blocks

dataloader = DataLoader(
    dataset,
    batch_size=batch_size, 
    shuffle=True,
    collate_fn=custom_collate_fn
)
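To make the batch structure concrete, here is a self-contained toy run through the custom_collate_fn above (dummy tensors and made-up block dicts, not my real dataset), showing why blocks becomes a list of lists once batch_size > 1:

# Two fake samples: (Y-frame tensor, list of block records)
sample_a = (torch.zeros(1, 480, 320), [{"x": 0, "y": 0, "w": 64, "h": 64}])
sample_b = (torch.zeros(1, 480, 320), [{"x": 64, "y": 0, "w": 32, "h": 32},
                                       {"x": 96, "y": 0, "w": 32, "h": 32}])

frames, blocks = custom_collate_fn([sample_a, sample_b])
print(frames.shape)  # torch.Size([2, 1, 480, 320])
print(len(blocks))   # 2: one list of block records per sample
print(blocks[0])     # blocks for sample 0 only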

Observations:

  • When batch_size = 1, the blocks_batch in the training loop is a list containing a single set of block data.

  • With batch_size > 1, it becomes a list of lists (one inner list per sample), and code that assumes a single sample breaks when indexing into it.

for i, (frame, blocks_batch) in enumerate(dataloader):
    frame = frame.to(device)  # Shape: [batch_size, 1, H, W]
    blocks = blocks_batch[0]  # uses only the first sample's blocks; works with batch_size=1 but breaks for larger batches

My Assumption:

It seems the issue arises from how blocks_batch is handled when batch_size > 1: it becomes a list of per-sample block lists, and indexing it as if it held a single sample's data means only the first sample is ever used. A per-sample loop like the sketch below is what I think is needed.
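Here is the shape of the loop I think is required (a sketch, not my actual training code; it assumes the target for each sample can be built from that sample's blocks alone):

for i, (frame, blocks_batch) in enumerate(dataloader):
    frame = frame.to(device)                   # [batch_size, 1, H, W]
    for b, blocks in enumerate(blocks_batch):  # one block list per sample
        sample_frame = frame[b:b+1]            # keep the batch dim: [1, 1, H, W]
        # ... build the target from `blocks` and run the model on sample_frame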

Questions:

  • How can I adjust the custom_collate_fn or the training loop to handle batch sizes greater than 1 effectively?

  • If there's a better approach to batch-wise handling of variable-length block data (for example, padding), I'd appreciate any advice; a padding-based sketch of what I have in mind follows this list.
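In case padding is the better route, this is the kind of collate function I have in mind (a sketch; it assumes each block record can be flattened into a fixed-width numeric row, e.g. x, y, w, h):

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_collate_fn(batch):
    # Assumes item[1] converts to an [N_i, 4] float tensor (x, y, w, h per block).
    frames = torch.stack([item[0] for item in batch], dim=0)
    block_tensors = [torch.as_tensor(item[1], dtype=torch.float32) for item in batch]
    lengths = torch.tensor([t.shape[0] for t in block_tensors])
    padded = pad_sequence(block_tensors, batch_first=True)            # [B, N_max, 4]
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]  # [B, N_max] validity mask
    return frames, padded, mask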

  File "C:\Users\Administrator\Documents\VVC_fast\test4\train.py", line 91, in <module>
    loss1 = criterion(out_split, target_split)
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\loss.py", line 725, in forward
    return F.binary_cross_entropy_with_logits(input, target,
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\functional.py", line 3193, in binary_cross_entropy_with_logits
    raise ValueError(f"Target size ({target.size()}) must be the same as input size ({input.size()})")
ValueError: Target size (torch.Size([1, 1])) must be the same as input size (torch.Size([2, 1]))
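The sizes in the error match the observation above: out_split was produced from the full batch (batch dimension 2), while target_split was built from blocks_batch[0] alone (batch dimension 1). Accumulating the loss per sample would keep the sizes aligned; a sketch, where build_target is a hypothetical stand-in for whatever currently produces target_split from one sample's blocks:

# Accumulate the loss sample by sample so input and target sizes match.
loss1 = 0.0
for b, blocks in enumerate(blocks_batch):
    out_b = model(frame[b:b+1])        # output for one sample: batch dim 1
    target_b = build_target(blocks)    # hypothetical: target for the same sample, batch dim 1
    loss1 = loss1 + criterion(out_b, target_b)
loss1 = loss1 / len(blocks_batch)      # average over the batch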