I'm developing an Android app that detects text in manga, so that the user can later tap the text and get a definition for each word.
I'm currently working on the OCR part and trying to run the manga-ocr OCR model on Android. From searching, I found that the model exported to ONNX can be used directly on Android, but I'm stuck on how to actually run it.
/**
 * Loads an ONNX model bundled in the app's assets and runs a single forward pass on a bitmap.
 *
 * Call [init] once before [process]; [process] does nothing while no session has been created.
 *
 * NOTE(review): one forward pass with a single start token only yields the output for the first
 * decoding step. If the model is an encoder/decoder text generator, a full transcription
 * presumably needs a loop that appends the arg-max token and re-runs until the end token.
 * [createDecoderInputIds] accepts a token prefix for that purpose.
 */
object ONNXModelHelper {
    private const val MODEL_ASSET = "model2.onnx"

    private const val CHANNELS = 3 // For RGB

    // NOTE(review): confirm 384x384 against the model's preprocessor config.
    private const val WIDTH = 384
    private const val HEIGHT = 384

    // NOTE(review): assumed to be the decoder start token of the exported model; verify.
    private const val DECODER_START_TOKEN_ID = 2L

    // Written under the object's lock in init(), read lock-free in process(); @Volatile makes
    // the assignment visible to other threads.
    @Volatile
    private var session: OrtSession? = null

    /**
     * Copies the asset [filename] into the app's files directory (if not already there) and
     * returns the resulting file.
     *
     * The copy goes through a temporary file that is renamed on success, so an interrupted
     * copy cannot leave a truncated model behind that later launches would pick up.
     */
    private fun loadModelFromAssets(context: Context, filename: String): File {
        val file = File(context.filesDir, filename)
        // An empty file can only be the leftover of a failed copy; treat it as missing.
        if (file.exists() && file.length() > 0) return file

        val partial = File(context.filesDir, "$filename.tmp")
        try {
            context.assets.open(filename).use { input ->
                partial.outputStream().use { output -> input.copyTo(output) }
            }
            check(partial.renameTo(file)) { "Could not move ${partial.path} to ${file.path}" }
        } finally {
            // No-op after a successful rename; removes the partial copy on failure.
            partial.delete()
        }
        return file
    }

    /** Creates an inference session for the bundled model using default session options. */
    private fun loadONNXModel(context: Context): OrtSession {
        val env = OrtEnvironment.getEnvironment()
        val modelFile = loadModelFromAssets(context, MODEL_ASSET)
        // The options object owns native memory and is only needed while the session is built.
        return OrtSession.SessionOptions().use { options ->
            env.createSession(modelFile.absolutePath, options)
        }
    }

    /** Creates the session on first call; later calls are no-ops. */
    @Synchronized
    fun init(context: Context) {
        if (session == null) {
            session = loadONNXModel(context)
        }
    }

    /**
     * Runs the model once on [bitmap] and prints the flattened first output.
     *
     * Returns without doing anything if [init] has not been called.
     *
     * @throws IllegalStateException if the first model output is not a float tensor.
     */
    fun process(bitmap: Bitmap) {
        val session = session ?: return

        // Tensors and results hold native memory, so each one is closed as soon as it is done.
        convertToTensor(bitmap).use { pixelValues ->
            createDecoderInputIds().use { decoderInputIds ->
                val inputs = mapOf(
                    "pixel_values" to pixelValues,
                    "decoder_input_ids" to decoderInputIds,
                )
                session.run(inputs).use { result ->
                    // `value` is a nested array for outputs of rank > 1, so it cannot be cast
                    // to FloatArray; the float buffer is a flat copy for any rank.
                    val flat = (result.get(0) as? OnnxTensor)?.floatBuffer
                        ?: error("First model output is not a float tensor")
                    val output = FloatArray(flat.remaining())
                    flat.get(output)
                    println("Inference Output: ${output.joinToString()}")
                }
            }
        }
    }

    /**
     * Builds the `decoder_input_ids` tensor of shape `[1, tokenIds.size]`.
     *
     * @param tokenIds decoder token prefix; defaults to the single start token.
     * @return a tensor the caller is responsible for closing.
     * @throws IllegalArgumentException if [tokenIds] is empty.
     */
    fun createDecoderInputIds(
        tokenIds: LongArray = longArrayOf(DECODER_START_TOKEN_ID),
    ): OnnxTensor {
        require(tokenIds.isNotEmpty()) { "tokenIds must contain at least the start token" }
        return OnnxTensor.createTensor(
            OrtEnvironment.getEnvironment(),
            LongBuffer.wrap(tokenIds),
            longArrayOf(1, tokenIds.size.toLong()),
        )
    }

    /**
     * Scales [bitmap] to [WIDTH] x [HEIGHT] and converts it to planar float data.
     *
     * The buffer is laid out channel-first (all R values, then all G, then all B) so that it
     * matches the returned shape `[1, 3, HEIGHT, WIDTH]`. Values are scaled to `[0, 1]`.
     *
     * NOTE(review): many vision models also expect mean/std normalization (e.g.
     * `(x - 0.5) / 0.5`); check the model's preprocessor config before relying on the output.
     */
    private fun preprocessImage(bitmap: Bitmap): Pair<FloatBuffer, LongArray> {
        val resized = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, true)
        val planeSize = WIDTH * HEIGHT
        val pixels = IntArray(planeSize)
        resized.getPixels(pixels, 0, WIDTH, 0, 0, WIDTH, HEIGHT)
        // createScaledBitmap may hand back the source itself; never recycle the caller's bitmap.
        if (resized !== bitmap) resized.recycle()

        val buffer = FloatBuffer.allocate(CHANNELS * planeSize)
        for (i in 0 until planeSize) {
            val pixel = pixels[i]
            // Absolute puts: one plane per channel, position stays at 0.
            buffer.put(i, (pixel shr 16 and 0xFF) / 255f)
            buffer.put(planeSize + i, (pixel shr 8 and 0xFF) / 255f)
            buffer.put(2 * planeSize + i, (pixel and 0xFF) / 255f)
        }
        return buffer to longArrayOf(1, CHANNELS.toLong(), HEIGHT.toLong(), WIDTH.toLong())
    }

    /** Wraps the preprocessed image in a float tensor of shape `[1, 3, HEIGHT, WIDTH]`. */
    private fun convertToTensor(bitmap: Bitmap): OnnxTensor {
        val (data, shape) = preprocessImage(bitmap)
        return OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), data, shape)
    }
}
If anybody has already done something similar or knows how to do this, could you please help me?