
python - MobileNetV3-Large TFLite Model Works in Notebook But Always Predicts the Same Class in Flutter App - Stack Overflow


I trained a 46-class image classification model using MobileNetV3-Large with TensorFlow/Keras and saved it as a .keras model. I then converted it to TFLite using tf.lite.TFLiteConverter.from_keras_model().

When I test the TFLite model in my notebook, it produces correct predictions. However, when I build my Flutter APK and run the model there, it always outputs the same class regardless of the input image. My previous model, based on MobileNetV2, worked correctly in Flutter. Below are the training code and the conversion code I used. I suspect the problem is related to differences in preprocessing or model conversion between MobileNetV2 and MobileNetV3-Large. Has anyone encountered this issue, or does anyone have suggestions on what might be causing the constant output in the Flutter environment?
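One preprocessing difference is worth ruling out first. Keras' MobileNetV2 expects its input to be scaled to [-1, 1] beforehand by `mobilenet_v2.preprocess_input` (x / 127.5 - 1), whereas `MobileNetV3Large` is built with `include_preprocessing=True` by default and applies that same rescaling inside the model, so it expects raw [0, 255] pixels. The sketch below assumes (the Flutter code is not shown) that the app still normalizes to [-1, 1] as it did for the MobileNetV2 model, and uses plain NumPy to show how rescaling twice squeezes very different images into a sliver of values around -1, which would be consistent with a near-constant prediction.

```python
import numpy as np

def rescale_like_keras(x):
    """(x / 127.5) - 1: MobileNetV2's preprocess_input, and the Rescaling
    layer that MobileNetV3Large adds when include_preprocessing=True."""
    return x.astype(np.float32) / 127.5 - 1.0

# Two very different synthetic 224x224 RGB images (all black vs. all white).
black = np.zeros((224, 224, 3), dtype=np.uint8)
white = np.full((224, 224, 3), 255, dtype=np.uint8)

for name, img in [("black", black), ("white", white)]:
    # Expected for V3: feed raw pixels; the model rescales once internally.
    once = rescale_like_keras(img)
    # Assumed app behaviour: normalize in Dart (V2 habit), model rescales again.
    twice = rescale_like_keras(rescale_like_keras(img))
    print(f"{name}: once -> {once.mean():+.4f}, twice -> {twice.mean():+.4f}")
```

A single rescaling leaves the black and white images 2.0 apart; applying it twice leaves them only 2 / 127.5 (about 0.016) apart, both close to -1.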

Below is the model training code:

from tensorflow.keras.applications import MobileNetV3Large
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

l2_strength = 0.01

base_model = MobileNetV3Large(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # Freeze the pretrained backbone

# Classification head on top of the frozen backbone
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='swish', kernel_regularizer=l2(l2_strength))(x)
x = Dropout(0.5)(x)  # Dropout for regularization
predictions = Dense(46, activation='softmax', kernel_regularizer=l2(l2_strength))(x)

model = Model(inputs=base_model.input, outputs=predictions)

model.compile(optimizer=Adam(learning_rate=0.0005), loss='categorical_crossentropy', metrics=['accuracy'])

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-7)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# train_generator, validation_generator and test_generator are defined elsewhere (not shown)
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=50,
    callbacks=[reduce_lr, early_stopping]
)

# Evaluate accuracy on the held-out test set
test_loss, test_accuracy = model.evaluate(test_generator)
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')

# Save the trained model in the native Keras format
model.save('classifier_V3.keras')
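Whatever the Flutter side does, the converted model should receive the same kind of input the Keras model saw during training. The sketch below shows, under stated assumptions, the tensor the converted model most likely expects: float32, shape (1, 224, 224, 3) in NHWC order, RGB channels, with pixel values left in [0, 255] because `MobileNetV3Large` (with its default `include_preprocessing=True`) rescales internally. It assumes `train_generator` also fed raw pixels; its definition is not shown, and if it applied something like `rescale=1./255`, the app would have to repeat that instead. `to_model_input` is a hypothetical helper, not part of any library.

```python
import numpy as np

def to_model_input(rgb_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn one 224x224x3 uint8 RGB image into the
    float32 NHWC batch the converted model takes.

    Assumes the training pipeline fed raw [0, 255] pixels, so no scaling
    is applied here; MobileNetV3Large's built-in Rescaling layer does that.
    """
    if rgb_uint8.shape != (224, 224, 3):
        raise ValueError(f"expected (224, 224, 3), got {rgb_uint8.shape}")
    # Cast to float32 and add the batch dimension -> (1, 224, 224, 3)
    return rgb_uint8.astype(np.float32)[np.newaxis, ...]

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(224, 224, 3)).astype(np.uint8)
x = to_model_input(img)
print(x.shape, x.dtype, x.min(), x.max())
```

The app-side equivalent (in Dart) would need to produce the same shape, dtype, channel order and value range; that code is not shown here.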

Below is how I convert my Keras model to TFLite:

import tensorflow as tf
from tensorflow.keras.models import load_model
model_path = "classifier_V3.keras"
tflite_model_path = "classifier_V3.tflite"

model = load_model(model_path)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
print(f"TensorFlow Lite model saved to {tflite_model_path}")
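To tell a conversion problem apart from an app-side preprocessing problem, the .tflite file can be exercised in the notebook with the same kind of input the app sends. A minimal sketch, assuming a `tf.lite.Interpreter` on which `allocate_tensors()` has already been called; `predict_classes` and `looks_collapsed` are hypothetical helpers that only use the interpreter's `get_input_details`, `get_output_details`, `set_tensor`, `invoke` and `get_tensor` methods.

```python
import numpy as np

def predict_classes(interpreter, batches):
    """Run each input batch through a TFLite interpreter and return the
    arg-max class index of the first (only) row of each output."""
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    classes = []
    for batch in batches:
        # Cast to the dtype the model declares (float32 for this model)
        interpreter.set_tensor(inp["index"], batch.astype(inp["dtype"]))
        interpreter.invoke()
        probs = interpreter.get_tensor(out["index"])[0]
        classes.append(int(np.argmax(probs)))
    return classes

def looks_collapsed(classes):
    """True when several inputs all produced the same class (the reported symptom)."""
    return len(classes) > 1 and len(set(classes)) == 1
```

For example, after `interpreter = tf.lite.Interpreter(model_path="classifier_V3.tflite")` and `interpreter.allocate_tensors()`, one could run `predict_classes` on a handful of test images once as raw [0, 255] arrays and once normalized to [-1, 1]. If only the normalized run collapses to a single class, the conversion is probably fine and the app's preprocessing is the more likely suspect.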

If you have any clue, please feel free to share.

I am using TensorFlow 2.18 and Keras 3.9.


Asked Mar 5 at 16:48 by luz_de_guada, edited Mar 5 at 16:51
  • Maybe first use print() (and print(type(...)), print(len(...)), etc.) to see which part of code is executed and what you really have in variables. It is called "print debugging" and it helps to see what code is really doing. – furas Commented Mar 6 at 5:33
  • Maybe you could also ask on related sites such as Data Science or Cross Validated. – furas Commented Mar 6 at 5:34
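Applied to this problem, the print-debugging suggestion could look like the sketch below: a small hypothetical helper, `describe`, that reports the type, shape, dtype and value range of whatever array is about to be fed to the model, so the notebook input and the app input can be compared side by side. Equivalent logging on the Flutter side would have to be written in Dart and is not shown.

```python
import numpy as np

def describe(name, value):
    """Return a one-line summary of an array-like value for print debugging."""
    arr = np.asarray(value)
    return (f"{name}: type={type(value).__name__} shape={arr.shape} "
            f"dtype={arr.dtype} min={float(arr.min()):.4f} max={float(arr.max()):.4f}")

x = np.full((1, 224, 224, 3), 255, dtype=np.float32)
print(describe("notebook input", x))
```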

1 Answer


After many trials and going through the online documentation on both the Google and Flutter sides, I concluded that MobileNetV3 is not supported in Flutter.

In the end, I decided to continue with ResNet50 instead.
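If ResNet50 is used instead, its Keras preprocessing differs yet again: `keras.applications.resnet50.preprocess_input` uses "caffe" mode, which converts RGB to BGR and subtracts the per-channel ImageNet means without any scaling, and the ResNet50 model does not include this step itself. The NumPy sketch below reproduces that arithmetic so the Flutter side can mirror it; `caffe_preprocess` is a hypothetical name, and the sketch assumes the training pipeline actually used `resnet50.preprocess_input`.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe" mode.
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_preprocess(rgb):
    """Mirror of resnet50.preprocess_input for an (..., 3) RGB array in [0, 255]."""
    bgr = np.asarray(rgb, dtype=np.float32)[..., ::-1]  # RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR  # zero-center each channel, no scaling

pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)  # one pure-red pixel
print(caffe_preprocess(pixel))
```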
