I'm building a neural network to predict how an image will be partitioned during compression using VVC (Versatile Video Coding). The model takes a single Y-frame from a YUV420 image as input and uses a CSV file containing the ground truth block positions and sizes for training.
Input and Ground Truth:
Input: a single frame of 10-bit YUV420 video (only the Y plane is fed to the model).
Ground Truth: a CSV file with block positions, sizes, and additional partitioning flags.
Example input: 388016_320x480_37.yuv
Example ground truth: 388016_320x480_37.csv
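For reference, this is roughly how I read the Y plane from such a file (a sketch; load_y_plane is an illustrative helper, and it assumes the common little-endian 16-bit storage for 10-bit samples, with the Y plane first):

import numpy as np
import torch

def load_y_plane(path, width=320, height=480):
    # 10-bit YUV420: each sample occupies 2 bytes; the Y plane comes first.
    y = np.fromfile(path, dtype="<u2", count=width * height)
    y = y.reshape(height, width).astype(np.float32) / 1023.0  # map 10-bit range to [0, 1]
    return torch.from_numpy(y).unsqueeze(0)  # shape [1, H, W]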
Problem Description:
I implemented train.py and dataset.py, but I'm encountering an error when setting batch_size > 1 in the DataLoader. With a batch size of 1 the model trains correctly, but any larger batch size raises the runtime error shown in the traceback at the end of this post.
Code Summary:
Below is a simplified version of my custom_collate_fn and DataLoader setup:
import torch
from torch.utils.data import DataLoader

def custom_collate_fn(batch):
    frames = [item[0] for item in batch]  # Y-frame tensors, each [1, H, W]
    blocks = [item[1] for item in batch]  # per-sample block information
    frames = torch.stack(frames, dim=0)   # stack frames along the batch dimension
    return frames, blocks                 # blocks stays a plain Python list

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn
)
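Continuing from the snippet above, here is a quick check of what this collate function returns for a toy batch of two samples (the block tuples are made up purely to illustrate the structure):

# Two fake samples: each is (Y-frame tensor, list of block records).
sample_a = (torch.zeros(1, 480, 320), [(0, 0, 64, 64, 1), (64, 0, 64, 64, 0)])
sample_b = (torch.zeros(1, 480, 320), [(0, 0, 128, 128, 1)])

frames, blocks = custom_collate_fn([sample_a, sample_b])
print(frames.shape)  # torch.Size([2, 1, 480, 320])
print(len(blocks))   # 2 -- one block list per sample, i.e. a list of lists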
Observations:
When batch_size = 1, blocks_batch in the training loop is a list containing a single set of block data, so indexing element 0 works. With batch_size > 1, it becomes a list of lists (one block list per sample), and that indexing causes errors.
for i, (frame, blocks_batch) in enumerate(dataloader):
    frame = frame.to(device)      # shape: [batch_size, 1, H, W]
    blocks = blocks_batch[0]      # works with batch_size=1 but fails with larger sizes
My Assumption:
The issue seems to arise from how blocks_batch is handled when batch_size > 1: the loss target is built from only the first sample's blocks, while the model output covers the whole batch, so their batch dimensions no longer match (see the traceback below: target [1, 1] vs. input [2, 1]). The nested list structure also makes it awkward to process all samples.
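Based on that assumption, the direction I'm considering is to keep blocks as a per-sample list and loop over the batch when building targets, matching each sample's target to the corresponding slice of the model output. A sketch (model, build_targets, criterion, optimizer, and device stand in for my actual training code):

for i, (frames, blocks_batch) in enumerate(dataloader):
    frames = frames.to(device)              # [batch_size, 1, H, W]
    outputs = model(frames)                 # [batch_size, ...]

    loss = 0.0
    for b, blocks in enumerate(blocks_batch):
        # build_targets is a placeholder for my CSV-to-tensor conversion
        target = build_targets(blocks).to(device)
        loss = loss + criterion(outputs[b:b + 1], target)
    loss = loss / len(blocks_batch)         # average over samples in the batch

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()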
Questions:
1. How can I adjust the custom_collate_fn or the training loop to handle batch sizes greater than 1 effectively (perhaps along the lines of the per-sample loop sketched above)?
2. If there's a better approach to batch-wise handling of variable-length block data, such as the padding idea sketched below, I'd appreciate any advice.
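For question 2, one alternative I've seen for variable-length data is to pad each sample's block list to the longest list in the batch and return a validity mask, so the targets become regular tensors (a sketch; it assumes every block record is a fixed-length tuple of numbers and every sample has at least one block):

import torch

def pad_collate_fn(batch):
    frames = torch.stack([item[0] for item in batch], dim=0)
    block_lists = [torch.as_tensor(item[1], dtype=torch.float32) for item in batch]
    max_blocks = max(b.shape[0] for b in block_lists)
    feat_dim = block_lists[0].shape[1]
    blocks = torch.zeros(len(batch), max_blocks, feat_dim)
    mask = torch.zeros(len(batch), max_blocks, dtype=torch.bool)
    for i, b in enumerate(block_lists):
        blocks[i, :b.shape[0]] = b    # copy real records
        mask[i, :b.shape[0]] = True   # mark them valid
    return frames, blocks, mask

The mask could then be used to exclude padded entries from the loss.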
File "C:\Users\Administrator\Documents\VVC_fast\test4\train.py", line 91, in <module>
loss1 = criterion(out_split, target_split)
File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\loss.py", line 725, in forward
return F.binary_cross_entropy_with_logits(input, target,
File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\functional.py", line 3193, in binary_cross_entropy_with_logits
raise ValueError(f"Target size ({target.size()}) must be the same as input size ({input.size()})")
ValueError: Target size (torch.Size([1, 1])) must be the same as input size (torch.Size([2, 1]))