I am implementing a multi-task model built from the following custom layers:
@tf.keras.utils.register_keras_serializable()
class Bcb_block(Layer):
    def __init__(self, mid_chan, out_chan, data_format='channels_first', **kwargs):
        super(Bcb_block, self).__init__(**kwargs)
        self.mid_chan = mid_chan
        self.out_chan = out_chan
        self.data_format = data_format
        self.conv2d_3x3_1 = Conv2D(filters=self.mid_chan, kernel_size=3, strides=1, padding='same', data_format=self.data_format, use_bias=True)
        self.ReLU_1 = ReLU()
        self.conv2d_3x3_2 = Conv2D(filters=self.out_chan, kernel_size=3, strides=1, padding='same', data_format=self.data_format, use_bias=True)
        self.ReLU_2 = ReLU()

    def call(self, input_tensor):
        x = self.conv2d_3x3_1(input_tensor)
        x = self.ReLU_1(x)
        x = self.conv2d_3x3_2(x)
        x = self.ReLU_2(x)
        return x

    def build(self, input_shape):
        # Call the parent build method to finalize the building process
        super(Bcb_block, self).build(input_shape)

    def get_config(self):
        config = super(Bcb_block, self).get_config()
        config.update({
            "mid_chan": self.mid_chan,
            "out_chan": self.out_chan,
            "data_format": self.data_format,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
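As a quick shape check for this block (the sizes below are just illustrative, and channels_first convolutions typically require a GPU runtime in TensorFlow):

import tensorflow as tf

# Hypothetical input: 2 samples, 3 channels, 64x64, channels_first
dummy = tf.random.normal((2, 3, 64, 64))
block = Bcb_block(mid_chan=8, out_chan=8, data_format='channels_first')
print(block(dummy).shape)  # (2, 8, 64, 64): two same-padded 3x3 convs keep H and W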
@tf.keras.utils.register_keras_serializable()
class Upsampler2D(layers.Layer):
    def __init__(self, size=(2, 2), data_format='channels_first', **kwargs):
        super(Upsampler2D, self).__init__(**kwargs)
        self.size = size
        self.data_format = data_format
        self.upsample = UpSampling2D(size=size, interpolation='bilinear', data_format='channels_last')

    def call(self, inputs):
        # Step 1: Permute the input tensor to channels_last
        if self.data_format == 'channels_first':
            inputs = tf.transpose(inputs, perm=[0, 2, 3, 1])  # (batch, height, width, channels)
        # Step 2: Apply the UpSampling2D layer
        upsampled = self.upsample(inputs)
        # Step 3: Permute the output tensor back to channels_first
        if self.data_format == 'channels_first':
            upsampled = tf.transpose(upsampled, perm=[0, 3, 1, 2])  # (batch, channels, height, width)
        return upsampled

    def build(self, input_shape):
        # Call the parent build method to finalize the building process
        super(Upsampler2D, self).build(input_shape)

    def get_config(self):
        """Returns the configuration of the layer that will be used for serialization."""
        config = super(Upsampler2D, self).get_config()
        config.update({
            'size': self.size,
            'data_format': self.data_format,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Deserializes the configuration back into a layer."""
        return cls(**config)
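And the corresponding check for the upsampler, verifying that the transpose round-trip keeps the tensor in channels_first layout:

up = Upsampler2D(size=(2, 2), data_format='channels_first')
print(up(tf.random.normal((2, 3, 64, 64))).shape)  # (2, 3, 128, 128): 2x bilinear upsampling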
Here is the model:
@tf.keras.utils.register_keras_serializable()
class UNet_plus_plus(Model):
    def __init__(self, data_format='channels_first', output_ch=1, **kwargs):
        super(UNet_plus_plus, self).__init__(**kwargs)
        self.data_format = data_format
        self.output_ch = output_ch
        self.axis = 1 if self.data_format == 'channels_first' else -1
        n1 = 32
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.sigmoid_1 = tf.keras.layers.Activation('sigmoid', name="mask_1_output")
        self.sigmoid_2 = tf.keras.layers.Activation('sigmoid', name="mask_2_output")
        self.Up = Upsampler2D(size=(2, 2), data_format='channels_first')
        self.pool = MaxPool2D(pool_size=(2, 2), strides=2, data_format=self.data_format)
        # Encoder BCB blocks
        self.Bcb_block_0_0 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_0 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_2_0 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        self.Bcb_block_3_0 = Bcb_block(mid_chan=filters[3], out_chan=filters[3], data_format='channels_first')
        self.Bcb_block_4_0 = Bcb_block(mid_chan=filters[4], out_chan=filters[4], data_format='channels_first')
        # Level 1 BCB blocks of each skip connection
        self.Bcb_block_0_1 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_1 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_2_1 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        self.Bcb_block_3_1 = Bcb_block(mid_chan=filters[3], out_chan=filters[3], data_format='channels_first')
        self.Bcb_block_3_11 = Bcb_block(mid_chan=filters[3], out_chan=filters[3], data_format='channels_first')
        # Level 2 BCB blocks of each skip connection
        self.Bcb_block_0_2 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_2 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_2_2 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        self.Bcb_block_2_21 = Bcb_block(mid_chan=filters[2], out_chan=filters[2], data_format='channels_first')
        # Level 3 BCB blocks of each skip connection
        self.Bcb_block_0_3 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_1_3 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        self.Bcb_block_1_31 = Bcb_block(mid_chan=filters[1], out_chan=filters[1], data_format='channels_first')
        # Level 4 BCB blocks of each skip connection
        self.Bcb_block_0_4 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.Bcb_block_0_41 = Bcb_block(mid_chan=filters[0], out_chan=filters[0], data_format='channels_first')
        self.final_0 = Conv2D(data_format=self.data_format, filters=output_ch, kernel_size=1, strides=1, padding='valid')
        self.final_1 = Conv2D(data_format=self.data_format, filters=output_ch, kernel_size=1, strides=1, padding='valid')

    def call(self, x):
        # Encoder
        x0_0 = self.Bcb_block_0_0(x)
        x1_0 = self.Bcb_block_1_0(self.pool(x0_0))
        x0_1 = self.Bcb_block_0_1(Concatenate(axis=self.axis)([x0_0, self.Up(x1_0)]))
        x2_0 = self.Bcb_block_2_0(self.pool(x1_0))
        x1_1 = self.Bcb_block_1_1(Concatenate(axis=self.axis)([x1_0, self.Up(x2_0)]))
        x0_2 = self.Bcb_block_0_2(Concatenate(axis=self.axis)([x0_0, x0_1, self.Up(x1_1)]))
        x3_0 = self.Bcb_block_3_0(self.pool(x2_0))
        x2_1 = self.Bcb_block_2_1(Concatenate(axis=self.axis)([x2_0, self.Up(x3_0)]))
        x1_2 = self.Bcb_block_1_2(Concatenate(axis=self.axis)([x1_0, x1_1, self.Up(x2_1)]))
        x0_3 = self.Bcb_block_0_3(Concatenate(axis=self.axis)([x0_0, x0_1, x0_2, self.Up(x1_2)]))
        x4_0 = self.Bcb_block_4_0(self.pool(x3_0))
        # Decoder 0
        x3_1 = self.Bcb_block_3_1(Concatenate(axis=self.axis)([x3_0, self.Up(x4_0)]))
        x2_2 = self.Bcb_block_2_2(Concatenate(axis=self.axis)([x2_0, x2_1, self.Up(x3_1)]))
        x1_3 = self.Bcb_block_1_3(Concatenate(axis=self.axis)([x1_0, x1_1, x1_2, self.Up(x2_2)]))
        x0_4 = self.Bcb_block_0_4(Concatenate(axis=self.axis)([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)]))
        output0 = self.final_0(x0_4)
        output0 = self.sigmoid_1(output0)
        # Decoder 1
        x3_11 = self.Bcb_block_3_11(Concatenate(axis=self.axis)([x3_0, self.Up(x4_0)]))
        x2_21 = self.Bcb_block_2_21(Concatenate(axis=self.axis)([x2_0, x2_1, self.Up(x3_11)]))
        x1_31 = self.Bcb_block_1_31(Concatenate(axis=self.axis)([x1_0, x1_1, x1_2, self.Up(x2_21)]))
        x0_41 = self.Bcb_block_0_41(Concatenate(axis=self.axis)([x0_0, x0_1, x0_2, x0_3, self.Up(x1_31)]))
        output1 = self.final_1(x0_41)
        output1 = self.sigmoid_2(output1)
        return {"mask_1_output": output0, "mask_2_output": output1}

    def build(self, input_shape):
        # Call the parent build method to finalize the building process
        super(UNet_plus_plus, self).build(input_shape)

    def get_config(self):
        """Returns the configuration dictionary for model serialization."""
        config = super().get_config()
        config.update({
            "data_format": self.data_format,
            "output_ch": self.output_ch,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Creates a model instance from a configuration dictionary."""
        return cls(**config)
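Before training, I check the two output heads with a dummy forward pass. The input shape here is only an example (a 3-channel input is my assumption); H and W must be divisible by 16 because of the four pooling steps:

model = UNet_plus_plus(data_format='channels_first', output_ch=1)
# Hypothetical input: 2 samples, 3 channels, 128x128, channels_first
outputs = model(tf.random.normal((2, 3, 128, 128)))
print(outputs['mask_1_output'].shape)  # (2, 1, 128, 128)
print(outputs['mask_2_output'].shape)  # (2, 1, 128, 128)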
Also, I am using a custom loss function:
@register_keras_serializable()
def soft_dice_loss(y_true, y_pred):
    epsilon = 1e-8  # small epsilon to avoid division by zero
    # Numerator and denominator of the Dice coefficient, summed over H and W
    numerator_dice_coef = 2 * tf.reduce_sum(y_true * y_pred, axis=(2, 3)) + epsilon
    den_dice_coef = tf.reduce_sum(y_true * y_true, axis=(2, 3)) + tf.reduce_sum(y_pred * y_pred, axis=(2, 3)) + epsilon
    # Dice coefficient per channel for each image in the batch
    dice_coef = numerator_dice_coef / den_dice_coef
    # mean_dice_coef = tf.reduce_mean(dice_coef)  # average over the whole batch (unused)
    # Average Dice coefficient over all channels for each sample
    mean_dice_coef_per_sample = tf.reduce_mean(dice_coef, axis=1)  # shape: [B]
    return 1 - mean_dice_coef_per_sample
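As a quick sanity check (dummy shapes, channels_first assumed), the loss returns one value per sample, which Keras then averages over the batch:

# Hypothetical shapes: batch of 4, 1 channel, 32x32 masks, channels_first
y_true = tf.cast(tf.random.uniform((4, 1, 32, 32)) > 0.5, tf.float32)
y_pred = tf.random.uniform((4, 1, 32, 32))
print(soft_dice_loss(y_true, y_true).numpy())  # ~0.0 everywhere: perfect overlap
print(soft_dice_loss(y_true, y_pred).shape)    # (4,): one loss value per sample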
When I train the model:
task_1_weight = 0.3  # weight for task 1
task_2_weight = 0.7  # weight for task 2

# Initialize the model
model = UNet_plus_plus(data_format='channels_first', output_ch=1)
adam = Adam(learning_rate=1e-4, beta_1=0.999, beta_2=0.999)
model.compile(optimizer=adam,
              loss={'mask_1_output': soft_dice_loss,
                    'mask_2_output': soft_dice_loss},
              loss_weights={'mask_1_output': task_1_weight,
                            'mask_2_output': task_2_weight})

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=30)
mc = ModelCheckpoint(f'/kaggle/working/best_model_iteration_{fold_num}.keras',
                     monitor='val_loss', mode='min', verbose=1, save_best_only=True)

history = model.fit(x=train_dataset,
                    validation_data=val_dataset,
                    epochs=3,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[es, mc])
After that, I load the best model to evaluate it on the test set:
custom_objects = {
    'UNet_plus_plus': UNet_plus_plus,
    'Upsampler2D': Upsampler2D,
    'Bcb_block': Bcb_block,
    'soft_dice_loss': soft_dice_loss,
}
model = load_model(f'/kaggle/working/best_model_iteration_{fold_num}.keras',
                   custom_objects=custom_objects)
However, I always get this warning:
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 162 variables.
saveable.load_own_variables(weights_store.get(inner_path))
I am not sure why I am getting this warning or what I should do to fix it. I searched for it, but none of the solutions I found fixed my problem.
Update 1: I noticed that if I set compile=False in load_model, the warning disappears. However, the values of the test metrics are then very low, whereas if I skip loading the best model and evaluate directly after training (setting steps_per_epoch to 1 so the training run is fast enough for debugging), I get much higher metric values. So there is definitely something going wrong while loading the model.
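For reference, this is what I mean by the compile=False load; the manual re-compile afterwards is just a sketch of how I re-attach the losses for evaluation, repeating the settings from training:

model = load_model(f'/kaggle/working/best_model_iteration_{fold_num}.keras',
                   custom_objects=custom_objects,
                   compile=False)
# Re-attach the loss configuration so model.evaluate() reports losses;
# the optimizer state from the checkpoint is intentionally discarded here.
model.compile(optimizer=Adam(learning_rate=1e-4, beta_1=0.999, beta_2=0.999),
              loss={'mask_1_output': soft_dice_loss,
                    'mask_2_output': soft_dice_loss},
              loss_weights={'mask_1_output': task_1_weight,
                            'mask_2_output': task_2_weight})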
Update 2: When I remove the build method from UNet_plus_plus, the problem is fixed. However, I then get this warning instead:
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'u_net_plus_plus', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(
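My best guess at what the warning means by a "proper" build() is something like the sketch below (an untested assumption on my part; the shape handling in particular is hypothetical):

def build(self, input_shape):
    # Untested sketch: force every sublayer to create its weights by running
    # one dummy forward pass with the expected shape, instead of only
    # delegating to the parent build(). input_shape includes the batch
    # dimension, e.g. (None, C, H, W).
    super(UNet_plus_plus, self).build(input_shape)
    if None not in input_shape[1:]:
        self.call(tf.zeros((1, *input_shape[1:])))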
I don't know why removing build() fixes the loading problem. Any ideas, please? Thanks in advance.