diff --git a/shared_module/container/hashed_refcount_hive.hpp b/shared_module/container/hashed_refcount_hive.hpp new file mode 100644 index 0000000..7e769d5 --- /dev/null +++ b/shared_module/container/hashed_refcount_hive.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include "hive.hpp" +#include "hashmap.hpp" + +namespace detail +{ +template +struct hasher { + size_t operator()(const K& k) const; +}; + +// template +// struct eq_compare +// { +// bool operator()(const K& lhs, const K& rhs) const; +// }; + +template +struct default_elem_ctor { + V operator()(const K& k) const; +}; + +template > +struct hashed_refcount_hive { + using iterator = typename hive::iterator; + +public: + std::pair acquire(const K& key) + { + size_t hash_key = hasher()(key); + auto iter = refcount_data_map.find(hash_key); + if (iter == refcount_data_map.end()) { + auto data_iter = data.emplace(default_elem_ctor{}(key)); + refcount_data_map.emplace(hash_key, std::make_pair(1, data_iter)); + return {data_iter, true}; + } else { + iter->second.first += 1; + return {iter->second.second, false}; + } + } + + void release(const K& key) + { + size_t hash_key = hasher()(key); + auto iter = refcount_data_map.find(hash_key); + if (iter == refcount_data_map.end()) throw std::runtime_error("Key not found in refcount map."); + if (--iter->second.first == 0) { + data.erase(iter->second.second); + refcount_data_map.erase(iter); + } + } + + void release(const iterator& ptr) + { + for (auto iter = refcount_data_map.begin(); iter != refcount_data_map.end(); ++iter) { + if (iter->second.second == ptr) { + if (--iter->second.first == 0) { + data.erase(iter->second.second); + refcount_data_map.erase(iter); + } + return; + } + } + throw std::runtime_error("Pointer not found in refcount map."); + } + +protected: + hive data{}; + // manually calculate hash for the key, so that we do not need to store the key + flat_hash_map> refcount_data_map{}; +}; + +template +struct tagged_hasher; + +template > +struct tagged_hashed_refcount_hive { + using iterator = typename hive::iterator; + +public: + std::pair acquire(const K& key) + { + size_t hash_key = tagged_hasher()(key); + auto iter = refcount_data_map.find(hash_key); + if (iter == refcount_data_map.end()) { + auto data_iter = data.emplace(default_elem_ctor{}(key)); + refcount_data_map.emplace(hash_key, std::make_pair(1, data_iter)); + return {data_iter, true}; + } else { + iter->second.first += 1; + return {iter->second.second, false}; + } + } + + void release(const K& key) + { + size_t hash_key = tagged_hasher()(key); + auto iter = refcount_data_map.find(hash_key); + if (iter == refcount_data_map.end()) throw std::runtime_error("Key not found in refcount map."); + if (--iter->second.first == 0) { + data.erase(iter->second.second); + refcount_data_map.erase(iter); + } + } + + void release(const iterator& ptr) + { + for (auto iter = refcount_data_map.begin(); iter != refcount_data_map.end(); ++iter) { + if (iter->second.second == ptr) { + if (--iter->second.first == 0) { + data.erase(iter->second.second); + refcount_data_map.erase(iter); + } + return; + } + } + throw std::runtime_error("Pointer not found in refcount map."); + } + +protected: + hive data{}; + // manually calculate hash for the key, so that we do not need to store the key + flat_hash_map> refcount_data_map{}; +}; +} // namespace detail + +template +using hashed_refcount_hive_mp = detail::hashed_refcount_hive>; + +template +using tagged_hashed_refcount_hive_mp = detail::tagged_hashed_refcount_hive>; \ No newline at end of file diff --git a/shared_module/container/wrapper/object_map.hpp b/shared_module/container/wrapper/object_map.hpp index 66333ea..bf65df0 100644 --- a/shared_module/container/wrapper/object_map.hpp +++ b/shared_module/container/wrapper/object_map.hpp @@ -1,4 +1,5 @@ #include +#include #include namespace detail @@ -32,13 +33,10 @@ struct flat_object_map { auto [iter, is_new] = data.try_emplace(key, object_with_refcount{}); auto& value = iter->second; if (is_new) { - value.refcount = 1; - V* new_value = static_cast(mi_aligned_alloc(alignof(V), sizeof(V))); - value.object_pointer = make_pointer_wrapper(new_value); - if constexpr (std::is_same_v) - *value.object_pointer = key; - else - *value.object_pointer = std::move(default_elem_ctor{}(key)); + value.refcount = 1; + V* new_value = static_cast(mi_aligned_alloc(alignof(V), sizeof(V))); + value.object_pointer = make_pointer_wrapper(new_value); + *value.object_pointer = std::move(default_elem_ctor{}(key)); return {value.object_pointer, true}; } else { value.refcount++; @@ -61,7 +59,54 @@ protected: value_pointer_t object_pointer{}; }; - flat_hash_map, Allocator>> data{}; + flat_hash_map, Allocator>> data{}; +}; + +template typename Allocator> +struct flat_object_map, Allocator> { + using value_pointer_t = pointer_wrapper; + + std::pair acquire(const K& key) + { + size_t hash_key = Hasher()(key); + auto [iter, is_new] = refcount_data_map.try_emplace(hash_key, object_with_refcount{}); + auto& [refcount, object_iter] = iter->second; + if (is_new) { + refcount = 1; + object_iter = data.emplace(key); + return {make_pointer_wrapper(object_iter.operator->()), true}; + } else { + refcount++; + return {make_pointer_wrapper(object_iter.operator->()), false}; + } + } + + void release(const K& key) + { + size_t hash_key = Hasher()(key); + if (auto iter = refcount_data_map.find(hash_key); iter == refcount_data_map.end()) { + throw std::runtime_error("Key not found in refcount map."); + } else if (--iter->second.refcount == 0) { + data.erase(iter->second.object_iter); + refcount_data_map.erase(iter); + } + } + +protected: + hive> data{}; + using iterator = typename decltype(data)::iterator; + + struct object_with_refcount { + size_t refcount{}; + iterator object_iter{}; + }; + + flat_hash_map, + eq_compare, + Allocator>> + refcount_data_map{}; }; } // namespace detail