# -*- coding: utf-8 -*-
"""A collection of utility algorithms for printing
of UFL objects in the DOT graph visualization language,
mostly intended for debugging purposes."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
from ufl.log import error
from ufl.core.expr import Expr
from ufl.form import Form
from ufl.variable import Variable
from ufl.algorithms.multifunction import MultiFunction
class ReprLabeller(MultiFunction):
    """Labeller that produces verbose, repr-based node labels.

    Terminals are labelled with their ``repr`` and operators with the
    name of their UFL class.
    """

    def __init__(self):
        super().__init__()

    def terminal(self, e):
        """Return ``repr(e)`` as the label of a terminal."""
        return repr(e)

    def operator(self, e):
        """Return the part of the UFL class name of ``e`` after the last dot."""
        class_name = e._ufl_class_.__name__
        return class_name.rpartition(".")[2]
class CompactLabeller(ReprLabeller):
    """Labeller that produces short, symbolic node labels.

    Every handler below returns a fixed abbreviation (or a short string
    derived from the node) for one kind of UFL node; kinds without a
    handler here fall back to the labels of ReprLabeller.
    """

    def __init__(self, function_mapping=None):
        super().__init__()
        # Stored for later use; no handler in this class reads it.
        self._function_mapping = function_mapping

    # --- Terminals ---

    def scalar_value(self, e):
        return repr(e._value)

    def zero(self, e):
        return "0"

    def identity(self, e):
        return "Id"

    def multi_index(self, e):
        return str(e)

    def geometric_quantity(self, e):
        return str(e)

    # --- Operators ---

    def sum(self, e):
        return "+"

    def product(self, e):
        return "*"

    def division(self, e):
        return "/"

    def power(self, e):
        return "**"

    def math_function(self, e):
        return e._name

    def index_sum(self, e):
        return "∑"

    def indexed(self, e):
        return "[]"

    # TODO: Understandable short notation for this?
    def component_tensor(self, e):
        return "]["

    def negative_restricted(self, e):
        return "[-]"

    def positive_restricted(self, e):
        return "[+]"

    # TODO: Understandable short notation for this?
    def cell_avg(self, e):
        return "_K_"

    # TODO: Understandable short notation for this?
    def facet_avg(self, e):
        return "_F_"

    def conj(self, e):
        return "conj"

    def real(self, e):
        return "real"

    def imag(self, e):
        return "imag"

    def inner(self, e):
        return "inner"

    def dot(self, e):
        return "dot"

    def outer(self, e):
        return "outer"

    def transposed(self, e):
        return "transp."

    def determinant(self, e):
        return "det"

    def trace(self, e):
        return "tr"

    def dev(self, e):
        return "dev"

    def skew(self, e):
        return "skew"

    def grad(self, e):
        return "grad"

    def div(self, e):
        return "div"

    def curl(self, e):
        return "curl"

    def nabla_grad(self, e):
        return "nabla_grad"

    def nabla_div(self, e):
        return "nabla_div"

    def diff(self, e):
        return "diff"
# Fancy math-symbol labels keyed by UFL class name.
# TODO: Give FancyLabeller handler methods, like those of CompactLabeller,
# that return these labels.
# NOTE(review): "Sum" maps to the same symbol as "IndexSum" - confirm intended.
class2label = {
    "IndexSum": "∑",
    "Sum": "∑",
    "Product": "∏",
    "Division": "/",
    "Inner": ":",
    "Dot": "⋅",
    "Outer": "⊗",
    "Grad": "grad",
    "Div": "div",
    "NablaGrad": "∇⊗",
    "NablaDiv": "∇⋅",
    "Curl": "∇×",
}
class FancyLabeller(CompactLabeller):
    """Placeholder labeller meant for fancy math-symbol labels.

    It currently adds nothing to CompactLabeller and therefore yields
    exactly the same labels.
    """
def build_entities(e, nodes, edges, nodeoffset, prefix="", labeller=None):
    """Recursively collect graph nodes and edges for the expression ``e``.

    ``nodes`` and ``edges`` are filled in place and nothing is returned.

    :arg e: expression to traverse; it must be a Variable or provide
        ``ufl_operands``.
    :arg nodes: dict mapping ``id(expr)`` to ``(nodename, label)``.
    :arg edges: list receiving ``(id(parent), id(child), edge_label)``
        tuples, appended after the child's whole subtree has been handled.
    :arg nodeoffset: integer added to the running node count when
        numbering node names.
    :arg prefix: string put in front of every node name.
    :arg labeller: callable producing the label of an expression;
        a ReprLabeller is created when omitted.
    """
    # TODO: Maybe this can be cleaner written using the graph
    # utilities.

    # TODO: To collapse equal nodes with different objects, do not use
    # id as key. Make this an option?

    # An object that already has a node was fully handled earlier.
    key = id(e)
    if key in nodes:
        return

    if labeller is None:
        labeller = ReprLabeller()

    # Variables are special-cased: their single child is the wrapped
    # expression and the label comes from the variable's label count.
    if isinstance(e, Variable):  # FIXME: Is this really necessary?
        children = (e._expression,)
        label = "variable %d" % e._label._count
    else:
        children = e.ufl_operands
        label = labeller(e)

    # Register the parent before descending, so the node number reflects
    # the order in which objects are first encountered.
    nodes[key] = ("%sn%04d" % (prefix, len(nodes) + nodeoffset), label)

    # Edge labels: L/R for binary nodes, op<i> for more than two
    # operands, and no label otherwise.
    count = len(children)
    if count == 2:
        edge_labels = ["L", "R"]
    elif count > 2:
        edge_labels = ["op%d" % i for i in range(count)]
    else:
        edge_labels = [None] * count

    for child, edge_label in zip(children, edge_labels):
        # First the entire subtree below the child ...
        build_entities(child, nodes, edges, nodeoffset, prefix, labeller)
        # ... then the edge from the parent to the child.
        edges.append((key, id(child), edge_label))
# DOT fragment emitted once per integral of a form: declares the integral
# node, links the form node to it and it to the root node of the
# integrand, then appends the already formatted entities of the integrand.
integralgraphformat = (
    ' %(node)s [label="%(label)s"]\n'
    'form_%(formname)s -> %(node)s ;\n'
    '%(node)s -> %(root)s ;\n'
    '%(entities)s'
)

# Complete DOT document wrapping the formatted entities of one expression.
exprgraphformat = (
    ' digraph ufl_expression\n'
    '{\n'
    '%s\n'
    '}'
)
def ufl2dot(expression, formname="a", nodeoffset=0, begin=True, end=True,
            labeling="repr", object_names=None):
    """Render a UFL Form or Expr as source text in the DOT graph language.

    :arg expression: the Form or Expr to visualize.
    :arg formname: name used for the form node and as prefix of the
        integral nodes (only used for Form input).
    :arg nodeoffset: number from which node names are counted, so that
        several calls can contribute to one graph without name clashes.
    :arg begin: emit the opening ``digraph ufl_form {`` header
        (only used for Form input).
    :arg end: emit the closing ``}`` (only used for Form input).
    :arg labeling: ``"repr"`` for verbose labels or ``"compact"`` for
        short symbolic ones.
    :arg object_names: mapping handed to the compact labeller; an empty
        dict is used when it is None or empty.
    :returns: tuple ``(s, nodeoffset)`` with the DOT source and the input
        offset increased by the number of nodes created.

    Invalid ``labeling`` values and unsupported ``expression`` types are
    reported through ``error``.
    """
    if labeling == "repr":
        labeller = ReprLabeller()
    elif labeling == "compact":
        labeller = CompactLabeller(object_names or {})
    else:
        # Report an unknown labeling explicitly, in the same way as an
        # invalid expression type below, rather than failing later on an
        # unbound local. error() from ufl.log is expected to raise.
        error("Invalid labeling %r, expecting 'repr' or 'compact'." % (labeling,))

    if isinstance(expression, Form):
        subgraphs = []
        # The running index gives every integral its own node-name prefix.
        for k, itg in enumerate(expression.integrals()):
            prefix = "itg%d_" % k
            integral_type = itg.integral_type()
            subdomain_id = itg.subdomain_id()
            integralkey = "%s%s" % (integral_type, subdomain_id)
            integrallabel = "%s integral %s" % (
                integral_type.capitalize().replace("_", " "), subdomain_id)

            integrand = itg.integrand()
            nodes = {}
            edges = []
            build_entities(integrand, nodes, edges, nodeoffset, prefix,
                           labeller)
            subgraphs.append(integralgraphformat % {
                'node': "%s_%s" % (formname, integralkey),
                'label': integrallabel,
                'formname': formname,
                'root': nodes[id(integrand)][0],
                'entities': format_entities(nodes, edges),
            })
            # Keep node numbers unique across the integrals of this form.
            nodeoffset += len(nodes)

        parts = []
        if begin:
            parts.append('digraph ufl_form\n{\n node [shape="box"] ;\n')
        parts.append(' form_%s [label="Form %s"] ;' % (formname, formname))
        parts.append("\n".join(subgraphs))
        if end:
            parts.append("\n}")
        s = "".join(parts)

    elif isinstance(expression, Expr):
        nodes = {}
        edges = []
        build_entities(expression, nodes, edges, nodeoffset, '', labeller)
        s = exprgraphformat % format_entities(nodes, edges)
        nodeoffset += len(nodes)

    else:
        error("Invalid object type %s" % type(expression))

    return s, nodeoffset