如果我們得到了tflite文件,如何在python中使用?這裡可以在tensorflow庫的幫助下或者tflite_runtime庫的幫助下使用
tensorflow庫中有個lite子庫,是為tflite而設計的
給出示例代碼:
import tensorflow as tf
import cv2
import numpy as np
def preprocess(image): # 輸入圖像預處理
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = cv2.resize(image, (64, 64))
tensor = np.expand_dims(image, axis=[0, -1])
tensor = tensor.astype('float32')
return tensor
# API文檔:https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter#args_1
emotion_model_tflite = tf.lite.Interpreter("output.tflite") # 加載tflite模型
emotion_model_tflite.allocate_tensors() # 預先計劃張量分配以優化推理
tflife_input_details = emotion_model_tflite.get_input_details() # 獲取輸入節點的細節
tflife_output_details = emotion_model_tflite.get_output_details() # 獲取輸出節點的細節
# 加載並處理成輸入張量,和keras推理或者tensorflow推理的輸入張量一樣
img = cv2.imread("1fae49da5f2472cf260e3d0aa08d7e32.jpeg")
input_tensor = preprocess(img)
# 填入輸入tensor
emotion_model_tflite.set_tensor(tflife_input_details[0]['index'], input_tensor)
# 運行推理
emotion_model_tflite.invoke()
# 獲取推理結果
custom = emotion_model_tflite.get_tensor(tflife_output_details[0]['index'])
print(custom)
見名知意,tflite_runtime就是tflite的運行環境庫。因為tensorflow畢竟太大了,如果我們只是想使用tflite模型推理,那麼使用該庫是個不錯的選擇
首先在安裝 TensorFlow Lite 解釋器根據你的平台和python版本,下載對應的whl文件,然後使用pip安裝即可:pip install 下載的whl文件路徑
先給出代碼:
import tflite_runtime.interpreter as tflite # 改動一
import cv2
import numpy as np
def preprocess(image):
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = cv2.resize(image, (64, 64))
tensor = np.expand_dims(image, axis=[0, -1])
tensor = tensor.astype('float32')
return tensor
emotion_model_tflite = tflite.Interpreter("output.tflite") # 改動二
emotion_model_tflite.allocate_tensors()
tflife_input_details = emotion_model_tflite.get_input_details()
tflife_output_details = emotion_model_tflite.get_output_details()
img = cv2.imread("1fae49da5f2472cf260e3d0aa08d7e32.jpeg")
input_tensor = preprocess(img)
emotion_model_tflite.set_tensor(tflife_input_details[0]['index'], input_tensor)
emotion_model_tflite.invoke()
custom = emotion_model_tflite.get_tensor(tflife_output_details[0]['index'])
print(custom)