
Warning when loading the model because of Adam optimizer


I am implementing a multi-task model built from the following custom layers:

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Layer, Conv2D, ReLU, UpSampling2D, MaxPool2D, Concatenate
from tensorflow.keras.models import Model

@tf.keras.utils.register_keras_serializable()
class Bcb_block(Layer):
    def __init__(self, mid_chan, out_chan, data_format='channels_first', **kwargs):
        super(Bcb_block, self).__init__(**kwargs)
        self.mid_chan = mid_chan
        self.out_chan = out_chan
        self.data_format = data_format

        # Two 3x3 conv + ReLU stages
        self.conv2d_3x3_1 = Conv2D(filters=self.mid_chan, kernel_size=3, strides=1, padding='same', data_format=self.data_format, use_bias=True)
        self.ReLU_1 = ReLU()

        self.conv2d_3x3_2 = Conv2D(filters=self.out_chan, kernel_size=3, strides=1, padding='same', data_format=self.data_format, use_bias=True)
        self.ReLU_2 = ReLU()


    def call(self, input_tensor):

        x = self.conv2d_3x3_1(input_tensor)
        x = self.ReLU_1(x)

        x = self.conv2d_3x3_2(x)
        x = self.ReLU_2(x)

        return x
        
    def build(self, input_shape):
        # Call the parent build method to finalize the building process
        super(Bcb_block, self).build(input_shape)
        
    def get_config(self):
        config = super(Bcb_block, self).get_config()
        config.update({
            "mid_chan": self.mid_chan,
            "out_chan": self.out_chan,
            "data_format": self.data_format,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
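
As a quick shape check for the block I run something like the snippet below (my own test, not part of the original code; note that a channels_first Conv2D typically needs a GPU runtime):

block = Bcb_block(mid_chan=32, out_chan=32, data_format='channels_first')
x = tf.random.uniform((1, 3, 64, 64))  # (batch, channels, H, W)
print(block(x).shape)  # expect (1, 32, 64, 64): padding='same' keeps the spatial size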

@tf.keras.utils.register_keras_serializable()
class Upsampler2D(layers.Layer):
    def __init__(self, size=(2, 2), data_format='channels_first', **kwargs):
        super(Upsampler2D, self).__init__(**kwargs)
        self.size = size
        self.data_format = data_format
        # The wrapped layer always runs in channels_last; call() transposes around it
        self.upsample = UpSampling2D(size=size, interpolation='bilinear', data_format='channels_last')

    def call(self, inputs):
        # Step 1: Permute the input tensor to channels_last
        if self.data_format == 'channels_first':
            inputs = tf.transpose(inputs, perm=[0, 2, 3, 1])  # (batch, height, width, channels)
        # Step 2: Apply the UpSampling2D layer
        upsampled = self.upsample(inputs)
        # Step 3: Permute the output tensor back to channels_first
        if self.data_format == 'channels_first':
            upsampled = tf.transpose(upsampled, perm=[0, 3, 1, 2])  # (batch, channels, height, width)

        return upsampled
        
    def build(self, input_shape):
        # Call the parent build method to finalize the building process
        super(Upsampler2D, self).build(input_shape)

    def get_config(self):
        """
        This method returns the configuration of the layer that will be used for serialization
        """
        config = super(Upsampler2D, self).get_config()
        config.update({
            'size': self.size,
            'data_format': self.data_format
        })
        return config

    @classmethod
    def from_config(cls, config):
        """
        This method is used for deserializing the configuration back into a layer.
        """
        return cls(**config)
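
As a quick round-trip check that this layer serializes and upsamples as intended, I use a snippet like this (my own test, runs on CPU):

up = Upsampler2D(size=(2, 2), data_format='channels_first')
restored = Upsampler2D.from_config(up.get_config())
x = tf.random.uniform((1, 4, 8, 8))  # (batch, channels, H, W)
print(restored(x).shape)  # expect (1, 4, 16, 16): channels kept, H and W doubled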

Here is the model:

@tf.keras.utils.register_keras_serializable()
class UNet_plus_plus(Model):
    def __init__(self, data_format='channels_first', output_ch=1, **kwargs):
        super(UNet_plus_plus, self).__init__(**kwargs)
        self.data_format = data_format
        self.output_ch = output_ch
        self.axis = 1 if self.data_format == 'channels_first' else -1
        n1 = 32
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.sigmoid_1 = tf.keras.layers.Activation('sigmoid', name="mask_1_output")
        self.sigmoid_2 = tf.keras.layers.Activation('sigmoid', name="mask_2_output")
        self.Up = Upsampler2D(size=(2, 2), data_format='channels_first')
        self.pool = MaxPool2D(pool_size=(2, 2), strides=2, data_format=self.data_format)
        # Encoder BCB BLOCKS
        self.Bcb_block_0_0 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_0 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_2_0 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        self.Bcb_block_3_0 = Bcb_block(mid_chan=filters[3], out_chan=filters[3], data_format='channels_first')
        self.Bcb_block_4_0 = Bcb_block(mid_chan=filters[4], out_chan=filters[4], data_format='channels_first')
        # Level 1 BCB blocks of each skip connection
        self.Bcb_block_0_1 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_1 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_2_1 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        self.Bcb_block_3_1 = Bcb_block(mid_chan=filters[3], out_chan=filters[3], data_format='channels_first')
        self.Bcb_block_3_11 = Bcb_block(mid_chan=filters[3], out_chan=filters[3], data_format='channels_first')
        # Level 2 BCB blocks of each skip connection
        self.Bcb_block_0_2 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_2 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_2_2 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        self.Bcb_block_2_21 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        # Level 3 BCB blocks of each skip connection
        self.Bcb_block_0_3 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_3 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_1_31 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        # Level 4 BCB block of each skip connection
        self.Bcb_block_0_4 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_0_41 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')

        
        self.final_0 = Conv2D(data_format=self.data_format, filters=output_ch, kernel_size=1, strides=1, padding='valid')
        self.final_1 = Conv2D(data_format=self.data_format, filters=output_ch, kernel_size=1, strides=1, padding='valid')


    def call(self, x):
        # Encoder
        x0_0 = self.Bcb_block_0_0(x)
        x1_0 = self.Bcb_block_1_0(self.pool(x0_0))
        x0_1 = self.Bcb_block_0_1(Concatenate(axis=self.axis)([x0_0, self.Up(x1_0)]))

        x2_0 = self.Bcb_block_2_0(self.pool(x1_0))
        x1_1 = self.Bcb_block_1_1(Concatenate(axis=self.axis)([x1_0, self.Up(x2_0)]))
        x0_2 = self.Bcb_block_0_2(Concatenate(axis=self.axis)([x0_0, x0_1, self.Up(x1_1)]))

        x3_0 = self.Bcb_block_3_0(self.pool(x2_0))
        x2_1 = self.Bcb_block_2_1(Concatenate(axis=self.axis)([x2_0, self.Up(x3_0)]))
        x1_2 = self.Bcb_block_1_2(Concatenate(axis=self.axis)([x1_0, x1_1, self.Up(x2_1)]))
        x0_3 = self.Bcb_block_0_3(Concatenate(axis=self.axis)([x0_0, x0_1, x0_2, self.Up(x1_2)]))

        x4_0 = self.Bcb_block_4_0(self.pool(x3_0))

        # Decoder 0
        x3_1 = self.Bcb_block_3_1(Concatenate(axis=self.axis)([x3_0, self.Up(x4_0)]))
        x2_2 = self.Bcb_block_2_2(Concatenate(axis=self.axis)([x2_0, x2_1, self.Up(x3_1)]))
        x1_3 = self.Bcb_block_1_3(Concatenate(axis=self.axis)([x1_0, x1_1, x1_2, self.Up(x2_2)]))
        x0_4 = self.Bcb_block_0_4(Concatenate(axis=self.axis)([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)]))

        output0 = self.final_0(x0_4)
        output0 = self.sigmoid_1(output0)

        # Decoder 1
        x3_11 = self.Bcb_block_3_11(Concatenate(axis=self.axis)([x3_0, self.Up(x4_0)]))
        x2_21 = self.Bcb_block_2_21(Concatenate(axis=self.axis)([x2_0, x2_1, self.Up(x3_11)]))
        x1_31 = self.Bcb_block_1_31(Concatenate(axis=self.axis)([x1_0, x1_1, x1_2, self.Up(x2_21)]))
        x0_41 = self.Bcb_block_0_41(Concatenate(axis=self.axis)([x0_0, x0_1, x0_2, x0_3, self.Up(x1_31)]))

        output1 = self.final_1(x0_41)
        output1 = self.sigmoid_2(output1)

        return {"mask_1_output": output0, "mask_2_output": output1}

    def build(self, input_shape):
        # Call the parent build method to finalize the building process
        super(UNet_plus_plus, self).build(input_shape)

    def get_config(self):
        """Returns the configuration dictionary for model serialization."""
        config = super().get_config()
        config.update({
            "data_format": self.data_format,
            "output_ch": self.output_ch,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Creates a model instance from a configuration dictionary."""
        return cls(**config)
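
The model round-trips through its config the same way; a quick check of my own (no forward pass needed):

m = UNet_plus_plus(data_format='channels_first', output_ch=1)
m_restored = UNet_plus_plus.from_config(m.get_config())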

Also, I am using a custom loss function:

@tf.keras.utils.register_keras_serializable()
def soft_dice_loss(y_true, y_pred):
    epsilon = 1e-8  # small epsilon to avoid division by zero

    # Numerator and denominator, summed over the spatial axes
    # (axis=(2, 3) is height/width for channels_first tensors)
    numerator_dice_coef = 2 * tf.reduce_sum(y_true * y_pred, axis=(2, 3)) + epsilon
    den_dice_coef = tf.reduce_sum(y_true * y_true, axis=(2, 3)) + tf.reduce_sum(y_pred * y_pred, axis=(2, 3)) + epsilon

    # Dice coefficient per channel for each image in the batch
    dice_coef = numerator_dice_coef / den_dice_coef

    # Average Dice coefficient over all channels for each sample
    mean_dice_coef_per_sample = tf.reduce_mean(dice_coef, axis=1)  # Shape: [B]

    return 1 - mean_dice_coef_per_sample
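
To check that the loss really returns one value per sample (shape [B]) rather than a scalar, I use test tensors like these (my own snippet, channels_first layout):

y_true = tf.cast(tf.random.uniform((2, 1, 64, 64)) > 0.5, tf.float32)  # binary masks
y_pred = tf.random.uniform((2, 1, 64, 64))                             # predicted scores
print(soft_dice_loss(y_true, y_pred).shape)  # -> (2,): one Dice loss per sample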

When I train the model:

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

task_1_weight = 0.3  # weight for task 1
task_2_weight = 0.7  # weight for task 2

# Initialize the model
model = UNet_plus_plus(data_format='channels_first', output_ch=1)
adam = Adam(learning_rate=1e-4, beta_1=0.999, beta_2=0.999)
model.compile(optimizer=adam,
              loss={'mask_1_output': soft_dice_loss,
                    'mask_2_output': soft_dice_loss},
              loss_weights={'mask_1_output': task_1_weight,
                            'mask_2_output': task_2_weight})
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=30)
mc = ModelCheckpoint(f'/kaggle/working/best_model_iteration_{fold_num}.keras', 
                     monitor='val_loss', mode='min', verbose=1, save_best_only=True)
history = model.fit(x=train_dataset,
                    validation_data=val_dataset,
                    epochs=3,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[es, mc])

After that, I load the best model to evaluate it on the test set:

from tensorflow.keras.models import load_model

custom_objects = {
    'UNet_plus_plus': UNet_plus_plus,
    'Upsampler2D': Upsampler2D,
    'Bcb_block': Bcb_block,
    'soft_dice_loss': soft_dice_loss,
}
model = load_model(f'/kaggle/working/best_model_iteration_{fold_num}.keras', custom_objects=custom_objects)

However, I always get this warning:

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 162 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
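
For what it's worth, the counts seem consistent with Adam keeping two slot variables per weight tensor: the model has 40 Conv2D layers (two in each of the 19 Bcb_blocks, plus final_0 and final_1), i.e. 80 kernel/bias tensors, and 2 + 2 × 80 = 162, while a never-built optimizer only has its iteration counter and learning rate, hence the 2. To inspect this after loading I use the following (a sketch; I believe optimizer.variables is the Keras 3 property):

print(len(model.optimizer.variables))  # 2 right after loading: iterations + learning rate
print(len(model.weights))              # 80: the tensors the saved optimizer had slots for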

I am not sure why I am getting this warning or what I should do to fix it. I searched for the warning, but none of the solutions I found fixed my problem.

Update 1: I noticed that if I set compile=False in load_model, the warning disappears. However, the values of the testing metrics are then very low, whereas if I skip loading the best model and instead evaluate the model directly after training (setting steps_per_epoch to 1 for a fast training run, just to debug), I get higher metric values. So something is definitely going wrong while loading the model.
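
Concretely, this is roughly what I do in that case; the re-compile simply repeats the training configuration from above so that evaluate() has losses again:

model = load_model(f'/kaggle/working/best_model_iteration_{fold_num}.keras',
                   custom_objects=custom_objects, compile=False)
# compile=False skips restoring the optimizer state, so I compile again
# with the same settings as for training before calling evaluate():
model.compile(optimizer=Adam(learning_rate=1e-4, beta_1=0.999, beta_2=0.999),
              loss={'mask_1_output': soft_dice_loss, 'mask_2_output': soft_dice_loss},
              loss_weights={'mask_1_output': task_1_weight, 'mask_2_output': task_2_weight})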

Update 2: When I remove the build method from UNet_plus_plus, the problem is fixed, but I then get this warning instead:

/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'u_net_plus_plus', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(

I don't know why removing build() fixes the problem. Any ideas, please? Thanks in advance.
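
One thing I plan to try instead of overriding build() is forcing weight creation with a dummy forward pass right after constructing the model (a sketch; the input shape (1, 3, 256, 256) is just a placeholder for my data):

model = UNet_plus_plus(data_format='channels_first', output_ch=1)
# A single forward pass makes every sublayer create its variables before
# compile()/fit(), instead of relying on an overridden build().
_ = model(tf.zeros((1, 3, 256, 256)))  # (batch, channels, H, W), assumed shape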
