22 #include "torch/script.h" 23 #include "torch/torch.h" 30 namespace prediction {
58 LaneSequence* lane_sequence_ptr,
59 std::vector<double>* feature_values);
64 std::string
GetName()
override {
return "CRUISE_MLP_EVALUATOR"; }
74 void SetObstacleFeatureValues(
const Obstacle* obstacle_ptr,
75 std::vector<double>* feature_values);
84 void SetInteractionFeatureValues(
Obstacle* obstacle_ptr,
86 LaneSequence* lane_sequence_ptr,
87 std::vector<double>* feature_values);
95 void SetLaneFeatureValues(
const Obstacle* obstacle_ptr,
96 const LaneSequence* lane_sequence_ptr,
97 std::vector<double>* feature_values);
104 void ModelInference(
const std::vector<torch::jit::IValue>& torch_inputs,
105 torch::jit::script::Module torch_model_ptr,
106 LaneSequence* lane_sequence_ptr);
// Feature-vector layout sizes.
// `static constexpr` (instead of `static const`) makes each value a
// guaranteed compile-time constant and, from C++17 on, an implicitly inline
// variable. Binding one to a const reference (e.g. passing it to std::min or
// std::max) is an ODR-use; with this change that no longer requires a separate
// out-of-class definition in the .cpp file.
// NOTE(review): 23 + 5 * 9 presumably means 23 single-frame features plus
// 5 frames of 9 features each — confirm against SetObstacleFeatureValues.
static constexpr size_t OBSTACLE_FEATURE_SIZE = 23 + 5 * 9;
// Presumably the number of values written by SetInteractionFeatureValues.
static constexpr size_t INTERACTION_FEATURE_SIZE = 8;
// Presumably features per lane point and number of lane points used by
// SetLaneFeatureValues — confirm.
static constexpr size_t SINGLE_LANE_FEATURE_SIZE = 4;
static constexpr size_t LANE_POINTS_SIZE = 20;
// TorchScript modules used for inference (presumably loaded in the
// constructor — confirm).
// NOTE(review): by name, torch_go_model_ presumably scores ordinary ("go")
// lane sequences and torch_cutin_model_ scores cut-in sequences — confirm
// in the .cpp.
114 torch::jit::script::Module torch_go_model_;
115 torch::jit::script::Module torch_cutin_model_;
// Device the models run on; presumably chosen in the constructor's
// initializer list — confirm.
116 torch::Device device_;
CruiseMLPEvaluator()
Constructor.
Prediction obstacle.
Definition: obstacle.h:52
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
void ExtractFeatureValues(Obstacle *obstacle_ptr, LaneSequence *lane_sequence_ptr, std::vector< double > *feature_values)
Extract feature vector.
virtual ~CruiseMLPEvaluator()=default
Destructor.
Definition: cruise_mlp_evaluator.h:32
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
std::string GetName() override
Get the name of evaluator.
Definition: cruise_mlp_evaluator.h:64
Definition: evaluator.h:39
Define the data container base class.