Skip to content

Commit e198af8

Browse files
committed
add no_grad
1 parent ed475dd commit e198af8

File tree

3 files changed

+300
-0
lines changed

3 files changed

+300
-0
lines changed

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import operator
1919
import sys
2020
import types
21+
from contextlib import AbstractContextManager
2122
from enum import Enum
2223
from functools import cached_property, reduce
2324
from typing import TYPE_CHECKING, Any
@@ -2572,3 +2573,77 @@ def from_value(value: object, graph: FunctionGraph, tracker: Tracker):
25722573
class_var.tracker = GetAttrTracker(var, "__class__")
25732574
return var
25742575
return None
2576+
2577+
2578+
class ContextManagerVariable(VariableBase):
2579+
def __init__(
2580+
self,
2581+
value: AbstractContextManager,
2582+
graph: FunctionGraph,
2583+
tracker: Tracker,
2584+
):
2585+
super().__init__(graph=graph, tracker=tracker)
2586+
self.value = value
2587+
2588+
def get_py_type(self):
2589+
return self.value.__class__
2590+
2591+
def get_py_value(self, allow_tensor=False):
2592+
return self.value
2593+
2594+
@VariableFactory.register_from_value()
2595+
def from_value(value: object, graph: FunctionGraph, tracker: Tracker):
2596+
if isinstance(value, AbstractContextManager):
2597+
var = ContextManagerVariable(
2598+
value,
2599+
graph=graph,
2600+
tracker=tracker,
2601+
)
2602+
return var
2603+
return None
2604+
2605+
2606+
class NoGradContextManagerVariable(ContextManagerVariable):
2607+
def getattr(self, name: str, default=None):
2608+
from .callable import InternalFunctionVariable, MethodVariable
2609+
2610+
if name == "__enter__":
2611+
from ..function_graph import NoGradContext
2612+
2613+
def no_grap_enter(self):
2614+
self.graph.sir_builder._current_statement_ctxs.append(
2615+
NoGradContext()
2616+
)
2617+
return self
2618+
2619+
variable = InternalFunctionVariable(
2620+
no_grap_enter, self.graph, DummyTracker([])
2621+
)
2622+
return MethodVariable(
2623+
self, variable, self.graph, GetAttrTracker(self, "__enter__")
2624+
)
2625+
2626+
if name == "__exit__":
2627+
2628+
def no_grad_exit(self, exc, value, traceback):
2629+
self.graph.sir_builder._current_statement_ctxs.pop()
2630+
return ConstantVariable.wrap_literal(False, graph=self.graph)
2631+
2632+
variable = InternalFunctionVariable(
2633+
no_grad_exit, self.graph, DummyTracker([])
2634+
)
2635+
return MethodVariable(
2636+
self, variable, self.graph, GetAttrTracker(self, "__exit__")
2637+
)
2638+
return super().getattr(name, default)
2639+
2640+
@VariableFactory.register_from_value(successor="ContextManagerVariable")
2641+
def from_value(value: object, graph: FunctionGraph, tracker: Tracker):
2642+
if isinstance(value, paddle.no_grad):
2643+
var = NoGradContextManagerVariable(
2644+
value,
2645+
graph=graph,
2646+
tracker=tracker,
2647+
)
2648+
return var
2649+
return None

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,22 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
11911191
return None
11921192

11931193

1194+
class InternalFunctionVariable(FunctionVariable):
1195+
# Handles internal implementations specific for SOT, such as passing hook functions.
1196+
# Refer to NoGradContextManagerVariable.getattr for usage examples.
1197+
def __init__(
1198+
self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker
1199+
):
1200+
super().__init__(fn, graph, tracker)
1201+
self.fn = fn
1202+
1203+
def call_function(self, /, *args, **kwargs):
1204+
return self.fn(*args, **kwargs)
1205+
1206+
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
1207+
raise NotImplementedError
1208+
1209+
11941210
class ClassVariable(CallableVariable):
11951211
def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker):
11961212
super().__init__(graph, tracker)

test/sot/test_25_with.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright (c) 2025 paddlepaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from contextlib import contextmanager
17+
18+
from test_case_base import (
19+
TestCaseBase,
20+
)
21+
22+
import paddle
23+
from paddle import nn
24+
from paddle.jit.sot.opcode_translator.executor.opcode_executor import (
25+
ALREADY_SUPPORTED_EXCEPTION,
26+
)
27+
from paddle.jit.sot.psdb import check_no_breakgraph
28+
from paddle.jit.sot.utils import strict_mode_guard
29+
30+
31+
class Manager:
32+
def __init__(self):
33+
pass
34+
35+
def __enter__(self):
36+
return self
37+
38+
def __exit__(self, exc, value, traceback):
39+
pass
40+
41+
42+
class ManagerExitReturnFalse(Manager):
43+
def __exit__(self, *args):
44+
return False
45+
46+
47+
class ManagerExitReturnTrue(Manager):
48+
def __exit__(self, *args):
49+
return True
50+
51+
52+
TEST_WITH_STATEMENT_FLAG = False
53+
54+
55+
@contextmanager
56+
def my_context():
57+
global TEST_WITH_STATEMENT_FLAG
58+
try:
59+
TEST_WITH_STATEMENT_FLAG = True
60+
yield
61+
finally:
62+
TEST_WITH_STATEMENT_FLAG = False
63+
64+
65+
@check_no_breakgraph
66+
def with_manager_normal(x):
67+
with Manager() as mgr:
68+
x *= 2
69+
return x
70+
71+
72+
@check_no_breakgraph
73+
def with_manager_exit_true_raise_error(x):
74+
with ManagerExitReturnTrue() as mgr_true:
75+
x *= 3
76+
raise ValueError("test")
77+
x -= 4
78+
return x
79+
80+
81+
@check_no_breakgraph
82+
def with_manager_exit_true_zero_division(x):
83+
with ManagerExitReturnTrue() as mgr_true:
84+
x += 5
85+
# TODO(DrRyanHuang): Division by zero (x / 0) will raise an InnerError.
86+
# In the future, the actual Exception should be propagated rather than being wrapped as InnerError.
87+
1 / 0 # noqa: B018
88+
x *= 6
89+
return x
90+
91+
92+
@check_no_breakgraph
93+
def with_contextmanager_flag_behavior(x):
94+
global TEST_WITH_STATEMENT_FLAG
95+
with my_context():
96+
if TEST_WITH_STATEMENT_FLAG:
97+
x /= 7
98+
else:
99+
x *= 7
100+
101+
if not TEST_WITH_STATEMENT_FLAG:
102+
x += 8
103+
return x
104+
105+
106+
@check_no_breakgraph
107+
def with_manager_exit_false(x):
108+
try:
109+
with ManagerExitReturnFalse() as mgr_false:
110+
x *= 4
111+
1 / 0 # noqa: B018
112+
except ZeroDivisionError:
113+
x /= 4
114+
return x
115+
116+
117+
# TODO(DrRyanHuang): NoGradContextManagerVariable and UserDefinedContextManagerVariable will be implemented separately in the future.
118+
# The @strict_mode_guard decorator will be removed here to ensure that fallback is no longer permitted.
119+
@strict_mode_guard(False)
120+
def test_no_grad_behavior():
121+
x = paddle.rand([1, 2])
122+
p = paddle.rand([1, 2])
123+
p.stop_gradient = False
124+
x.stop_gradient = True
125+
with paddle.no_grad():
126+
y = (x * p).sum()
127+
y.backward()
128+
return x.grad, p.grad
129+
130+
131+
class TestWithStatement(TestCaseBase):
132+
def test_manager_normal(self):
133+
t = paddle.to_tensor(-10.0)
134+
self.assert_results(with_manager_normal, t)
135+
136+
def test_manager_exit_true_suppresses(self):
137+
t = paddle.to_tensor(-10.0)
138+
self.assert_results(with_manager_exit_true_raise_error, t)
139+
140+
def test_manager_exit_true_zero_division(self):
141+
t = paddle.to_tensor(-10.0)
142+
self.assert_results(with_manager_exit_true_zero_division, t)
143+
144+
def test_my_context_flag_behavior(self):
145+
t = paddle.to_tensor(-10.0)
146+
self.assert_results(with_contextmanager_flag_behavior, t)
147+
148+
def test_with_manager_exit_false(self):
149+
t = paddle.to_tensor(-10.0)
150+
self.assert_results(with_manager_exit_false, t)
151+
152+
def test_no_grad(self):
153+
self.assert_results(test_no_grad_behavior)
154+
155+
156+
class SimpleNet(nn.Layer):
157+
def __init__(self, input_dim, output_dim):
158+
super().__init__()
159+
self.layer = nn.Linear(input_dim, output_dim)
160+
161+
def forward(self, x):
162+
with paddle.static.amp.fp16_guard():
163+
return self._forward(x)
164+
165+
def _forward(self, x):
166+
return self.layer(x)
167+
168+
169+
@check_no_breakgraph
170+
def net_call(x: paddle.Tensor, net: nn.Layer):
171+
return net(x)
172+
173+
174+
def inner_no_grad_fn(x, y):
175+
with paddle.no_grad():
176+
return x * y + x**2
177+
178+
179+
@check_no_breakgraph
180+
def no_grad_fn_caller(x, y):
181+
z = inner_no_grad_fn(x * y, y)
182+
a = x * y + x**3 - 1
183+
return z + a
184+
185+
186+
class TestPaddleContextManager(TestCaseBase):
187+
# TODO(DrRyanHuang): Python 3.11 introduced a new opcode, BEFORE_WITH, which is not supported yet.
188+
# Therefore, for versions 3.11 and above, fallback is allowed for now.
189+
@strict_mode_guard(ALREADY_SUPPORTED_EXCEPTION)
190+
def test_fp16_guard(self):
191+
x = paddle.randn([4, 4])
192+
model = SimpleNet(4, 8)
193+
self.assert_results(net_call, x, model)
194+
195+
def test_no_grad(self):
196+
x = paddle.randn([12])
197+
y = paddle.randn([12])
198+
x.stop_gradient = False
199+
y.stop_gradient = False
200+
self.assert_results_with_grad(
201+
[x, y],
202+
no_grad_fn_caller,
203+
x,
204+
y,
205+
)
206+
207+
208+
if __name__ == '__main__':
209+
unittest.main()

0 commit comments

Comments
 (0)