最新消息:雨落星辰是一个专注网站SEO优化、网站SEO诊断、搜索引擎研究、网络营销推广、网站策划运营及站长类的自媒体原创博客

python - Discrepancies between TensorFlow and PyTorch while keeping the batch size, loss function, architecture, metrics, and data the same

programmeradmin4浏览0评论

This is a follow-up to the question "Investigating discrepancies in TensorFlow and PyTorch performance".

The author of that question reported a performance gap between PyTorch and TensorFlow despite keeping the architecture, loss function, optimizer, metric, and data the same. They stated the reason was that TensorFlow's `model.fit` defaults to mini-batching (with a batch size of 32), while the PyTorch code simply used batching. However, I see a similar gap even though both of my implementations use the same batch size.

Can anyone help please?

Below I present two implementations built on the libraries segmentation_models (`sm`) and segmentation_models_pytorch (`smp`). I assume the architectures are the same, and I have adjusted the parameters so that they match.

Both implementations use the following settings:

  • batch_size:2
  • Epoch: 50
  • optimizer: Adam (lr =0.001)
  • loss: dice_loss
  • metrics: accuracy, AUC, IoU (intersection over union)

Unfortunately, I cannot share the data, so the example is not fully reproducible.

TensorFlow


# iou
def iou(y_true, y_pred, threshold=0.5):
    """Mean per-image intersection-over-union, used as a Keras metric.

    ``y_pred`` is binarised with ``y_pred > threshold``. Intersection and
    union are summed per sample over axes 1-3 and the per-sample IoU values
    are averaged over the batch.

    Numerator and denominator are both smoothed with the Keras epsilon (the
    same form ``dice_loss`` uses). An image whose ground-truth mask and
    prediction are both empty therefore scores 1.0 rather than 0.0.
    Otherwise every correctly predicted empty image would pull the batch mean
    towards zero. For non-empty unions the change is negligible.

    Args:
        y_true: ground-truth masks, cast to float32. Assumes rank 4, presumably
            (batch, height, width, 1) like the NHWC model output; TODO confirm.
        y_pred: predicted probabilities with the same shape as ``y_true``.
        threshold: cut-off used to binarise ``y_pred``.

    Returns:
        Scalar float32 tensor: the batch-mean IoU.
    """
    eps = tf.keras.backend.epsilon()
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    y_true = tf.cast(y_true, tf.float32)

    axes = [1, 2, 3]  # every non-batch axis of a rank-4 mask tensor
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    union = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes) - intersection

    per_image_iou = (intersection + eps) / (union + eps)
    return tf.reduce_mean(per_image_iou)





# dice_loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss for binary segmentation (Keras loss function).

    Both tensors are flattened, so the Dice score is pooled over the whole
    batch rather than averaged per image:

        1 - (2 * sum(t * p) + eps) / (sum(t^2) + sum(p^2) + eps)

    with ``eps = K.epsilon()``. The eps terms keep the ratio defined when both
    masks are empty; that case gives a score of 1, i.e. zero loss. The formula
    is symmetric in ``y_true`` and ``y_pred``.

    Returns:
        Scalar float32 tensor.
    """
    eps = K.epsilon()
    y_true_f = tf.cast(K.flatten(y_true), tf.float32)
    y_pred_f = tf.cast(K.flatten(y_pred), tf.float32)
    intersection = K.sum(y_true_f * y_pred_f)
    denominator = K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f)
    return 1. - (2. * intersection + eps) / (denominator + eps)



# Batch both splits with the same batch size the PyTorch DataLoaders use.
# NOTE(review): no .shuffle() is applied to train_dataset here, while the
# PyTorch train loader uses shuffle=True. Confirm the tf.data pipeline is
# shuffled upstream; otherwise the two runs see batches in different orders.
train_dataset = train_dataset.batch(config.BATCH_SIZE)
test_dataset = test_dataset.batch(config.BATCH_SIZE)


# U-Net with an ImageNet-pretrained VGG16 encoder and a sigmoid head. This is
# the counterpart of smp.Unet(encoder_name="vgg16", classes=1) followed by the
# sigmoid applied in the PyTorch training loop. Backbone and activation are
# spelled out (they are the segmentation_models defaults) so the intended
# match with the PyTorch model is explicit.
pretrained_base_model = sm.Unet(
    backbone_name='vgg16',
    encoder_weights='imagenet',
    classes=1,
    activation='sigmoid',
)

# The encoder takes 3-channel input, so a learned 1x1 convolution first maps
# the 4-channel input to 3 channels.
inp = keras.Input(shape=(None, None, 4))
l1 = keras.layers.Conv2D(3, (1, 1))(inp)
out = pretrained_base_model(l1)

model = keras.Model(inp, out, name=pretrained_base_model.name)
model.compile(
    # Explicit learning rate to mirror torch.optim.Adam(lr=0.001).
    # NOTE(review): Adam's epsilon default differs between Keras and PyTorch;
    # set it explicitly on both sides if exact parity matters.
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=dice_loss,  # custom Dice loss
    metrics=[
        iou,  # custom IoU: per-image IoU averaged over each batch
        keras.metrics.AUC(),
        keras.metrics.BinaryAccuracy(),
    ],
)

# Lets an interrupted fit() resume from ./keras_backups.
backup_callback = keras.callbacks.BackupAndRestore(backup_dir="./keras_backups")
# Keep the model with the highest validation IoU (the custom `iou` metric is
# reported as "val_iou" for the validation data).
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(config.SAVE_MODEL_DIR, config.VERSION, "best_mdl.keras"),
    monitor='val_iou',
    save_best_only=True,
    mode='max',
    verbose=0,
)

# Pass both callbacks defined above to fit().
callbacks = [backup_callback, checkpoint_callback]
artery_history = model.fit(
    train_dataset,
    epochs=config.EPOCHS,
    callbacks=callbacks,
    validation_data=test_dataset,
)

PyTorch

import torch.nn as nn
# Learned 1x1 convolution projecting the 4 input channels onto the 3 channels
# the ImageNet-pretrained encoder expects (counterpart of the Keras
# Conv2D(3, (1, 1)) layer).
# NOTE(review): Keras and PyTorch use different default initializers for this
# layer; check the framework docs if initial conditions must match.
channel_mapper = nn.Conv2d(4, 3, kernel_size=1)


class WrappedModel(nn.Module):
    """Run a channel-mapping module in front of a base segmentation model.

    ``channel_mapper`` is applied to the input first (e.g. a 1x1 Conv2d
    mapping 4 input channels to 3). Its output is fed to ``base_model``, and
    ``forward`` returns whatever ``base_model`` returns.
    """

    def __init__(self, base_model, channel_mapper):
        super().__init__()
        # Keep this assignment order: submodules are registered in the order
        # they are assigned, which fixes the order of parameters() and hence
        # of the optimizer state saved in checkpoints.
        self.channel_mapper = channel_mapper
        self.base_model = base_model

    def forward(self, x):
        mapped = self.channel_mapper(x)
        return self.base_model(mapped)


# PyTorch model: the channel mapper followed by a U-Net whose VGG16 encoder
# starts from ImageNet weights. No `activation` is passed to smp.Unet; the
# training loop applies a sigmoid to the outputs instead.
model = WrappedModel(
    base_model=smp.Unet(
        encoder_name="vgg16",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
    ),
    channel_mapper=channel_mapper,
).to(device)


# Only the training split is shuffled; both loaders use config.BATCH_SIZE.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)


def iou(y_true, y_pred, threshold=0.5):
    """Return the intersection-over-union of a binary mask and a prediction.

    All elements are pooled together, giving one overall IoU rather than a
    per-image average. ``y_pred`` is binarised with ``y_pred > threshold``.
    Numerator and denominator are both smoothed by 1e-6 (the same form
    ``dice_loss`` uses), so an all-empty target and prediction score 1.0, a
    perfect match, instead of 0.0.

    Args:
        y_true: ground-truth mask; a tensor, NumPy array, or nested sequence.
        y_pred: predicted probabilities (or already-binary predictions) with
            the same number of elements as ``y_true``.
        threshold: cut-off applied to ``y_pred``.

    Returns:
        The IoU as a Python float.
    """
    # as_tensor accepts tensors, NumPy arrays and sequences, and shares memory
    # with NumPy input where possible instead of always copying it.
    y_true = torch.as_tensor(y_true).float()
    y_pred = (torch.as_tensor(y_pred) > threshold).float()

    intersection = torch.sum(y_true * y_pred)
    union = torch.sum(y_true) + torch.sum(y_pred) - intersection

    eps = 1e-6
    return ((intersection + eps) / (union + eps)).item()


def dice_loss(y_true, y_pred):
    """Soft Dice loss for binary segmentation.

    Both tensors are flattened, so the score is pooled over the whole batch
    (the same formulation as the Keras ``dice_loss``):

        1 - (2 * sum(t * p) + eps) / (sum(t^2) + sum(p^2) + eps),  eps = 1e-6

    The formula is symmetric in its two arguments, so passing
    (prediction, target) gives the same result as (target, prediction).

    Returns:
        Scalar tensor (differentiable w.r.t. the inputs).
    """
    # reshape (unlike view) also works on non-contiguous tensors such as
    # permuted HWC->CHW masks; it only copies when it has to.
    y_true_f = y_true.reshape(-1).float()
    y_pred_f = y_pred.reshape(-1).float()

    intersection = torch.sum(y_true_f * y_pred_f)

    eps = 1e-6
    dice_score = (2. * intersection + eps) / (
        torch.sum(y_true_f * y_true_f) + torch.sum(y_pred_f * y_pred_f) + eps
    )
    return 1. - dice_score



optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = dice_loss


def _run_epoch(model, loader, loss_fn, device, optimizer=None, desc=""):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    The model's raw outputs go through a sigmoid before the loss is computed.

    Returns:
        ``(mean batch loss, probabilities, targets)``. The last two are
        flattened 1-D NumPy arrays covering the whole pass; all three are
        empty/zero if the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()

    total_loss = 0.0
    n_batches = 0
    prob_chunks, target_chunks = [], []
    with torch.set_grad_enabled(training):
        for images, masks in tqdm(loader, desc=desc):
            images = images.to(device)
            masks = masks.to(device)

            probs = torch.sigmoid(model(images))
            # dice_loss, the loss used in this script, is symmetric in its two
            # arguments, so this (prediction, target) order is harmless for it.
            loss = loss_fn(probs, masks)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            # Keep probabilities rather than thresholded masks: ROC AUC needs scores.
            prob_chunks.append(probs.detach().cpu().numpy().ravel())
            target_chunks.append(masks.detach().cpu().numpy().ravel())

    if n_batches == 0:
        empty = np.empty(0, dtype=np.float32)
        return 0.0, empty, empty
    return total_loss / n_batches, np.concatenate(prob_chunks), np.concatenate(target_chunks)


def _compute_metrics(probs, targets, metrics, threshold):
    """Compute the requested epoch-level metrics from flattened probabilities and masks.

    Supported names are "iou", "accuracy" and "auc"; other names are ignored.
    "auc" is skipped when ``targets`` holds a single class, because ROC AUC is
    undefined there.
    """
    if targets.size == 0:
        return {}
    preds = (probs > threshold).astype(np.float32)
    results = {}
    for metric in metrics:
        if metric == "iou":
            # Pooled IoU over every pixel of the pass. NOTE: the Keras `iou`
            # metric averages per-image IoU within each batch, so the two
            # numbers are not directly comparable.
            results["iou"] = iou(targets, probs, threshold)
        elif metric == "accuracy":
            results["accuracy"] = float(np.mean(preds == targets))
        elif metric == "auc" and len(np.unique(targets)) > 1:
            # Score the probabilities, not hard 0/1 predictions.
            results["auc"] = float(roc_auc_score(targets, probs))
    return results


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_fn,
    num_epochs,
    save_dir,
    device,
    checkpoint_path,
    model_name,
    metrics=("iou", "accuracy", "auc"),  # "precision", "recall", "f1" not implemented
    threshold=0.5,
):
    """Train ``model`` with per-epoch checkpoints and keep the best-IoU weights.

    Each epoch runs a training pass and a validation pass and appends the mean
    loss and the requested ``metrics`` to ``history``. Whenever validation IoU
    improves, ``model.state_dict()`` is saved to ``save_dir/model_name``. A
    checkpoint is written to ``checkpoint_path`` via ``save_checkpoint`` after
    every epoch. Start state comes from ``load_checkpoint``, presumably a
    fresh state when no checkpoint exists (confirm). The checkpoint file is
    deleted after the last epoch.

    Args:
        metrics: names of metrics to compute ("iou", "accuracy", "auc").
        threshold: probability cut-off used for IoU and accuracy.

    Returns:
        The ``history`` dict from ``load_checkpoint``, extended with this run's
        "train_loss", "val_loss", "train_metrics" and "val_metrics" entries.
    """
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create parent dirs
    best_model_path = os.path.join(save_dir, model_name)

    model, optimizer, start_epoch, history = load_checkpoint(model, optimizer, checkpoint_path)
    best_val_iou = max(history["val_metrics"].get("iou", []), default=0.0)

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 40)

        # Training
        train_loss, train_probs, train_targets = _run_epoch(
            model, train_loader, loss_fn, device, optimizer=optimizer, desc="Training"
        )
        train_metrics = _compute_metrics(train_probs, train_targets, metrics, threshold)
        del train_probs, train_targets  # free the per-pixel arrays before validation

        history["train_loss"].append(train_loss)
        for metric, value in train_metrics.items():
            history["train_metrics"].setdefault(metric, []).append(value)

        print(f"Training Loss: {train_loss:.4f}")
        for metric, value in train_metrics.items():
            print(f"Training {metric.capitalize()}: {value:.4f}")

        # Validation
        val_loss, val_probs, val_targets = _run_epoch(
            model, val_loader, loss_fn, device, desc="Validation"
        )
        val_metrics = _compute_metrics(val_probs, val_targets, metrics, threshold)
        del val_probs, val_targets

        history["val_loss"].append(val_loss)
        for metric, value in val_metrics.items():
            history["val_metrics"].setdefault(metric, []).append(value)

        print(f"Validation Loss: {val_loss:.4f}")
        for metric, value in val_metrics.items():
            print(f"Validation {metric.capitalize()}: {value:.4f}")

        # Save best model
        if val_metrics.get("iou", 0) > best_val_iou:
            best_val_iou = val_metrics["iou"]
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved with IoU: {best_val_iou:.4f}")

        # Save checkpoint
        save_checkpoint(model, optimizer, epoch, history, checkpoint_path)

    # Training finished: remove the checkpoint itself rather than a hard-coded
    # directory that may not be where it lives (or may hold unrelated files).
    try:
        os.remove(checkpoint_path)
    except FileNotFoundError:
        pass

    return history


# Entry point: it must come after the definitions above, otherwise
# train_model is not yet bound when it is called.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=config.EPOCHS,
    save_dir=os.path.join(config.SAVE_MODEL_DIR, config.VERSION),
    checkpoint_path="./checkpoints/checkpoint.pth",
    model_name="unet.pth",
    device=device,
)



与本文相关的文章

发布评论

评论列表(0)

  1. 暂无评论