Apollo 6.0
Open-source self-driving car software
onnx_obstacle_detector.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2020 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <iostream>
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "NvInfer.h"
27 #include "NvOnnxParser.h"
28 #include "NvInferVersion.h"
29 
32 
33 namespace apollo {
34 namespace perception {
35 namespace inference {
36 
37 using BlobPtr = std::shared_ptr<apollo::perception::base::Blob<float>>;
38 
39 // Logger for TensorRT info/warning/errors
40 class Logger : public nvinfer1::ILogger {
41  public:
42  explicit Logger(Severity severity = Severity::kWARNING)
43  : reportable_severity(severity) {}
44 
45  void log(Severity severity, const char* msg) override {
46  // suppress messages with severity enum value greater than the reportable
47  if (severity > reportable_severity) return;
48 
49  switch (severity) {
50  case Severity::kINTERNAL_ERROR:
51  std::cerr << "INTERNAL_ERROR: ";
52  break;
53  case Severity::kERROR:
54  std::cerr << "ERROR: ";
55  break;
56  case Severity::kWARNING:
57  std::cerr << "WARNING: ";
58  break;
59  case Severity::kINFO:
60  std::cerr << "INFO: ";
61  break;
62  default:
63  std::cerr << "UNKNOWN: ";
64  break;
65  }
66  std::cerr << msg << std::endl;
67  }
68 
70 };
71 
73  public:
74  OnnxObstacleDetector(const std::string &model_file,
75  const float score_threshold,
76  const std::vector<std::string> &outputs,
77  const std::vector<std::string> &inputs);
78 
79  OnnxObstacleDetector(const std::string &model_file,
80  const std::vector<std::string> &outputs,
81  const std::vector<std::string> &inputs);
82 
83  virtual ~OnnxObstacleDetector();
84 
91  void OnnxToTRTModel(const std::string& model_file,
92  nvinfer1::ICudaEngine** engine_ptr);
93 
94  void inference();
95 
96  bool Init(const std::map<std::string, std::vector<int>> &shapes) override;
97  void Infer() override;
98  BlobPtr get_blob(const std::string &name) override;
99 
100  private:
101  std::string model_file_;
102  float score_threshold_;
103  std::vector<std::string> output_names_;
104  std::vector<std::string> input_names_;
105  BlobMap blobs_;
106  nvinfer1::ICudaEngine* engine_;
107  nvinfer1::IExecutionContext* context_;
108  Logger g_logger_;
109 
110  int num_classes_;
111  int kBatchSize;
112 };
113 
114 } // namespace inference
115 } // namespace perception
116 } // namespace apollo
Severity reportable_severity
Definition: onnx_obstacle_detector.h:69
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
std::map< std::string, std::shared_ptr< apollo::perception::base::Blob< float > > > BlobMap
Definition: inference.h:34
Definition: onnx_obstacle_detector.h:40
Logger(Severity severity=Severity::kWARNING)
Definition: onnx_obstacle_detector.h:42
void log(Severity severity, const char *msg) override
Definition: onnx_obstacle_detector.h:45
std::shared_ptr< apollo::perception::base::Blob< float > > BlobPtr
Definition: caffe_net.h:33
bool Init(const char *binary_name)
Definition: onnx_obstacle_detector.h:72