Package sfc :: Package representation :: Module swiginac_eval
[hide private]
[frames] | no frames]

Source Code for Module sfc.representation.swiginac_eval

  1  """This module defines evaluation algorithms for 
  2  converting UFL expressions to swiginac representation.""" 
  3   
  4  # Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory 
  5  # 
  6  # This file is part of SyFi. 
  7  # 
  8  # SyFi is free software: you can redistribute it and/or modify 
  9  # it under the terms of the GNU General Public License as published by 
 10  # the Free Software Foundation, either version 2 of the License, or 
 11  # (at your option) any later version. 
 12  # 
 13  # SyFi is distributed in the hope that it will be useful, 
 14  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 15  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 
 16  # GNU General Public License for more details. 
 17  # 
 18  # You should have received a copy of the GNU General Public License 
 19  # along with SyFi. If not, see <http://www.gnu.org/licenses/>. 
 20  # 
 21  # Modified by Kent-Andre Mardal, 2010. 
 22  # 
 23  # First added:  2008-08-22 
 24  # Last changed: 2009-03-19 
 25   
 26  from collections import defaultdict 
 27  from itertools import izip, chain 
 28   
 29  import swiginac 
 30   
 31  from ufl import * 
 32  from ufl.classes import * 
 33  from ufl.common import some_key, product, Stack, StackDict 
 34  from ufl.algorithms.transformations import Transformer, MultiFunction 
 35  from ufl.permutation import compute_indices 
 36   
 37  from sfc.common import sfc_assert, sfc_error, sfc_warning 
 38  from sfc.symbolic_utils import symbol, symbols 
 39   
class EvaluateAsSwiginac(MultiFunction):
    """Multifunction mapping UFL expression nodes to swiginac expressions.

    Terminal handlers accept two optional tuples, ``component`` and
    ``derivatives``, which are filled in by the ``indexed`` and
    ``spatial_derivative`` handlers before control is passed on to the
    terminal.  Operator handlers apply the corresponding Python/swiginac
    operation to the operand arguments they are given (presumably already
    converted to swiginac expressions by the caller -- TODO confirm).
    """

    def __init__(self, formrep, itgrep, data, on_facet): # TODO: Remove unused arguments after implementing:
        """Store the representation objects used by the handlers.

        formrep  -- form representation; provides rank, the x_sym/n_sym
                    symbols, geomrep and the v/w/Dv/Dw symbol and
                    expression factories used below.
        itgrep   -- stored only, not read by any handler in this class.
        data     -- stored only, not read by any handler in this class.
        on_facet -- flag forwarded to the formrep factories and required
                    to be true by facet_normal.
        """
        MultiFunction.__init__(self)
        self.formrep = formrep
        self.itgrep = itgrep
        self.data = data
        self.on_facet = on_facet
        # One basis function index per form argument; all None initially.
        self.current_basis_function = (None,)*formrep.rank

    ### Fallback handlers:

    def expr(self, o, *ops):
        "Fallback for operator types without a handler: report an error."
        sfc_error("Evaluation not implemented for expr %s." % type(o).__name__)

    def terminal(self, o, *ops):
        "Fallback for terminal types without a handler: report an error."
        sfc_error("Evaluation not implemented for terminal %s." % type(o).__name__)

    ### Terminals:

    def zero(self, o):
        "Return numeric zero."
        return swiginac.numeric(0)

    def scalar_value(self, o):
        "Return the value of o as a swiginac numeric."
        return swiginac.numeric(o.value())

    def spatial_coordinate(self, o, component=(), derivatives=()):
        """Return the coordinate symbol x_sym[c], or its derivative.

        With derivatives the result is a constant: 0 for more than one
        derivative direction, otherwise 1 if the direction equals the
        component and 0 if not.
        """
        # Called by indexed

        if component:
            # 2D, 3D
            c, = component
        else:
            # 1D
            c = 0

        if derivatives:
            if len(derivatives) > 1:
                return swiginac.numeric(0)
            d, = derivatives
            if d == c:
                return swiginac.numeric(1)
            return swiginac.numeric(0)

        return self.formrep.x_sym[c]

    def facet_normal(self, o, component=(), derivatives=()):
        "Return the normal symbol n_sym[c]; any derivative of it is zero."
        # Called by indexed
        sfc_assert(self.on_facet, "Expecting to be on a facet in facet_normal.")

        if derivatives:
            return swiginac.numeric(0)

        if component:
            # 2D, 3D
            c, = component
        else:
            # 1D
            c = 0

        return self.formrep.n_sym[c]

    def cell_volume(self, o, component=(), derivatives=()):
        "Return abs(detG_sym * reference_volume); any derivative is zero."
        if derivatives:
            return swiginac.numeric(0)
        gr = self.formrep.geomrep
        return swiginac.abs(gr.detG_sym * gr.sfc_cell.reference_volume)

    def argument(self, o, component=(), derivatives=()):
        """Return the symbol or expression for a form argument.

        Both the symbol and the full expression are requested from
        formrep.  The expression itself is returned when it has no
        operands (e.nops() == 0), otherwise the symbol is returned.
        """

        # Assuming renumbered arguments!
        iarg = o.count()
        sfc_assert(iarg >= 0, "Argument count shouldn't be negative.")
        sfc_assert(isinstance(component, tuple), "Expecting tuple for component.")

        j = self.current_basis_function[iarg]

        if derivatives:
            s = self.formrep.Dv_sym(iarg, j, component, derivatives, self.on_facet)
            e = self.formrep.Dv_expr(iarg, j, component, derivatives, False, self.on_facet) # FIXME: use_symbols = False -> can do better
        else:
            s = self.formrep.v_sym(iarg, j, component, self.on_facet)
            e = self.formrep.v_expr(iarg, j, component)

        if e.nops() == 0:
            return e # FIXME: Avoid generating code for s when not using it
        return s

    def coefficient(self, o, component=(), derivatives=()):
        """Return the symbol or expression for a coefficient function.

        Same selection rule as for argument: the expression is returned
        when it has no operands, otherwise the symbol.
        """
        # Assuming renumbered arguments!
        iarg = o.count()
        sfc_assert(iarg >= 0, "Coefficient count shouldn't be negative.")
        sfc_assert(isinstance(component, tuple), "Expecting tuple for component.")

        if derivatives:
            s = self.formrep.Dw_sym(iarg, component, derivatives)
            e = self.formrep.Dw_expr(iarg, component, derivatives, False, self.on_facet) # FIXME: use_symbols = False -> can do better
        else:
            # w^i_h(x) = \sum_j w[i][j] * phi^i_j(x)
            s = self.formrep.w_sym(iarg, component)
            e = self.formrep.w_expr(iarg, component, False, self.on_facet) # FIXME: use_symbols = False -> can do better

        if e.nops() == 0:
            return e # FIXME: Avoid generating code for s when not using it
        return s

    ### Indexing and derivatives:

    def multi_index(self, o):
        "Return the entries of o converted to a tuple of ints."
        return tuple(map(int, o))

    def spatial_derivative(self, o, f, i, component=(), derivatives=()):
        """Collect derivative directions and pass control to the operand.

        The directions in i are merged with any directions collected by
        an enclosing derivative and sorted.  An Indexed operand is
        unwrapped into its base expression and component first.  A nested
        SpatialDerivative is then unpacked here instead of going through
        self(...), so that higher order derivatives accumulate all their
        directions before the terminal handler is reached.
        """
        # tuple() guards the concatenation: on a nested call derivatives
        # is the list produced by sorted() in the enclosing call.
        derivatives = sorted(tuple(derivatives) + self.multi_index(i))

        if isinstance(f, Indexed):
            # Since expand_indices moves Indexed in to the terminals,
            # SpatialDerivative can be outside an Indexed
            sfc_assert(component == (), "Expecting no component outside of Indexed!")
            f, ii = f.operands()
            component = self.multi_index(ii)

        if isinstance(f, SpatialDerivative):
            # Same explicit unpacking as in indexed(): calling
            # self(f, component, derivatives) here would hand component
            # and derivatives to this handler's f and i parameters.
            g, j = f.operands()
            return self.spatial_derivative(f, g, j, component, derivatives)

        return self(f, component, derivatives)

    def indexed(self, o, A, ii):
        "Convert ii to a component tuple and pass control to the handler of A."
        # Passes on control to one of:
        #def argument(self, o, component):
        #def coefficient(self, o, component):
        #def facet_normal(self, o, component):
        #def spatial_coordinate(self, o, component):
        #def spatial_derivative(self, o, component):
        component = self.multi_index(ii)
        if isinstance(A, SpatialDerivative):
            f, i = A.operands()
            return self.spatial_derivative(A, f, i, component)
        return self(A, component)

    ### Algebraic operators:

    def power(self, o, a, b):
        "Return a**b."
        return a**b

    def sum(self, o, *ops):
        "Return the sum of the operands (builtin sum, starting from 0)."
        return sum(ops)

    def product(self, o, *ops):
        "Return the product of the operands."
        return product(ops)

    def division(self, o, a, b):
        "Return a / b."
        return a / b

    def abs(self, o, a):
        "Return swiginac.abs(a)."
        return swiginac.abs(a)

    ### Basic math functions:

    def sqrt(self, o, a):
        "Return swiginac.sqrt(a)."
        return swiginac.sqrt(a)

    def exp(self, o, a):
        "Return swiginac.exp(a)."
        return swiginac.exp(a)

    def ln(self, o, a):
        "Return swiginac.log(a)."
        return swiginac.log(a)

    def cos(self, o, a):
        "Return swiginac.cos(a)."
        return swiginac.cos(a)

    def sin(self, o, a):
        "Return swiginac.sin(a)."
        return swiginac.sin(a)

    def tan(self, o, a):
        "Return swiginac.tan(a)."
        return swiginac.tan(a)

    def acos(self, o, a):
        "Return swiginac.acos(a)."
        return swiginac.acos(a)

    def asin(self, o, a):
        "Return swiginac.asin(a)."
        return swiginac.asin(a)

    def atan(self, o, a):
        "Return swiginac.atan(a)."
        return swiginac.atan(a)

    def variable(self, o):
        "Variables are not expected here: report an error."
        sfc_error("Should strip away variables before building graph.")

    # FIXME: Implement all missing operators here

class SwiginacEvaluator(Transformer):
    """Algorithm for evaluation of an UFL expression as a swiginac expression.

    The current tensor component is kept on a stack (self._components)
    and the current values of free indices in a stack dictionary
    (self._index2value); handlers push and pop these while visiting
    subexpressions.

    NOTE(review): some handlers take only (self, x) and visit their
    operands themselves, others take the operands as extra arguments.
    Presumably Transformer decides from the handler signature whether to
    visit operands first -- TODO confirm before changing any signature.
    """

    def __init__(self, formrep, use_symbols, on_facet):
        """Set up evaluation state.

        formrep     -- form representation providing rank, cell.nsd, the
                       x_sym/n_sym/xi_sym/Ginv_sym symbols and the
                       v/w/Dv/Dw symbol and expression factories.
        use_symbols -- if true, arguments and coefficients evaluate to
                       their symbols, otherwise to full expressions.
        on_facet    -- flag forwarded to the formrep factories and
                       required to be true by facet_normal.
        """
        Transformer.__init__(self)#, variable_cache)

        # input
        self._formrep = formrep
        self._use_symbols = use_symbols
        self._on_facet = on_facet

        # current basis function configuration
        self._current_basis_function = (0,) * formrep.rank

        # current indexing status
        self._components = Stack()
        self._index2value = StackDict()

        # code and cache structures
        # FIXME: Need pre-initialized self._variable2symbol and self._tokens
        self._variable2symbol = {}
        self._tokens = []

        # convenience variables
        self.nsd = self._formrep.cell.nsd

    def pop_tokens(self):
        "Return the accumulated (symbol, expression) tokens and reset the list."
        # TODO: Make a generator approach to this? Allow handlers to trigger a "token yield"?
        t = self._tokens
        self._tokens = []
        return t

    def update(self, iota):
        "Set the current basis function configuration to tuple(iota)."
        self._current_basis_function = tuple(iota)

    def component(self):
        "Return current component tuple."
        if len(self._components):
            return self._components.peek()
        return ()

    def _single_component(self):
        """Return the current component as a single index.

        An empty component (the 1D case in EvaluateAsSwiginac) maps to
        index 0; otherwise the component must hold exactly one entry.
        """
        c = self.component()
        if not c:
            return 0
        c, = c
        return c

    ### Fallback handlers:

    def expr(self, x):
        "Fallback for operator types without a handler: report an error."
        sfc_error("Missing ufl to swiginac handler for type %s" % str(type(x)))

    def terminal(self, x):
        "Fallback for terminal types without a handler: report an error."
        sfc_error("Missing ufl to swiginac handler for terminal type %s" % str(type(x)))

    ### Handlers for basic terminal objects:

    def zero(self, x):
        "Return numeric zero."
        #sfc_assert(len(self.component()) == len(x.shape()), "Index component length mismatch in zero tensor!")
        return swiginac.numeric(0)

    def scalar_value(self, x):
        "Return the stored value of x as a swiginac numeric."
        #sfc_assert(self.component() == (), "Shouldn't have any component at this point.")
        return swiginac.numeric(x._value)

    def identity(self, x):
        "Return 1 if the two entries of the current component are equal, else 0."
        c = self.component()
        v = 1 if c[0] == c[1] else 0
        return swiginac.numeric(v)

    def argument(self, x):
        "Return the symbol or expression of the current basis function of x."
        iarg = x.count()
        sfc_assert(iarg >= 0, "Argument count shouldn't be negative.")
        j = self._current_basis_function[iarg]
        c = self.component()
        if self._use_symbols:
            return self._formrep.v_sym(iarg, j, c, self._on_facet)
        else:
            return self._formrep.v_expr(iarg, j, c)

    def coefficient(self, x):
        "Return the symbol or expression of coefficient x at the current component."
        iarg = x.count()
        c = self.component()
        if self._use_symbols:
            return self._formrep.w_sym(iarg, c)
        else:
            # w^i_h(x) = \sum_j w[i][j] * phi^i_j(x)
            return self._formrep.w_expr(iarg, c, False, self._on_facet)

    def facet_normal(self, x):
        "Return the facet normal symbol for the current component (0 if none)."
        sfc_assert(self._on_facet, "Expecting to be on a facet in facet_normal.")
        return self._formrep.n_sym[self._single_component()]

    def spatial_coordinate(self, x):
        "Return the coordinate symbol for the current component (0 if none)."
        return self._formrep.x_sym[self._single_component()]

    ### Handler for variables:

    def variable(self, x):
        "Evaluate the expression wrapped by the variable."
        return self.visit(x._expression)

    def garbage(self, x): # TODO: Maybe some of the ideas here can be used in code generation
        """Return a cached symbol standing for the variable expression.

        On a cache miss the expression is evaluated, a symbol named from
        the variable count, component, index values and basis function
        configuration is created, and the (symbol, expression) pair is
        appended to the token list.
        """
        c = self.component()
        index_values = tuple(self._index2value[k] for k in x._expression.free_indices())

        # TODO: Doesn't always depend on _current_basis_function, this is crap:
        key = (x.count(), c, index_values, self._current_basis_function)
        vsym = self._variable2symbol.get(key)

        if vsym is None:
            expr = self.visit(x._expression)
            # TODO: Doesn't always depend on _current_basis_function, this is crap:
            compstr = "_".join("%d" % k for k in chain(c, index_values, self._current_basis_function))
            vname = "_".join(("t_%d" % x.count(), compstr))
            vsym = symbol(vname)
            self._variable2symbol[key] = vsym
            self._tokens.append((vsym, expr))

        # The symbol was previously built but never handed back.
        return vsym

    ### Handlers for basic algebra:

    def sum(self, x, *ops):
        "Return the sum of the operands (builtin sum, starting from 0)."
        return sum(ops)

    def index_sum(self, x):
        "Evaluate the summand once per value of the summation index and add up."
        ops = []
        summand, multiindex = x.operands()
        index, = multiindex
        for i in range(x.dimension()):
            self._index2value.push(index, i)
            ops.append(self.visit(summand))
            self._index2value.pop()
        return sum(ops)

    def product(self, x):
        "Evaluate each operand and return their product."
        sfc_assert(not self.component(), "Non-empty indexing component in product!")
        ops = [self.visit(o) for o in x.operands()]
        return product(ops)
        # ...

    def division(self, x, a, b):
        "Return a / b."
        return a / b

    def power(self, x, a, b):
        "Return a ** b."
        return a ** b

    def abs(self, x, a):
        "Return swiginac.abs(a)."
        return swiginac.abs(a)

    ### Basic math functions:
    def sqrt(self, x, y):
        "Return swiginac.sqrt(y)."
        return swiginac.sqrt(y)

    def exp(self, x, y):
        "Return swiginac.exp(y)."
        return swiginac.exp(y)

    def ln(self, x, y):
        "Return swiginac.log(y)."
        return swiginac.log(y)

    def cos(self, x, y):
        "Return swiginac.cos(y)."
        return swiginac.cos(y)

    def sin(self, x, y):
        "Return swiginac.sin(y)."
        return swiginac.sin(y)

    ### Index handling:
    def multi_index(self, x):
        """Return the integer values of the indices in x as a tuple.

        Fixed indices contribute their stored value, free indices the
        value currently bound in the index map.  Any other index type is
        reported as an error instead of being dropped, since dropping it
        would yield a component tuple of the wrong length.
        """
        subcomp = []
        for i in x:
            if isinstance(i, FixedIndex):
                subcomp.append(i._value)
            elif isinstance(i, Index):
                subcomp.append(self._index2value[i])
            else:
                sfc_error("Unknown index type %s in multi index." % type(i).__name__)
        return tuple(subcomp)

    def indexed(self, x):
        "Evaluate the indexed expression with its index values as component."
        A, ii = x.operands()
        self._components.push(self.visit(ii))
        result = self.visit(A)
        self._components.pop()
        return result

    ### Container handling:

    def old_list_tensor(self, x): # doesn't support e.g. building a matrix from vector rows
        "Pick the scalar subexpression at the current component and evaluate it."
        component = self.component()
        sfc_assert(len(component) > 0 and \
                   all(isinstance(i, int) for i in component),
                   "Can't index tensor with %s." % repr(component))

        # Hide indexing when evaluating subexpression
        self._components.push(())

        # Get scalar UFL subexpression from tensor
        e = x
        for i in component:
            e = e._expressions[i]
        sfc_assert(e.shape() == (), "Expecting scalar expression "\
                   "after extracting component from tensor.")

        # Apply conversion to scalar subexpression
        r = self.visit(e)

        # Return to previous component state
        self._components.pop()
        return r

    def list_tensor(self, x):
        "Evaluate the subtensor selected by the first entry of the component."
        # Pick the right subtensor and subcomponent
        c = self.component()
        # Same guard as in old_list_tensor, instead of a bare IndexError
        sfc_assert(len(c) > 0, "Can't index tensor with %s." % repr(c))
        c0, c1 = c[0], c[1:]
        op = x.operands()[c0]
        # Evaluate subtensor with this subcomponent
        self._components.push(c1)
        r = self.visit(op)
        self._components.pop()
        return r

    def component_tensor(self, x):
        "Evaluate the scalar base expression with indices bound to the component."
        # this function evaluates the tensor expression
        # with indices equal to the current component tuple
        expression, indices = x.operands()
        sfc_assert(expression.shape() == (), "Expecting scalar base expression.")

        # update index map with component tuple values
        comp = self.component()
        sfc_assert(len(indices) == len(comp), "Index/component mismatch.")
        for i, v in zip(indices._indices, comp):
            self._index2value.push(i, v)
        self._components.push(())

        # evaluate with these indices
        result = self.visit(expression)

        # revert index map, one pop per index pushed above
        for _ in comp:
            self._index2value.pop()
        self._components.pop()
        return result

    ### Differentiation:

    def _ddx(self, f, i):
        """Differentiate swiginac expression f w.r.t. x_i, using
        df/dx_i = df/dxi_j dxi_j/dx_i."""
        Ginv = self._formrep.Ginv_sym
        xi = self._formrep.xi_sym
        return sum(Ginv[j, i] * swiginac.diff(f, xi[j]) for j in range(self.nsd))

    def spatial_derivative(self, x):
        """Evaluate the spatial derivative of a terminal.

        Arguments and coefficients are delegated to the Dv/Dw factories
        of formrep, the facet normal differentiates to zero, and a
        spatial coordinate to 1 or 0 depending on whether the first
        derivative direction equals its component.
        """
        # Assuming that AD has been applied, so
        # the expression to differentiate is always a Terminal.

        f, ii = x.operands()

        sfc_assert(isinstance(f, Terminal), \
            "Expecting to differentiate a Terminal object, you must apply AD first!") # The exception is higher order derivatives, ignoring for now

        # Get component and derivative directions
        c = self.component()
        der = self.visit(ii)

        # --- Handle derivatives of basis functions
        if isinstance(f, Argument):
            iarg = f.count()
            i = self._current_basis_function[iarg]
            if self._use_symbols:
                return self._formrep.Dv_sym(iarg, i, c, der, self._on_facet)
            else:
                return self._formrep.Dv_expr(iarg, i, c, der, False, self._on_facet)

        # --- Handle derivatives of coefficient functions
        if isinstance(f, Coefficient):
            iarg = f.count()
            if self._use_symbols:
                return self._formrep.Dw_sym(iarg, c, der)
            else:
                return self._formrep.Dw_expr(iarg, c, der, False, self._on_facet)

        # --- Handle derivatives of geometry objects
        if isinstance(f, FacetNormal):
            return swiginac.numeric(0.0)

        if isinstance(f, SpatialCoordinate):
            # An empty component means index 0, as in spatial_coordinate
            c = self._single_component()
            if der[0] == c:
                return swiginac.numeric(1.0)
            else:
                return swiginac.numeric(0.0)

        sfc_error("Missing spatial derivative handler for terminal type %s" % type(f).__name__)

    def derivative(self, x):
        "Derivative nodes are not expected here: report an error."
        sfc_error("Derivative shouldn't occur here, you must apply AD first!")

    ### Interior facet stuff:

    def positive_restricted(self, x, y):
        "Report restrictions as unimplemented (returns y if sfc_error returns)."
        sfc_error("TODO: Restrictions not implemented!")
        return y

    def negative_restricted(self, x, y):
        "Report restrictions as unimplemented (returns y if sfc_error returns)."
        sfc_error("TODO: Restrictions not implemented!")
        return y


### These require code structure and thus shouldn't occur in SwiginacEvaluator
# (i.e. any conditionals should be handled externally)
#d[EQ] =
#d[NE] =
#d[LE] =
#d[GE] =
#d[LT] =
#d[GT] =

### These are replaced by expand_compounds, so we skip them here:
#d[Identity]    =
#d[Transposed]  =
#d[Outer]  =
#d[Inner]  =
#d[Dot]    =
#d[Cross]  =
#d[Trace]  =
#d[Determinant] =
#d[Inverse]     =
#d[Deviatoric]  =
#d[Cofactor]    =
#d[Grad]   =
#d[Div]    =
#d[Curl]   =
#d[Rot]    =