Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1528,8 +1528,12 @@ void HandleForCudaGraphOp(
auto cuda_graph_op = op_item->dyn_cast<CudaGraphOp>();
std::vector<pir::Type> new_outputs;
for (size_t i = 0; i < cuda_graph_op.num_results(); ++i) {
new_outputs.push_back(
ConvertOpTypeToKernelType(ctx, cuda_graph_op.result(i).type(), place));
// Here, we set place as an undefined type to avoid unnecessary memcpy
// operations that may occur if place is fixed to a specific device (e.g.,
// GPU) too early. The real output place will be inferred later in
// `ProcessBlock` and then assigned to the outputs of new_cg_op.
new_outputs.push_back(ConvertOpTypeToKernelType(
ctx, cuda_graph_op.result(i).type(), phi::Place()));
}
auto new_cg_op = builder.Build<CudaGraphOp>(std::move(new_outputs));

Expand All @@ -1540,7 +1544,24 @@ void HandleForCudaGraphOp(
ctx,
map_op_pair,
map_value_pair,
true);
/*for_if_block=*/false);

PADDLE_ENFORCE_EQ(new_cg_op.block()->back().isa<::pir::YieldOp>(),
true,
common::errors::PreconditionNotMet(
"CudaGraphOp's block should end with YieldOp"));

auto yield_op = new_cg_op.block()->back().dyn_cast<::pir::YieldOp>();

PADDLE_ENFORCE_EQ(
yield_op.num_operands(),
new_cg_op.num_results(),
common::errors::PreconditionNotMet(
"CudaGraphOp's num_operands must equal to its YieldOp's"));

for (size_t i = 0; i < yield_op.num_operands(); ++i) {
new_cg_op->result(i).set_type(yield_op.operand_type(i));
}

// update map
(*map_op_pair)[op_item] = new_cg_op;
Expand Down Expand Up @@ -1879,18 +1900,16 @@ void HandleForSpecialOp(
if (op_item->isa<::pir::CombineOp>()) {
// Copy op inputs
std::vector<pir::Type> vec_inner_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(
cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
vec_inputs.push_back(new_in);
vec_inner_types.push_back(new_in.type());
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(
cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
vec_inputs.push_back(new_in);
vec_inner_types.push_back(new_in.type());
}
// Copy op output type

Expand Down
100 changes: 100 additions & 0 deletions test/dygraph_to_static/test_cudagraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from contextlib import contextmanager

import numpy as np

import paddle
from paddle.jit.dy2static.utils import CUDAGraphState

SEED = 2025
# Seed from the shared constant so the value is defined in exactly one place.
np.random.seed(SEED)


class Dy2StCudaGraphManager:
    """Tracks CUDA graph capture/replay state across dy2st program runs."""

    def __init__(self):
        # Start disabled; the caller flips state to CAPTURE/REPLAY explicitly.
        self.state = CUDAGraphState.DISABLE
        # Batch sizes for which a graph has already been captured.
        self.captured_batch_size = set()
        self.batch_size = -1

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
        """Wrap the original run_impl, injecting CUDA graph attrs.

        Demotes REPLAY to DISABLE when the current batch size was never
        captured, and records the batch size when capturing.
        """
        program_attrs, graph_attrs = attrs
        if self.state == CUDAGraphState.REPLAY:
            # No graph exists for this batch size — run normally instead.
            if self.batch_size not in self.captured_batch_size:
                self.state = CUDAGraphState.DISABLE
        elif self.state == CUDAGraphState.CAPTURE:
            self.captured_batch_size.add(self.batch_size)

        if self.state != CUDAGraphState.DISABLE:
            dispatch_key = self.batch_size
        else:
            dispatch_key = 0
        graph_attrs.update(
            {
                "cuda_graph_state": self.state,
                "cuda_graph_dispatch_key": dispatch_key,
            }
        )
        return original_run_impl(
            inputs, parameters, (program_attrs, graph_attrs)
        )

    @contextmanager
    def run_impl_guard(self):
        """Temporarily route partial-program execution through run_impl."""
        replace_guard = (
            paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard
        )
        with replace_guard(self.run_impl):
            yield


class CudaGraphRunner:
    """Runs a callable once in CAPTURE mode, then replays it afterwards."""

    def __init__(self, runnable):
        self.runnable = runnable
        self.captured = False
        self.cuda_graph_manager = Dy2StCudaGraphManager()

    def run_static_model(self, x):
        """Execute ``runnable(x)``.

        The first call captures the CUDA graph and returns ``None``;
        subsequent calls replay it and return the output.
        """
        manager = self.cuda_graph_manager
        manager.batch_size = x.shape[0]

        if self.captured:
            # Replay the previously captured graph.
            manager.state = CUDAGraphState.REPLAY
            with manager.run_impl_guard():
                return self.runnable(x)

        # Capture the graph for this batch size; no output is returned.
        manager.state = CUDAGraphState.CAPTURE
        self.captured = True
        with manager.run_impl_guard():
            self.runnable(x)


class TestCUDAGraph(unittest.TestCase):
def get_function(self):
return lambda x: x + x

def test_cuda_graph(self):
x = paddle.rand([32, 64])
fn = self.get_function()
runner = CudaGraphRunner(fn)
# Captured
runner.run_static_model(x)
# Replay
y_cg = runner.run_static_model(x)
y_dy = fn(x)

self.assertTrue(paddle.allclose(y_dy, y_cg))


# Allow running this test file directly via `python test_cudagraph.py`.
if __name__ == "__main__":
    unittest.main()
Loading