Apollo 6.0
Open-source self-driving car software
trajectory_imitation_libtorch_inference.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2020 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
22 #pragma once
23 
24 #include <string>
25 
26 #include "torch/extension.h"
27 #include "torch/script.h"
28 
30 
31 namespace apollo {
32 namespace planning {
33 
35  public:
40  const LearningModelInferenceTaskConfig& config);
41 
45  virtual ~TrajectoryImitationLibtorchInference() = default;
46 
50  std::string GetName() override { return "TRAJECTORY_IMITATION_INFERENCE"; };
51 
55  bool LoadModel() override;
56 
61  bool DoInference(LearningDataFrame* const learning_data_frame) override;
62 
63  private:
67  bool LoadCNNModel();
68 
72  bool LoadCNNLSTMModel();
73 
78  bool DoCNNMODELInference(LearningDataFrame* const learning_data_frame);
79 
84  bool DoCNNLSTMMODELInference(LearningDataFrame* const learning_data_frame);
85 
89  void output_postprocessing(const at::Tensor& torch_output_tensor,
90  LearningDataFrame* const learning_data_frame);
91 
92  torch::jit::script::Module model_;
93  torch::Device device_;
94 };
95 
96 } // namespace planning
97 } // namespace apollo
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
TrajectoryImitationLibtorchInference(const LearningModelInferenceTaskConfig &config)
Constructor.
Definition: trajectory_imitation_libtorch_inference.h:34
Definition: model_inference.h:32
Planning module main class. It processes GPS and IMU data as input to generate planning info...
Defines the model inference base class.
std::string GetName() override
Gets the name of the model inference.
Definition: trajectory_imitation_libtorch_inference.h:50
virtual ~TrajectoryImitationLibtorchInference()=default
Destructor.
bool LoadModel() override
Loads a learned model.
bool DoInference(LearningDataFrame *const learning_data_frame) override
Runs inference with a learned model.