Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions test/dygraph_to_static/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ endif()
list(REMOVE_ITEM TEST_OPS test_build_strategy)

if(NOT WITH_GPU)
# TODO(SigureMo): Temporarily disable train step on Windows CPU CI.
# We should remove this after fix the performance issue.
list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_adam)
list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_sgd)
# disable some model test on CPU to avoid timeout
list(REMOVE_ITEM TEST_OPS test_resnet)
list(REMOVE_ITEM TEST_OPS test_bert)
Expand Down Expand Up @@ -56,8 +52,6 @@ if(APPLE)
endif()

if(WITH_GPU)
set_tests_properties(test_train_step_resnet18_sgd PROPERTIES TIMEOUT 240)
set_tests_properties(test_train_step_resnet18_adam PROPERTIES TIMEOUT 240)
set_tests_properties(test_bert PROPERTIES TIMEOUT 240)
set_tests_properties(test_transformer PROPERTIES TIMEOUT 240)
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 240)
Expand Down
89 changes: 5 additions & 84 deletions test/dygraph_to_static/dygraph_to_static_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import importlib
import inspect
import logging
import os
import sys
import unittest
from contextlib import contextmanager
Expand All @@ -28,7 +27,7 @@
from typing_extensions import TypeAlias

import paddle
from paddle import get_flags, set_flags, static
from paddle import set_flags
from paddle.jit.api import sot_mode_guard
from paddle.jit.dy2static.utils import (
ENV_ENABLE_CINN_IN_DY2ST,
Expand Down Expand Up @@ -114,12 +113,9 @@ def lower_case_name(self):
DEFAULT_TO_STATIC_MODE = (
ToStaticMode.AST | ToStaticMode.SOT | ToStaticMode.SOT_MGS10
)
DEFAULT_IR_MODE = IrMode.PT | IrMode.PIR
DEFAULT_IR_MODE = IrMode.PIR
DEFAULT_BACKEND_MODE = BackendMode.PHI | BackendMode.CINN
VALID_MODES = [
# For `.pd_model` export, we still need test AST+PT / AST+LEGACY_IR
(ToStaticMode.AST, IrMode.LEGACY_IR, BackendMode.PHI),
(ToStaticMode.AST, IrMode.PT, BackendMode.PHI),
(ToStaticMode.AST, IrMode.PIR, BackendMode.PHI),
(ToStaticMode.SOT, IrMode.PIR, BackendMode.PHI),
(ToStaticMode.SOT_MGS10, IrMode.PIR, BackendMode.PHI),
Expand All @@ -138,9 +134,7 @@ def lower_case_name(self):

DISABLED_IR_TEST_FILES = {
IrMode.LEGACY_IR: [],
IrMode.PT: [
"test_tensor_hook",
],
IrMode.PT: [],
IrMode.PIR: [],
}
DISABLED_BACKEND_TEST_FILES = {
Expand All @@ -158,15 +152,6 @@ def pir_dygraph_guard():
yield


@contextmanager
def legacy_ir_dygraph_guard():
in_dygraph_mode = paddle.in_dynamic_mode()
with paddle.pir_utils.OldIrGuard():
if in_dygraph_mode:
paddle.disable_static()
yield


def to_ast_test(fn):
"""
convert run AST
Expand Down Expand Up @@ -220,45 +205,11 @@ def sot_mgs10_impl(*args, **kwargs):


def to_legacy_ir_test(fn):
@wraps(fn)
def legacy_ir_impl(*args, **kwargs):
logger.info("[LEGACY_IR] running legacy ir")
with legacy_ir_dygraph_guard():
pt_in_dy2st_flag = ENV_ENABLE_PIR_WITH_PT_IN_DY2ST.name
original_flag_value = get_flags(pt_in_dy2st_flag)[pt_in_dy2st_flag]
with EnvironmentVariableGuard(
ENV_ENABLE_PIR_WITH_PT_IN_DY2ST, False
):
try:
set_flags({pt_in_dy2st_flag: False})
return fn(*args, **kwargs)
finally:
set_flags({pt_in_dy2st_flag: original_flag_value})

return legacy_ir_impl
raise NotImplementedError("Legacy IR is not supported")


def to_pt_test(fn):
@wraps(fn)
def pt_impl(*args, **kwargs):
logger.info("[PT] running PT")
with legacy_ir_dygraph_guard():
pt_in_dy2st_flag = ENV_ENABLE_PIR_WITH_PT_IN_DY2ST.name
original_flag_value = get_flags(pt_in_dy2st_flag)[pt_in_dy2st_flag]
if os.environ.get('FLAGS_use_stride_kernel', False):
return
with (
static.scope_guard(static.Scope()),
static.program_guard(static.Program()),
EnvironmentVariableGuard(ENV_ENABLE_PIR_WITH_PT_IN_DY2ST, True),
):
try:
set_flags({pt_in_dy2st_flag: True})
return fn(*args, **kwargs)
finally:
set_flags({pt_in_dy2st_flag: original_flag_value})

return pt_impl
raise NotImplementedError("PT is not supported")


def to_pir_test(fn):
Expand Down Expand Up @@ -484,41 +435,11 @@ def test_sot_only(fn):
return fn


def test_legacy_only(fn):
fn = set_ir_mode(IrMode.LEGACY_IR)(fn)
return fn


def test_pt_only(fn):
fn = set_ir_mode(IrMode.PT)(fn)
return fn


def test_pir_only(fn):
fn = set_ir_mode(IrMode.PIR)(fn)
return fn


def test_legacy_and_pt(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PT)(fn)
return fn


def test_pt_and_pir(fn):
fn = set_ir_mode(IrMode.PT | IrMode.PIR)(fn)
return fn


def test_legacy_and_pir(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)(fn)
return fn


def test_legacy_and_pt_and_pir(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PT | IrMode.PIR)(fn)
return fn


def test_phi_only(fn):
fn = set_backend_mode(BackendMode.PHI)(fn)
return fn
Expand Down
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_decorator_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_pt_only,
)

import paddle
Expand Down Expand Up @@ -197,7 +196,6 @@ def test_deco_transform(self):
np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)

@test_ast_only
@test_pt_only
def test_contextmanager_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
Expand Down
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_for_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pir,
)

import paddle
Expand Down Expand Up @@ -560,7 +559,6 @@ def test_for_zip_error(self):
model_path,
)

@test_legacy_and_pir
def test_for_zip(self):
model_path = os.path.join(self.temp_dir.name, 'for_zip')
paddle.jit.save(
Expand Down
8 changes: 0 additions & 8 deletions test/dygraph_to_static/test_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@

import numpy as np
from dygraph_to_static_utils import (
BackendMode,
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
enable_to_static_guard,
test_ast_only,
test_pir_only,
Expand Down Expand Up @@ -105,10 +101,6 @@ def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_with_if_else2

# TODO(dev): fix AST mode
@disable_test_case(
(ToStaticMode.AST, IrMode.PT, BackendMode.PHI | BackendMode.CINN)
)
def test_ast_to_func(self):
np.testing.assert_allclose(
self._run_dygraph(), self._run_static(), atol=1e-7, rtol=1e-7
Expand Down
13 changes: 0 additions & 13 deletions test/dygraph_to_static/test_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
Dy2StTestBase,
static_guard,
test_ast_only,
test_pir_only,
test_pt_only,
)

import paddle
Expand Down Expand Up @@ -165,17 +163,6 @@ def setUp(self):
)

@test_ast_only
@test_pt_only
def test_len_legacy(self):
with static_guard():
(
selected_rows_var_len,
var_tensor_len,
) = legacy_len_with_selected_rows(self.place)
self.assertEqual(selected_rows_var_len, var_tensor_len)

@test_ast_only
@test_pir_only
def test_len(self):
with static_guard():
selected_rows_var_len, var_tensor_len = len_with_selected_rows(
Expand Down
7 changes: 0 additions & 7 deletions test/dygraph_to_static/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@

import numpy as np
from dygraph_to_static_utils import (
BackendMode,
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
test_ast_only,
)

Expand Down Expand Up @@ -297,9 +293,6 @@ def train(self, to_static=False):
res = self.dygraph_func(self.input, self.iter_num)
return self.result_to_numpy(res)

@disable_test_case(
(ToStaticMode.AST, IrMode.PT, BackendMode.PHI | BackendMode.CINN)
)
def test_transformed_static_result(self):
self.compare_transformed_static_result()

Expand Down
Loading