This is a follow-up question to Investigating discrepancies in TensorFlow and PyTorch performance.
The author of that question observed a performance gap between PyTorch and TensorFlow even though the architecture, loss function, optimizer, metrics, and data were kept the same, and the explanation given was that TensorFlow's model.fit defaults to mini-batching (with a batch size of 32) while the PyTorch code batched differently. However, I am seeing a similar gap even though both of my implementations batch with the same batch size.
Can anyone help, please?
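To be concrete about the batching claim: Keras only falls back to its default batch size of 32 when model.fit is given raw arrays; a batched tf.data.Dataset or a PyTorch DataLoader fixes the batch size explicitly. A minimal toy sketch (the model and data here are throwaway placeholders, not my pipeline):
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import DataLoader, TensorDataset

x = np.random.rand(8, 4).astype("float32")
y = np.random.randint(0, 2, (8, 1)).astype("float32")

toy = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
toy.compile(optimizer="adam", loss="binary_crossentropy")
toy.fit(x, y, epochs=1, verbose=0)  # no batch_size given -> Keras defaults to 32
toy.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(2), epochs=1, verbose=0)  # the dataset controls the batch size

loader = DataLoader(TensorDataset(torch.from_numpy(x), torch.from_numpy(y)), batch_size=2)  # PyTorch equivalent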
In the sections below I present the two implementations, built with the segmentation_models (sm) and segmentation_models_pytorch (smp) libraries. I assume the architectures are the same, and I have adjusted the parameters so that they match.
Both implementations use the following settings (a sketch of the assumed config module follows the list):
- batch_size: 2
- epochs: 50
- optimizer: Adam (lr = 0.001)
- loss: dice_loss
- metrics: accuracy, AUC, IoU (intersection over union)
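For reference, the config object that both snippets use is just a small settings module along these lines (a sketch with placeholder paths, not the exact file):
# config.py (sketch) -- shared settings; the paths are placeholders
BATCH_SIZE = 2
EPOCHS = 50
LEARNING_RATE = 0.001
SAVE_MODEL_DIR = "./saved_models"  # placeholder
VERSION = "v1"                     # placeholder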
Unfortunately, I cannot share the data for reproducibility.
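What I can share is the data layout, so here is a random stand-in with the same shapes (purely illustrative; the spatial size is made up):
import numpy as np

N, H, W = 16, 256, 256  # made-up sizes; only the channel layout matters
images_tf = np.random.rand(N, H, W, 4).astype("float32")            # Keras: channels-last, 4 input channels
masks_tf = np.random.randint(0, 2, (N, H, W, 1)).astype("float32")  # binary masks, 1 class
images_pt = images_tf.transpose(0, 3, 1, 2)                         # PyTorch: channels-first
masks_pt = masks_tf.transpose(0, 3, 1, 2)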
TensorFlow
**# imports**
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
import segmentation_models as sm
**# iou**
def iou(y_true, y_pred, threshold=0.5):
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3]) - intersection
    iou = tf.reduce_mean(intersection / (union + tf.keras.backend.epsilon()))
    return iou
**# dice_loss**
def dice_loss(y_true, y_pred):
    y_true_f = tf.cast(K.flatten(y_true), tf.float32)
    y_pred_f = tf.cast(K.flatten(y_pred), tf.float32)
    intersection = K.sum(y_true_f * y_pred_f)
    val = (2. * intersection + K.epsilon()) / (K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + K.epsilon())
    return 1. - val
train_dataset = train_dataset.batch(config.BATCH_SIZE)
test_dataset = test_dataset.batch(config.BATCH_SIZE)
pretrained_base_model = sm.Unet(encoder_weights='imagenet', classes=1)
# Input shape (for 4-channel input) and adjust it to 3 channels for the pretrained model
inp = keras.Input(shape=(None, None, 4))
l1 = keras.layers.Conv2D(3, (1, 1))(inp) # Map 4-channel input to 3-channel input
out = pretrained_base_model(l1)
model = keras.Model(inp, out, name=pretrained_base_model.name)
model.compile(
    optimizer='Adam',
    loss=dice_loss,  # Custom dice_loss
    metrics=[
        iou,  # Custom IoU score
        keras.metrics.AUC(),
        keras.metrics.BinaryAccuracy()
    ]
)
backup_callback = keras.callbacks.BackupAndRestore(
    backup_dir="./keras_backups"
)
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(config.SAVE_MODEL_DIR, config.VERSION, "best_mdl.keras"),
    monitor='val_iou',
    save_best_only=True,
    mode='max',
    verbose=0
)
callbacks = [backup_callback, checkpoint_callback]  # the list passed to fit below
artery_history = model.fit(train_dataset, epochs=config.EPOCHS, callbacks=callbacks, validation_data=test_dataset)
PyTorch
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
channel_mapper = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1)

class WrappedModel(nn.Module):
    def __init__(self, base_model, channel_mapper):
        super(WrappedModel, self).__init__()
        self.channel_mapper = channel_mapper
        self.base_model = base_model

    def forward(self, x):
        x = self.channel_mapper(x)  # Map 4-channel input to 3 channels for the pretrained encoder
        return self.base_model(x)

model = WrappedModel(
    base_model=smp.Unet(encoder_name="vgg16", encoder_weights="imagenet", in_channels=3, classes=1),
    channel_mapper=channel_mapper,
).to(device)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
def iou(y_true, y_pred, threshold=0.5):
    if isinstance(y_true, np.ndarray):
        y_true = torch.tensor(y_true)
    if isinstance(y_pred, np.ndarray):
        y_pred = torch.tensor(y_pred)
    y_pred = (y_pred > threshold).float()
    y_true = y_true.float()
    intersection = torch.sum(y_true * y_pred)
    union = torch.sum(y_true) + torch.sum(y_pred) - intersection
    # IoU
    iou = intersection / (union + 1e-6)  # Add a small epsilon to prevent division by zero
    return iou.item()  # Return scalar IoU score
def dice_loss(y_true, y_pred):
    y_true_f = y_true.view(-1).float()
    y_pred_f = y_pred.view(-1).float()
    # intersection
    intersection = torch.sum(y_true_f * y_pred_f)
    # Dice score
    dice_score = (2. * intersection + 1e-6) / (torch.sum(y_true_f * y_true_f) + torch.sum(y_pred_f * y_pred_f) + 1e-6)
    # Dice loss
    return 1. - dice_score
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = dice_loss

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=config.EPOCHS,
    save_dir=os.path.join(config.SAVE_MODEL_DIR, config.VERSION),
    checkpoint_path="./checkpoints/checkpoint.pth",
    model_name="unet.pth",
    device=device,
)
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_fn,
    num_epochs,
    save_dir,
    device,
    checkpoint_path,
    model_name,
    metrics=["iou", "accuracy", "auc"],  # "precision", "recall", "f1"
):
    os.makedirs("./checkpoints", exist_ok=True)
    best_model_path = os.path.join(save_dir, model_name)
    model, optimizer, start_epoch, history = load_checkpoint(model, optimizer, checkpoint_path)
    sigmoid = torch.nn.Sigmoid()
    best_val_iou = max(history["val_metrics"]["iou"], default=0.0) if "iou" in history["val_metrics"] else 0.0
    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 40)

        # Training
        model.train()
        train_loss = 0
        all_train_preds = []
        all_train_targets = []
        for images, masks in tqdm(train_loader, desc="Training"):
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            outputs = sigmoid(outputs)
            loss = loss_fn(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            preds = (outputs > 0.5).float()
            all_train_preds.append(preds.cpu().numpy())
            all_train_targets.append(masks.cpu().numpy())
        train_loss /= len(train_loader)
        history["train_loss"].append(train_loss)

        # Compute training metrics
        all_train_preds = np.concatenate(all_train_preds).flatten()
        all_train_targets = np.concatenate(all_train_targets).flatten()
        train_metrics = {}
        # Update metrics
        for metric in metrics:
            if metric == "iou":
                train_metrics["iou"] = iou(all_train_targets, all_train_preds)
            elif metric == "accuracy":
                train_metrics["accuracy"] = np.mean(all_train_preds == all_train_targets)
            elif metric == "auc":
                if len(np.unique(all_train_targets)) > 1:
                    train_metrics["auc"] = roc_auc_score(all_train_targets, all_train_preds)
        for metric, value in train_metrics.items():
            history["train_metrics"].setdefault(metric, []).append(value)
        # Training metrics
        print(f"Training Loss: {train_loss:.4f}")
        for metric, value in train_metrics.items():
            print(f"Training {metric.capitalize()}: {value:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        all_val_preds = []
        all_val_targets = []
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc="Validation"):
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                outputs = sigmoid(outputs)
                loss = loss_fn(outputs, masks)
                val_loss += loss.item()
                preds = (outputs > 0.5).float()
                all_val_preds.append(preds.cpu().numpy())
                all_val_targets.append(masks.cpu().numpy())
        val_loss /= len(val_loader)
        history["val_loss"].append(val_loss)

        all_val_preds = np.concatenate(all_val_preds).flatten()
        all_val_targets = np.concatenate(all_val_targets).flatten()
        val_metrics = {}
        for metric in metrics:
            if metric == "iou":
                intersection = np.sum(all_val_preds * all_val_targets)
                union = np.sum(all_val_preds) + np.sum(all_val_targets) - intersection
                val_metrics["iou"] = intersection / (union + 1e-6)
            elif metric == "accuracy":
                val_metrics["accuracy"] = np.mean(all_val_preds == all_val_targets)
            elif metric == "auc":
                val_metrics["auc"] = roc_auc_score(all_val_targets, all_val_preds)
        for metric, value in val_metrics.items():
            history["val_metrics"].setdefault(metric, []).append(value)
        # Print validation metrics
        print(f"Validation Loss: {val_loss:.4f}")
        for metric, value in val_metrics.items():
            print(f"Validation {metric.capitalize()}: {value:.4f}")

        # Save best model
        if val_metrics.get("iou", 0) > best_val_iou:
            best_val_iou = val_metrics["iou"]
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved with IoU: {best_val_iou:.4f}")

        # Save checkpoint
        save_checkpoint(model, optimizer, epoch, history, checkpoint_path)

    shutil.rmtree("./checkpoints", ignore_errors=True)
    return history