Source code for ufl.index_combination_utils

"""Utilities for analysing and manipulating free index tuples."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

from ufl.core.multiindex import FixedIndex, Index, indices

# FIXME: Some of these might be merged into one function, some might
# be optimized


def unique_sorted_indices(indices):
    """Collapse runs of equal ids in a sorted sequence of (id, dim) pairs.

    The input is expected to be sorted by id, so entries sharing an id are
    adjacent. The first entry of each run is kept; every later entry of the
    same run must carry the same dimension as the kept one.

    Args:
        indices: Sequence of ``(id, dim)`` pairs, sorted by id.

    Returns:
        Tuple of the kept ``(id, dim)`` pairs, in input order.

    Raises:
        ValueError: If two adjacent entries share an id but differ in
            dimension.
    """
    kept = []
    # Sentinel that compares unequal to the id of the first real entry.
    previous = (None, None)
    for entry in indices:
        if entry[0] != previous[0]:
            # New id: keep it and make it the reference for the run.
            kept.append(entry)
            previous = entry
            continue
        # Same id as the previous kept entry: only the dimension is checked.
        if entry[1] != previous[1]:
            raise ValueError("Nonmatching dimensions for free indices with same id!")
    return tuple(kept)
def merge_unique_indices(afi, afid, bfi, bfid):
    """Merge two sorted (ids, dimensions) pairs into one pair without duplicate ids.

    ``afi`` and ``bfi`` are assumed to be sorted by id in ascending order,
    with ``afid`` / ``bfid`` holding the dimension at the matching position.
    When an id occurs in both inputs it is emitted once, with the dimension
    taken from ``afid``. The two dimensions are NOT compared.

    Args:
        afi: Sorted ids of the first operand.
        afid: Dimensions matching ``afi`` position by position.
        bfi: Sorted ids of the second operand.
        bfid: Dimensions matching ``bfi`` position by position.

    Returns:
        A pair ``(ids, dims)`` of tuples holding the merged, sorted ids and
        their dimensions. Tuples are returned on every path, including when
        one of the inputs is empty.
    """
    na = len(afi)
    nb = len(bfi)
    ak = 0
    bk = 0
    fi = []
    fid = []

    # Standard two-pointer merge; stops as soon as either input is exhausted.
    while ak < na and bk < nb:
        if afi[ak] < bfi[bk]:
            fi.append(afi[ak])
            fid.append(afid[ak])
            ak += 1
        elif afi[ak] > bfi[bk]:
            fi.append(bfi[bk])
            fid.append(bfid[bk])
            bk += 1
        else:
            # Same id in both inputs: emit once and advance both cursors.
            fi.append(afi[ak])
            fid.append(afid[ak])
            ak += 1
            bk += 1

    # At most one of these tails is non-empty; the other extends by nothing.
    fi.extend(afi[ak:])
    fid.extend(afid[ak:])
    fi.extend(bfi[bk:])
    fid.extend(bfid[bk:])

    return tuple(fi), tuple(fid)
def remove_indices(fi, fid, rfi):
    """Remove the ids listed in ``rfi`` from a sorted free-index list.

    The scan below walks ``fi`` once and compares ids with ``<`` and ``==``,
    so ``fi`` must be sorted by id in ascending order. ``rfi`` may be in any
    order; it is sorted internally while remembering each id's original
    position.

    Args:
        fi: Ids of the free indices, sorted ascending.
        fid: Dimensions matching ``fi`` position by position.
        rfi: Ids to remove.

    Returns:
        A 3-tuple ``(fi, fid, shape)`` of tuples: the remaining ids, their
        dimensions, and the dimensions of the removed ids ordered like
        ``rfi``. When ``rfi`` is empty, ``shape`` is ``()``.

    Raises:
        ValueError: If an id in ``rfi`` is not present in ``fi`` (this
            includes an id listed twice in ``rfi``, since the first
            occurrence already consumed it).
    """
    if not rfi:
        # Same arity as the general path so callers can always unpack three.
        return tuple(fi), tuple(fid), ()

    # (id, position in rfi) pairs, ordered by id to match the order of fi.
    rfip = sorted((r, p) for p, r in enumerate(rfi))
    nfi = len(fi)
    shape = [None] * len(rfi)
    newfiid = []
    pos = 0

    for rk, rpos in rfip:
        # Keep everything that sorts before the next id to remove.
        while pos < nfi and fi[pos] < rk:
            newfiid.append((fi[pos], fid[pos]))
            pos += 1

        # Skip every entry equal to the id being removed, recording its
        # dimension at the slot the id had in rfi.
        start = pos
        while pos < nfi and fi[pos] == rk:
            shape[rpos] = fid[pos]
            pos += 1

        # Expecting to find each index from rfi in fi.
        if pos == start:
            raise ValueError(f"Index to be removed ({rk}) not part of indices ({fi}).")

    # No more to remove, keep the rest.
    newfiid.extend(zip(fi[pos:], fid[pos:]))

    # Unpack into two tuples.
    new_fi, new_fid = zip(*newfiid) if newfiid else ((), ())
    return new_fi, new_fid, tuple(shape)
def create_slice_indices(component, shape, fi):
    """Expand a subscript into one index object per axis of ``shape``.

    Each entry of ``component`` is handled according to its type:

    * ``Index``: used as is. It is also reported as repeated when its
      ``count()`` is found in ``fi`` or when the same index appeared earlier
      in ``component``.
    * ``FixedIndex``: used as is, after an upper-bound check against the
      axis it addresses.
    * ``int``: same upper-bound check, then wrapped in ``FixedIndex``.
    * full slice (``:``): replaced by a newly created ``Index``.
    * ``Ellipsis``: replaced by the result of ``indices(n)``, where ``n`` is
      the number of axes not covered by the other entries.

    Args:
        component: Sequence of subscript entries as described above.
        shape: Shape being subscripted.
        fi: Collection of free-index counts to test ``Index.count()`` against.

    Returns:
        A 3-tuple ``(all_indices, slice_indices, repeated_indices)`` of
        tuples: every resulting index in axis order, the indices created for
        slices and ellipsis, and the ``Index`` entries flagged as repeated.

    Raises:
        ValueError: If a fixed entry is not below the extent of its axis, a
            slice is not the full slice, an entry has an unsupported type,
            or the number of resulting indices differs from ``len(shape)``.
    """
    resolved = []
    sliced = []
    repeated = []
    seen = []

    for entry in component:
        if isinstance(entry, Index):
            resolved.append(entry)
            if entry.count() in fi or entry in seen:
                repeated.append(entry)
            seen.append(entry)
        elif isinstance(entry, (FixedIndex, int)):
            # The axis this entry addresses is the next unfilled position.
            if int(entry) >= shape[len(resolved)]:
                raise ValueError("Index out of bounds.")
            resolved.append(entry if isinstance(entry, FixedIndex) else FixedIndex(entry))
        elif isinstance(entry, slice):
            if entry != slice(None):
                raise ValueError("Only full slices (:) allowed.")
            fresh = Index()
            sliced.append(fresh)
            resolved.append(fresh)
        elif entry == Ellipsis:
            # The ellipsis itself occupies one slot of component, hence +1.
            span = len(shape) - len(component) + 1
            expanded = indices(span)
            sliced.extend(expanded)
            resolved.extend(expanded)
        else:
            raise ValueError(f"Not expecting {entry}.")

    if len(resolved) != len(shape):
        raise ValueError("Component and shape length don't match.")

    return tuple(resolved), tuple(sliced), tuple(repeated)
# Outer etc.
def merge_nonoverlapping_indices(a, b):
    """Combine the free indices of two operands that share no index.

    Example:
        C[i,j,r,s] = outer(A[i,s], B[j,r])
        A, B -> (i,j,r,s), (idim,jdim,rdim,sdim)

    Args:
        a: Object exposing ``ufl_free_indices`` and ``ufl_index_dimensions``.
        b: Object exposing ``ufl_free_indices`` and ``ufl_index_dimensions``.

    Returns:
        A pair ``(free_indices, index_dimensions)`` of tuples, sorted by
        index id. Both are empty tuples when neither operand has free
        indices.

    Raises:
        ValueError: If the same index id occurs more than once in the
            combined list.
    """
    combined_ids = a.ufl_free_indices + b.ufl_free_indices
    combined_dims = a.ufl_index_dimensions + b.ufl_index_dimensions

    pairs = sorted(zip(combined_ids, combined_dims))
    if not pairs:
        return (), ()

    free_indices, index_dimensions = zip(*pairs)

    # Consistency check: every id may appear only once.
    if len(free_indices) != len(set(free_indices)):
        raise ValueError("Not expecting repeated indices.")

    return free_indices, index_dimensions
# Product
def merge_overlapping_indices(afi, afid, bfi, bfid):
    """Split the indices of two operands into free and repeated sets.

    An index id that occurs in both ``afi`` and ``bfi`` is repeated; every
    other id stays free.

    Example:
        C[j,r] := A[i,j,k] * B[i,r,k]
        A, B -> (j,r), (jdim,rdim), (i,k), (idim,kdim)

    Args:
        afi: Index ids of the first operand.
        afid: Dimensions matching ``afi`` position by position.
        bfi: Index ids of the second operand.
        bfid: Dimensions matching ``bfi`` position by position.

    Returns:
        A 4-tuple of tuples ``(free_indices, index_dimensions,
        repeated_indices, repeated_index_dimensions)``. Free indices are
        sorted by id. Repeated indices follow their order in ``afi`` and
        carry the dimension from ``afid``; the dimension in ``bfid`` is not
        compared against it.

    Raises:
        ValueError: If an id appears more than once among the free indices,
            or if the counts show that some id occurs other than exactly
            twice across both operands.
    """
    # Set lookups replace the nested scan over both operands.
    b_ids = set(bfi)
    repeated = [(i, d) for i, d in zip(afi, afid) if i in b_ids]
    repeated_ids = {i for i, _ in repeated}

    # Collect only non-repeated indices, sorted by id.
    free = [
        (i, d)
        for i, d in sorted(zip(afi + bfi, afid + bfid))
        if i not in repeated_ids
    ]
    free_indices = tuple(i for i, _ in free)

    # Consistency checks
    if len(set(free_indices)) != len(free_indices):
        raise ValueError("Not expecting repeated indices left.")
    if len(free_indices) + 2 * len(repeated) != len(afi) + len(bfi):
        raise ValueError("Expecting only twice repeated indices.")

    return (
        free_indices,
        tuple(d for _, d in free),
        tuple(i for i, _ in repeated),
        tuple(d for _, d in repeated),
    )