程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

獲取 onnx 模型的輸入輸出信息 Python 腳本

編輯:Python

直接將以下腳本onnx模型路徑 onnx_path 修改即可:

from pprint import pprint
import onnxruntime
onnx_path = "originalpool/output.onnx"
# onnx_path = "custompool/output.onnx"
provider = "CPUExecutionProvider"
onnx_session = onnxruntime.InferenceSession(onnx_path, providers=[provider])
print("----------------- 輸入部分 -----------------")
input_tensors = onnx_session.get_inputs() # 該 API 會返回列表
for input_tensor in input_tensors: # 因為可能有多個輸入,所以為列表
input_info = {

"name" : input_tensor.name,
"type" : input_tensor.type,
"shape": input_tensor.shape,
}
pprint(input_info)
print("----------------- 輸出部分 -----------------")
output_tensors = onnx_session.get_outputs() # 該 API 會返回列表
for output_tensor in output_tensors: # 因為可能有多個輸出,所以為列表
output_info = {

"name" : output_tensor.name,
"type" : output_tensor.type,
"shape": output_tensor.shape,
}
pprint(output_info)

值得說明的是,如果onnx模型的輸入shape是固定的,該腳本的輸出是:

'----------------- 輸入部分 -----------------'
{'name': 'x',
'shape': [1, 3, 224, 224],
'type': 'tensor(float)'}
'----------------- 輸出部分 -----------------'
{'name': 'bilinear_interp_v2_7.tmp_0',
'shape': [1, 2, 224, 224],
'type': 'tensor(float)'}

值得說明的是,如果onnx模型的輸入shape是非固定的,該腳本的輸出是:

----------------- 輸入部分 -----------------
{'name': 'x',
'shape': [None, 3, None, None],
'type': 'tensor(float)'}
----------------- 輸出部分 -----------------
{'name': 'bilinear_interp_v2_7.tmp_0',
'shape': [None, 2, None, None],
'type': 'tensor(float)'}

shape 中有 None


  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved