# -*- coding: utf-8 -*-
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Massimiliano Leoni, 2016
from ufl.core.expr import Expr
from ufl.core.compute_expr_hash import compute_expr_hash
from ufl.utils.formatting import camel2underscore
# Make UFL type coercion available under the as_ufl name
# as_ufl = Expr._ufl_coerce_
def attach_operators_from_hash_data(cls):
    """Class decorator attaching ``__hash__``, ``__eq__`` and ``__ne__`` to *cls*.

    All three operators are expressed through the ``_ufl_hash_data_()`` method,
    which *cls* must already provide. That method should return a tuple (or
    other hashable and comparable data) identifying the instance. Two instances
    compare equal only if their types are identical and their hash data compare
    equal.

    Returns *cls* itself, so the function can be used with ``@`` syntax.
    """
    assert hasattr(cls, "_ufl_hash_data_")

    def __hash__(self):
        "__hash__ implementation attached in attach_operators_from_hash_data"
        return hash(self._ufl_hash_data_())

    def __eq__(self, other):
        "__eq__ implementation attached in attach_operators_from_hash_data"
        # Exact type match on purpose: instances of different classes with
        # equal hash data must not compare equal.
        if type(self) != type(other):
            return False
        return self._ufl_hash_data_() == other._ufl_hash_data_()

    def __ne__(self, other):
        "__ne__ implementation attached in attach_operators_from_hash_data"
        if type(self) != type(other):
            return True
        return self._ufl_hash_data_() != other._ufl_hash_data_()

    for operator_impl in (__hash__, __eq__, __ne__):
        setattr(cls, operator_impl.__name__, operator_impl)
    return cls
def get_base_attr(cls, name):
    """Return the first non-``None`` value of attribute *name* along ``cls.mro()``.

    The search begins with *cls* itself and then follows its base classes in
    method resolution order. A missing attribute and a ``None`` value are both
    skipped. Falsy non-``None`` values (``False``, ``0``, ``()``) are returned.
    Returns ``None`` when no class supplies a non-``None`` value.
    """
    # getattr's default folds "attribute missing" into "value is None".
    values = (getattr(klass, name, None) for klass in cls.mro())
    return next((value for value in values if value is not None), None)
def set_trait(cls, basename, value, inherit=False):
    """Store *value* on *cls* as the class attribute ``_ufl_<basename>_``.

    If *value* is ``None`` and *inherit* is true, the stored value is instead
    the first non-``None`` value of that attribute along ``cls.mro()``, looked
    up with ``get_base_attr``. The result may still be ``None`` when no class
    provides one.
    """
    attr_name = "".join(("_ufl_", basename, "_"))
    take_from_bases = value is None and inherit
    setattr(cls, attr_name, get_base_attr(cls, attr_name) if take_from_bases else value)
def determine_num_ops(cls, num_ops, unop, binop, rbinop):
    """Work out the number of operands for type *cls*.

    Precedence, first match wins:

    1. an explicit *num_ops* (anything other than ``None``, e.g. ``"varying"``);
    2. ``0`` if ``cls._ufl_is_terminal_`` is truthy;
    3. ``1`` if a unary operator name *unop* is given;
    4. ``2`` if a binary (*binop*) or reflected binary (*rbinop*) name is given;
    5. otherwise the first non-``None`` ``_ufl_num_ops_`` along ``cls.mro()``,
       which may be ``None``.
    """
    if num_ops is not None:
        return num_ops
    if cls._ufl_is_terminal_:
        # Terminals never have operands.
        return 0
    if unop:
        return 1
    if binop or rbinop:
        return 2
    # Nothing class-specific to go on: use whatever the bases declare.
    return get_base_attr(cls, "_ufl_num_ops_")
def check_is_terminal_consistency(cls):
    """Check that *cls* has an ``is_terminal`` trait that agrees with its bases.

    Raises ``TypeError`` if ``cls._ufl_is_terminal_`` is ``None`` (neither set
    nor inherited). Also raises ``TypeError`` if it differs from the first
    non-``None`` ``_ufl_is_terminal_`` found among the base classes of *cls*.
    """
    if cls._ufl_is_terminal_ is None:
        msg = ("Class {0.__name__} has not specified the is_terminal trait." +
               " Did you forget to inherit from Terminal or Operator?")
        raise TypeError(msg.format(cls))

    # Search the bases only. cls.mro() begins with cls itself, so a lookup
    # that includes cls would always return cls's own (non-None) value and
    # the comparison below could never detect a conflict.
    base_is_terminal = None
    for base in cls.mro()[1:]:
        candidate = getattr(base, "_ufl_is_terminal_", None)
        if candidate is not None:
            base_is_terminal = candidate
            break

    if base_is_terminal is not None and cls._ufl_is_terminal_ != base_is_terminal:
        msg = ("Conflicting given and automatic 'is_terminal' trait for class {0.__name__}." +
               " Check if you meant to inherit from Terminal or Operator.")
        raise TypeError(msg.format(cls))
def check_abstract_trait_consistency(cls):
    """Check that the classes preceding ``Expr`` in the MRO of *cls* are UFL types.

    Walks ``cls.mro()``, which starts with *cls* itself, until ``Expr`` is
    reached. Any class in that prefix that is not a subclass of ``Expr`` causes
    a ``TypeError``, unless it explicitly defines a falsy ``_ufl_is_abstract_``.
    """
    for base in cls.mro():
        if base is Expr:
            break
        # Default to True so that a foreign class lacking the trait altogether
        # is reported through the TypeError below, rather than escaping as an
        # unrelated AttributeError from the attribute lookup.
        if not issubclass(base, Expr) and getattr(base, "_ufl_is_abstract_", True):
            msg = ("Base class {0.__name__} of class {1.__name__} "
                   "is not an abstract subclass of {2.__name__}.")
            raise TypeError(msg.format(base, cls, Expr))
def check_has_slots(cls):
    """Require ``__slots__`` on *cls* and on each of its base classes.

    A class that defines ``_ufl_noslots_`` in its own body is exempt and
    nothing is checked. Otherwise ``TypeError`` is raised in two cases:

    * *cls* does not define ``__slots__`` itself;
    * some class in ``cls.mro()``, other than *cls* and the final entry
      (``object``), does not define ``__slots__`` itself.
    """
    own_namespace = cls.__dict__
    if "_ufl_noslots_" in own_namespace:
        return

    if "__slots__" not in own_namespace:
        msg = ("Class {0.__name__} is missing the __slots__ "
               "attribute and is not marked with _ufl_noslots_.")
        raise TypeError(msg.format(cls))

    # Skip cls (checked above) and object, which is always last in the MRO.
    offender = next((base for base in cls.mro()[1:-1]
                     if "__slots__" not in base.__dict__), None)
    if offender is not None:
        msg = ("Class {0.__name__} is has a base class "
               "{1.__name__} with __slots__ missing.")
        raise TypeError(msg.format(cls, offender))
def check_type_traits_consistency(cls):
    """Run sanity checks on the ``_ufl_*_`` type traits of *cls*.

    First asserts that each global per-type collection on ``Expr`` (handler
    names, classes, object init counters and object del counters) holds exactly
    ``Expr._ufl_num_typecodes_`` entries. Then raises ``TypeError`` if:

    * *cls* is not abstract and its ``num_ops`` is ``None``;
    * *cls* is not abstract and its ``num_ops`` is neither an ``int`` nor the
      string ``"varying"``;
    * *cls* is terminal and its ``num_ops`` is not ``0``;
    * *cls* is not scalar while ``get_base_attr(cls, "_ufl_is_scalar_")`` is
      truthy.
    """
    # Every global registry must have one entry per allocated typecode.
    num_typecodes = Expr._ufl_num_typecodes_
    assert num_typecodes == len(Expr._ufl_all_handler_names_)
    assert num_typecodes == len(Expr._ufl_all_classes_)
    assert num_typecodes == len(Expr._ufl_obj_init_counts_)
    assert num_typecodes == len(Expr._ufl_obj_del_counts_)

    if not cls._ufl_is_abstract_:
        declared_ops = cls._ufl_num_ops_
        # Concrete types must end up with an operand count...
        if declared_ops is None:
            msg = "Class {0.__name__} has not specified num_ops."
            raise TypeError(msg.format(cls))
        # ...which is either an integer or the "varying" marker.
        if not (isinstance(declared_ops, int) or declared_ops == "varying"):
            msg = 'Class {0.__name__} has invalid num_ops value {1} (integer or "varying").'
            raise TypeError(msg.format(cls, declared_ops))

    # Terminals cannot have operands.
    if cls._ufl_is_terminal_ and cls._ufl_num_ops_ != 0:
        msg = "Class {0.__name__} has num_ops > 0 but is terminal."
        raise TypeError(msg.format(cls))

    # A non-scalar type must not derive from a scalar one.
    # NOTE(review): get_base_attr starts its search at cls itself. Whenever
    # cls._ufl_is_scalar_ is a falsy non-None value (e.g. False), that value
    # is returned, so this check only ever fires when cls's own value is None.
    # Restricting the lookup to the bases would enforce the stated rule, but
    # may reject existing subclasses - TODO confirm before tightening.
    if not cls._ufl_is_scalar_ and get_base_attr(cls, "_ufl_is_scalar_"):
        msg = "Non-scalar class {0.__name__} is has a scalar base class."
        raise TypeError(msg.format(cls))
def check_implements_required_methods(cls):
    """Check that a non-abstract *cls* provides every method in ``Expr._ufl_required_methods_``.

    Abstract classes are not checked. For concrete classes, ``TypeError`` is
    raised for the first required name that is missing or is not callable.
    """
    if cls._ufl_is_abstract_:
        return
    for method_name in Expr._ufl_required_methods_:
        if not hasattr(cls, method_name):
            msg = "Class {0.__name__} has no {1} method."
            raise TypeError(msg.format(cls, method_name))
        if not callable(getattr(cls, method_name)):
            msg = "Required method {1} of class {0.__name__} is not callable."
            raise TypeError(msg.format(cls, method_name))
def check_implements_required_properties(cls):
    """Check that a non-abstract *cls* has every attribute in ``Expr._ufl_required_properties_``.

    Abstract classes are not checked. For concrete classes, ``TypeError`` is
    raised for the first required name that is missing. It is also raised if
    the name, looked up on the class, is callable, i.e. it was implemented as a
    method rather than as a property or plain value.
    """
    if cls._ufl_is_abstract_:
        return
    for prop_name in Expr._ufl_required_properties_:
        if not hasattr(cls, prop_name):
            msg = "Class {0.__name__} has no {1} property."
            raise TypeError(msg.format(cls, prop_name))
        if callable(getattr(cls, prop_name)):
            msg = "Required property {1} of class {0.__name__} is a callable method."
            raise TypeError(msg.format(cls, prop_name))
def attach_implementations_of_indexing_interface(cls,
                                                 inherit_shape_from_operand,
                                                 inherit_indices_from_operand):
    """Attach ``ufl_shape``, ``ufl_free_indices`` and ``ufl_index_dimensions`` where derivable.

    * Scalar types (``_ufl_is_scalar_``) get ``ufl_shape = ()``.
    * Scalar or index-free (``_ufl_is_index_free_``) types get empty
      ``ufl_free_indices`` and ``ufl_index_dimensions``.
    * If *inherit_shape_from_operand* is not ``None``, ``ufl_shape`` becomes a
      property returning ``self.ufl_operands[inherit_shape_from_operand].ufl_shape``.
    * If *inherit_indices_from_operand* is not ``None``, the two index
      attributes become properties forwarding to that operand in the same way.

    The forwarding properties are attached after the constant defaults, so they
    win when both apply.
    """
    # Constant defaults derived from the scalar / index-free traits.
    scalar = cls._ufl_is_scalar_
    if scalar:
        cls.ufl_shape = ()
    if scalar or cls._ufl_is_index_free_:
        cls.ufl_free_indices = ()
        cls.ufl_index_dimensions = ()

    # Many operator types simply pass one operand's shape and/or indices
    # through; generating those properties here saves writing them per type.
    if inherit_shape_from_operand is not None:
        def _inherited_ufl_shape(self):
            operand = self.ufl_operands[inherit_shape_from_operand]
            return operand.ufl_shape

        cls.ufl_shape = property(_inherited_ufl_shape)

    if inherit_indices_from_operand is not None:
        def _inherited_ufl_free_indices(self):
            operand = self.ufl_operands[inherit_indices_from_operand]
            return operand.ufl_free_indices

        def _inherited_ufl_index_dimensions(self):
            operand = self.ufl_operands[inherit_indices_from_operand]
            return operand.ufl_index_dimensions

        cls.ufl_free_indices = property(_inherited_ufl_free_indices)
        cls.ufl_index_dimensions = property(_inherited_ufl_index_dimensions)
def update_global_expr_attributes(cls):
    """Register *cls* in the global type collections stored on ``Expr``.

    In order, this:

    * appends *cls* to ``Expr._ufl_all_classes_``;
    * adds its handler name to ``Expr._ufl_all_handler_names_``;
    * appends *cls* to ``Expr._ufl_terminal_modifiers_`` if it is a terminal
      modifier;
    * for non-abstract classes with a ``_ufl_function_``, copies the class
      docstring onto that function and records it in
      ``Expr._ufl_language_operators_`` under the handler name;
    * appends a zero counter slot to ``Expr._ufl_obj_init_counts_`` and
      ``Expr._ufl_obj_del_counts_``.
    """
    Expr._ufl_all_classes_.append(cls)
    Expr._ufl_all_handler_names_.add(cls._ufl_handler_name_)

    if cls._ufl_is_terminal_modifier_:
        Expr._ufl_terminal_modifiers_.append(cls)

    # Collect language-level operator functions by handler name.
    # TODO: I don't think this functionality is fully completed, check
    # it out later.
    if not cls._ufl_is_abstract_ and hasattr(cls, "_ufl_function_"):
        language_function = cls._ufl_function_
        # .__func__ reaches the underlying function object, so this expects
        # _ufl_function_ to be a bound method when accessed on the class.
        language_function.__func__.__doc__ = cls.__doc__
        Expr._ufl_language_operators_[cls._ufl_handler_name_] = language_function

    # Reserve a slot for counting creation and destruction of objects of
    # this type.
    Expr._ufl_obj_init_counts_.append(0)
    Expr._ufl_obj_del_counts_.append(0)
def ufl_type(is_abstract=False,
             is_terminal=None,
             is_scalar=False,
             is_index_free=False,
             is_shaping=False,
             is_literal=False,
             is_terminal_modifier=False,
             is_in_reference_frame=False,
             is_restriction=False,
             is_evaluation=False,
             is_differential=None,
             use_default_hash=True,
             num_ops=None,
             inherit_shape_from_operand=None,
             inherit_indices_from_operand=None,
             wraps_type=None,
             unop=None,
             binop=None,
             rbinop=None):
    """Return the class decorator applied to every subclass in the UFL ``Expr`` hierarchy.

    The decorator works in four stages:

    * It gives the class the next integer typecode and a handler name derived
      from the class name ("TypeName" -> "type_name").
    * It stores the arguments below as ``_ufl_<name>_`` class attributes
      ("traits").
    * It optionally installs the shared hash and the indexing-interface
      shortcuts, and registers the class in the global ``Expr`` collections.
    * It runs consistency checks that reject malformed types.

    Trait arguments:

    is_abstract -- stored as-is, never inherited from base classes.
    is_terminal, is_differential -- default ``None``, meaning "inherit the
        first non-``None`` value from the base classes".
    is_scalar, is_index_free, is_shaping, is_literal, is_terminal_modifier,
    is_in_reference_frame, is_restriction, is_evaluation -- stored as given;
        they are inherited only if ``None`` is passed explicitly. A true
        *is_scalar* forces the stored is_index_free trait to ``True``.

    Other arguments:

    use_default_hash -- if true, ``__hash__`` is set to ``compute_expr_hash``.
    num_ops -- explicit operand count (an int or ``"varying"``). When omitted,
        it is derived by ``determine_num_ops`` from is_terminal, *unop*,
        *binop*/*rbinop*, or the base classes.
    inherit_shape_from_operand, inherit_indices_from_operand -- operand index
        whose shape / free indices the type forwards, if not ``None``.
    unop, binop, rbinop -- here only used to infer num_ops.
    wraps_type -- accepted but not used by this function.
    """
    def _ufl_type_decorator_(cls):
        # Typecodes are handed out incrementally, one per decorated class.
        typecode = Expr._ufl_num_typecodes_
        Expr._ufl_num_typecodes_ = typecode + 1

        # Map "TypeName" to "type_name".
        handler_name = camel2underscore(cls.__name__)

        # A scalar is necessarily free of free indices.
        effective_index_free = True if is_scalar else is_index_free

        cls._ufl_class_ = cls

        # Traits that belong to this class alone.
        for trait, trait_value in (("handler_name", handler_name),
                                   ("typecode", typecode),
                                   ("is_abstract", is_abstract)):
            set_trait(cls, trait, trait_value, inherit=False)

        # Traits that fall back to the base classes when given as None.
        for trait, trait_value in (("is_terminal", is_terminal),
                                   ("is_literal", is_literal),
                                   ("is_terminal_modifier", is_terminal_modifier),
                                   ("is_shaping", is_shaping),
                                   ("is_in_reference_frame", is_in_reference_frame),
                                   ("is_restriction", is_restriction),
                                   ("is_evaluation", is_evaluation),
                                   ("is_differential", is_differential),
                                   ("is_scalar", is_scalar),
                                   ("is_index_free", effective_index_free)):
            set_trait(cls, trait, trait_value, inherit=True)

        # Must follow the is_terminal trait: determine_num_ops reads it.
        set_trait(cls, "num_ops",
                  determine_num_ops(cls, num_ops, unop, binop, rbinop))

        # Not done here: builtin type wrappers (wraps_type) are currently
        # handled in the as_ufl implementation in constantvalue.py, and the
        # unop/binop/rbinop special methods on Expr are currently attached in
        # exproperators.py.

        # Python 3 sets __hash__ to None on a class that defines __eq__ but
        # not __hash__; give every type the shared expression hash unless the
        # caller opts out with use_default_hash=False.
        if use_default_hash:
            cls.__hash__ = compute_expr_hash

        # May attach ufl_shape / ufl_free_indices / ufl_index_dimensions.
        attach_implementations_of_indexing_interface(cls,
                                                     inherit_shape_from_operand,
                                                     inherit_indices_from_operand)

        update_global_expr_attributes(cls)

        # Checks Python itself won't do for us; each one raises on failure.
        # check_type_traits_consistency must run after the registry update
        # above, because it asserts the registry sizes match the typecode count.
        for check in (check_abstract_trait_consistency,
                      check_has_slots,
                      check_is_terminal_consistency,
                      check_implements_required_methods,
                      check_implements_required_properties,
                      check_type_traits_consistency):
            check(cls)

        return cls

    return _ufl_type_decorator_