
python - Trying to reconstruct original image from patches - Stack Overflow


I am running a weed detection pipeline which involves creating patches and running an ML algorithm to detect regions of weeds. When I try to reconstruct the patches into the complete weed map, the patches come out in the wrong order when recombined (see attached image).

Here is the code that creates the patches:

class Patches:

    def __init__(self, im_list, msk_list, PATCH_SIZE, threshold=0.03):
        self.im_list = im_list
        self.msk_list = msk_list
        self.PATCH_SIZE = PATCH_SIZE
        self.threshold = threshold

    def image_to_patches(self, image, b_msk=False):
        slc_size = self.PATCH_SIZE
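        # Round the image dimensions up to whole multiples of the patch size
        # and zero-pad the remainder so patchify can tile the image exactly.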
        x = int(np.ceil(image.shape[0] / slc_size))
        y = int(np.ceil(image.shape[1] / slc_size))
        padded_shape = (x * slc_size, y * slc_size)
        
        if not b_msk:
            padded_rgb_image = np.zeros((padded_shape[0], padded_shape[1], 3), dtype=np.uint8)
            padded_rgb_image[:image.shape[0], :image.shape[1]] = image
            patches = patchify(padded_rgb_image, (slc_size, slc_size, 3), step=slc_size)
        else:
            padded_rgb_image = np.zeros((padded_shape[0], padded_shape[1]), dtype=np.uint8)
            padded_rgb_image[:image.shape[0], :image.shape[1]] = image
            patches = patchify(padded_rgb_image, (slc_size, slc_size), step=slc_size)

        return patches, slc_size

    def load_image(self, path):
        """
        loads an image based on the path
        """
        rgb_image = skio.imread(path)
        return rgb_image

    def patchify_image_mask(self):
        imgs = []
        anns = []
        AREA = self.PATCH_SIZE * self.PATCH_SIZE
        f_AREA = int(self.threshold * AREA)
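        # f_AREA is the minimum number of foreground (weed) pixels a patch's
        # mask must contain for the patch to be kept.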
        print(f'Threshold: {self.threshold} * {AREA} = {f_AREA}')
        print(f"Patchifying images and masks...")
        for im_path, msk_path in zip(self.im_list, self.msk_list):
            patches, _ = self.image_to_patches(self.load_image(im_path))
            masks, _ = self.image_to_patches(self.load_image(msk_path), b_msk=True)
            for i in range(patches.shape[0]):
                for j in range(patches.shape[1]):
                    patch = patches[i, j, :, :, :]
                    mask = masks[i, j, ::]
                    if mask.reshape(-1).sum() > f_AREA:
                        imgs.append(patch)
                        anns.append(mask)
        return np.array(imgs), np.array(anns)
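
For context, my understanding is that patchify returns the patches as a row-major (i, j) grid, and unpatchify can invert the operation exactly as long as every patch of that grid is kept. A minimal round-trip sketch with a dummy image (not my actual data):

import numpy as np
from patchify import patchify, unpatchify

# Dummy 256x256 RGB image standing in for a real field image
rng = np.random.default_rng(0)
image = rng.integers(0, 255, size=(256, 256, 3), dtype=np.uint8)

# patchify returns a (rows, cols, 1, 64, 64, 3) grid in row-major order
patches = patchify(image, (64, 64, 3), step=64)
print(patches.shape)  # (4, 4, 1, 64, 64, 3)

# unpatchify inverts patchify exactly when no patch is missing or reordered
restored = unpatchify(patches, image.shape)
assert np.array_equal(restored, image)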

And here is the test.py logic:

import sys

sys.path.append("/Users/i/Downloads/Deep-Weed-Segmentation-main/")

import numpy as np
import random
import matplotlib.pyplot as plt
from scripts.prepare_dataset import Prepare_Dataset
print("hallala")
import argparse
from scripts.model import Models
from tensorflow import keras
import segmentation_models as sm
import tensorflow_advanced_segmentation_models as tasm
import os
import cv2
from patchify import patchify, unpatchify

# /Users/icom/Downloads/CoFly-WeedDB

def main():
    parser = argparse.ArgumentParser(description='Testing Model')
    parser.add_argument('--network', help='Network Model type', default='custom')
    parser.add_argument('--backbone', help='Backbone', default='None')
    parser.add_argument('--weight_path', help='path to saved weights', default='./models')
    parser.add_argument('--image_path', help='path to image', default='./images')
    parser.add_argument('--patch_size', type=int, required=True, help='Size of the patches')
    parser.add_argument('--binary', help='Enable Binary Segmentation', default=False, action='store_true')
    parser.add_argument('--data_path', help='root path to dataset',
                        default='/Users/i/Downloads/CoFly-WeedDB')

    args = parser.parse_args()

    args = vars(args)
    args = {k: v for k, v in args.items() if v is not None}

    patch_size = int(args['patch_size'])

    # Get the dataset and original dimensions
    Y_train_cat, Y_test_cat, X_train, Y_test, X_test, p_weights, n_classes, original_height, original_width = Prepare_Dataset(
        patch_size, binary=args['binary'], backbone=args['backbone'], data_path=args['data_path']
    ).prepare_all()

    Test(args['network'], args['backbone'], args['weight_path'], args['image_path'], patch_size,
         args['data_path'], args['binary']).test()


class Test:
    def __init__(self, network, backbone, weight_path, input_image, PATCH_SIZE, data_path, binary):
        self.network = network
        self.backbone = backbone
        self.PATCH_SIZE = self.size_(PATCH_SIZE)
        self.weights_path = weight_path
        self.input_image = input_image
        (self.Y_train_cat, self.Y_test_cat, self.X_train, self.Y_test, self.X_test, self.p_weights,
         self.n_classes, self.original_height, self.original_width) = Prepare_Dataset(self.PATCH_SIZE, binary=binary, backbone=backbone, data_path=data_path).prepare_all()
        self.total_loss = Prepare_Dataset(self.PATCH_SIZE).get_loss(p_weights=self.p_weights)
        self.binary = binary

    def size_(self, PATCH_SIZE):
        if self.network == 'pspnet':
            if PATCH_SIZE % 48 != 0:
                print('Image size must be divisible by 48')
                PATCH_SIZE = int(PATCH_SIZE / 48) * 48 + 48
                print(f'New image size: {PATCH_SIZE}x{PATCH_SIZE}x3')
        return PATCH_SIZE

    def test(self):
        print('Testing...')

        metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
        LR = 0.0001
        optim = keras.optimizers.Adam(LR)

        print('Building Model...')
        # Initialize model
        if self.network == 'custom':
            model = Models(self.n_classes, self.PATCH_SIZE, IMG_CHANNELS=3, model_name=self.network,
                           backbone=self.backbone).simple_unet_model()
        elif self.network == 'segnet':
            model, self.backbone = Models(self.n_classes, self.PATCH_SIZE, IMG_CHANNELS=3, model_name=self.network,
                                          backbone=self.backbone).segnet_architecture()
        elif self.network == 'unet' or self.network == 'linknet' or self.network == 'pspnet':
            model = Models(self.n_classes, self.PATCH_SIZE, IMG_CHANNELS=3, model_name=self.network,
                           backbone=self.backbone).segmented_models()
        elif self.network == 'deeplabv3':
            base_model, layers, layer_names = Models(self.n_classes, self.PATCH_SIZE, IMG_CHANNELS=3,
                                                     model_name=self.network,
                                                     backbone=self.backbone).deeplabv3(name=self.backbone,
                                                                                       weights='imagenet',
                                                                                       height=self.PATCH_SIZE,
                                                                                       width=self.PATCH_SIZE)
            model = tasm.DeepLabV3plus(n_classes=self.n_classes, base_model=base_model, output_layers=layers,
                                       backbone_trainable=False)
            model.build((None, self.PATCH_SIZE, self.PATCH_SIZE, 3))
        else:
            print(f'{self.network} network not available.')
            quit()
        # Compilation
        model.compile(optimizer=optim,
                      loss=self.total_loss,
                      metrics=[metrics])
        print(model.summary())
        path_to_model = os.path.join(self.weights_path,
                                     self.network + '_' + self.backbone + '_' + str(
                                         self.PATCH_SIZE) + '_binary_' + str(self.binary) +
                                     '.weights.h5')
        try:
            model.load_weights(path_to_model)
        except Exception as err:
            print(err)
            quit()

        # Initialize an empty array for the reconstructed original image with overlays
        reconstructed_image_with_overlays = np.zeros((self.original_height, self.original_width, 3), dtype=np.uint8)
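        # NOTE: this canvas uses the original (unpadded) image size, while the
        # patches were cut from a zero-padded grid.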

        counter = 0
        for img in range(len(self.X_test)):
            counter += 1
            test_img = self.X_test[img]
            ground_truth = self.Y_test[img]
            test_img_input = np.expand_dims(test_img, 0)
            prediction = model.predict(test_img_input)
            predicted_img = np.argmax(prediction, axis=3)[0, :, :]

            # Create an overlay for the predicted areas
            overlay = test_img.copy()
            for class_id in range(1, self.n_classes):  # Assuming class 0 is background
                mask = predicted_img == class_id
                overlay[mask] = [255, 0, 0]  # Example: Red overlay for predicted areas

            # Blend the overlay with the original test image
            blended_overlay = cv2.addWeighted(test_img, 0.7, overlay, 0.3, 0)

            # Save individual patch prediction
            plt.figure(figsize=(16, 8))
            plt.subplot(241)
            plt.title('Testing Image')
            plt.imshow(test_img)
            plt.subplot(242)
            plt.title('Testing Label')
            plt.imshow(ground_truth, cmap='jet')
            plt.subplot(243)
            plt.title('Prediction on test image')
            plt.imshow(predicted_img, cmap='jet')
            plt.subplot(244)
            plt.title('Overlay of Test Image and Prediction')
            plt.imshow(blended_overlay)
            plt.savefig(f'./plots/figure_{counter}_.png')
            plt.close()  # Close the figure to avoid memory issues
            print(f'{(counter/len(self.X_test))*100}% done.')

            # Calculate the starting row and column based on the patch index
            num_patches_per_row = self.original_width // self.PATCH_SIZE
            start_row = (img // num_patches_per_row) * self.PATCH_SIZE
            start_col = (img % num_patches_per_row) * self.PATCH_SIZE
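            # NOTE: this mapping assumes X_test contains every patch of a
            # single image, in row-major order, with none filtered out by the
            # threshold in patchify_image_mask and none shuffled by the
            # train/test split.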
            
            # Place the blended overlay patch in the reconstructed image
            if start_row + self.PATCH_SIZE <= self.original_height and start_col + self.PATCH_SIZE <= self.original_width:
                reconstructed_image_with_overlays[start_row:start_row + self.PATCH_SIZE, start_col:start_col + self.PATCH_SIZE] = blended_overlay

        # Save the reconstructed original image with overlays
        plt.figure(figsize=(12, 8))
        plt.title('Reconstructed Original Image with Overlays')
        plt.imshow(reconstructed_image_with_overlays)
        plt.axis('off')
        plt.savefig('./plots/reconstructed_image_with_overlays.png')

I am wondering why the patches are being reconstructed out of order here, and what could be done to fix it?
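
For reference, this is roughly how I would expect the reconstruction to work if every patch were kept: record each patch's (i, j) grid position when patchifying, and place predictions back by that stored index instead of by their position in X_test. A sketch of that idea, where indices is a hypothetical list of (i, j) positions I would have to record in patchify_image_mask:

import numpy as np

def reconstruct(patches, indices, patch_size, height, width):
    # Rebuild the zero-padded canvas first, then crop to the original size
    rows = int(np.ceil(height / patch_size))
    cols = int(np.ceil(width / patch_size))
    canvas = np.zeros((rows * patch_size, cols * patch_size, 3), dtype=np.uint8)
    for patch, (i, j) in zip(patches, indices):
        r, c = i * patch_size, j * patch_size
        canvas[r:r + patch_size, c:c + patch_size] = patch
    return canvas[:height, :width]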

Update: here is the input image that I was processing.
