| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 | 
							- import onnxruntime
 
- import numpy as np
 
class OnnxModel:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    The model's input and output names are read from the session once at
    construction time, so that ``forward`` can map positional inputs onto
    the model's named inputs.
    """

    def __init__(self, model_path):
        """Create an inference session for the ONNX model at *model_path*.

        When ``onnxruntime.get_device()`` reports ``'GPU'``, the CUDA
        execution provider is requested with the CPU provider as fallback.
        Otherwise only the CPU provider is used.
        """
        sess_options = onnxruntime.SessionOptions()
        # NOTE: graph optimization level and intra-op thread count are left at
        # onnxruntime's defaults; set them on ``sess_options`` here to tune.
        onnx_gpu = onnxruntime.get_device() == 'GPU'
        if onnx_gpu:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.sess = onnxruntime.InferenceSession(
            model_path, sess_options, providers=providers
        )
        self._input_names = [item.name for item in self.sess.get_inputs()]
        self._output_names = [item.name for item in self.sess.get_outputs()]

    @property
    def input_names(self):
        """Names of the model inputs, in session order."""
        return self._input_names

    @property
    def output_names(self):
        """Names of the model outputs, in session order."""
        return self._output_names

    def forward(self, inputs):
        """Run the model and return its outputs.

        Args:
            inputs: either a single value, or a tuple/list of values matched
                positionally to ``self.input_names``.

        Returns:
            If *inputs* was a single (non-tuple/list) value and the model has
            exactly one output, that output alone. Otherwise, the list
            returned by ``InferenceSession.run`` (one entry per output name).

        Raises:
            ValueError: if more inputs are given than the model has input
                names. Previously the extras were silently dropped by ``zip``.
        """
        single_input = not isinstance(inputs, (tuple, list))
        if single_input:
            inputs = [inputs]
        if len(inputs) > len(self.input_names):
            raise ValueError(
                f'Too many inputs: model expects at most '
                f'{len(self.input_names)} {self.input_names}, got {len(inputs)}!'
            )
        input_feed = dict(zip(self.input_names, inputs))
        outputs = self.sess.run(self.output_names, input_feed)
        # Unwrap only when the caller passed a bare value and there is a
        # single output, mirroring the shape of the call.
        if single_input and len(self.output_names) == 1:
            return outputs[0]
        return outputs
 
-             
 
def check_image_dtype_and_shape(image):
    """Validate that *image* is an array this module accepts as an image.

    Accepted inputs are ``np.ndarray`` objects with dtype ``uint8`` or
    ``uint16``. The array must have either 2 dimensions, or 3 dimensions
    whose last axis (treated as the channel axis) has size 1, 3 or 4.

    Raises:
        TypeError: if *image* is not an ``np.ndarray`` or its dtype is
            unsupported.
        ValueError: if the number of dimensions or channels is unsupported.
        (Both are ``Exception`` subclasses, so existing ``except Exception``
        handlers keep working.)
    """
    if not isinstance(image, np.ndarray):
        raise TypeError('image is not np.ndarray!')
    # ``image.dtype`` is an ``np.dtype`` instance, so it must be compared by
    # equality against the scalar types. An ``isinstance`` check never matches.
    if image.dtype not in (np.uint8, np.uint16):
        raise TypeError(
            f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!'
        )
    if image.ndim not in (2, 3):
        raise ValueError(
            f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!'
        )
    if image.ndim == 3:
        num_channels = image.shape[-1]
        if num_channels not in (1, 3, 4):
            raise ValueError(
                f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!'
            )
 
 
  |