25 #include <unordered_map> 28 #include "torch/script.h" 29 #include "torch/torch.h" 38 namespace prediction {
66 std::vector<double>* feature_values);
71 std::string
GetName()
override {
return "PEDESTRIAN_INTERACTION_EVALUATOR"; }
// Builds and returns the social-pooling tensor. What exactly is pooled
// (presumably neighbouring obstacles' LSTM hidden states) is not visible in
// this chunk -- TODO confirm against the implementation file.
85 torch::Tensor GetSocialPooling();
// LSTMState per int key; per the member name the key is presumably an
// obstacle id -- verify against callers.
88 std::unordered_map<int, LSTMState> obstacle_id_lstm_state_map_;
// TorchScript modules (position embedding, social embedding, single LSTM,
// prediction layer, per their names). Where they are loaded is not shown here.
89 torch::jit::script::Module torch_position_embedding_;
90 torch::jit::script::Module torch_social_embedding_;
91 torch::jit::script::Module torch_single_lstm_;
92 torch::jit::script::Module torch_prediction_layer_;
// Device the modules are meant to run on; its initialization is not visible
// in this chunk.
93 torch::Device device_;
// Model size constants. NOTE(review): these presumably have to match the
// dimensions the TorchScript modules above were exported with -- confirm.
95 static const int kGridSize = 2;
96 static const int kEmbeddingSize = 64;
97 static const int kHiddenSize = 128;
Prediction obstacle.
Definition: obstacle.h:52
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
virtual ~PedestrianInteractionEvaluator()=default
Destructor.
bool ExtractFeatures(const Obstacle *obstacle_ptr, std::vector< double > *feature_values)
Extract features for learning model's input.
PedestrianInteractionEvaluator()
Constructor.
std::string GetName() override
Get the name of evaluator.
Definition: pedestrian_interaction_evaluator.h:71
Definition: evaluator.h:39
Definition: pedestrian_interaction_evaluator.h:40
Define the data container base class.