TFLite 使用训练自己的数据集并部署到手机上

TFLite部分

第一步:下载代码

打开命令行工具输入以下命令行,从google 的codelabs下载代码:

git clone https://github.com/googlecodelabs/tensorflow-for-poets-2

下载完后,会生成一个叫“tensorflow-for-poets”的文件夹。

代码下载

文件夹内容的组成如下:

  • scripts-----包含机器学习的python代码文件
  • tf_files-----包含输出文件,比如graph.pb和labels.txt
  • android-----包含安卓app项目,又分为tfmobile和TFLite
  • iOS----包含ios App的项目,需要使用xCode

第二步:下载数据集

下载链接

点击下载链接,下载约200MB的公开数据集,该数据集包含五种分类的花:Rose(玫瑰花)、Daisy(雏菊)、Dandelion(蒲公英)、Sunflower(向日葵)

解压缩到tf_files > flowes_photos目录

第三步:重新训练模型

在tensorflow-for-poets-2目录中打开命令行工具,输入:

python scripts/retrain.py --output_graph=tf_files/retrained_graph.pb --output_labels=tf_files/retrained_labels.txt --image_dir=tf_files/flower_photos --architecture=mobilenet_1.0_224 --summaries_dir tf_files/training_summaries/mobilenet_1.0_224

然后开始下载预训练的Mobilenet_1.0_224 的 frozen graph ;并且在tf_files目录中生成 retrained_graph.pbretrained_labels.txt 文件

第四步:打开Tensorboard(可跳过)

在Tensorboard中可以观察准确度和交叉熵损失函数的变化。

tensorboard --logdir=tf_files/training_summaries/mobilenet_1.0_244

第五步:确认模型的有效性

从互联网随机下载一张花的图片,放入工作目录中,查看模型识别结果

python scripts/label_image --graph=tf_files/retrained_graph.pb --image=new_rose.jpg

第六步:将模型转成TFLIte格式

系统要求:Ubuntu

Toco 使用来将pb文件文件转成.lite格式文件的转换器,更多细节可以使用 toco --help 查看说明。

IMAGE_SIZE=224
toco --graph_def_file=tf_files/retrained_graph.pb --output_file=tf_files/optimized_graph.lite --output_format=TFLITE --input_shape=1,${IMAGE_SIZE},${IMAGE_SIZE},3 --input_array=input \
--output_array=final_result --inference_type=FLOAT --inference_input_type=FLOAT

上述命令会在tf_files目录中生成optimized_graph.lite 文件

tips: 1)--input_file 已经更新成 --graph_def_file 2)--input_format 对于mobile_net 的计算图没有必要性

移动端部分

Android

第一步:模型和标签的替换

tf_files 中生成的 optimized_graph.literetrained_labels.txt,复制到android >tflite项目的assets中,并替换原有的graph.litelabels.txt 文件

cp tf_files/optimized_graph.lite android/tflite/app/src/main/assets/graph.lite

cp tf_files/retrained_labels.txt android/tflite/app/src/main/assets/labels.txt

第二步:生成app

打开Android Studio,并打开已有项目,选中android/tflite目录,点击BUILD-->Bulid APK, app-debug.apk文件就会产生,然后安装到安卓手机上。

IOS

  1. 安装Xcode

    xcode-select --install
    
  2. 安装Cocoapods

    sudo gem install cocoapods
    
  3. 安装 TFlite Cocoapod

    pod install --project-directory=ios/tflite/
    
  4. 替换模型和文件

    cp tf_files/optimized_graph.lite ios/tflite/data/graph.lite
    cp tf_files/retrained_labels.txt ios/tflite/data/labels.txt
    
  5. 打开模拟器,运行项目,查看结果。