You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
1.7 KiB
Python
38 lines
1.7 KiB
Python
import numpy as np
|
|
import MNN
|
|
from loguru import logger
|
|
|
|
|
|
class MNNAdapter(object):
|
|
|
|
def __init__(self, model_path: str, input_shape: tuple,
|
|
dim_type: int = MNN.Tensor_DimensionType_Caffe, outputs_name=None, outputs_shape=None):
|
|
self.interpreter = MNN.Interpreter(model_path)
|
|
self.session = self.interpreter.createSession()
|
|
self.input_tensor = self.interpreter.getSessionInput(self.session)
|
|
self.input_shape = input_shape
|
|
self.dim_type = dim_type
|
|
self.outputs_name = outputs_name
|
|
self.outputs_shape = outputs_shape
|
|
|
|
def inference(self, tensor: np.ndarray) -> np.ndarray:
|
|
tensor = tensor.astype(np.float32)
|
|
tmp_input = MNN.Tensor(self.input_shape, MNN.Halide_Type_Float, tensor, self.dim_type)
|
|
self.input_tensor.copyFrom(tmp_input)
|
|
self.interpreter.runSession(self.session)
|
|
output_tensor = list()
|
|
if self.outputs_name:
|
|
if self.outputs_shape:
|
|
for idx, shape in enumerate(self.outputs_shape):
|
|
tmp_output = MNN.Tensor(shape, MNN.Halide_Type_Float, np.ones(shape).astype(np.float32), self.dim_type)
|
|
tmp_tensor = self.interpreter.getSessionOutput(self.session, self.outputs_name[idx])
|
|
tmp_tensor.copyToHostTensor(tmp_output)
|
|
output_tensor.append(np.asarray(tmp_output.getData()))
|
|
else:
|
|
output_tensor = [np.asarray(self.interpreter.getSessionOutput(self.session, name).getData()) for name in self.outputs_name]
|
|
else:
|
|
output_tensor.append(self.interpreter.getSessionOutput(self.session).getData())
|
|
res = np.asarray(output_tensor)
|
|
|
|
return res
|