I'm defining my own train() function - that does training and validation (perhaps the name of the function is not the most descriptive here)
Since I'm using a custom iterable dataset along with a dataloader, I'm trying to figure out how to calculate those metrics. The iterable dataset does not have a length() so I can't really figure another way out.
Here's my code so far:
import torch
from torch.utils.data import DataLoader
from new_data_loader import ProtobufIterableDataset, BatchShuffleDataset
from model import SimpleModel
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
# Constants
# Load Training Data
trainingDataset = ProtobufIterableDataset('train_examples', scaling=True)
trainingDataset = BatchShuffleDataset(trainingDataset, batch_size=SHUFFLE_BUFFER_SIZE)
trainingDataLoader = DataLoader(trainingDataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# Load Validation Data
validationDataset = ProtobufIterableDataset('val_examples', scaling=True)
validationDataset = BatchShuffleDataset(validationDataset, batch_size=SHUFFLE_BUFFER_SIZE)
validationDataLoader = DataLoader(validationDataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel().to(device)
epochs = 10
learning_rate = 0.001
crossEntropyLoss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
def train(model, trainingDataLoader, validationDataLoader, optimizer, criterion, device, epochs=10):
# Create outer progress bar for epochs
epoch_pbar = tqdm(range(epochs), desc='Training Progress', position=0)
for epoch in epoch_pbar:
# Training Phase
train_loss = 0
train_correct = 0
train_total = 0
# Create progress bar for training batches
train_pbar = tqdm(trainingDataLoader,
desc=f'Training Epoch {epoch+1}/{epochs}',
# For every batch
for set_features, access_features, cache_features, labels in train_pbar:
# Combine features
cache_features_flat = cache_features.reshape(-1, 17*9)
combined_features = torch.cat([set_features, access_features, cache_features_flat], dim=1)
# Move to device
combined_features = combined_features.to(device)
labels = labels.to(device)
# Zero gradients
# Forward pass
outputs = model(combined_features)
loss = criterion(outputs, labels)
# Backward pass
# Training statistics
train_loss += loss.item()
_, predicted = torch.max(outputs, 1)
train_total += labels.size(0)
train_correct += (predicted == torch.max(labels, 1)[1]).sum().item()
# Update training progress bar with current loss and accuracy
'loss': f'{train_loss/train_total:.4f}',
'acc': f'{100 * train_correct/train_total:.2f}%'
# Validation Phase
val_loss = 0
val_correct = 0
val_total = 0
# Create progress bar for validation batches
val_pbar = tqdm(validationDataLoader,
desc=f'Validation Epoch {epoch+1}/{epochs}',
with torch.no_grad():
for set_features, access_features, cache_features, labels in val_pbar:
# Combine features
cache_features_flat = cache_features.reshape(-1, 17*9)
combined_features = torch.cat([set_features, access_features, cache_features_flat], dim=1)
# Move to device
combined_features = combined_features.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(combined_features)
loss = criterion(outputs, labels)
# Validation statistics
val_loss += loss.item()
_, predicted = torch.max(outputs, 1)
val_total += labels.size(0)
val_correct += (predicted == torch.max(labels, 1)[1]).sum().item()
# Update validation progress bar with current loss and accuracy
'loss': f'{val_loss/val_total:.4f}',
'acc': f'{100 * val_correct/val_total:.2f}%'
# Update epoch progress bar with final metrics
'train_loss': f'{train_loss/(trainingDataLoader.__len__()):.4f}',
'train_acc': f'{100 * train_correct/train_total:.2f}%',
'val_loss': f'{val_loss/(validationDataLoader.__len__()):.4f}',
'val_acc': f'{100 * val_correct/val_total:.2f}%'
# Print final statistics for the epoch
print(f'\nEpoch {epoch+1}/{epochs}:')
print(f'Training Loss: {train_loss/(trainingDataLoader.__len__()):.4f}, '
f'Training Accuracy: {100 * train_correct/train_total:.2f}%')
print(f'Validation Loss: {val_loss/(validationDataLoader.__len__()):.4f}, '
f'Validation Accuracy: {100 * val_correct/val_total:.2f}%\n')