diff --git a/include/cv/aggregator.hpp b/include/cv/aggregator.hpp index 003173a9..c10f8ee2 100644 --- a/include/cv/aggregator.hpp +++ b/include/cv/aggregator.hpp @@ -53,7 +53,9 @@ class CVAggregator { // gets the record of all cv results LockPtr> getCVRecord(); - + // Calculates the clustered center of each airdroptype in matched results, + // then assigns matched_results.matched_airdrop to the results + void calculateClusters(); void updateRecords(std::vector& new_values); private: diff --git a/src/cv/aggregator.cpp b/src/cv/aggregator.cpp index 1e61b098..346a0cb1 100644 --- a/src/cv/aggregator.cpp +++ b/src/cv/aggregator.cpp @@ -4,6 +4,7 @@ #include "utilities/lockptr.hpp" #include "utilities/locks.hpp" #include "utilities/logging.hpp" +#include "cv/clustering.hpp" CVAggregator::CVAggregator(Pipeline&& p) : pipeline(std::move(p)) { this->num_worker_threads = 0; @@ -123,3 +124,35 @@ std::vector CVAggregator::popAllRuns() { this->results->runs.clear(); return out; } + +void CVAggregator::calculateClusters() { + LockPtr> results = this->getCVRecord(); + std::vector> cluster_input; + for (const auto& pair : *results.data.get()) { + GPSProtoVec cords = pair.second.coordinates(); + auto types = pair.second.target_type(); + // add all types up to the current, in order to keep list format + for (int i = 0; i < cords.size(); i++) { + GPSCoord location = cords.at(i); + int type = types.at(i); + if (cluster_input.size() < type) { + for (int j = cluster_input.size(); j < type; j++) { + cluster_input.push_back(std::vector()); + } + } + cluster_input[type].push_back(location); + } + } + Clustering clustering; + std::vector clusterCenters = clustering.FindClustersCenter(cluster_input); + + LockPtr matched = this->getMatchedResults(); + std::unordered_map matched_clusters; + for (int i = 0; i < clusterCenters.size(); i++) { + AirdropTarget airdrop; + airdrop.set_index(static_cast(i)); + airdrop.mutable_coordinate()->CopyFrom(clusterCenters[i]); + matched_clusters.insert(std::pair(static_cast(i), std::move(airdrop))); + } + matched.data->matched_airdrop = std::move(matched_clusters); +} diff --git a/src/ticks/cv_loiter.cpp b/src/ticks/cv_loiter.cpp index ccd6d4e4..f16ebe35 100644 --- a/src/ticks/cv_loiter.cpp +++ b/src/ticks/cv_loiter.cpp @@ -6,6 +6,7 @@ #include "ticks/fly_search.hpp" #include "ticks/ids.hpp" #include "utilities/constants.hpp" +#include "cv/clustering.hpp" CVLoiterTick::CVLoiterTick(std::shared_ptr state) : Tick(state, TickID::CVLoiter) { this->status = CVLoiterTick::Status::None; @@ -39,9 +40,7 @@ Tick* CVLoiterTick::tick() { AirdropType::Water, AirdropType::Beacon, }; */ - - LockPtr results = state->getCV()->getResults(); - + state->getCV()->calculateClusters(); // for (const auto& bottle : ALL_BOTTLES) { // // contains will never be false but whatever // if (results.data->matches.contains(bottle)) { diff --git a/tests/integration/clustering.cpp b/tests/integration/clustering.cpp index 0bb638af..45ef9344 100644 --- a/tests/integration/clustering.cpp +++ b/tests/integration/clustering.cpp @@ -6,7 +6,7 @@ #define POINT_DISTANCE 10 #define NUM_OUTLIERS 5 #define NUM_RUNS 5 -//redeclaring method to not have fixed seed so the test is deterministic +//redeclaring method to not have fixed seed so the test is non-deterministic unsigned int seed = time(NULL); double random(double min, double max) { return min + static_cast(rand_r(&seed)) / RAND_MAX * (max - min); @@ -86,11 +86,14 @@ int main() scatter(x, y, std::vector(), c); hold(on); auto plot = scatter(cluster_x, cluster_y); + ::matplot::legend({"Cluster points", "Calculated cluster Center"}); plot->marker_style(line_spec::marker_style::diamond); + title("Clustering run " + std::to_string(run)); std::ostringstream stringStream; stringStream << "run" << run << ".png"; std::string copyOfStr = stringStream.str(); - // std::filesystem::remove(copyOfStr); + std::cout << "saved to /build/" << copyOfStr << std::endl; + //std::filesystem::remove(copyOfStr); save(copyOfStr); hold(off);