"""
Adaptive solver for the Poisson problem -div grad u = f (with f = 1) on an
L-shaped domain with u = 0 on the boundary, using a residual-based
energy-norm error estimator

  eta_h**2 = sum_T eta_T**2

with

  eta_T**2 = h_T**2 ||R_T||_T**2 + 1/2 h_T ||R_dT||_dT**2

where

  R_T =  f + div grad u_h
  R_dT = jump(grad u_h . n)   (jump of the normal derivative across interior facets)

and a maximal marking strategy (refining those cells for
which the error indicator is greater than a certain fraction of the
largest error indicator)

Adapted by Douglas Arnold from code of Marie Rognes
"""

from dolfin import *

# Dirichlet boundary: every point that dolfin flags as lying on the
# boundary of the domain belongs to this subdomain.
class Boundary(SubDomain):
    """Subdomain made up of the whole domain boundary."""

    def inside(self, x, on_boundary):
        """Return the ``on_boundary`` flag unchanged; ``x`` is not inspected."""
        is_boundary_point = on_boundary
        return is_boundary_point

# Stopping criteria: stop once the global error estimate falls below
# `tolerance`, or after `max_iterations` solve/refine cycles at most.
tolerance, max_iterations = 0.1, 20

# Maximal marking parameter: a cell is refined when its error indicator
# exceeds `fraction` times the largest error indicator.
fraction = 0.7

# Load the starting mesh, then order its entities (presumably the
# ordering dolfin's assembler expects -- confirm for this dolfin version).
mesh = Mesh("l-shape-mesh.xml")
mesh.order()

# SOLVE - ESTIMATE - MARK - REFINE loop.  The for/else reports the case
# where the iteration budget runs out before the tolerance is met.
for i in range(max_iterations):

    # *** SOLVE step
    # Poisson problem -div grad u = f with f = 1 and u = 0 on the boundary,
    # discretized with continuous piecewise-linear (CG1) elements.
    V = FunctionSpace(mesh, "CG", 1)
    u = TrialFunction(V)
    v = TestFunction(V)
    f = Constant(mesh, 1.0)
    a = inner(grad(u), grad(v))*dx
    L = f*v*dx
    u0 = Constant(mesh, 0.0)
    bc = DirichletBC(V, u0, Boundary())
    # Solve variational problem on current mesh
    u_h = VariationalProblem(a, L, bc).solve()

    # *** ESTIMATE step
    # Cell residual f + div grad u_h (u_h is linear on each cell, so its
    # cellwise Laplacian vanishes and this reduces to f).
    R_T = f + div(grad(u_h))
    # Normal flux of u_h through the cell facets.  (In other versions of
    # dolfin the normal is obtained from FacetNormal instead.)
    n = V.cell().n
    R_dT = dot(grad(u_h), n)
    # Testing against the piecewise-constant DG0 basis localizes the
    # indicator form: entry T of the assembled vector is eta_T**2.
    Constants = FunctionSpace(mesh, "DG", 0)
    w = TestFunction(Constants)
    h = CellSize(mesh)
    # NOTE(review): with avg(v) = (v('+') + v('-'))/2 and n('-') = -n('+'),
    # avg(R_dT) is half the normal-derivative jump and 2*avg(w) is
    # w('+') + w('-'), so each neighbouring cell receives avg(h)*jump**2/4
    # from an interior facet -- half of the 1/2 h_T weight stated in the
    # module docstring.  Confirm which constant is intended.
    eta2 = assemble(h**2*R_T**2*w*dx + avg(h)*avg(R_dT)**2*2*avg(w)*dS)
    eta = [sqrt(eta2_T) for eta2_T in eta2]
    # Global estimate eta_h = sqrt(sum_T eta_T**2)
    error_estimate = sqrt(sum(eta2_T for eta2_T in eta2))
    # Report the mesh that was actually solved on, alongside its own error
    # estimate (this also covers the final mesh when the tolerance is met).
    print("Mesh %d: %d triangles, %d vertices, hmax = %g, hmin = %g, errest = %g"
          % (i, mesh.num_cells(), mesh.num_vertices(), mesh.hmax(), mesh.hmin(),
             error_estimate))
    if error_estimate < tolerance:
        print("\nTolerance achieved.  Exiting.")
        break

    # *** MARK step
    # Maximal marking: refine every cell whose indicator exceeds
    # fraction * (largest indicator).  The threshold is loop-invariant.
    threshold = fraction*max(eta)
    cell_markers = MeshFunction("bool", mesh, mesh.topology().dim())
    for c in cells(mesh):
        cell_markers[c] = eta[c.index()] > threshold

    # *** REFINE step
    # mesh.refine is used for its effect on `mesh` (its return value is not
    # kept), so the next iteration solves on the refined mesh.
    mesh.refine(cell_markers)
    plot(mesh)
else:
    print("\nMaximum number of iterations (%d) reached without achieving "
          "the tolerance." % max_iterations)

interactive()
