Skip to content

Commit 60bb28e

Browse files
committed
Test multigrid for HCT/HCT-red
1 parent f00196b commit 60bb28e

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

firedrake/mg/embedded.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313

1414

1515
native_families = frozenset(["Lagrange", "Discontinuous Lagrange", "Real", "Q", "DQ"])
16-
non_native_integral_variants = frozenset(["integral", "fdm"])
16+
alfeld_families = frozenset(["Hsieh-Clough-Tocher", "Reduced Hsieh-Clough-Tocher", "Johnson-Mercier"])
17+
non_native_variants = frozenset(["integral", "fdm"])
18+
19+
20+
def get_embedding_element(element):
21+
dg_element = get_embedding_dg_element(element)
22+
if element.family() in alfeld_families:
23+
dg_element = dg_element.reconstruct(variant="alfeld")
24+
return dg_element
1725

1826

1927
class Op(IntEnum):
@@ -28,7 +36,7 @@ class Cache(object):
2836
2937
:arg element: The element to use for the caching."""
3038
def __init__(self, element):
31-
self.embedding_element = get_embedding_dg_element(element)
39+
self.embedding_element = get_embedding_element(element)
3240
self._dat_versions = {}
3341
self._V_DG_mass = {}
3442
self._DG_inv_mass = {}
@@ -59,7 +67,7 @@ def is_native(self, element):
5967
return True
6068
if isinstance(element.cell, ufl.TensorProductCell) and len(element.sub_elements) > 0:
6169
return reduce(and_, map(self.is_native, element.sub_elements))
62-
return element.family() in native_families and not element.variant() in non_native_integral_variants
70+
return element.family() in native_families and not element.variant() in non_native_variants
6371

6472
def _native_transfer(self, element, op):
6573
try:

tests/macro/test_macro_multigrid.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,20 @@ def test_macro_grid_transfer(hierarchy, space, degrees, variant, transfer_type):
127127
run_prolongation(hierarchy, space, degrees, variant)
128128

129129

130+
mg_params = {
131+
"mat_type": "matfree",
132+
"ksp_type": "cg",
133+
"ksp_monitor": None,
134+
"pc_type": "mg",
135+
"mg_levels_ksp_type": "chebyshev",
136+
"mg_levels_pc_type": "jacobi",
137+
"mg_coarse_pc_type": "python",
138+
"mg_coarse_pc_python_type": "firedrake.AssembledPC",
139+
}
140+
141+
130142
@pytest.mark.parametrize("degree", (1,))
131-
def test_macro_multigrid(hierarchy, degree, variant):
143+
def test_macro_multigrid_poisson(hierarchy, degree, variant):
132144
mesh = hierarchy[-1]
133145
V = FunctionSpace(mesh, "CG", degree, variant=variant)
134146
u = TrialFunction(V)
@@ -138,19 +150,35 @@ def test_macro_multigrid(hierarchy, degree, variant):
138150
bcs = [DirichletBC(V, 0, "on_boundary")]
139151

140152
uh = Function(V)
141-
sp = {
142-
"mat_type": "matfree",
143-
"ksp_type": "cg",
144-
"pc_type": "mg",
145-
"mg_levels_ksp_type": "chebyshev",
146-
"mg_levels_pc_type": "jacobi",
147-
"mg_coarse_pc_type": "python",
148-
"mg_coarse_pc_python_type": "firedrake.AssembledPC",
149-
}
150153
problem = LinearVariationalProblem(a, L, uh, bcs=bcs)
151-
solver = LinearVariationalSolver(problem, solver_parameters=sp)
154+
solver = LinearVariationalSolver(problem, solver_parameters=mg_params)
152155
solver.solve()
153156
expected = 10
154157
if mesh.geometric_dimension() == 3 and variant == "alfeld":
155158
expected = 14
156159
assert solver.snes.ksp.getIterationNumber() <= expected
160+
161+
162+
@pytest.fixture()
163+
def square_hierarchy():
164+
refine = 4
165+
base = UnitSquareMesh(3, 3)
166+
return MeshHierarchy(base, refine)
167+
168+
169+
@pytest.mark.parametrize("family", ("HCT-red", "HCT"))
170+
def test_macro_multigrid_biharmonic(square_hierarchy, family):
171+
mesh = square_hierarchy[-1]
172+
V = FunctionSpace(mesh, family, 3)
173+
u = TrialFunction(V)
174+
v = TestFunction(V)
175+
a = inner(div(grad(u)), div(grad(v))) * dx
176+
L = inner(Constant(1), v) * dx
177+
bcs = [DirichletBC(V, 0, "on_boundary")]
178+
179+
uh = Function(V)
180+
problem = LinearVariationalProblem(a, L, uh, bcs=bcs)
181+
solver = LinearVariationalSolver(problem, solver_parameters=mg_params)
182+
solver.solve()
183+
expected = 16
184+
assert solver.snes.ksp.getIterationNumber() <= expected

0 commit comments

Comments
 (0)