Skip to content

Commit b6566e6

Browse files
committed
Allow empty creation from all matrix formats convertible to CSR.
1 parent dfbacd6 commit b6566e6

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

core/solver/trs.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4545

4646
#include "core/solver/trs_kernels.hpp"
4747

48+
4849
namespace gko {
4950
namespace solver {
5051

@@ -68,11 +69,14 @@ void Trs<ValueType, IndexType>::generate(const LinOp *system_matrix,
6869
GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix);
6970
// This is needed because it does not make sense to call the copy and
7071
// convert if the existing matrix (if not CSR) is empty.
71-
if (dynamic_cast<const CsrMatrix *>(system_matrix) == nullptr) {
72-
GKO_ASSERT_IS_NON_EMPTY_MATRIX(system_matrix);
73-
}
7472
const auto exec = this->get_executor();
75-
csr_system_matrix_ = copy_and_convert_to<CsrMatrix>(exec, system_matrix);
73+
auto temp_cast = dynamic_cast<const CsrMatrix *>(system_matrix);
74+
if (!system_matrix->get_size()) {
75+
csr_system_matrix_ = CsrMatrix::create(exec);
76+
} else {
77+
csr_system_matrix_ =
78+
copy_and_convert_to<CsrMatrix>(exec, system_matrix);
79+
}
7680
auto dense_b = as<const Vector>(b);
7781
exec->run(trs::make_generate(gko::lend(csr_system_matrix_), dense_b));
7882
}

include/ginkgo/core/base/exception.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ class BadDimension : public Error {
281281
* @param line The source code line number where the error occurred
282282
* @param func The function name where the error occurred
283283
* @param op_name The name of the operator
284-
* @param op_num_rows The output dimension of the operator
285-
* @param op_num_cols The input dimension of the operator
286-
* @param clarification An additional message describing the error further
284+
* @param op_num_rows The row dimension of the operator
285+
* @param op_num_cols The column dimension of the operator
286+
* @param clarification An additional message further describing the error
287287
*/
288288
BadDimension(const std::string &file, int line, const std::string &func,
289289
const std::string &op_name, size_type op_num_rows,

reference/test/solver/trs.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3939
#include <gtest/gtest.h>
4040

4141

42-
#include <core/test/utils/assertions.hpp>
4342
#include <ginkgo/core/base/exception.hpp>
4443
#include <ginkgo/core/base/executor.hpp>
4544
#include <ginkgo/core/matrix/csr.hpp>
4645
#include <ginkgo/core/matrix/dense.hpp>
4746

4847

48+
#include "core/test/utils/assertions.hpp"
49+
50+
4951
namespace {
5052

5153

@@ -97,7 +99,7 @@ TEST_F(Trs, TrsFactoryCreatesCorrectSolver)
9799

98100
TEST_F(Trs, CanBeCopied)
99101
{
100-
auto copy = Solver::build().on(exec)->generate(CsrMtx::create(exec),
102+
auto copy = Solver::build().on(exec)->generate(Mtx::create(exec),
101103
Mtx::create(exec));
102104

103105
copy->copy_from(lend(trs_solver));
@@ -112,7 +114,7 @@ TEST_F(Trs, CanBeCopied)
112114

113115
TEST_F(Trs, CanBeMoved)
114116
{
115-
auto copy = trs_factory->generate(CsrMtx::create(exec), Mtx::create(exec));
117+
auto copy = trs_factory->generate(Mtx::create(exec), Mtx::create(exec));
116118

117119
copy->copy_from(std::move(trs_solver));
118120

0 commit comments

Comments
 (0)