diff --git a/fml/concurrent_message_loop.cc b/fml/concurrent_message_loop.cc index a116356632856..3d4ca470e41fe 100644 --- a/fml/concurrent_message_loop.cc +++ b/fml/concurrent_message_loop.cc @@ -26,6 +26,10 @@ ConcurrentMessageLoop::ConcurrentMessageLoop(size_t worker_count) WorkerMain(); }); } + + for (const auto& worker : workers_) { + worker_thread_ids_.emplace_back(worker.get_id()); + } } ConcurrentMessageLoop::~ConcurrentMessageLoop() { @@ -73,25 +77,43 @@ void ConcurrentMessageLoop::PostTask(const fml::closure& task) { void ConcurrentMessageLoop::WorkerMain() { while (true) { std::unique_lock lock(tasks_mutex_); - tasks_condition_.wait(lock, - [&]() { return tasks_.size() > 0 || shutdown_; }); + tasks_condition_.wait(lock, [&]() { + return tasks_.size() > 0 || shutdown_ || HasThreadTasksLocked(); + }); - if (tasks_.size() == 0) { - // This can only be caused by shutdown. - FML_DCHECK(shutdown_); - break; + // Shutdown cannot be read with the task mutex unlocked. + bool shutdown_now = shutdown_; + fml::closure task; + std::vector thread_tasks; + + if (tasks_.size() != 0) { + task = tasks_.front(); + tasks_.pop(); } - auto task = tasks_.front(); - tasks_.pop(); + if (HasThreadTasksLocked()) { + thread_tasks = GetThreadTasksLocked(); + FML_DCHECK(!HasThreadTasksLocked()); + } - // Don't hold onto the mutex while the task is being executed as it could - // itself try to post another tasks to this message loop. + // Don't hold onto the mutex while tasks are being executed as they could + // themselves try to post more tasks to the message loop. lock.unlock(); TRACE_EVENT0("flutter", "ConcurrentWorkerWake"); - // Execute the one tasks we woke up for. - task(); + // Execute the primary task we woke up for. + if (task) { + task(); + } + + // Execute any thread tasks. + for (const auto& thread_task : thread_tasks) { + thread_task(); + } + + if (shutdown_now) { + break; + } } } @@ -101,6 +123,31 @@ void ConcurrentMessageLoop::Terminate() { tasks_condition_.notify_all(); } +void ConcurrentMessageLoop::PostTaskToAllWorkers(fml::closure task) { + if (!task) { + return; + } + + std::scoped_lock lock(tasks_mutex_); + for (const auto& worker_thread_id : worker_thread_ids_) { + thread_tasks_[worker_thread_id].emplace_back(task); + } + tasks_condition_.notify_all(); +} + +bool ConcurrentMessageLoop::HasThreadTasksLocked() const { + return thread_tasks_.count(std::this_thread::get_id()) > 0; +} + +std::vector ConcurrentMessageLoop::GetThreadTasksLocked() { + auto found = thread_tasks_.find(std::this_thread::get_id()); + FML_DCHECK(found != thread_tasks_.end()); + std::vector pending_tasks; + std::swap(pending_tasks, found->second); + thread_tasks_.erase(found); + return pending_tasks; +} + ConcurrentTaskRunner::ConcurrentTaskRunner( std::weak_ptr weak_loop) : weak_loop_(std::move(weak_loop)) {} diff --git a/fml/concurrent_message_loop.h b/fml/concurrent_message_loop.h index a3487b6c119ee..6071f45cafa82 100644 --- a/fml/concurrent_message_loop.h +++ b/fml/concurrent_message_loop.h @@ -6,6 +6,7 @@ #define FLUTTER_FML_CONCURRENT_MESSAGE_LOOP_H_ #include +#include #include #include @@ -30,6 +31,8 @@ class ConcurrentMessageLoop void Terminate(); + void PostTaskToAllWorkers(fml::closure task); + private: friend ConcurrentTaskRunner; @@ -38,6 +41,8 @@ class ConcurrentMessageLoop std::mutex tasks_mutex_; std::condition_variable tasks_condition_; std::queue tasks_; + std::vector worker_thread_ids_; + std::map> thread_tasks_; bool shutdown_ = false; ConcurrentMessageLoop(size_t worker_count); @@ -46,6 +51,10 @@ class ConcurrentMessageLoop void PostTask(const fml::closure& task); + bool HasThreadTasksLocked() const; + + std::vector GetThreadTasksLocked(); + FML_DISALLOW_COPY_AND_ASSIGN(ConcurrentMessageLoop); }; diff --git a/runtime/dart_vm.cc b/runtime/dart_vm.cc index c4d0cf23e890c..3338b0952bd01 100644 --- a/runtime/dart_vm.cc +++ b/runtime/dart_vm.cc @@ -504,4 +504,8 @@ DartVM::GetConcurrentWorkerTaskRunner() const { return concurrent_message_loop_->GetTaskRunner(); } +std::shared_ptr DartVM::GetConcurrentMessageLoop() { + return concurrent_message_loop_; +} + } // namespace flutter diff --git a/runtime/dart_vm.h b/runtime/dart_vm.h index ac84fe77020d2..3a03f82dbb793 100644 --- a/runtime/dart_vm.h +++ b/runtime/dart_vm.h @@ -147,6 +147,19 @@ class DartVM { std::shared_ptr GetConcurrentWorkerTaskRunner() const; + //---------------------------------------------------------------------------- + /// @brief The concurrent message loop hosts threads that are used by the + /// engine to perform tasks long running background tasks. + /// Typically, to post tasks to this message loop, the + /// `GetConcurrentWorkerTaskRunner` method may be used. + /// + /// @see GetConcurrentWorkerTaskRunner + /// + /// @return The concurrent message loop used by this running Dart VM + /// instance. + /// + std::shared_ptr GetConcurrentMessageLoop(); + private: const Settings settings_; std::shared_ptr concurrent_message_loop_; diff --git a/shell/common/shell.h b/shell/common/shell.h index 39ffd6055435c..d768c0673b973 100644 --- a/shell/common/shell.h +++ b/shell/common/shell.h @@ -352,6 +352,14 @@ class Shell final : public PlatformView::Delegate, /// @brief Accessor for the disable GPU SyncSwitch std::shared_ptr GetIsGpuDisabledSyncSwitch() const; + //---------------------------------------------------------------------------- + /// @brief Get a pointer to the Dart VM used by this running shell + /// instance. + /// + /// @return The Dart VM pointer. + /// + DartVM* GetDartVM(); + private: using ServiceProtocolHandler = std::function rasterizer, std::unique_ptr io_manager); - DartVM* GetDartVM(); - void ReportTimings(); // |PlatformView::Delegate| diff --git a/shell/platform/embedder/embedder.cc b/shell/platform/embedder/embedder.cc index c638a6a8c1576..8fcd336a8ced7 100644 --- a/shell/platform/embedder/embedder.cc +++ b/shell/platform/embedder/embedder.cc @@ -1724,7 +1724,6 @@ FlutterEngineResult FlutterEnginePostDartObject( return kSuccess; } -FLUTTER_EXPORT FlutterEngineResult FlutterEngineNotifyLowMemoryWarning( FLUTTER_API_SYMBOL(FlutterEngine) raw_engine) { auto engine = reinterpret_cast(raw_engine); @@ -1747,3 +1746,27 @@ FlutterEngineResult FlutterEngineNotifyLowMemoryWarning( kInternalInconsistency, "Could not dispatch the low memory notification message."); } + +FlutterEngineResult FlutterEnginePostCallbackOnAllNativeThreads( + FLUTTER_API_SYMBOL(FlutterEngine) engine, + FlutterNativeThreadCallback callback, + void* user_data) { + if (engine == nullptr) { + return LOG_EMBEDDER_ERROR(kInvalidArguments, "Invalid engine handle."); + } + + if (callback == nullptr) { + return LOG_EMBEDDER_ERROR(kInvalidArguments, + "Invalid native thread callback."); + } + + return reinterpret_cast(engine) + ->PostTaskOnEngineManagedNativeThreads( + [callback, user_data](FlutterNativeThreadType type) { + callback(type, user_data); + }) + ? kSuccess + : LOG_EMBEDDER_ERROR(kInvalidArguments, + "Internal error while attempting to post " + "tasks to all threads."); +} diff --git a/shell/platform/embedder/embedder.h b/shell/platform/embedder/embedder.h index ebd7ec990f083..cab034129d500 100644 --- a/shell/platform/embedder/embedder.h +++ b/shell/platform/embedder/embedder.h @@ -934,6 +934,31 @@ typedef struct { }; } FlutterEngineDartObject; +/// This enum allows embedders to determine the type of the engine thread in the +/// FlutterNativeThreadCallback. Based on the thread type, the embedder may be +/// able to tweak the thread priorities for optimum performance. +typedef enum { + /// The Flutter Engine considers the thread on which the FlutterEngineRun call + /// is made to be the platform thread. There is only one such thread per + /// engine instance. + kFlutterNativeThreadTypePlatform, + /// This is the thread the Flutter Engine uses to execute rendering commands + /// based on the selected client rendering API. There is only one such thread + /// per engine instance. + kFlutterNativeThreadTypeRender, + /// This is a dedicated thread on which the root Dart isolate is serviced. + /// There is only one such thread per engine instance. + kFlutterNativeThreadTypeUI, + /// Multiple threads are used by the Flutter engine to perform long running + /// background tasks. + kFlutterNativeThreadTypeWorker, +} FlutterNativeThreadType; + +/// A callback made by the engine in response to +/// `FlutterEnginePostCallbackOnAllNativeThreads` on all internal thread. +typedef void (*FlutterNativeThreadCallback)(FlutterNativeThreadType type, + void* user_data); + typedef struct { /// The size of this struct. Must be sizeof(FlutterProjectArgs). size_t struct_size; @@ -1667,6 +1692,45 @@ FLUTTER_EXPORT FlutterEngineResult FlutterEngineNotifyLowMemoryWarning( FLUTTER_API_SYMBOL(FlutterEngine) engine); +//------------------------------------------------------------------------------ +/// @brief Schedule a callback to be run on all engine managed threads. +/// The engine will attempt to service this callback the next time +/// the message loop for each managed thread is idle. Since the +/// engine manages the entire lifecycle of multiple threads, there +/// is no opportunity for the embedders to finely tune the +/// priorities of threads directly, or, perform other thread +/// specific configuration (for example, setting thread names for +/// tracing). This callback gives embedders a chance to affect such +/// tuning. +/// +/// @attention This call is expensive and must be made as few times as +/// possible. The callback must also return immediately as not doing +/// so may risk performance issues (especially for callbacks of type +/// kFlutterNativeThreadTypeUI and kFlutterNativeThreadTypeRender). +/// +/// @attention Some callbacks (especially the ones of type +/// kFlutterNativeThreadTypeWorker) may be called after the +/// FlutterEngine instance has shut down. Embedders must be careful +/// in handling the lifecycle of objects associated with the user +/// data baton. +/// +/// @attention In case there are multiple running Flutter engine instances, +/// their workers are shared. +/// +/// @param[in] engine A running engine instance. +/// @param[in] callback The callback that will get called multiple times on +/// each engine managed thread. +/// @param[in] user_data A baton passed by the engine to the callback. This +/// baton is not interpreted by the engine in any way. +/// +/// @return Returns if the callback was successfully posted to all threads. +/// +FLUTTER_EXPORT +FlutterEngineResult FlutterEnginePostCallbackOnAllNativeThreads( + FLUTTER_API_SYMBOL(FlutterEngine) engine, + FlutterNativeThreadCallback callback, + void* user_data); + #if defined(__cplusplus) } // extern "C" #endif diff --git a/shell/platform/embedder/embedder_engine.cc b/shell/platform/embedder/embedder_engine.cc index 3b260834ea33c..23202052f2096 100644 --- a/shell/platform/embedder/embedder_engine.cc +++ b/shell/platform/embedder/embedder_engine.cc @@ -247,7 +247,34 @@ bool EmbedderEngine::RunTask(const FlutterTask* task) { task->task); } -const Shell& EmbedderEngine::GetShell() const { +bool EmbedderEngine::PostTaskOnEngineManagedNativeThreads( + std::function closure) const { + if (!IsValid() || closure == nullptr) { + return false; + } + + const auto trampoline = [closure](FlutterNativeThreadType type, + fml::RefPtr runner) { + runner->PostTask([closure, type] { closure(type); }); + }; + + // Post the task to all thread host threads. + const auto& task_runners = shell_->GetTaskRunners(); + trampoline(kFlutterNativeThreadTypeRender, task_runners.GetGPUTaskRunner()); + trampoline(kFlutterNativeThreadTypeWorker, task_runners.GetIOTaskRunner()); + trampoline(kFlutterNativeThreadTypeUI, task_runners.GetUITaskRunner()); + trampoline(kFlutterNativeThreadTypePlatform, + task_runners.GetPlatformTaskRunner()); + + // Post the task to all worker threads. + auto vm = shell_->GetDartVM(); + vm->GetConcurrentMessageLoop()->PostTaskToAllWorkers( + [closure]() { closure(kFlutterNativeThreadTypeWorker); }); + + return true; +} + +Shell& EmbedderEngine::GetShell() { FML_DCHECK(shell_); return *shell_.get(); } diff --git a/shell/platform/embedder/embedder_engine.h b/shell/platform/embedder/embedder_engine.h index 16525d9dc14fd..124f8af0cf9aa 100644 --- a/shell/platform/embedder/embedder_engine.h +++ b/shell/platform/embedder/embedder_engine.h @@ -80,7 +80,10 @@ class EmbedderEngine { bool RunTask(const FlutterTask* task); - const Shell& GetShell() const; + bool PostTaskOnEngineManagedNativeThreads( + std::function closure) const; + + Shell& GetShell(); private: const std::unique_ptr thread_host_; diff --git a/shell/platform/embedder/tests/embedder_unittests.cc b/shell/platform/embedder/tests/embedder_unittests.cc index cd0b536c2acc4..d63685c65eec2 100644 --- a/shell/platform/embedder/tests/embedder_unittests.cc +++ b/shell/platform/embedder/tests/embedder_unittests.cc @@ -3904,5 +3904,137 @@ TEST_F(EmbedderTest, CanSendLowMemoryNotification) { ASSERT_EQ(FlutterEngineNotifyLowMemoryWarning(engine.get()), kSuccess); } +TEST_F(EmbedderTest, CanPostTaskToAllNativeThreads) { + UniqueEngine engine; + size_t worker_count = 0; + fml::AutoResetWaitableEvent sync_latch; + + // One of the threads that the callback will be posted to is the platform + // thread. So we cannot wait for assertions to complete on the platform + // thread. Create a new thread to manage the engine instance and wait for + // assertions on the test thread. + auto platform_task_runner = CreateNewThread("platform_thread"); + + platform_task_runner->PostTask([&]() { + auto& context = GetEmbedderContext(); + + EmbedderConfigBuilder builder(context); + builder.SetSoftwareRendererConfig(); + + engine = builder.LaunchEngine(); + + ASSERT_TRUE(engine.is_valid()); + + worker_count = ToEmbedderEngine(engine.get()) + ->GetShell() + .GetDartVM() + ->GetConcurrentMessageLoop() + ->GetWorkerCount(); + + sync_latch.Signal(); + }); + + sync_latch.Wait(); + + ASSERT_GT(worker_count, 4u /* three base threads plus workers */); + const auto engine_threads_count = worker_count + 4u; + + struct Captures { + // Waits the adequate number of callbacks to fire. + fml::CountDownLatch latch; + + // Ensures that the expect number of distinct threads were serviced. + std::set thread_ids; + + size_t platform_threads_count = 0; + size_t render_threads_count = 0; + size_t ui_threads_count = 0; + size_t worker_threads_count = 0; + + Captures(size_t count) : latch(count) {} + }; + + Captures captures(engine_threads_count); + + platform_task_runner->PostTask([&]() { + ASSERT_EQ(FlutterEnginePostCallbackOnAllNativeThreads( + engine.get(), + [](FlutterNativeThreadType type, void* baton) { + auto captures = reinterpret_cast(baton); + + switch (type) { + case kFlutterNativeThreadTypeRender: + captures->render_threads_count++; + break; + case kFlutterNativeThreadTypeWorker: + captures->worker_threads_count++; + break; + case kFlutterNativeThreadTypeUI: + captures->ui_threads_count++; + break; + case kFlutterNativeThreadTypePlatform: + captures->platform_threads_count++; + break; + } + + captures->thread_ids.insert(std::this_thread::get_id()); + captures->latch.CountDown(); + }, + &captures), + kSuccess); + }); + + captures.latch.Wait(); + ASSERT_EQ(captures.thread_ids.size(), engine_threads_count); + ASSERT_EQ(captures.platform_threads_count, 1u); + ASSERT_EQ(captures.render_threads_count, 1u); + ASSERT_EQ(captures.ui_threads_count, 1u); + ASSERT_EQ(captures.worker_threads_count, worker_count + 1u /* for IO */); + + platform_task_runner->PostTask([&]() { + engine.reset(); + sync_latch.Signal(); + }); + sync_latch.Wait(); + + // The engine should have already been destroyed on the platform task runner. + ASSERT_FALSE(engine.is_valid()); +} + +TEST_F(EmbedderTest, CanPostTaskToAllNativeThreadsRecursively) { + EmbedderConfigBuilder builder(GetEmbedderContext()); + + builder.SetSoftwareRendererConfig(); + + static std::mutex engine_mutex; + static UniqueEngine engine; + static fml::AutoResetWaitableEvent event; + + std::unique_lock engine_lock(engine_mutex); + engine.reset(); + engine = builder.LaunchEngine(); + ASSERT_TRUE(engine.is_valid()); + ASSERT_EQ(FlutterEnginePostCallbackOnAllNativeThreads( + engine.get(), + [](FlutterNativeThreadType type, void* baton) { + // This should deadlock if the task mutex acquisition is + // busted. + std::scoped_lock engine_lock_inner(engine_mutex); + if (engine.is_valid()) { + ASSERT_EQ(FlutterEnginePostCallbackOnAllNativeThreads( + engine.get(), + [](FlutterNativeThreadType type, + void* baton) { event.Signal(); }, + nullptr), + kSuccess); + } + }, + &engine), + kSuccess); + engine_lock.unlock(); + event.Wait(); + engine.reset(); +} + } // namespace testing } // namespace flutter