Skip to content

Commit ff77fb1

Browse files
committed
Update with the new upstream changes.
1 parent 1e53dbe commit ff77fb1

File tree

3 files changed

+132
-53
lines changed

3 files changed

+132
-53
lines changed

core/solver/upper_trs_kernels.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace kernels {
4848
namespace upper_trs {
4949

5050

51-
#define GKO_DECLARE_UPPER_TRS_CHECK_SHOULD_PERFORM_TRANSPOSE_KERNEL() \
51+
#define GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL() \
5252
void should_perform_transpose(std::shared_ptr<const DefaultExecutor> exec, \
5353
bool &do_transpose)
5454

@@ -73,12 +73,12 @@ namespace upper_trs {
7373
const matrix::Dense<_vtype> *b, matrix::Dense<_vtype> *x)
7474

7575

76-
#define GKO_DECLARE_ALL_AS_TEMPLATES \
77-
GKO_DECLARE_UPPER_TRS_CHECK_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \
78-
GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL(); \
79-
template <typename ValueType, typename IndexType> \
80-
GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(ValueType, IndexType); \
81-
template <typename ValueType, typename IndexType> \
76+
#define GKO_DECLARE_ALL_AS_TEMPLATES \
77+
GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \
78+
GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL(); \
79+
template <typename ValueType, typename IndexType> \
80+
GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(ValueType, IndexType); \
81+
template <typename ValueType, typename IndexType> \
8282
GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(ValueType, IndexType)
8383

8484

cuda/solver/upper_trs_kernels.cu

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,99 @@ namespace upper_trs {
6464

6565

6666
void should_perform_transpose(std::shared_ptr<const CudaExecutor> exec,
67-
bool &do_transpose) GKO_NOT_IMPLEMENTED;
67+
bool &do_transpose)
68+
{
69+
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))
70+
71+
72+
do_transpose = false;
73+
74+
75+
#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))
76+
77+
78+
do_transpose = true;
79+
80+
81+
#endif
82+
}
6883

6984

7085
void init_struct(std::shared_ptr<const CudaExecutor> exec,
71-
std::shared_ptr<gko::solver::SolveStruct> &solve_struct)
86+
std::shared_ptr<solver::SolveStruct> &solve_struct)
7287
{
73-
const auto id = exec->get_device_id();
74-
device_guard g(id);
75-
solve_struct = std::shared_ptr<gko::solver::SolveStruct>(
76-
kernels::cuda::cusparse::init_trs_solve_struct(),
77-
[id](gko::solver::SolveStruct *solve_struct_) {
78-
device_guard g(id);
79-
kernels::cuda::cusparse::clear_trs_solve_struct(solve_struct_);
80-
});
88+
solve_struct =
89+
std::shared_ptr<solver::SolveStruct>(new solver::SolveStruct());
8190
}
8291

8392

8493
template <typename ValueType, typename IndexType>
8594
void generate(std::shared_ptr<const CudaExecutor> exec,
8695
const matrix::Csr<ValueType, IndexType> *matrix,
87-
solver::SolveStruct *solve_struct,
88-
const gko::size_type num_rhs) GKO_NOT_IMPLEMENTED;
96+
solver::SolveStruct *solve_struct, const gko::size_type num_rhs)
97+
{
98+
if (cusparse::is_supported<ValueType, IndexType>::value) {
99+
auto handle = exec->get_cusparse_handle();
100+
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode(
101+
solve_struct->factor_descr, CUSPARSE_FILL_MODE_UPPER));
102+
103+
104+
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))
105+
106+
107+
ValueType one = 1.0;
108+
109+
GKO_ASSERT_NO_CUSPARSE_ERRORS(
110+
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST));
111+
cusparse::buffer_size_ext(
112+
handle, solve_struct->algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE,
113+
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs,
114+
matrix->get_num_stored_elements(), &one, solve_struct->factor_descr,
115+
matrix->get_const_values(), matrix->get_const_row_ptrs(),
116+
matrix->get_const_col_idxs(), nullptr, num_rhs,
117+
solve_struct->solve_info, solve_struct->policy,
118+
&solve_struct->factor_work_size);
119+
120+
// allocate workspace
121+
if (solve_struct->factor_work_vec != nullptr) {
122+
GKO_ASSERT_NO_CUDA_ERRORS(cudaFree(solve_struct->factor_work_vec));
123+
}
124+
solve_struct->factor_work_vec =
125+
exec->alloc<void *>(solve_struct->factor_work_size);
126+
127+
cusparse::csrsm2_analysis(
128+
handle, solve_struct->algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE,
129+
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs,
130+
matrix->get_num_stored_elements(), &one, solve_struct->factor_descr,
131+
matrix->get_const_values(), matrix->get_const_row_ptrs(),
132+
matrix->get_const_col_idxs(), nullptr, num_rhs,
133+
solve_struct->solve_info, solve_struct->policy,
134+
solve_struct->factor_work_vec);
135+
GKO_ASSERT_NO_CUSPARSE_ERRORS(
136+
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE));
137+
138+
139+
#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))
140+
141+
142+
GKO_ASSERT_NO_CUSPARSE_ERRORS(
143+
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST));
144+
cusparse::csrsm_analysis(
145+
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0],
146+
matrix->get_num_stored_elements(), solve_struct->factor_descr,
147+
matrix->get_const_values(), matrix->get_const_row_ptrs(),
148+
matrix->get_const_col_idxs(), solve_struct->solve_info);
149+
GKO_ASSERT_NO_CUSPARSE_ERRORS(
150+
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE));
151+
152+
153+
#endif
154+
155+
156+
} else {
157+
GKO_NOT_IMPLEMENTED;
158+
}
159+
}
89160

90161
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
91162
GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL);

cuda/test/solver/upper_trs_kernels.cpp

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333
#include <ginkgo/core/solver/upper_trs.hpp>
3434

3535

36-
#include <gtest/gtest.h>
37-
38-
3936
#include <memory>
4037
#include <random>
4138

4239

4340
#include <cuda.h>
41+
#include <gtest/gtest.h>
4442

4543

4644
#include <ginkgo/core/base/exception.hpp>
@@ -60,6 +58,7 @@ class UpperTrs : public ::testing::Test {
6058
protected:
6159
using CsrMtx = gko::matrix::Csr<double, gko::int32>;
6260
using Mtx = gko::matrix::Dense<>;
61+
6362
UpperTrs() : rand_engine(30) {}
6463

6564
void SetUp()
@@ -92,7 +91,32 @@ class UpperTrs : public ::testing::Test {
9291
std::normal_distribution<>(-1.0, 1.0), rand_engine, ref);
9392
}
9493

94+
void initialize_data(int m, int n)
95+
{
96+
mtx = gen_u_mtx(m, m);
97+
b = gen_mtx(m, n);
98+
x = gen_mtx(m, n);
99+
csr_mtx = CsrMtx::create(ref);
100+
mtx->convert_to(csr_mtx.get());
101+
d_csr_mtx = CsrMtx::create(cuda);
102+
d_x = Mtx::create(cuda);
103+
d_x->copy_from(x.get());
104+
d_csr_mtx->copy_from(csr_mtx.get());
105+
b2 = Mtx::create(ref);
106+
d_b2 = Mtx::create(cuda);
107+
d_b2->copy_from(b.get());
108+
b2->copy_from(b.get());
109+
}
95110

111+
std::shared_ptr<Mtx> b;
112+
std::shared_ptr<Mtx> b2;
113+
std::shared_ptr<Mtx> x;
114+
std::shared_ptr<Mtx> mtx;
115+
std::shared_ptr<CsrMtx> csr_mtx;
116+
std::shared_ptr<Mtx> d_b;
117+
std::shared_ptr<Mtx> d_b2;
118+
std::shared_ptr<Mtx> d_x;
119+
std::shared_ptr<CsrMtx> d_csr_mtx;
96120
std::shared_ptr<gko::ReferenceExecutor> ref;
97121
std::shared_ptr<const gko::CudaExecutor> cuda;
98122
std::ranlux48 rand_engine;
@@ -103,65 +127,49 @@ TEST_F(UpperTrs, CudaUpperTrsFlagCheckIsCorrect)
103127
{
104128
bool trans_flag = true;
105129
bool expected_flag = false;
106-
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))
107-
expected_flag = false;
108-
#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))
130+
131+
132+
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))
133+
134+
109135
expected_flag = true;
136+
137+
110138
#endif
111-
gko::kernels::cuda::upper_trs::perform_transpose(cuda, trans_flag);
139+
140+
141+
gko::kernels::cuda::upper_trs::should_perform_transpose(cuda, trans_flag);
112142

113143
ASSERT_EQ(expected_flag, trans_flag);
114144
}
115145

116146

117147
TEST_F(UpperTrs, CudaSingleRhsApplyIsEquivalentToRef)
118148
{
119-
std::shared_ptr<Mtx> mtx = gen_u_mtx(50, 50);
120-
std::shared_ptr<Mtx> b = gen_mtx(50, 1);
121-
std::shared_ptr<Mtx> x = gen_mtx(50, 1);
122-
std::shared_ptr<CsrMtx> csr_mtx = CsrMtx::create(ref);
123-
mtx->convert_to(csr_mtx.get());
124-
std::shared_ptr<CsrMtx> d_csr_mtx = CsrMtx::create(cuda);
125-
auto d_x = Mtx::create(cuda);
126-
d_x->copy_from(x.get());
127-
d_csr_mtx->copy_from(csr_mtx.get());
128-
std::shared_ptr<Mtx> b2 = Mtx::create(ref);
129-
std::shared_ptr<Mtx> d_b2 = Mtx::create(cuda);
130-
d_b2->copy_from(b.get());
131-
b2->copy_from(b.get());
132-
149+
initialize_data(50, 1);
133150
auto upper_trs_factory = gko::solver::UpperTrs<>::build().on(ref);
134151
auto d_upper_trs_factory = gko::solver::UpperTrs<>::build().on(cuda);
135152
auto solver = upper_trs_factory->generate(csr_mtx);
136153
auto d_solver = d_upper_trs_factory->generate(d_csr_mtx);
154+
137155
solver->apply(b2.get(), x.get());
138156
d_solver->apply(d_b2.get(), d_x.get());
157+
139158
GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14);
140159
}
141160

142161

143162
TEST_F(UpperTrs, CudaMultipleRhsApplyIsEquivalentToRef)
144163
{
145-
std::shared_ptr<Mtx> mtx = gen_u_mtx(50, 50);
146-
std::shared_ptr<Mtx> b = gen_mtx(50, 3);
147-
std::shared_ptr<Mtx> x = gen_mtx(50, 3);
148-
std::shared_ptr<CsrMtx> csr_mtx = CsrMtx::create(ref);
149-
mtx->convert_to(csr_mtx.get());
150-
std::shared_ptr<CsrMtx> d_csr_mtx = CsrMtx::create(cuda);
151-
auto d_x = Mtx::create(cuda);
152-
d_x->copy_from(x.get());
153-
d_csr_mtx->copy_from(csr_mtx.get());
154-
std::shared_ptr<Mtx> b2 = Mtx::create(ref);
155-
std::shared_ptr<Mtx> d_b2 = Mtx::create(cuda);
156-
d_b2->copy_from(b.get());
157-
b2->copy_from(b.get());
164+
initialize_data(50, 3);
158165

159166
auto upper_trs_factory =
160167
gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(ref);
161168
auto d_upper_trs_factory =
162169
gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(cuda);
163170
auto solver = upper_trs_factory->generate(csr_mtx);
164171
auto d_solver = d_upper_trs_factory->generate(d_csr_mtx);
172+
165173
solver->apply(b2.get(), x.get());
166174
d_solver->apply(d_b2.get(), d_x.get());
167175

0 commit comments

Comments
 (0)