# -*- coding: utf-8 -*-
"This module defines classes representing constant values."
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Anders Logg, 2011.
# Modified by Massimiliano Leoni, 2016.
from math import atan2
from ufl.log import error, UFLValueError
from ufl.core.expr import Expr
from ufl.core.terminal import Terminal
from ufl.core.multiindex import Index, FixedIndex
from ufl.core.ufl_type import ufl_type
# --- Helper functions imported here for compatibility---
from ufl.checks import is_python_scalar, is_ufl_scalar, is_true_ufl_scalar # noqa: F401
# Precision for float formatting
precision = None
# --- Base classes for constant types ---
@ufl_type(is_abstract=True)
class ConstantValue(Terminal):
    "Abstract base class for terminals that represent constant values."

    __slots__ = ()

    def __init__(self):
        # No state of its own; only the Terminal base is initialised.
        Terminal.__init__(self)

    def is_cellwise_constant(self):
        "Return True: a constant value is spatially constant over each cell."
        return True

    def ufl_domains(self):
        "Return the domains related to this terminal, which is always an empty tuple."
        return ()
# --- Class for representing zero tensors of different shapes ---
# TODO: Add geometric dimension/domain and Argument dependencies to
# Zero?
@ufl_type(is_literal=True)
class Zero(ConstantValue):
    """UFL literal type: Representation of a zero valued expression.

    Instances without free indices are shared: one instance per shape is
    stored in ``Zero._cache`` and returned on every later request for that
    shape. Instances with free indices are always created anew.
    """

    __slots__ = ("ufl_shape", "ufl_free_indices", "ufl_index_dimensions")

    # Shared instances that have no free indices, keyed by shape.
    _cache = {}

    def __getnewargs__(self):
        "Return the arguments that ``__new__`` needs to recreate this object."
        return (self.ufl_shape, self.ufl_free_indices, self.ufl_index_dimensions)

    def __new__(cls, shape=(), free_indices=(), index_dimensions=None):
        """Create a zero, or return the cached one for ``shape``.

        :arg shape: tuple of int giving the tensor shape.
        :arg free_indices: tuple of free indices, either ``Index`` objects
            (old format) or integer index ids (new format).
        :arg index_dimensions: dimensions of the free indices; a dict keyed
            by ``Index`` in the old format, a tuple of int in the new one.
        """
        if free_indices:
            # Zeros carrying free indices are never cached.
            self = ConstantValue.__new__(cls)
            self._init(shape, free_indices, index_dimensions)
            return self

        self = Zero._cache.get(shape)
        if self is None:
            self = ConstantValue.__new__(cls)
            # Validate and initialise BEFORE caching: if _init raises, no
            # half-constructed instance may be left behind in the cache,
            # otherwise later calls would be handed an object without
            # its attributes set.
            self._init(shape, free_indices, index_dimensions)
            Zero._cache[shape] = self
        return self

    def __init__(self, shape=(), free_indices=(), index_dimensions=None):
        # Intentionally empty: all initialisation happens in _init(), called
        # from __new__, so that returning a cached instance does not
        # re-initialise it.
        pass

    def _init(self, shape=(), free_indices=(), index_dimensions=None):
        """Validate the arguments and set the shape and free index attributes.

        Reports a problem through ``error`` if ``shape`` contains non-int
        entries, if ``free_indices`` is not a tuple, or if the free indices
        and their dimensions do not match one of the two accepted formats.
        """
        ConstantValue.__init__(self)
        if not all(isinstance(i, int) for i in shape):
            error("Expecting tuple of int.")
        if not isinstance(free_indices, tuple):
            error("Expecting tuple for free_indices, not %s" % str(free_indices))
        self.ufl_shape = shape

        if not free_indices:
            self.ufl_free_indices = ()
            self.ufl_index_dimensions = ()
        elif all(isinstance(i, Index) for i in free_indices):
            # Old input format: Index objects plus a dict {Index: dimension}.
            if not (isinstance(index_dimensions, dict) and
                    all(isinstance(i, Index) for i in index_dimensions.keys())):
                error("Expecting dict of index dimensions, not %s" % str(index_dimensions))
            # Store index counts, and the dimensions ordered by index count.
            self.ufl_free_indices = tuple(sorted(i.count() for i in free_indices))
            self.ufl_index_dimensions = tuple(
                d for i, d in sorted(index_dimensions.items(),
                                     key=lambda x: x[0].count()))
        else:
            # New input format: integer index ids plus a tuple of int dimensions.
            if not all(isinstance(i, int) for i in free_indices):
                error("Expecting tuple of integer free index ids, not %s" % str(free_indices))
            if not (isinstance(index_dimensions, tuple) and
                    all(isinstance(i, int) for i in index_dimensions)):
                error("Expecting tuple of integer index dimensions, not %s" % str(index_dimensions))
            # The ids are assumed to be sorted already; this is deliberately
            # not checked here to avoid the cost.
            self.ufl_free_indices = free_indices
            self.ufl_index_dimensions = index_dimensions

    def evaluate(self, x, mapping, component, index_values):
        "Return the float 0.0 regardless of the arguments."
        return 0.0

    def __str__(self):
        "Return '0', annotated with shape and/or index labels when present."
        if self.ufl_shape == () and self.ufl_free_indices == ():
            return "0"
        if self.ufl_free_indices == ():
            return "0 (shape %s)" % (self.ufl_shape,)
        if self.ufl_shape == ():
            return "0 (index labels %s)" % (self.ufl_free_indices,)
        return "0 (shape %s, index labels %s)" % (self.ufl_shape, self.ufl_free_indices)

    def __repr__(self):
        "Return a ``Zero(shape, free_indices, index_dimensions)`` string."
        return "Zero(%s, %s, %s)" % (
            repr(self.ufl_shape),
            repr(self.ufl_free_indices),
            repr(self.ufl_index_dimensions))

    def __eq__(self, other):
        """Compare with another Zero or with a python int/float.

        Two Zero objects are equal when shape, free indices and index
        dimensions all agree. An int or float is equal when it equals 0.
        Anything else compares unequal.
        """
        if isinstance(other, Zero):
            if self is other:
                return True
            return (self.ufl_shape == other.ufl_shape and
                    self.ufl_free_indices == other.ufl_free_indices and
                    self.ufl_index_dimensions == other.ufl_index_dimensions)
        if isinstance(other, (int, float)):
            return other == 0
        return False

    def __neg__(self):
        "Negation of zero is the same object."
        return self

    def __abs__(self):
        "Absolute value of zero is the same object."
        return self

    def __bool__(self):
        "A Zero is always falsy."
        return False

    # Python 2 name for the truth-value hook.
    __nonzero__ = __bool__

    def __float__(self):
        return 0.0

    def __int__(self):
        return 0

    def __complex__(self):
        return 0 + 0j
def zero(*shape):
    """UFL literal constant: Return a zero tensor with the given shape.

    The shape may be given either as separate arguments, ``zero(2, 3)``,
    or as one tuple, ``zero((2, 3))``.
    """
    given_as_tuple = len(shape) == 1 and isinstance(shape[0], tuple)
    dims = shape[0] if given_as_tuple else shape
    return Zero(dims)
# --- Scalar value types ---
@ufl_type(is_abstract=True, is_scalar=True)
class ScalarValue(ConstantValue):
    "Abstract base for constants that wrap one python scalar in ``_value``."

    __slots__ = ("_value",)

    def __init__(self, value):
        ConstantValue.__init__(self)
        self._value = value

    def value(self):
        "Return the wrapped python value."
        return self._value

    def evaluate(self, x, mapping, component, index_values):
        "Return the wrapped python value; all arguments are ignored."
        return self._value

    def __eq__(self, other):
        """Compare with an object of the same UFL class or a python int/float.

        Values of different UFL classes never compare equal, so
        IntValue(1) != FloatValue(1.0), whereas ufl-python comparisons like
            IntValue(1) == 1.0
            FloatValue(1.0) == 1
        can still succeed. These will however not have the same
        hash value and therefore not collide in a dict.
        """
        if isinstance(other, self._ufl_class_):
            return self._value == other._value
        if isinstance(other, (int, float)):
            # FIXME: Disallow this, require explicit 'expr ==
            # IntValue(3)' instead to avoid ambiguities!
            return other == self._value
        return False

    def __str__(self):
        return str(self._value)

    def __float__(self):
        return float(self._value)

    def __int__(self):
        return int(self._value)

    def __complex__(self):
        return complex(self._value)

    def __neg__(self):
        "Return the negated value, built through the constructor of the same type."
        same_type = type(self)
        return same_type(-self._value)

    def __abs__(self):
        "Return the absolute value, built through the constructor of the same type."
        same_type = type(self)
        return same_type(abs(self._value))

    def real(self):
        "Return the ``real`` attribute of the wrapped value."
        return self._value.real

    def imag(self):
        "Return the ``imag`` attribute of the wrapped value."
        return self._value.imag
@ufl_type(wraps_type=complex, is_literal=True)
class ComplexValue(ScalarValue):
    "UFL literal type: Representation of a constant, complex scalar"

    __slots__ = ()

    def __getnewargs__(self):
        "Return the arguments that ``__new__`` needs to recreate this object."
        return (self._value,)

    def __new__(cls, value):
        """Create the simplest literal that represents ``value``.

        A value with nonzero imaginary part gives a ComplexValue. Otherwise
        the result is Zero() when the real part is zero too, and
        FloatValue(value.real) when it is not.
        """
        if value.imag == 0:
            real_part = value.real
            return Zero() if real_part == 0 else FloatValue(real_part)
        return ConstantValue.__new__(cls)

    def __init__(self, value):
        ScalarValue.__init__(self, complex(value))

    def modulus(self):
        "Return the absolute value of the wrapped complex number."
        return abs(self.value())

    def argument(self):
        "Return ``atan2(imag, real)`` of the wrapped complex number."
        z = self.value()
        return atan2(z.imag, z.real)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, repr(self._value))

    def __float__(self):
        raise TypeError("ComplexValues cannot be cast to float")

    def __int__(self):
        raise TypeError("ComplexValues cannot be cast to int")
@ufl_type(is_abstract=True, is_scalar=True)
class RealValue(ScalarValue):
    "Abstract marker class separating real scalar values from complex ones."

    __slots__ = ()
@ufl_type(wraps_type=float, is_literal=True)
class FloatValue(RealValue):
    "UFL literal type: Representation of a constant scalar floating point value."

    __slots__ = ()

    def __getnewargs__(self):
        "Return the arguments that ``__new__`` needs to recreate this object."
        return (self._value,)

    def __new__(cls, value):
        "Return Zero() when ``value`` equals 0.0, otherwise a new FloatValue."
        # Always represent zero with Zero
        is_zero = value == 0.0
        return Zero() if is_zero else ConstantValue.__new__(cls)

    def __init__(self, value):
        # The stored value is always converted to a python float.
        ScalarValue.__init__(self, float(value))

    def __repr__(self):
        # format_float is a helper defined elsewhere in this module.
        return "%s(%s)" % (type(self).__name__, format_float(self._value))
@ufl_type(wraps_type=int, is_literal=True)
class IntValue(RealValue):
    """UFL literal type: Representation of a constant scalar integer value.

    Values with ``abs(value) < 100`` are shared through ``IntValue._cache``
    (fly-weight pattern) to reduce memory usage.
    """

    __slots__ = ()

    # Shared instances for small values, keyed by the value passed in.
    _cache = {}

    def __getnewargs__(self):
        "Return the arguments that ``__new__`` needs to recreate this object."
        return (self._value,)

    def __new__(cls, value):
        """Return Zero() for zero, a cached instance for small values, or a new one.

        :arg value: the number to wrap; it is stored as ``int(value)``.
        """
        if value == 0:
            # Always represent zero with Zero
            return Zero()

        cacheable = abs(value) < 100
        if cacheable:
            cached = IntValue._cache.get(value)
            if cached is not None:
                return cached

        self = RealValue.__new__(cls)
        # Initialise BEFORE caching: if int(value) raises, no uninitialised
        # instance may be left in the cache for later calls to pick up.
        self._init(value)
        if cacheable:
            IntValue._cache[value] = self
        return self

    def _init(self, value):
        "Store ``int(value)`` as the wrapped value."
        super(IntValue, self).__init__(int(value))

    def __init__(self, value):
        # Intentionally empty: initialisation happens in _init(), called from
        # __new__, so that returning a cached instance does not redo it.
        pass

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, repr(self._value))
# --- Identity matrix ---
@ufl_type()
class Identity(ConstantValue):
    "UFL literal type: Representation of an identity matrix."

    __slots__ = ("_dim", "ufl_shape")

    def __init__(self, dim):
        ConstantValue.__init__(self)
        self._dim = dim
        self.ufl_shape = (dim, dim)

    def evaluate(self, x, mapping, component, index_values):
        "Return 1 if the two entries of ``component`` are equal, else 0."
        row, col = component
        if row == col:
            return 1
        return 0

    def __getitem__(self, key):
        """Index the identity with a pair of indices.

        When both entries are int or FixedIndex the result is simplified on
        the spot to IntValue(1) or Zero(). Any other key is delegated to
        ``Expr.__getitem__``.
        """
        if len(key) != 2:
            error("Size mismatch for Identity.")
        all_fixed = all(isinstance(k, (int, FixedIndex)) for k in key)
        if not all_fixed:
            return Expr.__getitem__(self, key)
        row, col = int(key[0]), int(key[1])
        if row == col:
            return IntValue(1)
        return Zero()

    def __str__(self):
        return "I"

    def __repr__(self):
        return "Identity(%d)" % self._dim

    def __eq__(self, other):
        "Identity objects are equal when their dimensions are equal."
        return isinstance(other, Identity) and self._dim == other._dim
# --- Permutation symbol ---
@ufl_type()
class PermutationSymbol(ConstantValue):
    """UFL literal type: Representation of a permutation symbol.

    This is also known as the Levi-Civita symbol, antisymmetric symbol,
    or alternating symbol."""

    __slots__ = ("ufl_shape", "_dim")

    def __init__(self, dim):
        ConstantValue.__init__(self)
        self._dim = dim
        # A rank-dim tensor with every axis of extent dim.
        self.ufl_shape = (dim,) * dim

    def evaluate(self, x, mapping, component, index_values):
        "Evaluates the permutation symbol for the given component."
        return self.__eps(component)

    def __getitem__(self, key):
        """Index the symbol with ``dim`` indices.

        When every entry is int or FixedIndex the value is computed on the
        spot. Any other key is delegated to ``Expr.__getitem__``.
        """
        if len(key) != self._dim:
            error("Size mismatch for PermutationSymbol.")
        if all(isinstance(k, (int, FixedIndex)) for k in key):
            return self.__eps(key)
        return Expr.__getitem__(self, key)

    def __str__(self):
        return "eps"

    def __repr__(self):
        return "PermutationSymbol(%d)" % self._dim

    def __eq__(self, other):
        "PermutationSymbol objects are equal when their dimensions are equal."
        return isinstance(other, PermutationSymbol) and self._dim == other._dim

    def __eps(self, x):
        """Return the value of the permutation symbol for index sequence ``x``.

        The result is Zero() if any two entries are equal. Otherwise it is
        IntValue(1) negated once for every pair that is out of order.

        This function body is taken from
        http://www.mathkb.com/Uwe/Forum.aspx/math/29865/N-integer-Levi-Civita
        """
        # __getitem__ lets FixedIndex entries through, so convert everything
        # to plain ints first (as Identity.__getitem__ does). The ordering
        # comparison below then never relies on FixedIndex supporting '>'.
        values = [int(k) for k in x]
        result = IntValue(1)
        for i, x1 in enumerate(values):
            for x2 in values[i + 1:]:
                if x1 > x2:
                    result = -result
                elif x1 == x2:
                    return Zero()
        return result
def as_ufl(expression):
    """Converts expression to an Expr if possible.

    An Expr is returned unchanged. A python complex, float or int is
    wrapped with ComplexValue, FloatValue or IntValue respectively.
    Any other type raises UFLValueError.
    """
    if isinstance(expression, Expr):
        return expression

    # Checked in this order; the first matching python type wins.
    literal_types = (
        (complex, ComplexValue),
        (float, FloatValue),
        (int, IntValue),
    )
    for python_type, ufl_class in literal_types:
        if isinstance(expression, python_type):
            return ufl_class(expression)

    raise UFLValueError("Invalid type conversion: %s can not be converted"
                        " to any UFL type." % str(expression))