I have been training a CNN Autoencoder on binary images (pixels are either 0 or 1) of size 64x64. The model is shown below:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNAutoencoder(nn.Module):
"""
Input shape: (B, 1, 64, 64)
"""
def __init__(self, grid_size=64):
super().__init__()
self.grid_size = grid_size
# 1) Encoder layers
self.enc_conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.bn_enc1 = nn.BatchNorm2d(16)
self.enc_conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.bn_enc2 = nn.BatchNorm2d(32)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# After two pools => shape is (32, grid_size//4, grid_size//4)
flat_dim = 32 * (grid_size // 4) * (grid_size // 4)
latent_dim = 128
# For linear layers, we can use BatchNorm1d:
self.fc_enc = nn.Linear(flat_dim, latent_dim)
self.bn_fc_enc = nn.BatchNorm1d(latent_dim)
# 2) Decoder layers
self.fc_dec = nn.Linear(latent_dim, flat_dim)
self.bn_fc_dec = nn.BatchNorm1d(flat_dim)
self.dec_tconv1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
self.bn_dec1 = nn.BatchNorm2d(16)
self.dec_tconv2 = nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2)
def encoder(self, x):
# x => (B,1,64,64)
x = self.enc_conv1(x) # => (B,16,64,64)
x = self.bn_enc1(x)
x = F.relu(x)
x = self.pool(x) # => (B,16,32,32)
x = self.enc_conv2(x) # => (B,32,32,32)
x = self.bn_enc2(x)
x = F.relu(x)
x = self.pool(x) # => (B,32,16,16)
# Flatten
x = x.view(x.size(0), -1) # => (B, flat_dim=32*16*16)
x = self.fc_enc(x) # => (B, latent_dim=128)
x = self.bn_fc_enc(x)
x = F.relu(x)
return x
def decoder(self, z):
# z => (B,128)
x = self.fc_dec(z) # => (B, flat_dim)
x = self.bn_fc_dec(x)
x = F.relu(x)
# Reshape to (B, 32, 16, 16)
B = x.size(0)
x = x.view(B, 32, self.grid_size // 4, self.grid_size // 4)
x = self.dec_tconv1(x) # => (B,16,32,32)
x = self.bn_dec1(x)
x = F.relu(x)
# Final upsample => (B,1,64,64)
x = self.dec_tconv2(x)
x = torch.sigmoid(x)
return x
def forward(self, x):
z = self.encoder(x)
return self.decoder(z)
Train set includes 1079156 samples. I'm using batch size 128 and 7 epochs. Right now, 1 epoch takes approximately 3 hours to train. Does anyone have an idea what might be the problem?
I tried to switch between BatchNorm and LayerNorm/GroupNorm but the performance didn't change. I also tried early stopping, but this doesn't solve the problem of timing per epoch. For your reference, I've pasted my training code below:
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
# Model
model = CNNAutoencoder(grid_size=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()
best_val_loss = float('inf')
model_fname = f"best_{rt}_fold{fold}.pth"
global_step = 0
for ep in range(num_epochs):
if global_step >= max_steps:
print(f"Reached {max_steps} total steps; stopping early.")
break
t0 = time.time()
#############################
# Train Loop (per epoch)
#############################
model.train()
total_loss = 0.0
for x_in, x_tgt in train_loader:
x_in = x_in.to(device)
x_tgt = x_tgt.to(device)
optimizer.zero_grad()
out = model(x_in)
loss = loss_fn(out, x_tgt)
loss.backward()
optimizer.step()
total_loss += loss.item() * x_in.size(0)
global_step += 1 # increment step count
if global_step >= max_steps:
print(f"Reached {max_steps} total steps; stopping in mid-epoch.")
break
train_epoch_loss = total_loss / len(train_loader.dataset)
# If we already hit max_steps, break out
if global_step >= max_steps:
break
#############################
# Validation Loop
#############################
model.eval()
val_loss_sum = 0.0
with torch.no_grad():
for x_in, x_tgt in val_loader:
x_in, x_tgt = x_in.to(device), x_tgt.to(device)
out = model(x_in)
loss = loss_fn(out, x_tgt)
val_loss_sum += loss.item() * x_in.size(0)
current_val_loss = val_loss_sum / len(val_loader.dataset)
dt = time.time() - t0
print(f"Epoch [{ep+1}/{num_epochs}] => "
f"train_loss={train_epoch_loss:.4f}, val_loss={current_val_loss:.4f}, dt={dt:.2f}s, "
f"global_step={global_step}")
# Save best model
if current_val_loss < best_val_loss:
best_val_loss = current_val_loss
torch.save(model.state_dict(), model_fname)```
I have been training a CNN Autoencoder on binary images (pixels are either 0 or 1) of size 64x64. The model is shown below:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNAutoencoder(nn.Module):
"""
Input shape: (B, 1, 64, 64)
"""
def __init__(self, grid_size=64):
super().__init__()
self.grid_size = grid_size
# 1) Encoder layers
self.enc_conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.bn_enc1 = nn.BatchNorm2d(16)
self.enc_conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.bn_enc2 = nn.BatchNorm2d(32)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# After two pools => shape is (32, grid_size//4, grid_size//4)
flat_dim = 32 * (grid_size // 4) * (grid_size // 4)
latent_dim = 128
# For linear layers, we can use BatchNorm1d:
self.fc_enc = nn.Linear(flat_dim, latent_dim)
self.bn_fc_enc = nn.BatchNorm1d(latent_dim)
# 2) Decoder layers
self.fc_dec = nn.Linear(latent_dim, flat_dim)
self.bn_fc_dec = nn.BatchNorm1d(flat_dim)
self.dec_tconv1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
self.bn_dec1 = nn.BatchNorm2d(16)
self.dec_tconv2 = nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2)
def encoder(self, x):
# x => (B,1,64,64)
x = self.enc_conv1(x) # => (B,16,64,64)
x = self.bn_enc1(x)
x = F.relu(x)
x = self.pool(x) # => (B,16,32,32)
x = self.enc_conv2(x) # => (B,32,32,32)
x = self.bn_enc2(x)
x = F.relu(x)
x = self.pool(x) # => (B,32,16,16)
# Flatten
x = x.view(x.size(0), -1) # => (B, flat_dim=32*16*16)
x = self.fc_enc(x) # => (B, latent_dim=128)
x = self.bn_fc_enc(x)
x = F.relu(x)
return x
def decoder(self, z):
# z => (B,128)
x = self.fc_dec(z) # => (B, flat_dim)
x = self.bn_fc_dec(x)
x = F.relu(x)
# Reshape to (B, 32, 16, 16)
B = x.size(0)
x = x.view(B, 32, self.grid_size // 4, self.grid_size // 4)
x = self.dec_tconv1(x) # => (B,16,32,32)
x = self.bn_dec1(x)
x = F.relu(x)
# Final upsample => (B,1,64,64)
x = self.dec_tconv2(x)
x = torch.sigmoid(x)
return x
def forward(self, x):
z = self.encoder(x)
return self.decoder(z)
Train set includes 1079156 samples. I'm using batch size 128 and 7 epochs. Right now, 1 epoch takes approximately 3 hours to train. Does anyone have an idea what might be the problem?
I tried to switch between BatchNorm and LayerNorm/GroupNorm but the performance didn't change. I also tried early stopping, but this doesn't solve the problem of timing per epoch. For your reference, I've pasted my training code below:
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
# Model
model = CNNAutoencoder(grid_size=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()
best_val_loss = float('inf')
model_fname = f"best_{rt}_fold{fold}.pth"
global_step = 0
for ep in range(num_epochs):
if global_step >= max_steps:
print(f"Reached {max_steps} total steps; stopping early.")
break
t0 = time.time()
#############################
# Train Loop (per epoch)
#############################
model.train()
total_loss = 0.0
for x_in, x_tgt in train_loader:
x_in = x_in.to(device)
x_tgt = x_tgt.to(device)
optimizer.zero_grad()
out = model(x_in)
loss = loss_fn(out, x_tgt)
loss.backward()
optimizer.step()
total_loss += loss.item() * x_in.size(0)
global_step += 1 # increment step count
if global_step >= max_steps:
print(f"Reached {max_steps} total steps; stopping in mid-epoch.")
break
train_epoch_loss = total_loss / len(train_loader.dataset)
# If we already hit max_steps, break out
if global_step >= max_steps:
break
#############################
# Validation Loop
#############################
model.eval()
val_loss_sum = 0.0
with torch.no_grad():
for x_in, x_tgt in val_loader:
x_in, x_tgt = x_in.to(device), x_tgt.to(device)
out = model(x_in)
loss = loss_fn(out, x_tgt)
val_loss_sum += loss.item() * x_in.size(0)
current_val_loss = val_loss_sum / len(val_loader.dataset)
dt = time.time() - t0
print(f"Epoch [{ep+1}/{num_epochs}] => "
f"train_loss={train_epoch_loss:.4f}, val_loss={current_val_loss:.4f}, dt={dt:.2f}s, "
f"global_step={global_step}")
# Save best model
if current_val_loss < best_val_loss:
best_val_loss = current_val_loss
torch.save(model.state_dict(), model_fname)```
Share
Improve this question
asked Mar 17 at 19:18
SadnessAndSorrowSadnessAndSorrow
111 bronze badge
6
- Well you have a large training set. How long are expecting this to take? – Starship Remembers Shadow Commented Mar 17 at 19:20
- @Starship The set is large, but at the same time it only contains binary images. I just wanted to make sure that the training time looks reasonable. Does my model architecture and train/val procedure look fine? – SadnessAndSorrow Commented Mar 17 at 19:42
- What hardware are you running on? – xdurch0 Commented Mar 17 at 20:16
- @xdurch0 I'm running on Quadro RTX 6000. – SadnessAndSorrow Commented Mar 17 at 20:31
- 2 This is not necessarily a problem, only incorrect expectations. You should profile the code, else any answer would just be speculation. – Dr. Snoopy Commented Mar 18 at 7:26
1 Answer
Reset to default 2General advice would be to profile. It's hard to say anything without actual measurements.
But I'll try to break things down on how one could approach optimizing this or similar cases.
1. Use nvidia-smi to check GPU-utilization.
I typically do watch -n 0 nvidia-smi
, this will continuously poll nvidia-smi
and you will see your "GPU-utilization" percentage updating in real time. Ideally you want it to be close to 100% as much as possible. If it is not the case, you fail to saturate your GPU with workload.
This is a very basic, but useful and easy-to-get metric.
It is worth noting that even 100% all the time does not mean that GPU actually busy with useful work. The number simply indicates the percentage of time the kernels are active. With DDP training, 100% could be that it just waits for NCCL to transfer gradients. Sometimes low power usage can give you a hint that something isn't right.
2. Optimize data loading.
Your network seems to be very small. Its very likely that it is so small that data loading is the bottleneck, and not GPU.
2.1 Non-blocking transfer to GPU.
Try removing pin_memory=True
. Sometimes, using pinned memory might be slower. Generally, using pin_memory=True
and non-blocking transfer to GPU should be faster. However I do not see non-blocking transfer in your code, without it pin_memory=True
does not provide any benefit.
In order to do non-blocking transfer, you need to:
x_in = x_in.to(device, non_blocking=True)
x_tgt = x_tgt.to(device, non_blocking=True)
See: https://pytorch./tutorials/intermediate/pinmem_nonblock.html
2.2 Try to change number of workers in the dataloader.
Ideally, you would want as many workers as the degree of parallelism. With that exception when the individual worker runs multiple threads, which you would want to avoid.
2.3 Try to minimize the amount of data transferred from CPU to GPU.
You said that images are 0 or 1. Even if you load them as 8bit images, it is much better to return them as such from the dataset. The in your training code you would do:
x_in = x_in.to(device)
x_tgt = x_tgt.to(device)
out = model(x_in.float() / 255.0)
loss = loss_fn(out, x_tgt.float() / 255.0)
This way you would transfer 4 times less data to GPU.
2.4 Preload.
In your case, if that is actual dataset that you plan to use (not just some test case), I would recommend preload it to GPU memory. It isn't too large if bit-packed: (64 * 64) / 8 * 1e6 ~ 512MB
which should fit GPU memory. Basically, you would just need to remove the dataloader, generate a permutation of the dataset indices and iterate through it youself and unpack the images from bitmap to float tensors directly on GPU. This approach is significantly more involved, but would largely eliminate data loading issues.
If you find it too difficult, you can always do a simpler thing such as preloading the dataset to a CPU tensor and writing a simple Dataset that wraps it. Then you would just use the DataLoader as you do now. You would not avoid CPU -> GPU transfers, but at least you would avoid IO (and possibly decoding) latency.
3. Delete gradients, do not zero them out.
You should be doing optimizer.zero_grad(set_to_none=True)
, this will just delete the gradient tensors (which is a virtually free operation) instead of zeroing out them. However, this advice is only useful for older PyTorch versions, as the newer already switched to set_to_none been the default.
4. Avoid synchronization points.
Doing loss.item()
is a blocking call, which causes Pytorch to wait till all kernels finish to run so that it could transfer the result back to CPU. Try to avoid such operations. You could measure the loss once in a while, e.g. once in 100 iterations.
5. Try torchpile.
Using torchpile
in some cases can significantly speed up things by merging kernels. Since you network is quite small, the overhead of data movement due to non-fused kernels can be significant compared to the actual compute, which means that torchpile
should help if GPU is the bottleneck.
6. Profile.
Use torch.autograd.profiler
for profiling. For example:
from torch.profiler import profile, ProfilerActivity
def save_trace(profiler) -> None:
profiler.export_chrome_trace("trace.json")
...
def train_func(...):
...
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
profile_memory=False,
record_shapes=False,
schedule=torch.profiler.schedule(
skip_first=100, wait=1, warmup=5, active=1, repeat=1
),
on_trace_ready=save_trace,
) as profiler:
for ep in range(num_epochs):
if global_step >= max_steps:
print(f"Reached {max_steps} total steps; stopping early.")
break
t0 = time.time()
#############################
# Train Loop (per epoch)
#############################
model.train()
total_loss = 0.0
for x_in, x_tgt in train_loader:
x_in = x_in.to(device)
x_tgt = x_tgt.to(device)
optimizer.zero_grad()
out = model(x_in)
loss = loss_fn(out, x_tgt)
loss.backward()
optimizer.step()
profiler.step() # MAKE SURE YOU CALL step!!!
total_loss += loss.item() * x_in.size(0)
global_step += 1 # increment step count
if global_step >= max_steps:
print(f"Reached {max_steps} total steps; stopping in mid-epoch.")
break
Do not fet to call profiler.step()
! The resulting traces could be open in chrome using chrome://tracing
.
See https://pytorch./tutorials/recipes/recipes/profiler_recipe.html