You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.2 KiB
76 lines
2.2 KiB
#pragma once

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <numeric>
#include <type_traits>
#include <utility>

#include <container/stl_alias.hpp>
|
|
|
|
template <typename _Int = uint32_t>
|
|
class disjoint_set_union
|
|
{
|
|
public:
|
|
disjoint_set_union() noexcept = default;
|
|
|
|
disjoint_set_union(uint32_t size) noexcept { init(size); }
|
|
|
|
void init(uint32_t size) noexcept
|
|
{
|
|
parent.resize(size);
|
|
std::iota(parent.begin(), parent.end(), 0u);
|
|
}
|
|
|
|
template <typename _Other_Int, typename = std::enable_if_t<sizeof(_Other_Int) >= sizeof(_Int)>>
|
|
auto find(_Other_Int x) noexcept
|
|
{
|
|
while (parent[x] != x) {
|
|
parent[x] = parent[parent[x]];
|
|
x = parent[x];
|
|
}
|
|
return x;
|
|
}
|
|
|
|
template <typename _Other_Int, typename = std::enable_if_t<sizeof(_Other_Int) >= sizeof(_Int)>>
|
|
_Other_Int merge(_Other_Int x, _Other_Int y) noexcept
|
|
{
|
|
auto root_x = find(x), root_y = find(y);
|
|
return parent[root_y] = parent[root_x];
|
|
}
|
|
|
|
auto extract_disjoint_sets() noexcept -> std::pair<stl_vector_mp<stl_vector_mp<_Int>>, stl_vector_mp<_Int>>
|
|
{
|
|
std::pair<stl_vector_mp<stl_vector_mp<_Int>>, stl_vector_mp<_Int>> result{};
|
|
auto& [disjoint_sets, index_map] = result;
|
|
|
|
const auto num_disjoint_set = parent.size();
|
|
index_map.resize(num_disjoint_set, std::numeric_limits<_Int>::max());
|
|
uint32_t counter{};
|
|
// Assign each roots a unique index.
|
|
for (uint32_t i = 0; i < num_disjoint_set; i++) {
|
|
const auto root = find(i);
|
|
if (root == i) {
|
|
index_map[i] = counter;
|
|
counter++;
|
|
}
|
|
}
|
|
|
|
// Assign each element to its corresponding disjoint set.
|
|
for (uint32_t i = 0; i < num_disjoint_set; i++) {
|
|
const auto root = find(i);
|
|
assert(index_map[root] != std::numeric_limits<_Int>::max());
|
|
index_map[i] = index_map[root];
|
|
}
|
|
|
|
disjoint_sets.resize(counter);
|
|
for (uint32_t i = 0; i < num_disjoint_set; i++) {
|
|
const auto index = index_map[i];
|
|
disjoint_sets[index].emplace_back(i);
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
protected:
|
|
stl_vector_mp<_Int> parent{};
|
|
};
|