
python - Conversion of model weights from old Keras version to Pytorch - Stack Overflow


I want to transfer pretrained weights from an old project on GitHub:

The original Keras model code is:

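from keras.layers import Input, Convolution2D, MaxPooling2D, UpSampling2D
from keras.models import Model

def get_keras_autoencoder(self, input_size=256, nb_filter=96, k_size=5):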

    input_img = Input(shape=(1, input_size, input_size))

    conv1 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv1')(input_img)
    maxp1 = MaxPooling2D((2, 2), border_mode='same', name='maxp1')(conv1)
    
    conv2 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv2')(maxp1)
    maxp2 = MaxPooling2D((2, 2), border_mode='same', name='maxp2')(conv2)
    
    conv3 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv3')(maxp2)
    encoder = MaxPooling2D((2, 2), border_mode='same', name='encoder')(conv3)

    conv4 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv4')(encoder)
    upsa1 = UpSampling2D((2, 2), name='upsa1')(conv4)
    
    conv5 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv5')(upsa1)
    upsa2 = UpSampling2D((2, 2), name='upsa2')(conv5)
    
    conv6 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv6')(upsa2)
    upsa3 = UpSampling2D((2, 2), name='upsa3')(conv6)
    
    decoder = Convolution2D(1, k_size, k_size, activation='sigmoid', border_mode='same')(upsa3)
    
    autoencoder = Model(input_img, decoder)
    
    return autoencoder

The ported PyTorch code I've done is:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SAEModel(nn.Module):
    def __init__(self):
        super(SAEModel, self).__init__()

        # encoder
        self.conv1 = nn.Conv2d( 1, 96, kernel_size=5, stride=1, padding=2)
        self.maxp1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(96, 96, kernel_size=5, stride=1, padding=2)
        self.maxp2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(96, 96, kernel_size=5, stride=1, padding=2)
        self.encod = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # decoder
        self.conv4 = nn.Conv2d(96, 96, kernel_size=5, stride=1, padding=2)
        self.upsa1 = nn.Upsample(scale_factor=2)
        self.conv5 = nn.Conv2d(96, 96, kernel_size=5, stride=1, padding=2)
        self.upsa2 = nn.Upsample(scale_factor=2)
        self.conv6 = nn.Conv2d(96, 96, kernel_size=5, stride=1, padding=2)
        self.upsa3 = nn.Upsample(scale_factor=2)
        self.decod = nn.Conv2d(96,  1, kernel_size=5, stride=1, padding=2)
        
    def forward(self, x):

        # encoder activations
        x = F.relu(self.conv1(x))
        x = self.maxp1(x)
        x = F.relu(self.conv2(x))
        x = self.maxp2(x)
        x = F.relu(self.conv3(x))
        x = self.encod(x)
        
        # decoder activations
        x = F.relu(self.conv4(x))
        x = self.upsa1(x)
        x = F.relu(self.conv5(x))
        x = self.upsa2(x)
        x = F.relu(self.conv6(x))
        x = self.upsa3(x)
        x = torch.sigmoid(self.decod(x))  # F.sigmoid is deprecated in favour of torch.sigmoid

        return x
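For reference, Keras' UpSampling2D((2, 2)) repeats the rows and columns of its input, and nn.Upsample(scale_factor=2) defaults to mode='nearest', which for an integer factor of 2 does the same thing, so the upsampling layers themselves should not be a source of divergence. A tiny pure-Python illustration on a 2-D list (upsample_nearest_2x is a hypothetical helper, not part of either library):

```python
def upsample_nearest_2x(grid):
    """Nearest-neighbour upsampling by 2: repeat every value into a 2x2 block."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # repeat each value along the row
        out.append(wide)                           # then repeat the widened row
        out.append(list(wide))
    return out


print(upsample_nearest_2x([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```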

The summary dump of the Keras model:

Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 1, 256, 256)   0
____________________________________________________________________________________________________
conv1 (Convolution2D)            (None, 1, 256, 96)    2496        input_1[0][0]
____________________________________________________________________________________________________
maxp1 (MaxPooling2D)             (None, 1, 128, 96)    0           conv1[0][0]
____________________________________________________________________________________________________
conv2 (Convolution2D)            (None, 1, 128, 96)    230496      maxp1[0][0]
____________________________________________________________________________________________________
maxp2 (MaxPooling2D)             (None, 1, 64, 96)     0           conv2[0][0]
____________________________________________________________________________________________________
conv3 (Convolution2D)            (None, 1, 64, 96)     230496      maxp2[0][0]
____________________________________________________________________________________________________
encoder (MaxPooling2D)           (None, 1, 32, 96)     0           conv3[0][0]
____________________________________________________________________________________________________
conv4 (Convolution2D)            (None, 1, 32, 96)     230496      encoder[0][0]
____________________________________________________________________________________________________
upsa1 (UpSampling2D)             (None, 2, 64, 96)     0           conv4[0][0]
____________________________________________________________________________________________________
conv5 (Convolution2D)            (None, 2, 64, 96)     230496      upsa1[0][0]
____________________________________________________________________________________________________
upsa2 (UpSampling2D)             (None, 4, 128, 96)    0           conv5[0][0]
____________________________________________________________________________________________________
conv6 (Convolution2D)            (None, 4, 128, 96)    230496      upsa2[0][0]
____________________________________________________________________________________________________
upsa3 (UpSampling2D)             (None, 8, 256, 96)    0           conv6[0][0]
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 8, 256, 1)     2401        upsa3[0][0]
====================================================================================================
Total params: 1,157,377
Trainable params: 1,157,377
Non-trainable params: 0

And the one for the ported PyTorch model:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 96, 256, 256]           2,496
         MaxPool2d-2         [-1, 96, 128, 128]               0
            Conv2d-3         [-1, 96, 128, 128]         230,496
         MaxPool2d-4           [-1, 96, 64, 64]               0
            Conv2d-5           [-1, 96, 64, 64]         230,496
         MaxPool2d-6           [-1, 96, 32, 32]               0
            Conv2d-7           [-1, 96, 32, 32]         230,496
          Upsample-8           [-1, 96, 64, 64]               0
            Conv2d-9           [-1, 96, 64, 64]         230,496
         Upsample-10         [-1, 96, 128, 128]               0
           Conv2d-11         [-1, 96, 128, 128]         230,496
         Upsample-12         [-1, 96, 256, 256]               0
           Conv2d-13          [-1, 1, 256, 256]           2,401
================================================================
Total params: 1,157,377
Trainable params: 1,157,377
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 158.00
Params size (MB): 4.42
Estimated Total Size (MB): 162.67
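Both summaries report the same per-layer and total parameter counts, which follows from the convolution formula: out_channels x in_channels x k x k weights plus one bias per output channel. A quick check of that arithmetic (conv2d_params is a hypothetical helper):

```python
def conv2d_params(in_ch, out_ch, k):
    """Kernel weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch


# (in_channels, out_channels) for conv1..conv6 and the final decoder conv, k = 5
layers = [(1, 96)] + [(96, 96)] * 5 + [(96, 1)]
counts = [conv2d_params(i, o, 5) for i, o in layers]
print(counts)       # [2496, 230496, 230496, 230496, 230496, 230496, 2401]
print(sum(counts))  # 1157377
```

Equal counts only confirm that the layer shapes agree; they cannot reveal whether each kernel is ordered or oriented the same way in both frameworks.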

And finally, the code to convert the weights:

import os
import argparse

from pathlib import Path

from keras.layers import Input, Convolution2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras.models import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary

class Musica:

    def decompose_path(self, path):
        return Path(path).parent, Path(path).stem, Path(path).suffix

    def convert_keras_weights(self, weights_path):
        autoencoder = self.get_keras_autoencoder()
        autoencoder.load_weights(weights_path)
        directory, filename, extension = self.decompose_path(weights_path)
        autoencoder.save(os.path.join(directory, filename + '_model_included' + extension))

    def convert_keras_to_torch(self, full_model_path, print_models: bool = False, print_weights: bool = False):
        
        keras_model = load_model(full_model_path)
        if print_models:
            keras_model.summary()  # summary() prints by itself and returns None
        
        weights = keras_model.get_weights()

        torch_model = SAEModel()
        if print_models:
            summary(torch_model.cuda(), (1, 256, 256))

        if print_weights:
            print(f'Keras Weight {weights[0].shape} and Pytorch Weight {torch_model.conv1.weight.shape}')
            print(f'Keras Weight {weights[1].shape} and Pytorch Weight {torch_model.conv1.bias.shape}')
            print(f'Keras Weight {weights[2].shape} and Pytorch Weight {torch_model.conv2.weight.shape}')
            print(f'Keras Weight {weights[3].shape} and Pytorch Weight {torch_model.conv2.bias.shape}')
            print(f'Keras Weight {weights[4].shape} and Pytorch Weight {torch_model.conv3.weight.shape}')
            print(f'Keras Weight {weights[5].shape} and Pytorch Weight {torch_model.conv3.bias.shape}')
            print(f'Keras Weight {weights[6].shape} and Pytorch Weight {torch_model.conv4.weight.shape}')
            print(f'Keras Weight {weights[7].shape} and Pytorch Weight {torch_model.conv4.bias.shape}')
            print(f'Keras Weight {weights[8].shape} and Pytorch Weight {torch_model.conv5.weight.shape}')
            print(f'Keras Weight {weights[9].shape} and Pytorch Weight {torch_model.conv5.bias.shape}')
            print(f'Keras Weight {weights[10].shape} and Pytorch Weight {torch_model.conv6.weight.shape}')
            print(f'Keras Weight {weights[11].shape} and Pytorch Weight {torch_model.conv6.bias.shape}')
            print(f'Keras Weight {weights[12].shape} and Pytorch Weight {torch_model.decod.weight.shape}')
            print(f'Keras Weight {weights[13].shape} and Pytorch Weight {torch_model.decod.bias.shape}')

        # load keras weights into torch model; copy_() writes into the existing
        # parameters, so they keep their device (CUDA if summary() moved the model)
        with torch.no_grad():
            torch_model.conv1.weight.copy_(torch.from_numpy(weights[0]))
            torch_model.conv1.bias.copy_(torch.from_numpy(weights[1]))
            torch_model.conv2.weight.copy_(torch.from_numpy(weights[2]))
            torch_model.conv2.bias.copy_(torch.from_numpy(weights[3]))
            torch_model.conv3.weight.copy_(torch.from_numpy(weights[4]))
            torch_model.conv3.bias.copy_(torch.from_numpy(weights[5]))
            torch_model.conv4.weight.copy_(torch.from_numpy(weights[6]))
            torch_model.conv4.bias.copy_(torch.from_numpy(weights[7]))
            torch_model.conv5.weight.copy_(torch.from_numpy(weights[8]))
            torch_model.conv5.bias.copy_(torch.from_numpy(weights[9]))
            torch_model.conv6.weight.copy_(torch.from_numpy(weights[10]))
            torch_model.conv6.bias.copy_(torch.from_numpy(weights[11]))
            torch_model.decod.weight.copy_(torch.from_numpy(weights[12]))
            torch_model.decod.bias.copy_(torch.from_numpy(weights[13]))

        directory, filename, extension = self.decompose_path(full_model_path)
        export_path = os.path.join(directory, filename + '_torch' + '.pth')

        torch.save(torch_model.state_dict(), export_path)

    def get_keras_autoencoder(self, input_size=256, nb_filter=96, k_size=5):

        input_img = Input(shape=(1, input_size, input_size))

        conv1 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv1')(input_img)
        maxp1 = MaxPooling2D((2, 2), border_mode='same', name='maxp1')(conv1)
        
        conv2 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv2')(maxp1)
        maxp2 = MaxPooling2D((2, 2), border_mode='same', name='maxp2')(conv2)
        
        conv3 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv3')(maxp2)
        encoder = MaxPooling2D((2, 2), border_mode='same', name='encoder')(conv3)

        conv4 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv4')(encoder)
        upsa1 = UpSampling2D((2, 2), name='upsa1')(conv4)
        
        conv5 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv5')(upsa1)
        upsa2 = UpSampling2D((2, 2), name='upsa2')(conv5)
        
        conv6 = Convolution2D(nb_filter, k_size, k_size, activation='relu', border_mode='same', name='conv6')(upsa2)
        upsa3 = UpSampling2D((2, 2), name='upsa3')(conv6)
        
        decoder = Convolution2D(1, k_size, k_size, activation='sigmoid', border_mode='same')(upsa3)
        
        autoencoder = Model(input_img, decoder)
        
        return autoencoder
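The fourteen print statements and fourteen assignments in the conversion script follow a fixed pattern: get_weights() returns a kernel and then a bias for each convolution, in layer order. A more compact sketch builds a state_dict-style mapping in a loop and checks shapes along the way (keras_weights_to_state_dict and TORCH_LAYERS are hypothetical names; the layer order is taken from the weight dump in this question):

```python
import numpy as np

# Order of get_weights(): (kernel, bias) for conv1..conv6, then the unnamed
# output convolution, which is called "decod" in the PyTorch port.
TORCH_LAYERS = ["conv1", "conv2", "conv3", "conv4", "conv5", "conv6", "decod"]


def keras_weights_to_state_dict(weights, layer_names=TORCH_LAYERS):
    """Pair a flat [kernel, bias, kernel, bias, ...] list with state_dict keys.

    Hypothetical helper: returns float32 numpy arrays keyed like SAEModel's
    state_dict, e.g. {"conv1.weight": ..., "conv1.bias": ...}.
    """
    if len(weights) != 2 * len(layer_names):
        raise ValueError(f"expected {2 * len(layer_names)} arrays, got {len(weights)}")
    state = {}
    for i, name in enumerate(layer_names):
        kernel, bias = weights[2 * i], weights[2 * i + 1]
        # sanity checks: 4-D kernel, one bias per output channel
        if kernel.ndim != 4 or bias.shape != (kernel.shape[0],):
            raise ValueError(f"{name}: unexpected shapes {kernel.shape}, {bias.shape}")
        # contiguous copies, because torch.from_numpy() rejects negative strides
        state[f"{name}.weight"] = np.ascontiguousarray(kernel, dtype=np.float32)
        state[f"{name}.bias"] = np.ascontiguousarray(bias, dtype=np.float32)
    return state
```

The result could then be loaded with torch_model.load_state_dict({k: torch.from_numpy(v) for k, v in state.items()}), which raises an error on a missing key or a shape mismatch instead of silently replacing a parameter.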

The dump of the Keras and PyTorch weight shapes:

Keras Weight (96, 1, 5, 5) and Pytorch Weight torch.Size([96, 1, 5, 5])
Keras Weight (96,) and Pytorch Weight torch.Size([96])
Keras Weight (96, 96, 5, 5) and Pytorch Weight torch.Size([96, 96, 5, 5])
Keras Weight (96,) and Pytorch Weight torch.Size([96])
Keras Weight (96, 96, 5, 5) and Pytorch Weight torch.Size([96, 96, 5, 5])
Keras Weight (96,) and Pytorch Weight torch.Size([96])
Keras Weight (96, 96, 5, 5) and Pytorch Weight torch.Size([96, 96, 5, 5])
Keras Weight (96,) and Pytorch Weight torch.Size([96])
Keras Weight (96, 96, 5, 5) and Pytorch Weight torch.Size([96, 96, 5, 5])
Keras Weight (96,) and Pytorch Weight torch.Size([96])
Keras Weight (96, 96, 5, 5) and Pytorch Weight torch.Size([96, 96, 5, 5])
Keras Weight (96,) and Pytorch Weight torch.Size([96])
Keras Weight (1, 96, 5, 5) and Pytorch Weight torch.Size([1, 96, 5, 5])
Keras Weight (1,) and Pytorch Weight torch.Size([1])
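The shapes line up because the kernels are stored as (out_channels, in_channels, rows, cols), which is Keras' Theano-ordering kernel layout and also PyTorch's Conv2d layout. Matching shapes do not guarantee matching orientation, though. Assuming the weights were trained with Keras' Theano backend (the layout suggests this, but it is an assumption), Theano's conv2d computes a true convolution with a flipped kernel, whereas PyTorch's Conv2d computes a cross-correlation; Keras shipped a convert_kernel utility for this Theano/TensorFlow case. A minimal numpy sketch of the same 180-degree spatial flip, to be applied to each 4-D kernel but not to the biases (flip_kernel is a hypothetical name):

```python
import numpy as np


def flip_kernel(w):
    """Rotate each (rows, cols) filter of a 4-D kernel by 180 degrees.

    w has shape (out_channels, in_channels, rows, cols). Flipping both spatial
    axes turns a true-convolution kernel into the equivalent cross-correlation
    kernel, and back again: the operation is its own inverse.
    """
    if w.ndim != 4:
        raise ValueError("expected a 4-D convolution kernel")
    # contiguous copy: torch.from_numpy() rejects the negative strides of the view
    return np.ascontiguousarray(w[:, :, ::-1, ::-1])


k = np.arange(9, dtype=np.float32).reshape(1, 1, 3, 3)
print(flip_kernel(k)[0, 0])  # rows become [8, 7, 6], [5, 4, 3], [2, 1, 0]
```

In the conversion script this would mean torch.from_numpy(flip_kernel(weights[0])) and so on for every even index, leaving weights[1], weights[3], ... unchanged. Whether this is the actual cause here can only be confirmed by comparing the Keras and PyTorch outputs on the same patch.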

Everything runs without errors, and I run the PyTorch model on the image split into 256x256 patches. Unfortunately, the model fails at its task (removing staff lines from sheet music), and the output looks slightly padded, by a few pixels on the right and bottom. I suspect border_mode='same' (padding='same'): the two frameworks may pad differently, which would introduce a divergence.
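Assuming Keras followed the TensorFlow convention for border_mode='same' (this old setup appears to be Theano-based, so treat that as an assumption), the padding can be computed per axis: the output size is ceil(size / stride), the total padding is whatever makes that fit, and any odd leftover pixel goes after, i.e. to the bottom/right. A hypothetical helper showing that arithmetic for the sizes used by this autoencoder:

```python
import math


def same_padding(size, kernel, stride):
    """TensorFlow-style 'same' padding along one spatial axis.

    Returns (pad_before, pad_after); an odd leftover pixel goes after.
    """
    out = math.ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    return total // 2, total - total // 2


for size in (256, 128, 64, 32):
    print(size, "conv5x5:", same_padding(size, 5, 1))  # (2, 2) at every size
for size in (256, 128, 64):
    print(size, "pool2x2:", same_padding(size, 2, 2))  # (0, 0): no padding needed
```

Under that assumption, padding=2 in the PyTorch convolutions matches 'same' exactly, and the 2x2 pooling layers need no padding on these even sizes, so 'same' padding alone would not obviously explain an asymmetric right/bottom offset on a 256x256 patch. The asymmetric case only appears when the total padding is odd, for example with an even kernel or a size not divisible by the stride.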

Here is the input image:

Here is the same image, binarized:

Here is the output image:

I should add that the patching code works well and was thoroughly tested beforehand, so the horizontal and vertical lines in the output are not a patching problem; they come from the model processing each patch incorrectly.
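The patching code itself is not shown in the question. For readers reproducing the setup, here is a minimal numpy sketch of one way to do it, assuming a 2-D image that is padded on the bottom/right up to a multiple of 256 and cropped back after inference (split_into_patches and merge_patches are hypothetical names; the fill value should match the image background):

```python
import numpy as np


def split_into_patches(img, patch=256, fill=0):
    """Pad a 2-D image to a multiple of `patch` and cut it into tiles.

    Returns an array of shape (n_rows, n_cols, patch, patch).
    """
    h, w = img.shape
    pad_h, pad_w = -h % patch, -w % patch  # padding needed on bottom / right
    padded = np.pad(img, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=fill)
    rows, cols = padded.shape[0] // patch, padded.shape[1] // patch
    # (rows, patch, cols, patch) -> (rows, cols, patch, patch)
    return padded.reshape(rows, patch, cols, patch).swapaxes(1, 2)


def merge_patches(tiles, shape):
    """Inverse of split_into_patches: stitch the tiles and crop the padding."""
    rows, cols, patch, _ = tiles.shape
    full = tiles.swapaxes(1, 2).reshape(rows * patch, cols * patch)
    return full[: shape[0], : shape[1]]
```

Each tiles[r, c] would be fed to the model as a (1, 1, 256, 256) tensor; a round-trip check such as np.array_equal(merge_patches(split_into_patches(img), img.shape), img) is a quick way to rule out the tiling itself.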
