Skip to content

Commit f21842f

Browse files
author
Mogball
committed
[mlir] Optimize ThreadLocalCache by removing atomic bottleneck
The ThreadLocalCache implementation is used by the MLIRContext (among other things) to try to manage thread contention in the StorageUniquers. There is a bunch of fancy shared pointer/weak pointer setups that basically keeps everything alive across threads at the right time, but a huge bottleneck is the `weak_ptr::lock` call inside the `::get` method. This is because the `lock` method has to hit the atomic refcount several times, and this is bottlenecking performance across many threads. However, all this is doing is checking whether the storage is initialized. Importantly, when the `PerThreadInstance` goes out of scope, it does not remove all of its associated entries from the thread-local hash map (it contains dangling `PerThreadInstance *` keys). The `weak_ptr` also allows the thread local cache to synchronize with the `PerThreadInstance`'s destruction: 1. if `ThreadLocalCache` destructs, the `weak_ptr`s that reference its contained values are immediately invalidated 2. if `CacheType` destructs within a thread, any entries still live are removed from the owning `PerThreadInstance`, and it locks the `weak_ptr` first to ensure it's kept alive long enough for the removal. This PR changes the TLC entries to contain a `shared_ptr<ValueT*>` and a `weak_ptr<PerInstanceState>`. It gives the `PerInstanceState` entries a `weak_ptr<ValueT*>` on top of the `unique_ptr<ValueT>`. This enables `ThreadLocalCache::get` to check if the value is initialized by dereferencing the `shared_ptr<ValueT*>` and check if the contained pointer is null. When `PerInstanceState` destructs, the values inside the TLC are written to nullptr. The TLC uses the `weak_ptr<PerInstanceState>` to satisfy (2). (1) is no longer the case. When `ThreadLocalCache` begins destruction, the `weak_ptr<PerInstanceState>` are invalidated, but not the `shared_ptr<ValueT*>`. This is OK: because the overall object is being destroyed, `::get` cannot get called and because the `shared_ptr<PerInstanceState>` finishes destruction before freeing the pointer, it cannot get reallocated to another `ThreadLocalCache` during destruction. I.e. the values inside the TLC associated with a `PerInstanceState` cannot be read during destruction. The most important thing is to make sure destruction of the TLC doesn't race with the destructor of `PerInstanceState`. Because `PerInstanceState` carries `weak_ptr` references into the TLC, we guarantee to not have any use-after-frees.
1 parent ab7e6b6 commit f21842f

File tree

1 file changed

+76
-22
lines changed

1 file changed

+76
-22
lines changed

mlir/include/mlir/Support/ThreadLocalCache.h

+76-22
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,80 @@ namespace mlir {
2525
/// cache has very large lock contention.
2626
template <typename ValueT>
2727
class ThreadLocalCache {
28+
struct PerInstanceState;
29+
30+
/// The "observer" is owned by a thread-local cache instance. It is
31+
/// constructed the first time a `ThreadLocalCache` instance is accessed by a
32+
/// thread, unless `perInstanceState` happens to get re-allocated to the same
33+
/// address as a previous one. A `thread_local` instance of this class is
34+
/// destructed when the thread in which it lives is destroyed.
35+
///
36+
/// This class is called the "observer" because while values cached in
37+
/// thread-local caches are owned by `PerInstanceState`, a reference is stored
38+
/// via this class in the TLC. With a double pointer, it knows when the
39+
/// referenced value has been destroyed.
40+
struct Observer {
41+
/// This is the double pointer, explicitly allocated because we need to keep
42+
/// the address stable if the TLC map re-allocates. It is owned by the
43+
/// observer and shared with the value owner.
44+
std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
45+
/// Because `Owner` living inside `PerInstanceState` contains a reference to
46+
/// the double pointer, and livkewise this class contains a reference to the
47+
/// value, we need to synchronize destruction of the TLC and the
48+
/// `PerInstanceState` to avoid racing. This weak pointer is acquired during
49+
/// TLC destruction if the `PerInstanceState` hasn't entered its destructor
50+
/// yet, and prevents it from happening.
51+
std::weak_ptr<PerInstanceState> keepalive;
52+
};
53+
54+
/// This struct owns the cache entries. It contains a reference back to the
55+
/// reference inside the cache so that it can be written to null to indicate
56+
/// that the cache entry is invalidated. It needs to do this because
57+
/// `perInstanceState` could get re-allocated to the same pointer and we don't
58+
/// remove entries from the TLC when it is deallocated. Thus, we have to reset
59+
/// the TLC entries to a starting state in case the `ThreadLocalCache` lives
60+
/// shorter than the threads.
61+
struct Owner {
62+
/// Save a pointer to the reference and write it to the newly created entry.
63+
Owner(Observer &observer)
64+
: value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
65+
*observer.ptr = value.get();
66+
}
67+
~Owner() {
68+
if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
69+
*ptr = nullptr;
70+
}
71+
72+
Owner(Owner &&) = default;
73+
Owner &operator=(Owner &&) = default;
74+
75+
std::unique_ptr<ValueT> value;
76+
std::weak_ptr<ValueT *> ptrRef;
77+
};
78+
2879
// Keep a separate shared_ptr protected state that can be acquired atomically
2980
// instead of using shared_ptr's for each value. This avoids a problem
3081
// where the instance shared_ptr is locked() successfully, and then the
3182
// ThreadLocalCache gets destroyed before remove() can be called successfully.
3283
struct PerInstanceState {
33-
/// Remove the given value entry. This is generally called when a thread
34-
/// local cache is destructing.
84+
/// Remove the given value entry. This is called when a thread local cache
85+
/// is destructing but still contains references to values owned by the
86+
/// `PerInstanceState`. Removal is required because it prevents writeback to
87+
/// a pointer that was deallocated.
3588
void remove(ValueT *value) {
3689
// Erase the found value directly, because it is guaranteed to be in the
3790
// list.
3891
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
39-
auto it =
40-
llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
41-
return instance.get() == value;
42-
});
92+
auto it = llvm::find_if(instances, [&](Owner &instance) {
93+
return instance.value.get() == value;
94+
});
4395
assert(it != instances.end() && "expected value to exist in cache");
4496
instances.erase(it);
4597
}
4698

4799
/// Owning pointers to all of the values that have been constructed for this
48100
/// object in the static cache.
49-
SmallVector<std::unique_ptr<ValueT>, 1> instances;
101+
SmallVector<Owner, 1> instances;
50102

51103
/// A mutex used when a new thread instance has been added to the cache for
52104
/// this object.
@@ -57,21 +109,22 @@ class ThreadLocalCache {
57109
/// instance of the non-static cache and a weak reference to an instance of
58110
/// ValueT. We use a weak reference here so that the object can be destroyed
59111
/// without needing to lock access to the cache itself.
60-
struct CacheType
61-
: public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
112+
struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
62113
~CacheType() {
63-
// Remove the values of this cache that haven't already expired.
64-
for (auto &it : *this)
65-
if (std::shared_ptr<ValueT> value = it.second.lock())
66-
it.first->remove(value.get());
114+
// Remove the values of this cache that haven't already expired. This is
115+
// required because if we don't remove them, they will contain a reference
116+
// back to the data here that is being destroyed.
117+
for (auto &[instance, observer] : *this)
118+
if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
119+
state->remove(*observer.ptr);
67120
}
68121

69122
/// Clear out any unused entries within the map. This method is not
70123
/// thread-safe, and should only be called by the same thread as the cache.
71124
void clearExpiredEntries() {
72125
for (auto it = this->begin(), e = this->end(); it != e;) {
73126
auto curIt = it++;
74-
if (curIt->second.expired())
127+
if (!*curIt->second.ptr)
75128
this->erase(curIt);
76129
}
77130
}
@@ -88,22 +141,23 @@ class ThreadLocalCache {
88141
ValueT &get() {
89142
// Check for an already existing instance for this thread.
90143
CacheType &staticCache = getStaticCache();
91-
std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
92-
if (std::shared_ptr<ValueT> value = threadInstance.lock())
144+
Observer &threadInstance = staticCache[perInstanceState.get()];
145+
if (ValueT *value = *threadInstance.ptr)
93146
return *value;
94147

95148
// Otherwise, create a new instance for this thread.
96-
llvm::sys::SmartScopedLock<true> threadInstanceLock(
97-
perInstanceState->instanceMutex);
98-
perInstanceState->instances.push_back(std::make_unique<ValueT>());
99-
ValueT *instance = perInstanceState->instances.back().get();
100-
threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
149+
{
150+
llvm::sys::SmartScopedLock<true> threadInstanceLock(
151+
perInstanceState->instanceMutex);
152+
perInstanceState->instances.emplace_back(threadInstance);
153+
}
154+
threadInstance.keepalive = perInstanceState;
101155

102156
// Before returning the new instance, take the chance to clear out any used
103157
// entries in the static map. The cache is only cleared within the same
104158
// thread to remove the need to lock the cache itself.
105159
staticCache.clearExpiredEntries();
106-
return *instance;
160+
return **threadInstance.ptr;
107161
}
108162
ValueT &operator*() { return get(); }
109163
ValueT *operator->() { return &get(); }

0 commit comments

Comments
 (0)