Apollo 6.0
Open source self driving car software
pedestrian_interaction_evaluator.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2019 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
22 #pragma once
23 
24 #include <string>
25 #include <unordered_map>
26 #include <vector>
27 
28 #include "torch/script.h"
29 #include "torch/torch.h"
30 
32 
37 namespace apollo {
38 namespace prediction {
39 
41  public:
46 
50  virtual ~PedestrianInteractionEvaluator() = default;
51 
57  bool Evaluate(Obstacle* obstacle_ptr,
58  ObstaclesContainer* obstacles_container) override;
59 
65  bool ExtractFeatures(const Obstacle* obstacle_ptr,
66  std::vector<double>* feature_values);
67 
71  std::string GetName() override { return "PEDESTRIAN_INTERACTION_EVALUATOR"; }
72 
73  private:
74  struct LSTMState {
75  double timestamp;
76  torch::Tensor ct;
77  torch::Tensor ht;
78  int frame_count = 0;
79  };
80 
81  // void Clear();
82 
83  void LoadModel();
84 
85  torch::Tensor GetSocialPooling();
86 
87  private:
88  std::unordered_map<int, LSTMState> obstacle_id_lstm_state_map_;
89  torch::jit::script::Module torch_position_embedding_;
90  torch::jit::script::Module torch_social_embedding_;
91  torch::jit::script::Module torch_single_lstm_;
92  torch::jit::script::Module torch_prediction_layer_;
93  torch::Device device_;
94 
95  static const int kGridSize = 2;
96  static const int kEmbeddingSize = 64;
97  static const int kHiddenSize = 128;
98 };
99 
100 } // namespace prediction
101 } // namespace apollo
Prediction obstacle.
Definition: obstacle.h:52
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
virtual ~PedestrianInteractionEvaluator()=default
Destructor.
bool ExtractFeatures(const Obstacle *obstacle_ptr, std::vector< double > *feature_values)
Extract features for the learning model's input.
std::string GetName() override
Get the name of evaluator.
Definition: pedestrian_interaction_evaluator.h:71
Definition: evaluator.h:39
Definition: pedestrian_interaction_evaluator.h:40
Define the data container base class.