22 #include "torch/script.h" 23 #include "torch/torch.h" 29 namespace prediction {
64 std::vector<double>* feature_values);
69 std::string
GetName()
override {
return "JUNCTION_MLP_EVALUATOR"; }
77 void SetObstacleFeatureValues(
Obstacle* obstacle_ptr,
78 std::vector<double>*
const feature_values);
// Sets ego-vehicle feature values (purpose inferred from the name — confirm
// against the implementation).
// NOTE(review): the listing numbers jump from 86 to 88 across this
// declaration, so a parameter line may be missing here — verify the full
// signature in the real header before relying on it.
// NOTE(review): the top-level const on the pointer parameter has no effect
// in a declaration and could be dropped.
86 void SetEgoVehicleFeatureValues(
Obstacle* obstacle_ptr,
88 std::vector<double>*
const feature_values);
95 void SetJunctionFeatureValues(
Obstacle* obstacle_ptr,
96 std::vector<double>*
const feature_values);
// Feature-vector segment sizes. constexpr static data members are implicitly
// inline in C++17, so they may be odr-used (e.g. bound to a const reference)
// without a separate out-of-line definition in a .cc file.

// Size of the obstacle feature segment.
// NOTE(review): the meaning of the 4 + 2 * 5 breakdown is not shown here —
// confirm against the feature-extraction code.
static constexpr size_t OBSTACLE_FEATURE_SIZE = 4 + 2 * 5;

// Size of the ego-vehicle feature segment.
static constexpr size_t EGO_VEHICLE_FEATURE_SIZE = 4;

// Size of the junction feature segment.
// NOTE(review): presumably 12 groups of 8 values each — confirm.
static constexpr size_t JUNCTION_FEATURE_SIZE = 12 * 8;
111 torch::jit::script::Module torch_model_;
112 torch::Device device_;
Definition: junction_mlp_evaluator.h:31
Prediction obstacle.
Definition: obstacle.h:52
void Clear()
Clear the obstacle feature map.
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
void ExtractFeatureValues(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container, std::vector< double > *feature_values)
Extract the feature vector.
JunctionMLPEvaluator()
Constructor.
virtual ~JunctionMLPEvaluator()=default
Destructor.
Definition: evaluator.h:39
Define the data container base class.
std::string GetName() override
Get the name of the evaluator.
Definition: junction_mlp_evaluator.h:69