使用自定义TFLITE的Firebase ML Kit对Android上的各种输出产生相同的推断

罗汉·波贾(Rohan Bojja)

我正在研究一种音频分类模型,该模型根据其类型对音频进行分类。

该模型吸收了一些音频特征,例如频谱质心等,并产生了诸如classic / rock / etc之类的输出。输入形状-> [1,26]这是一个多标签分类器。我有一个Keras模型,已将其转换为TFLite模型以用于移动平台。我已经测试了初始模型,并且它的准确性相当不错,当在我的PC上运行Python时,tflite模型也能正常工作。

当我将其部署到Firebase的ML Kit并与Android API结合使用时,它会生成单个标签/类作为各种输入的输出。我认为该模型没有问题,因为在我的Jupyter笔记本电脑上它可以正常工作。我不明白如何为相同的输入产生不同的推论?

硬模型:

#The test model
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Dropout, Activation

model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))
model.add(Dropout(0.5))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop',
              loss='sparse_categorical_crossentropy',
             metrics=['sparse_categorical_accuracy'])
history = model.fit(X_train,
                    y_train,
                    epochs=10)
#print(X_test[:1],y_test)
pred = model.predict_classes(X_test)
print(pred)
print(y_test)

转换码:

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

输入/输出形状:

import tensorflow as tf
​
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
​
# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>
​
# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>
[ 1 26]
<class 'numpy.float32'>
[ 1 10]
<class 'numpy.float32'>

用于测试的演示Kotlin代码:

listenButton.setOnClickListener {
            incorrecttagButton.alpha = 1f
            incorrecttagButton.isClickable = true
            //Code for listening to music
           FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
               .addOnSuccessListener { isDownloaded ->
                   val options =
                       if (isDownloaded) {
                           FirebaseModelInterpreterOptions.Builder(remoteModel).build()
                       } else {
                           FirebaseModelInterpreterOptions.Builder(localModel).build()
                       }
                   Log.d("HUSKY","Downloaded? ${isDownloaded}")
                   val interpreter = FirebaseModelInterpreter.getInstance(options)
                   val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
                       .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 26))
                       .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1,10))
                       .build()
                   if(songNum==5){
                       songNum=0
                   }
                   val testSong = testsongs[songNum]
                   Log.d("HUSKY", "Song num = ${songNum} F = ${testSong} ")
                   val input = Array(1){FloatArray(26)}
                   val itr =  testSong.split(",").toTypedArray()
                   val preInput = itr.map { it.toFloat() }
                   var x = 0
                   preInput.forEach {
                       input[0][x] = preInput[x]
                       x+=1
                   }
                   //val input = preInput.toTypedArray()
                   Log.d("HUSKY", "${input[0][1]}")
                   val inputs = FirebaseModelInputs.Builder()
                       .add(input) // add() as many input arrays as your model requires
                       .build()

                   val labelArray = "blues classical country disco hiphop jazz metal pop reggae rock".split(" ").toTypedArray()
                   Log.d("HUSKY2", "GG")
                   interpreter?.run(inputs, inputOutputOptions)?.addOnSuccessListener { result ->
                       Log.d("HUSKY2", "GGWP")
                       val output = result.getOutput<Array<FloatArray>>(0)
                       val probabilities = output[0]
                       var bestMatch = 0f
                       var bestMatchIndex = 0
                       for (i in probabilities.indices){
                           if(probabilities[i]>bestMatch){
                               bestMatch = probabilities[i]
                               bestMatchIndex = i
                           }
                           Log.d("HUSKY2", "${labelArray[i]} ${probabilities[i]}")
                           genreLabel.text = labelArray[i]
                       }
                       genreLabel.text = labelArray[bestMatchIndex].capitalize()
                       confidenceLabel.text = probabilities[bestMatchIndex].toString()

                       // ...
                   }?.addOnFailureListener { e ->
                       // Task failed with an exception
                       // ...
                       Log.d("HUSKY2", "GGWP :( ${e.toString()}")
                   }

               }

我正在使用SongNum来增加字符串数组来更改歌曲。要素存储为字符串,并以逗号作为分隔符。

不管输入的功能如何(SongNum变量以更改歌曲[0-4]),输出都是相同的,并且对流行的置信度始终为1.0:

2020-02-25 00:11:21.014 17434-17434/com.rohanbojja.audient D/HUSKY: Downloaded? true
2020-02-25 00:11:21.015 17434-17434/com.rohanbojja.audient D/HUSKY: Song num = 0 F = 0.3595172803692916,0.04380025714635849,1365.710742222286,1643.935571084307,2725.445556640625,0.06513807508680555,-273.0061247040518,132.66331747988934,-31.86709317807114,44.21442952318603,4.335704872427025,32.32360339344842,-2.4662076330637714,20.458242724823684,-4.760171779927926,20.413702740993585,3.69545905318442,8.581128171784677,-15.601809275025104,5.295758930950924,-5.270195074271744,5.895109210872318,-6.1406603018722645,-2.9278519508415286,-1.9189588023091468,5.954495267889836 
2020-02-25 00:11:21.016 17434-17434/com.rohanbojja.audient D/HUSKY: 0.043800257
2020-02-25 00:11:21.016 17434-17434/com.rohanbojja.audient D/HUSKY2: GG
2020-02-25 00:11:21.021 17434-17434/com.rohanbojja.audient D/HUSKY2: GGWP
2020-02-25 00:11:21.021 17434-17434/com.rohanbojja.audient D/HUSKY2: blues 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: classical 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: country 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: disco 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: hiphop 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: jazz 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: metal 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: pop 1.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: reggae 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: rock 0.0

Jupyter Notebook上的输出如下:

(blues,)    (classical,)    (country,)  (disco,)    (hiphop,)   (jazz,) (metal,)    (pop,)  (reggae,)   (rock,)
0   0.257037    0.000705    0.429687    0.030933    0.009291    0.004909    1.734001e-03    0.000912    0.203305    0.061488

从我可以得出的结论来看,我搞砸了ML Kit API的用法吗?还是我传递输入数据或检索输出数据的方式?我是android开发的新手。

输出:“ pop”的置信度始终为1.0!预期的输出:每个流派都应该在[0-1.0]之间有一定的信心,而不是总是“流行”,就像我从Jupyter笔记本获得的结果一样。

抱歉,代码混乱。

任何帮助将不胜感激!

更新1:我用S型激活函数交换了relu,我注意到了区别。它仍然几乎总是“流行”,但有大约0.30的置信度。现在超级神秘。仅发生在ML Kit BTW上,还没有真正尝试在本机上实现它。

更新2:我不明白如何用相同的模型获得不同的推论。我迷路了。

罗汉·波贾(Rohan Bojja)

在预测阶段提取出的特征后,我还没有对其进行归一化处理,也就是说,提取出的特征不会进行转换。

我已经用

X = StandardScaler().fit_transform(np.array(data.iloc[:,1:-1]))

为了解决此问题,我必须转换功能:

scaler=StandardScaler().fit(np.array(data.iloc[:,1:-1]))
input_data = scaler.transform(input_data2)

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

在AWS SageMaker上酝酿自定义ML模型

我使用TFLiteConvert post_training_quantize = True,但是我的模型仍然太大,无法托管在Firebase ML Kit的自定义服务器中

Firebase ML Kit proguard 问题

如何在Google Cloud ML上使用自定义预处理和数据文件进行推理

无法从ML Kit加载自定义模型:FirebaseMLException:加载任务失败

Core ML上具有两个参数功能的自定义层

在Azure ML中使用自定义docker

Firebase ML KIT无法识别古吉拉特语

iOS Firebase ML Kit 简单音频识别“无法为给定模型创建 TFLite 解释器”

如何使用Firebase ML Kit查找图像中的标记?

如何使用Firebase ML Kit识别条形码?

在{TF 2.0.0-beta1上使用tflite_convert时,“未知(自定义)损失函数”;Keras}模型

在 Android 中使用 ML Kit 的人脸置信度

自定义模型 [MLKit] - FirebaseMLException:执行 Firebase ML 任务时发生内部错误

在Cloud ML Engine上使用TPU

Firebase Android ML Kit:在QR代码上隐藏显示值的方法

Google Cloud ML随附的TPU自定义芯片

在PySpark ML中创建自定义变压器

在Azure ML Studio上部署自定义模型

在没有cocapods的iOS上使用Firebase ML Kit时,GoogleMobileVision中的链接器错误

使用 Azure ML 中的自定义筛选器将评级列转换为布尔列

如何使用自定义环境和管道对Azure ML工作区进行版本控制?

使用R在Azure ML Jupyter / iPython Notebook中下载自定义数据集

在 Google Cloud ML Engine 中使用自定义依赖项

添加 firebase ML-kit 时的依赖项冲突

Firebase ML Kit可以用作人脸验证吗

Firebase ML Kit 人脸检测,无法检索实例 ID

Tflite模型在Android(ml视觉)和Python中提供不同的输出

将ML Kit与NNAPI一起使用