Skip to content

Commit ab1910b

Browse files
committed
Add eigen backend
1 parent e009f4f commit ab1910b

File tree

8 files changed

+333
-78
lines changed

8 files changed

+333
-78
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "dependencies/eigen"]
2+
path = dependencies/eigen
3+
url = [email protected]:libeigen/eigen.git

CMakeLists.txt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ project(simple-2d-constraint-solver)
77

88
set(CMAKE_CXX_STANDARD 11)
99

10+
option(USE_EIGEN ON)
11+
SET(USE_EIGEN 1)
12+
1013
# ========================================================
1114
# GTEST
12-
1315
include(FetchContent)
1416
FetchContent_Declare(
1517
googletest
@@ -26,11 +28,17 @@ set_property(TARGET gtest PROPERTY FOLDER "gtest")
2628
set_property(TARGET gtest_main PROPERTY FOLDER "gtest")
2729

2830
# ========================================================
31+
if(USE_EIGEN)
32+
include_directories(dependencies/eigen)
33+
add_definitions(-DATG_S2C_USE_EIGEN)
34+
endif(USE_EIGEN)
2935

3036
add_library(simple-2d-constraint-solver STATIC
37+
3138
# Source files
3239
src/rigid_body_system.cpp
33-
src/matrix.cpp
40+
src/matrix_custom.cpp
41+
src/matrix_eigen.cpp
3442
src/system_state.cpp
3543
src/utilities.cpp
3644
src/ode_solver.cpp
@@ -62,6 +70,8 @@ add_library(simple-2d-constraint-solver STATIC
6270
# Header files
6371
include/rigid_body_system.h
6472
include/matrix.h
73+
include/matrix_eigen.h
74+
include/matrix_custom.h
6575
include/utilities.h
6676
include/system_state.h
6777
include/ode_solver.h
@@ -93,10 +103,10 @@ add_library(simple-2d-constraint-solver STATIC
93103
)
94104

95105
# GTEST
96-
97106
enable_testing()
98107

99108
add_executable(simple-2d-constraint-solver-test
109+
100110
# Source files
101111
test/sanity_tests.cpp
102112
test/matrix_tests.cpp

dependencies/eigen

Submodule eigen added at e7248b2

include/matrix.h

Lines changed: 5 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,10 @@
11
#ifndef ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_H
22
#define ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_H
33

4-
#include <assert.h>
5-
6-
namespace atg_scs {
7-
class Matrix {
8-
public:
9-
Matrix();
10-
Matrix(int width, int height, double value = 0.0);
11-
~Matrix();
12-
13-
void initialize(int width, int height, double value);
14-
void initialize(int width, int height);
15-
void resize(int width, int height);
16-
void destroy();
17-
18-
void set(const double *data);
19-
20-
__forceinline void set(int column, int row, double value) {
21-
assert(column >= 0 && column < m_width);
22-
assert(row >= 0 && row < m_height);
23-
24-
m_matrix[row][column] = value;
25-
}
26-
27-
__forceinline void add(int column, int row, double value) {
28-
assert(column >= 0 && column < m_width);
29-
assert(row >= 0 && row < m_height);
30-
31-
m_matrix[row][column] += value;
32-
}
33-
34-
__forceinline double get(int column, int row) {
35-
assert(column >= 0 && column < m_width);
36-
assert(row >= 0 && row < m_height);
37-
38-
return m_matrix[row][column];
39-
}
40-
41-
void set(Matrix *reference);
42-
43-
void multiply(Matrix &b, Matrix *target);
44-
void componentMultiply(Matrix &b, Matrix *target);
45-
void transposeMultiply(Matrix &b, Matrix *target);
46-
void leftScale(Matrix &scale, Matrix *target);
47-
void rightScale(Matrix &scale, Matrix *target);
48-
void scale(double s, Matrix *target);
49-
void subtract(Matrix &b, Matrix *target);
50-
void add(Matrix &b, Matrix *target);
51-
void negate(Matrix *target);
52-
bool equals(Matrix &b, double err = 1e-6);
53-
double vectorMagnitudeSquared() const;
54-
double dot(Matrix &b) const;
55-
56-
void madd(Matrix &b, double s);
57-
void pmadd(Matrix &b, double s);
58-
59-
void transpose(Matrix *target);
60-
int getWidth() const { return m_width; }
61-
int getHeight() const { return m_height; }
62-
63-
__forceinline void fastRowSwap(int a, int b) {
64-
double *temp = m_matrix[a];
65-
m_matrix[a] = m_matrix[b];
66-
m_matrix[b] = temp;
67-
}
68-
69-
protected:
70-
double **m_matrix;
71-
double *m_data;
72-
int m_width;
73-
int m_height;
74-
int m_capacityWidth;
75-
int m_capacityHeight;
76-
};
77-
} /* namespace atg_scs */
4+
#ifdef ATG_S2C_USE_EIGEN
5+
#include "matrix_eigen.h"
6+
#else
7+
#include "matrix_custom.h"
8+
#endif
789

7910
#endif /* ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_H */

include/matrix_custom.h

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#ifndef ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_CUSTOM_H
2+
#define ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_CUSTOM_H
3+
4+
#include <assert.h>
5+
6+
namespace atg_scs {
7+
class Matrix {
8+
public:
9+
Matrix();
10+
Matrix(int width, int height, double value = 0.0);
11+
~Matrix();
12+
13+
void initialize(int width, int height, double value);
14+
void initialize(int width, int height);
15+
void resize(int width, int height);
16+
void destroy();
17+
18+
void set(const double *data);
19+
20+
__forceinline void set(int column, int row, double value) {
21+
assert(column >= 0 && column < m_width);
22+
assert(row >= 0 && row < m_height);
23+
24+
m_matrix[row][column] = value;
25+
}
26+
27+
__forceinline void add(int column, int row, double value) {
28+
assert(column >= 0 && column < m_width);
29+
assert(row >= 0 && row < m_height);
30+
31+
m_matrix[row][column] += value;
32+
}
33+
34+
__forceinline double get(int column, int row) {
35+
assert(column >= 0 && column < m_width);
36+
assert(row >= 0 && row < m_height);
37+
38+
return m_matrix[row][column];
39+
}
40+
41+
void set(Matrix *reference);
42+
43+
void multiply(Matrix &b, Matrix *target);
44+
void componentMultiply(Matrix &b, Matrix *target);
45+
void transposeMultiply(Matrix &b, Matrix *target);
46+
void leftScale(Matrix &scale, Matrix *target);
47+
void rightScale(Matrix &scale, Matrix *target);
48+
void scale(double s, Matrix *target);
49+
void subtract(Matrix &b, Matrix *target);
50+
void add(Matrix &b, Matrix *target);
51+
void negate(Matrix *target);
52+
bool equals(Matrix &b, double err = 1e-6);
53+
double vectorMagnitudeSquared() const;
54+
double dot(Matrix &b) const;
55+
56+
void madd(Matrix &b, double s);
57+
void pmadd(Matrix &b, double s);
58+
59+
void transpose(Matrix *target);
60+
int getWidth() const { return m_width; }
61+
int getHeight() const { return m_height; }
62+
63+
__forceinline void fastRowSwap(int a, int b) {
64+
double *temp = m_matrix[a];
65+
m_matrix[a] = m_matrix[b];
66+
m_matrix[b] = temp;
67+
}
68+
69+
protected:
70+
double **m_matrix;
71+
double *m_data;
72+
int m_width;
73+
int m_height;
74+
int m_capacityWidth;
75+
int m_capacityHeight;
76+
};
77+
} /* namespace atg_scs */
78+
79+
#endif /* ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_H */

include/matrix_eigen.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#ifndef ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_EIGEN_H
2+
#define ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_EIGEN_H
3+
4+
#include <assert.h>
5+
6+
#ifdef ATG_S2C_USE_EIGEN
7+
8+
#define EIGEN_NO_DEBUG
9+
#define EIGEN_NO_STATIC_ASSERT
10+
11+
#include "Eigen/Dense"
12+
13+
namespace atg_scs
14+
{
15+
class Matrix
16+
{
17+
typedef Eigen::MatrixXd MatrixType;
18+
19+
public:
20+
Matrix();
21+
Matrix(int width, int height, double value = 0.0);
22+
~Matrix();
23+
24+
void initialize(int width, int height, double value);
25+
void initialize(int width, int height);
26+
void resize(int width, int height);
27+
void destroy();
28+
29+
void set(const double *data);
30+
31+
__forceinline void set(int column, int row, double value)
32+
{
33+
m_matrix(row, column) = value;
34+
}
35+
36+
__forceinline void add(int column, int row, double value)
37+
{
38+
m_matrix(row, column) += value;
39+
}
40+
41+
__forceinline double get(int column, int row)
42+
{
43+
return m_matrix(row, column);
44+
}
45+
46+
void set(Matrix *reference);
47+
48+
void multiply(Matrix &b, Matrix *target);
49+
void componentMultiply(Matrix &b, Matrix *target);
50+
void transposeMultiply(Matrix &b, Matrix *target);
51+
void leftScale(Matrix &scale, Matrix *target);
52+
void rightScale(Matrix &scale, Matrix *target);
53+
void scale(double s, Matrix *target);
54+
void subtract(Matrix &b, Matrix *target);
55+
void add(Matrix &b, Matrix *target);
56+
void negate(Matrix *target);
57+
bool equals(Matrix &b, double err = 1e-6);
58+
double vectorMagnitudeSquared() const;
59+
double dot(Matrix &b) const;
60+
61+
void madd(Matrix &b, double s);
62+
void pmadd(Matrix &b, double s);
63+
64+
void transpose(Matrix *target);
65+
int getWidth() const { return m_matrix.cols(); }
66+
int getHeight() const { return m_matrix.rows(); }
67+
68+
__forceinline void fastRowSwap(int a, int b)
69+
{
70+
auto row_a = m_matrix.row(a);
71+
auto row_b = m_matrix.row(a);
72+
m_matrix.row(a) = row_b;
73+
m_matrix.row(b) = row_a;
74+
}
75+
76+
protected:
77+
MatrixType m_matrix;
78+
};
79+
} /* namespace atg_scs */
80+
81+
#endif
82+
83+
#endif /* ATG_SIMPLE_2D_CONSTRAINT_SOLVER_MATRIX_H */

src/matrix.cpp renamed to src/matrix_custom.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "../include/matrix.h"
1+
#include "../include/matrix_custom.h"
22

33
#include <algorithm>
44
#include <assert.h>

0 commit comments

Comments
 (0)