add memory switch mechanism in operator kernel switch #6991
@@ -411,7 +411,34 @@ void OperatorWithKernel::Run(const Scope& scope,
                      expected_kernel_key);
   }
 
-  kernel_iter->second->Compute(ctx);
+  if (actual_kernel_key == expected_kernel_key) {
+    kernel_iter->second->Compute(ctx);
+  } else {
+    Scope& op_scope = scope.NewScope();
+    auto input_vars = this->InputVars();
+    for (auto var_name : input_vars) {
+      op_scope.Var(var_name);
+    }
+
+    // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
+    platform::DeviceContext* trans_dev_ctx = nullptr;
+
+    // TODO(qijun) get appropriate DataTransformFn from global map
+    using DataTransformFn = std::function<void(
+        const Variable& in, Variable* out, platform::DeviceContext* ctx)>;
+    DataTransformFn trans_fun = nullptr;
+
+    for (auto var_name : input_vars) {
Member
There is a problem here: maybe not all of the input vars need to be transformed.
Member (Author)
I have not thought of an elegant solution yet. Maybe we can hard-code some checks before doing the data transform, just like ...
+      trans_fun(*(scope.FindVar(var_name)), op_scope.FindVar(var_name),
+                trans_dev_ctx);
+    }
+    // Wait for data transform finishing
+    trans_dev_ctx->Wait();
+
+    // Create a new ExecutionContext
+    ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
+    kernel_iter->second->Compute(op_ctx);
+  }
 }
 
 OpKernelType OperatorWithKernel::GetActualKernelType(
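The two TODOs above point at a DeviceContext pool and a global map of transform functions. Below is a minimal sketch of what such a global DataTransformFn registry could look like, using stand-in types (Variable, DeviceContext and KernelKey are placeholders for the real framework types, and the registry functions are hypothetical names): a transform is registered and looked up by the pair (source kernel key, destination kernel key).

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <utility>

// Stand-in types; the real framework would use its own Variable,
// OpKernelType and platform::DeviceContext.
struct Variable {};
struct DeviceContext { void Wait() {} };

using KernelKey = int;  // placeholder for OpKernelType
using DataTransformFn =
    std::function<void(const Variable& in, Variable* out, DeviceContext* ctx)>;

// Global registry keyed by (source kernel key, destination kernel key).
using TransformKey = std::pair<KernelKey, KernelKey>;

std::map<TransformKey, DataTransformFn>& TransformRegistry() {
  static std::map<TransformKey, DataTransformFn> registry;
  return registry;
}

void RegisterTransform(KernelKey from, KernelKey to, DataTransformFn fn) {
  TransformRegistry()[{from, to}] = std::move(fn);
}

DataTransformFn GetTransform(KernelKey from, KernelKey to) {
  auto it = TransformRegistry().find({from, to});
  if (it == TransformRegistry().end()) {
    throw std::runtime_error("no DataTransformFn registered for this key pair");
  }
  return it->second;
}
```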
@reyoung @dzhwinter @jacquesqiao I find that we cannot cache the transformed result variables in the current scope in order to reduce the number of transforms. Here is an example:
The output of op1 is the input of op2 and op3. If we cache in the current scope:
In the first batch of training: op2 runs first, creates a new variable (var_name + KernelType), and does the data transform. Then op3 checks whether this variable has already been created. Since op2 has already created it, op3 uses it directly and does not need to do the data transform.
In the second batch of training: we have to do the data transform again, but because we still only check whether the new variable has been created, the data transform will be wrongly skipped.
I checked the Executor: in every batch the local scope is deleted, so this problem will not happen. I will change the cache to the local scope instead of the op scope.
Since each batch creates a new local_scope, adding a cache seems to work for our framework.
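A minimal sketch of the per-batch cache being discussed, assuming stand-in types (Variable, LocalScope and the helper names are hypothetical, not the framework's API): the transformed variable is stored in the batch's local scope under a key built from the variable name plus the kernel type, so op3 can reuse what op2 produced within the same batch, and the cache disappears when the Executor drops the local scope at the end of the batch.

```cpp
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for a framework Variable holding the transformed data.
struct Variable {};

// Stand-in for the per-batch local scope that the Executor creates and then
// deletes after each batch, taking the cache with it.
using LocalScope = std::unordered_map<std::string, Variable>;

// Cache key: original variable name plus a string form of the kernel type,
// e.g. "image@GPU/NCHW/FP32".
std::string CacheKey(const std::string& var_name,
                     const std::string& kernel_type) {
  return var_name + "@" + kernel_type;
}

// Reuse the transformed variable if it is already in the local scope;
// otherwise run the transform once and store the result for later ops.
Variable* GetOrTransform(LocalScope* local_scope, const std::string& var_name,
                         const std::string& kernel_type,
                         const std::function<Variable()>& transform) {
  const std::string key = CacheKey(var_name, kernel_type);
  auto it = local_scope->find(key);
  if (it == local_scope->end()) {
    it = local_scope->emplace(key, transform()).first;  // transform once per batch
  }
  return &it->second;
}
```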