# -*- coding: utf-8 -*-
"""This module defines the single index types and some internal index utilities."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Massimiliano Leoni, 2016.
from ufl.log import error
from ufl.utils.counted import counted_init
from ufl.core.ufl_type import ufl_type
from ufl.core.terminal import Terminal
# Export list for ufl.classes
__all_classes__ = ["IndexBase", "FixedIndex", "Index"]
[docs]class IndexBase(object):
"""Base class for all indices."""
__slots__ = ()
def __init__(self):
pass
[docs]class FixedIndex(IndexBase):
"""UFL value: An index with a specific value assigned."""
__slots__ = ("_value", "_hash")
_cache = {}
def __getnewargs__(self):
return (self._value,)
def __new__(cls, value):
self = FixedIndex._cache.get(value)
if self is None:
if not isinstance(value, int):
error("Expecting integer value for fixed index.")
self = IndexBase.__new__(cls)
self._init(value)
FixedIndex._cache[value] = self
return self
def _init(self, value):
IndexBase.__init__(self)
self._value = value
self._hash = hash(("FixedIndex", self._value))
def __init__(self, value):
pass
def __hash__(self):
return self._hash
def __eq__(self, other):
return isinstance(other, FixedIndex) and (self._value == other._value)
def __int__(self):
return self._value
def __str__(self):
return "%d" % self._value
def __repr__(self):
r = "FixedIndex(%d)" % self._value
return r
[docs]class Index(IndexBase):
"""UFL value: An index with no value assigned.
Used to represent free indices in Einstein indexing notation."""
__slots__ = ("_count",)
_globalcount = 0
def __init__(self, count=None):
IndexBase.__init__(self)
counted_init(self, count, Index)
[docs] def count(self):
return self._count
def __hash__(self):
return hash(("Index", self._count))
def __eq__(self, other):
return isinstance(other, Index) and (self._count == other._count)
def __str__(self):
c = str(self._count)
if len(c) > 1:
c = "{%s}" % c
return "i_%s" % c
def __repr__(self):
r = "Index(%d)" % self._count
return r
[docs]@ufl_type()
class MultiIndex(Terminal):
"Represents a sequence of indices, either fixed or free."
__slots__ = ("_indices",)
_cache = {}
def __getnewargs__(self):
return (self._indices,)
def __new__(cls, indices):
if not isinstance(indices, tuple):
error("Expecting a tuple of indices.")
if all(isinstance(ind, FixedIndex) for ind in indices):
# Cache multiindices consisting of purely fixed indices
# (aka flyweight pattern)
key = tuple(ind._value for ind in indices)
self = MultiIndex._cache.get(key)
if self is not None:
return self
self = Terminal.__new__(cls)
MultiIndex._cache[key] = self
else:
# Create a new object if we have any free indices (too
# many combinations to cache)
if not all(isinstance(ind, IndexBase) for ind in indices):
error("Expecting only Index and FixedIndex objects.")
self = Terminal.__new__(cls)
# Initialize here instead of in __init__ to avoid overwriting
# self._indices from cached objects
self._init(indices)
return self
def __init__(self, indices):
pass
def _init(self, indices):
Terminal.__init__(self)
self._indices = indices
[docs] def indices(self):
"Return tuple of indices."
return self._indices
def _ufl_compute_hash_(self):
return hash(("MultiIndex",) + tuple(hash(ind) for ind in self._indices))
def __eq__(self, other):
return isinstance(other, MultiIndex) and \
self._indices == other._indices
[docs] def evaluate(self, x, mapping, component, index_values):
"Evaluate index."
# Build component from index values
component = []
for i in self._indices:
if isinstance(i, FixedIndex):
component.append(i._value)
elif isinstance(i, Index):
component.append(index_values[i])
return tuple(component)
@property
def ufl_shape(self):
"This shall not be used."
error("Multiindex has no shape (it is not a tensor expression).")
@property
def ufl_free_indices(self):
"This shall not be used."
error("Multiindex has no free indices (it is not a tensor expression).")
@property
def ufl_index_dimensions(self):
"This shall not be used."
error("Multiindex has no free indices (it is not a tensor expression).")
[docs] def is_cellwise_constant(self):
"Always True."
return True
[docs] def ufl_domains(self):
"Return tuple of domains related to this terminal object."
return ()
# --- Adding multiindices ---
def __add__(self, other):
if isinstance(other, tuple):
return MultiIndex(self._indices + other)
elif isinstance(other, MultiIndex):
return MultiIndex(self._indices + other._indices)
return NotImplemented
def __radd__(self, other):
if isinstance(other, tuple):
return MultiIndex(other + self._indices)
elif isinstance(other, MultiIndex):
return MultiIndex(other._indices + self._indices)
return NotImplemented
# --- String formatting ---
def __str__(self):
return ", ".join(str(i) for i in self._indices)
def __repr__(self):
r = "MultiIndex(%s)" % repr(self._indices)
return r
# --- Iteration protocol ---
def __len__(self):
return len(self._indices)
def __getitem__(self, i):
return self._indices[i]
def __iter__(self):
return iter(self._indices)
[docs]def as_multi_index(ii, shape=None):
"Return a ``MultiIndex`` version of *ii*."
if isinstance(ii, MultiIndex):
return ii
elif not isinstance(ii, tuple):
ii = (ii,)
return MultiIndex(ii)
[docs]def indices(n):
"UFL value: Return a tuple of :math:`n` new Index objects."
return tuple(Index() for i in range(n))