@@ -33,14 +33,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
33
#include < ginkgo/core/solver/upper_trs.hpp>
34
34
35
35
36
- #include < gtest/gtest.h>
37
-
38
-
39
36
#include < memory>
40
37
#include < random>
41
38
42
39
43
40
#include < cuda.h>
41
+ #include < gtest/gtest.h>
44
42
45
43
46
44
#include < ginkgo/core/base/exception.hpp>
@@ -60,6 +58,7 @@ class UpperTrs : public ::testing::Test {
60
58
protected:
61
59
using CsrMtx = gko::matrix::Csr<double , gko::int32>;
62
60
using Mtx = gko::matrix::Dense<>;
61
+
63
62
UpperTrs () : rand_engine(30 ) {}
64
63
65
64
void SetUp ()
@@ -92,7 +91,32 @@ class UpperTrs : public ::testing::Test {
92
91
std::normal_distribution<>(-1.0 , 1.0 ), rand_engine, ref);
93
92
}
94
93
94
+ void initialize_data (int m, int n)
95
+ {
96
+ mtx = gen_u_mtx (m, m);
97
+ b = gen_mtx (m, n);
98
+ x = gen_mtx (m, n);
99
+ csr_mtx = CsrMtx::create (ref);
100
+ mtx->convert_to (csr_mtx.get ());
101
+ d_csr_mtx = CsrMtx::create (cuda);
102
+ d_x = Mtx::create (cuda);
103
+ d_x->copy_from (x.get ());
104
+ d_csr_mtx->copy_from (csr_mtx.get ());
105
+ b2 = Mtx::create (ref);
106
+ d_b2 = Mtx::create (cuda);
107
+ d_b2->copy_from (b.get ());
108
+ b2->copy_from (b.get ());
109
+ }
95
110
111
+ std::shared_ptr<Mtx> b;
112
+ std::shared_ptr<Mtx> b2;
113
+ std::shared_ptr<Mtx> x;
114
+ std::shared_ptr<Mtx> mtx;
115
+ std::shared_ptr<CsrMtx> csr_mtx;
116
+ std::shared_ptr<Mtx> d_b;
117
+ std::shared_ptr<Mtx> d_b2;
118
+ std::shared_ptr<Mtx> d_x;
119
+ std::shared_ptr<CsrMtx> d_csr_mtx;
96
120
std::shared_ptr<gko::ReferenceExecutor> ref;
97
121
std::shared_ptr<const gko::CudaExecutor> cuda;
98
122
std::ranlux48 rand_engine;
@@ -103,65 +127,49 @@ TEST_F(UpperTrs, CudaUpperTrsFlagCheckIsCorrect)
103
127
{
104
128
bool trans_flag = true ;
105
129
bool expected_flag = false ;
106
- #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))
107
- expected_flag = false ;
108
- #elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))
130
+
131
+
132
+ #if (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))
133
+
134
+
109
135
expected_flag = true ;
136
+
137
+
110
138
#endif
111
- gko::kernels::cuda::upper_trs::perform_transpose (cuda, trans_flag);
139
+
140
+
141
+ gko::kernels::cuda::upper_trs::should_perform_transpose (cuda, trans_flag);
112
142
113
143
ASSERT_EQ (expected_flag, trans_flag);
114
144
}
115
145
116
146
117
147
TEST_F (UpperTrs, CudaSingleRhsApplyIsEquivalentToRef)
118
148
{
119
- std::shared_ptr<Mtx> mtx = gen_u_mtx (50 , 50 );
120
- std::shared_ptr<Mtx> b = gen_mtx (50 , 1 );
121
- std::shared_ptr<Mtx> x = gen_mtx (50 , 1 );
122
- std::shared_ptr<CsrMtx> csr_mtx = CsrMtx::create (ref);
123
- mtx->convert_to (csr_mtx.get ());
124
- std::shared_ptr<CsrMtx> d_csr_mtx = CsrMtx::create (cuda);
125
- auto d_x = Mtx::create (cuda);
126
- d_x->copy_from (x.get ());
127
- d_csr_mtx->copy_from (csr_mtx.get ());
128
- std::shared_ptr<Mtx> b2 = Mtx::create (ref);
129
- std::shared_ptr<Mtx> d_b2 = Mtx::create (cuda);
130
- d_b2->copy_from (b.get ());
131
- b2->copy_from (b.get ());
132
-
149
+ initialize_data (50 , 1 );
133
150
auto upper_trs_factory = gko::solver::UpperTrs<>::build ().on (ref);
134
151
auto d_upper_trs_factory = gko::solver::UpperTrs<>::build ().on (cuda);
135
152
auto solver = upper_trs_factory->generate (csr_mtx);
136
153
auto d_solver = d_upper_trs_factory->generate (d_csr_mtx);
154
+
137
155
solver->apply (b2.get (), x.get ());
138
156
d_solver->apply (d_b2.get (), d_x.get ());
157
+
139
158
GKO_ASSERT_MTX_NEAR (d_x, x, 1e-14 );
140
159
}
141
160
142
161
143
162
TEST_F (UpperTrs, CudaMultipleRhsApplyIsEquivalentToRef)
144
163
{
145
- std::shared_ptr<Mtx> mtx = gen_u_mtx (50 , 50 );
146
- std::shared_ptr<Mtx> b = gen_mtx (50 , 3 );
147
- std::shared_ptr<Mtx> x = gen_mtx (50 , 3 );
148
- std::shared_ptr<CsrMtx> csr_mtx = CsrMtx::create (ref);
149
- mtx->convert_to (csr_mtx.get ());
150
- std::shared_ptr<CsrMtx> d_csr_mtx = CsrMtx::create (cuda);
151
- auto d_x = Mtx::create (cuda);
152
- d_x->copy_from (x.get ());
153
- d_csr_mtx->copy_from (csr_mtx.get ());
154
- std::shared_ptr<Mtx> b2 = Mtx::create (ref);
155
- std::shared_ptr<Mtx> d_b2 = Mtx::create (cuda);
156
- d_b2->copy_from (b.get ());
157
- b2->copy_from (b.get ());
164
+ initialize_data (50 , 3 );
158
165
159
166
auto upper_trs_factory =
160
167
gko::solver::UpperTrs<>::build ().with_num_rhs (3u ).on (ref);
161
168
auto d_upper_trs_factory =
162
169
gko::solver::UpperTrs<>::build ().with_num_rhs (3u ).on (cuda);
163
170
auto solver = upper_trs_factory->generate (csr_mtx);
164
171
auto d_solver = d_upper_trs_factory->generate (d_csr_mtx);
172
+
165
173
solver->apply (b2.get (), x.get ());
166
174
d_solver->apply (d_b2.get (), d_x.get ());
167
175
0 commit comments