Apollo 6.0
Open-source self-driving car software
libtorch_obstacle_detector.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2020 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include <torch/script.h>
25 #include <torch/torch.h>
26 
28 
29 namespace apollo {
30 namespace perception {
31 namespace inference {
32 
33 using BlobPtr = std::shared_ptr<apollo::perception::base::Blob<float>>;
34 
35 class ObstacleDetector : public Inference {
36  public:
37  ObstacleDetector(const std::string &net_file, const std::string &model_file,
38  const std::vector<std::string> &outputs);
39 
40  ObstacleDetector(const std::string &net_file, const std::string &model_file,
41  const std::vector<std::string> &outputs,
42  const std::vector<std::string> &inputs);
43 
44  virtual ~ObstacleDetector() {}
45 
46  bool Init(const std::map<std::string, std::vector<int>> &shapes) override;
47 
48  void Infer() override;
49  BlobPtr get_blob(const std::string &name) override;
50 
51  protected:
52  bool shape(const std::string &name, std::vector<int> *res);
53  torch::jit::script::Module net_;
54 
55  private:
56  std::string net_file_;
57  std::string model_file_;
58  std::vector<std::string> output_names_;
59  std::vector<std::string> input_names_;
60  BlobMap blobs_;
61 
62  torch::DeviceType device_type_;
63  int device_id_ = 0;
64 };
65 
66 } // namespace inference
67 } // namespace perception
68 } // namespace apollo
Definition: libtorch_obstacle_detector.h:35
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
std::map< std::string, std::shared_ptr< apollo::perception::base::Blob< float > > > BlobMap
Definition: inference.h:34
BlobPtr get_blob(const std::string &name) override
ObstacleDetector(const std::string &net_file, const std::string &model_file, const std::vector< std::string > &outputs)
std::shared_ptr< apollo::perception::base::Blob< float > > BlobPtr
Definition: caffe_net.h:33
bool Init(const std::map< std::string, std::vector< int >> &shapes) override
torch::jit::script::Module net_
Definition: libtorch_obstacle_detector.h:53
bool shape(const std::string &name, std::vector< int > *res)
virtual ~ObstacleDetector()
Definition: libtorch_obstacle_detector.h:44