Apollo  6.0
Open-source self-driving car software
cruise_mlp_evaluator.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2018 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <string>
20 #include <vector>
21 
22 #include "torch/script.h"
23 #include "torch/torch.h"
24 
26 
28 
29 namespace apollo {
30 namespace prediction {
31 
/**
 * @class CruiseMLPEvaluator
 * @brief Evaluator implementation (derives from Evaluator) that owns two
 *        TorchScript modules, torch_go_model_ and torch_cutin_model_, plus
 *        the torch::Device they are associated with.
 *
 * NOTE(review): this listing shows no constructor declaration even though
 * the class has a torch::Device member; confirm against the original header
 * whether a constructor line was dropped when the listing was generated.
 */
32 class CruiseMLPEvaluator : public Evaluator {
33  public:
38 
  /**
   * @brief Destructor.
   */
42  virtual ~CruiseMLPEvaluator() = default;
43 
  /**
   * @brief Override Evaluate.
   * @param obstacle_ptr Obstacle pointer.
   * @param obstacles_container Obstacles container.
   * @return bool; presumably true on successful evaluation -- TODO confirm
   *         against the implementation, which is not visible here.
   */
49  bool Evaluate(Obstacle* obstacle_ptr,
50  ObstaclesContainer* obstacles_container) override;
51 
  /**
   * @brief Extract feature vector.
   * @param obstacle_ptr Obstacle pointer.
   * @param lane_sequence_ptr Lane sequence pointer.
   * @param feature_values Pointer to the vector of feature values;
   *        presumably filled by this call -- TODO confirm.
   */
57  void ExtractFeatureValues(Obstacle* obstacle_ptr,
58  LaneSequence* lane_sequence_ptr,
59  std::vector<double>* feature_values);
60 
  /**
   * @brief Get the name of evaluator.
   * @return The literal string "CRUISE_MLP_EVALUATOR".
   */
64  std::string GetName() override { return "CRUISE_MLP_EVALUATOR"; }
65 
  /**
   * @brief Clear.
   *        NOTE(review): what state is cleared is not visible in this
   *        header; see the implementation file.
   */
66  void Clear();
67 
68  private:
  /**
   * @brief Set obstacle-related feature values (per the method name;
   *        the exact features are defined in the implementation).
   * @param obstacle_ptr Obstacle pointer (read-only).
   * @param feature_values Pointer to the vector of feature values.
   */
74  void SetObstacleFeatureValues(const Obstacle* obstacle_ptr,
75  std::vector<double>* feature_values);
76 
  /**
   * @brief Set interaction-related feature values (per the method name;
   *        the exact features are defined in the implementation).
   * @param obstacle_ptr Obstacle pointer.
   * @param obstacles_container Obstacles container.
   * @param lane_sequence_ptr Lane sequence pointer.
   * @param feature_values Pointer to the vector of feature values.
   */
84  void SetInteractionFeatureValues(Obstacle* obstacle_ptr,
85  ObstaclesContainer* obstacles_container,
86  LaneSequence* lane_sequence_ptr,
87  std::vector<double>* feature_values);
88 
  /**
   * @brief Set lane-related feature values (per the method name;
   *        the exact features are defined in the implementation).
   * @param obstacle_ptr Obstacle pointer (read-only).
   * @param lane_sequence_ptr Lane sequence pointer (read-only).
   * @param feature_values Pointer to the vector of feature values.
   */
95  void SetLaneFeatureValues(const Obstacle* obstacle_ptr,
96  const LaneSequence* lane_sequence_ptr,
97  std::vector<double>* feature_values);
98 
  /**
   * @brief Load models.
   *        Presumably populates torch_go_model_ and torch_cutin_model_ --
   *        TODO confirm; the model source is not visible in this header.
   */
102  void LoadModels();
103 
  /**
   * @brief Run inference with a TorchScript module for one lane sequence.
   * @param torch_inputs Inputs handed to the module.
   * @param torch_model_ptr The module to run. Despite the "_ptr" suffix this
   *        parameter is a torch::jit::script::Module taken by value.
   * @param lane_sequence_ptr Lane sequence pointer; presumably receives the
   *        inference result -- TODO confirm against the implementation.
   */
104  void ModelInference(const std::vector<torch::jit::IValue>& torch_inputs,
105  torch::jit::script::Module torch_model_ptr,
106  LaneSequence* lane_sequence_ptr);
107 
108  private:
  // Feature-dimension constants. OBSTACLE_FEATURE_SIZE evaluates to 68;
  // the meaning of the individual terms (23, 5, 9) is not documented in
  // this header.
109  static const size_t OBSTACLE_FEATURE_SIZE = 23 + 5 * 9;
110  static const size_t INTERACTION_FEATURE_SIZE = 8;
111  static const size_t SINGLE_LANE_FEATURE_SIZE = 4;
112  static const size_t LANE_POINTS_SIZE = 20;
113 
  // TorchScript modules held by the evaluator; the names suggest one model
  // for "go" and one for "cut-in" behavior -- verify in the implementation.
114  torch::jit::script::Module torch_go_model_;
115  torch::jit::script::Module torch_cutin_model_;
  // Torch device held by the evaluator; how it is chosen is not shown here.
116  torch::Device device_;
117 };
118 
119 } // namespace prediction
120 } // namespace apollo
Prediction obstacle.
Definition: obstacle.h:52
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
void ExtractFeatureValues(Obstacle *obstacle_ptr, LaneSequence *lane_sequence_ptr, std::vector< double > *feature_values)
Extract feature vector.
virtual ~CruiseMLPEvaluator()=default
Destructor.
Definition: cruise_mlp_evaluator.h:32
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
Obstacles container.
std::string GetName() override
Get the name of evaluator.
Definition: cruise_mlp_evaluator.h:64
Definition: evaluator.h:39
Define the data container base class.