Skip to content

Commit 5250aa4

Browse files
committed
Fix some final review comments.
+ Remove GKO_COMMA and params for gen thanks to Thomas.
1 parent 58e59f9 commit 5250aa4

File tree

4 files changed

+19
-29
lines changed

4 files changed

+19
-29
lines changed

core/solver/lower_trs.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,10 @@ GKO_REGISTER_OPERATION(solve, lower_trs::solve);
6161

6262

6363
template <typename ValueType, typename IndexType>
64-
void LowerTrs<ValueType, IndexType>::generate(
65-
const matrix::Csr<ValueType, IndexType> *system_matrix,
66-
const matrix::Dense<ValueType> *b)
64+
void LowerTrs<ValueType, IndexType>::generate()
6765
{
6866
this->get_executor()->run(
69-
lower_trs::make_generate(gko::lend(system_matrix), gko::lend(b)));
67+
lower_trs::make_generate(gko::lend(system_matrix_), gko::lend(b_)));
7068
}
7169

7270

@@ -95,7 +93,7 @@ void LowerTrs<ValueType, IndexType>::apply_impl(const LinOp *alpha,
9593
auto x_clone = dense_x->clone();
9694
this->apply(b, x_clone.get());
9795
dense_x->scale(beta);
98-
dense_x->add_scaled(alpha, x_clone.get());
96+
dense_x->add_scaled(alpha, gko::lend(x_clone));
9997
}
10098

10199

include/ginkgo/core/base/polymorphic_object.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3434
#define GKO_CORE_BASE_POLYMORPHIC_OBJECT_HPP_
3535

3636

37+
#include <memory>
38+
39+
3740
#include <ginkgo/core/base/executor.hpp>
3841
#include <ginkgo/core/base/utils.hpp>
3942
#include <ginkgo/core/log/logger.hpp>
@@ -484,7 +487,7 @@ std::unique_ptr<const R, std::function<void(const R *)>> copy_and_convert_to(
484487

485488
/**
486489
* Converts the object to R and places it on Executor exec. This is the version
487-
* that takes in the shared_ptr and returns a shared_ptr
490+
* that takes in the std::shared_ptr and returns a std::shared_ptr
488491
*
489492
* If the object is already of the requested type and on the requested executor,
490493
* the copy and conversion is avoided and a reference to the original object is

include/ginkgo/core/solver/lower_trs.hpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,7 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
223223
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER(
224224
preconditioner, nullptr);
225225
};
226-
#define GKO_COMMA ,
227-
GKO_ENABLE_LOWER_TRS_FACTORY(LowerTrs<ValueType GKO_COMMA IndexType>,
228-
parameters, Factory);
229-
#undef GKO_COMMA
226+
GKO_ENABLE_LOWER_TRS_FACTORY(LowerTrs, parameters, Factory);
230227
GKO_ENABLE_BUILD_METHOD(Factory);
231228

232229
protected:
@@ -237,16 +234,8 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
237234

238235
/**
239236
* Generates the solver.
240-
*
241-
* @param system_matrix the source matrix used to generate the
242-
* solver.
243-
* @param b the right hand side used to generate the solver.
244-
*
245-
* @note the system_matrix to be passed in has to be convertible to CSR.
246-
* Otherwise an exception is thrown.
247237
*/
248-
void generate(const matrix::Csr<ValueType, IndexType> *system_matrix,
249-
const matrix::Dense<ValueType> *b);
238+
void generate();
250239

251240
explicit LowerTrs(std::shared_ptr<const Executor> exec)
252241
: EnableLinOp<LowerTrs>(std::move(exec))
@@ -279,7 +268,7 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
279268
preconditioner_ = matrix::Identity<ValueType>::create(
280269
this->get_executor(), this->get_size()[0]);
281270
}
282-
this->generate(gko::lend(system_matrix_), gko::lend(b_));
271+
this->generate();
283272
}
284273

285274
private:

reference/test/solver/lower_trs.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class LowerTrs : public ::testing::Test {
6262
mtx(gko::initialize<Mtx>(
6363
{{2, 0.0, 0.0}, {3.0, 1, 0.0}, {1.0, 2.0, 3}}, exec)),
6464
b(gko::initialize<Mtx>({{2, 0.0, 0.0}}, exec)),
65-
csr_mtx(gko::copy_and_convert_to<CsrMtx>(exec, mtx.get())),
65+
csr_mtx(gko::copy_and_convert_to<CsrMtx>(exec, gko::lend(mtx))),
6666
lower_trs_factory(Solver::build().on(exec)),
6767
lower_trs_solver(lower_trs_factory->generate(mtx, b))
6868
{}
@@ -97,10 +97,10 @@ TEST_F(LowerTrs, CanBeCopied)
9797
copy->copy_from(gko::lend(lower_trs_solver));
9898

9999
ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
100-
auto copy_mtx = copy.get()->get_system_matrix();
100+
auto copy_mtx = gko::lend(copy)->get_system_matrix();
101101
auto d_copy_mtx = Mtx::create(exec);
102102
copy_mtx->convert_to(gko::lend(d_copy_mtx));
103-
auto copy_b = copy.get()->get_rhs();
103+
auto copy_b = gko::lend(copy)->get_rhs();
104104

105105
GKO_ASSERT_MTX_NEAR(d_copy_mtx, mtx, 0);
106106
GKO_ASSERT_MTX_NEAR(copy_b, b, 0);
@@ -114,10 +114,10 @@ TEST_F(LowerTrs, CanBeMoved)
114114
copy->copy_from(std::move(lower_trs_solver));
115115

116116
ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
117-
auto copy_mtx = copy.get()->get_system_matrix();
117+
auto copy_mtx = gko::lend(copy)->get_system_matrix();
118118
auto d_copy_mtx = Mtx::create(exec);
119119
copy_mtx->convert_to(gko::lend(d_copy_mtx));
120-
auto copy_b = copy.get()->get_rhs();
120+
auto copy_b = gko::lend(copy)->get_rhs();
121121

122122
GKO_ASSERT_MTX_NEAR(d_copy_mtx, mtx, 0);
123123
GKO_ASSERT_MTX_NEAR(copy_b, b, 0);
@@ -128,10 +128,10 @@ TEST_F(LowerTrs, CanBeCloned)
128128
{
129129
auto clone = lower_trs_solver->clone();
130130

131-
auto clone_mtx = clone.get()->get_system_matrix();
131+
auto clone_mtx = gko::lend(clone)->get_system_matrix();
132132
auto d_clone_mtx = Mtx::create(exec);
133133
clone_mtx->convert_to(gko::lend(d_clone_mtx));
134-
auto clone_b = clone.get()->get_rhs();
134+
auto clone_b = gko::lend(clone)->get_rhs();
135135

136136
ASSERT_EQ(clone->get_size(), gko::dim<2>(3, 3));
137137
GKO_ASSERT_MTX_NEAR(d_clone_mtx, mtx, 0);
@@ -143,8 +143,8 @@ TEST_F(LowerTrs, CanBeCleared)
143143
{
144144
lower_trs_solver->clear();
145145

146-
auto solver_mtx = lower_trs_solver.get()->get_system_matrix();
147-
auto solver_b = lower_trs_solver.get()->get_rhs();
146+
auto solver_mtx = gko::lend(lower_trs_solver)->get_system_matrix();
147+
auto solver_b = gko::lend(lower_trs_solver)->get_rhs();
148148

149149
ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0));
150150
ASSERT_EQ(solver_mtx, nullptr);

0 commit comments

Comments
 (0)