@@ -37,7 +37,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
3737 const std::string &loss_var_name,
3838 const std::unordered_set<std::string> ¶ms,
3939 const std::vector<Scope *> &local_scopes,
40- platform::NCCLContextMap *nccl_ctxs, bool skip_scale_loss ,
40+ platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale ,
4141 bool use_nccl_allreduce)
4242 : loss_var_name_(loss_var_name),
4343 places_(places),
@@ -50,7 +50,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5050 const std::vector<platform::Place> &places,
5151 const std::string &loss_var_name,
5252 const std::unordered_set<std::string> ¶ms,
53- const std::vector<Scope *> &local_scopes, bool skip_scale_loss ,
53+ const std::vector<Scope *> &local_scopes, bool use_default_grad_scale ,
5454 bool use_nccl_allreduce)
5555 : loss_var_name_ (loss_var_name),
5656 places_ (places),
@@ -60,7 +60,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
6060 for (auto &p : params) {
6161 grad_names_.insert (GradVarName (p));
6262 }
63- skip_scale_loss_ = skip_scale_loss ;
63+ use_default_grad_scale_ = use_default_grad_scale ;
6464}
6565
6666void MultiDevSSAGraphBuilder::CreateOpHandleIOs (SSAGraph *result,
@@ -141,8 +141,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
141141 } else if (IsDistTrainOp (*op, send_op)) {
142142 CreateComputationalOps (&result, *op, 1 );
143143 } else if (IsScaleLossOp (*op)) {
144- // user can customize loss@grad if skip_scale_loss_
145- if (!skip_scale_loss_ ) {
144+ // user can customize loss@grad if not use_default_grad_scale_
145+ if (use_default_grad_scale_ ) {
146146 CreateScaleLossGradOp (&result);
147147 }
148148 is_forwarding = false ;
0 commit comments