Skip to content

Commit ea80cdb

Browse files
committed
为TensorMeta添加stop_gradient属性
1 parent 69c0518 commit ea80cdb

File tree

1 file changed

+1
-0
lines changed
  • python/paddle/distributed/auto_parallel/pipelining

1 file changed

+1
-0
lines changed

python/paddle/distributed/auto_parallel/pipelining/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(self, tensor: paddle.Tensor):
123123
self._local_shape = None
124124
self.dtype = tensor.dtype
125125
self.placements = None if not tensor.is_dist() else tensor.placements
126+
self.stop_gradient = tensor.stop_gradient
126127

127128
def __repr__(self):
128129
return f"TensorMeta(global_shape={self.shape},local_shape={self._local_shape}, dtype={self.dtype}, placements={self.placements})"

0 commit comments

Comments
 (0)