Skip to content

Commit 1bb5283

Browse files
committed
Some minor changes and clarifications.
1 parent a6fe819 commit 1bb5283

File tree

8 files changed

+39
-79
lines changed

8 files changed

+39
-79
lines changed

core/solver/trs.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ void Trs<ValueType, IndexType>::generate(const LinOp *system_matrix,
7070
using CsrMatrix = matrix::Csr<ValueType, IndexType>;
7171
using Vector = matrix::Dense<ValueType>;
7272
GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix);
73+
// This is needed because it does not make sense to call the copy and
74+
// convert if the existing matrix (if not CSR) is empty.
75+
if (dynamic_cast<const CsrMatrix *>(system_matrix) == nullptr) {
76+
GKO_ASSERT_IS_NON_EMPTY_MATRIX(system_matrix);
77+
}
7378
const auto exec = this->get_executor();
7479
csr_system_matrix_ = copy_and_convert_to<CsrMatrix>(exec, system_matrix);
7580
auto dense_b = as<const Vector>(b);

core/test/solver/trs.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333
#include <ginkgo/core/solver/trs.hpp>
3434

3535

36-
#include <typeinfo>
37-
36+
#include <memory>
3837

3938
#include <gtest/gtest.h>
4039

@@ -56,7 +55,6 @@ class Trs : public ::testing::Test {
5655

5756
std::shared_ptr<const gko::Executor> exec;
5857
std::unique_ptr<Solver::Factory> trs_factory;
59-
std::unique_ptr<gko::LinOp> solver;
6058
};
6159

6260

dev_tools/scripts/todo_trs.txt

Lines changed: 0 additions & 53 deletions
This file was deleted.

include/ginkgo/core/base/exception_helpers.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,22 @@ inline dim<2> get_size(const dim<2> &size) { return size; }
134134
::gko::detail::get_size(_op1)[1], "expected square matrix"); \
135135
}
136136

137+
/**
138+
*Asserts that _op1 is a non-empty matrix.
139+
*
140+
*@throw DimensionMismatch if the number of rows of _op1 is different from the
141+
* number of columns of _op1.
142+
*/
143+
#define GKO_ASSERT_IS_NON_EMPTY_MATRIX(_op1) \
144+
if (::gko::detail::get_size(_op1)[0] == 0 && \
145+
::gko::detail::get_size(_op1)[1] == 0) { \
146+
throw ::gko::DimensionMismatch( \
147+
__FILE__, __LINE__, __func__, #_op1, \
148+
::gko::detail::get_size(_op1)[0], \
149+
::gko::detail::get_size(_op1)[1], #_op1, \
150+
::gko::detail::get_size(_op1)[0], \
151+
::gko::detail::get_size(_op1)[1], "expected non-empty matrix"); \
152+
}
137153

138154
/**
139155
* Asserts that _op1 can be applied to _op2.

include/ginkgo/core/solver/trs.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ class Trs : public EnableLinOp<Trs<ValueType, IndexType>>,
227227
: parameters_{factory->get_parameters()},
228228
EnableLinOp<Trs>(factory->get_executor(),
229229
transpose(args.system_matrix->get_size())),
230-
system_matrix_{std::move(args.system_matrix)}
230+
system_matrix_{std::move(args.system_matrix)},
231+
b_{std::move(args.b)}
231232
{
232233
if (parameters_.preconditioner) {
233234
preconditioner_ =
@@ -236,12 +237,13 @@ class Trs : public EnableLinOp<Trs<ValueType, IndexType>>,
236237
preconditioner_ = matrix::Identity<ValueType>::create(
237238
this->get_executor(), this->get_size()[0]);
238239
}
239-
this->generate(gko::lend(system_matrix_), gko::lend(args.b));
240+
this->generate(gko::lend(system_matrix_), gko::lend(b_));
240241
}
241242

242243

243244
private:
244245
std::shared_ptr<const LinOp> system_matrix_{};
246+
std::shared_ptr<const LinOp> b_{};
245247
std::shared_ptr<const matrix::Csr<ValueType, IndexType>>
246248
csr_system_matrix_{};
247249
std::shared_ptr<const LinOp> preconditioner_{};

omp/solver/trs_kernels.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ template <typename ValueType, typename IndexType>
5656
void generate(std::shared_ptr<const OmpExecutor> exec,
5757
const matrix::Csr<ValueType, IndexType> *matrix,
5858
const matrix::Dense<ValueType> *b)
59-
{}
59+
{
60+
// This generate kernel is here to allow for a more sophisticated
61+
// implementation as for the CUDA executor. This kernel would perform the
62+
// "analysis" phase for the triangular matrix.
63+
}
6064

6165

6266
template <typename ValueType, typename IndexType>

reference/solver/trs_kernels.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3838
#include <ginkgo/core/base/math.hpp>
3939
#include <ginkgo/core/base/types.hpp>
4040
#include <ginkgo/core/matrix/csr.hpp>
41-
#include <iostream>
41+
42+
4243
namespace gko {
4344
namespace kernels {
4445
namespace reference {
@@ -53,7 +54,11 @@ template <typename ValueType, typename IndexType>
5354
void generate(std::shared_ptr<const ReferenceExecutor> exec,
5455
const matrix::Csr<ValueType, IndexType> *matrix,
5556
const matrix::Dense<ValueType> *b)
56-
{}
57+
{
58+
// This generate kernel is here to allow for a more sophisticated
59+
// implementation as for the CUDA executor. This kernel would perform the
60+
// "analysis" phase for the triangular matrix.
61+
}
5762

5863
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_TRS_GENERATE_KERNEL);
5964

reference/test/solver/trs.cpp

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3232

3333
#include <ginkgo/core/solver/trs.hpp>
3434

35+
#include <memory>
3536

3637
#include <gtest/gtest.h>
3738

@@ -41,10 +42,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4142
#include <ginkgo/core/base/executor.hpp>
4243
#include <ginkgo/core/matrix/csr.hpp>
4344
#include <ginkgo/core/matrix/dense.hpp>
44-
#include <ginkgo/core/stop/combined.hpp>
45-
#include <ginkgo/core/stop/iteration.hpp>
46-
#include <ginkgo/core/stop/residual_norm_reduction.hpp>
47-
#include <ginkgo/core/stop/time.hpp>
4845

4946

5047
namespace {
@@ -83,20 +80,6 @@ class Trs : public ::testing::Test {
8380
}
8481
}
8582
}
86-
87-
static void assert_same_csr_matrices(const CsrMtx *m1, const CsrMtx *m2)
88-
{
89-
ASSERT_EQ(m1->get_size()[0], m2->get_size()[0]);
90-
ASSERT_EQ(m1->get_size()[1], m2->get_size()[1]);
91-
92-
for (gko::size_type i = 0; i < m1->get_size()[0] + 1; ++i) {
93-
EXPECT_EQ(m1->get_const_row_ptrs()[i], m2->get_const_row_ptrs()[i]);
94-
}
95-
for (gko::size_type i = 0; i < m1->get_num_stored_elements(); ++i) {
96-
EXPECT_EQ(m1->get_const_col_idxs()[i], m2->get_const_col_idxs()[i]);
97-
EXPECT_EQ(m1->get_const_values()[i], m2->get_const_values()[i]);
98-
}
99-
}
10083
};
10184

10285

0 commit comments

Comments
 (0)