1 """This module defines evaluation algorithms for
2 converting UFL expressions to swiginac representation."""
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 from collections import defaultdict
27 from itertools import izip, chain
28
29 import swiginac
30
31 from ufl import *
32 from ufl.classes import *
33 from ufl.common import some_key, product, Stack, StackDict
34 from ufl.algorithms.transformations import Transformer, MultiFunction
35 from ufl.permutation import compute_indices
36
37 from sfc.common import sfc_assert, sfc_error, sfc_warning
38 from sfc.symbolic_utils import symbol, symbols
39
def __init__(self, formrep, itgrep, data, on_facet):
    """Store the evaluation context on the instance.

    formrep, itgrep, data and on_facet are kept unchanged as attributes
    of the same names. formrep must provide ``rank``, which determines
    the length of ``current_basis_function``.
    """
    MultiFunction.__init__(self)
    self.formrep, self.itgrep = formrep, itgrep
    self.data, self.on_facet = data, on_facet
    # One None entry per unit of formrep.rank.
    placeholders = [None] * formrep.rank
    self.current_basis_function = tuple(placeholders)
48
49
50
def expr(self, o, *ops):
    """Report through sfc_error that no evaluation exists for type(o).

    The operands in ``ops`` are ignored.
    """
    type_name = type(o).__name__
    sfc_error("Evaluation not implemented for expr %s." % type_name)
53
55 sfc_error("Evaluation not implemented for terminal %s." % type(o).__name__)
56
57
58
60 return swiginac.numeric(0)
61
63 return swiginac.numeric(o.value())
64
66
67
68 if component:
69
70 c, = component
71 else:
72
73 c = 0
74
75 if derivatives:
76 if len(derivatives) > 1:
77 return swiginac.numeric(0)
78 d, = derivatives
79 if d == c:
80 return swiginac.numeric(1)
81 return swiginac.numeric(0)
82
83 return self.formrep.x_sym[c]
84
86
87 sfc_assert(self.on_facet, "Expecting to be on a facet in facet_normal.")
88
89 if derivatives:
90 return swiginac.numeric(0)
91
92 if component:
93
94 c, = component
95 else:
96
97 c = 0
98
99 return self.formrep.n_sym[c]
100
def cell_volume(self, o, component=(), derivatives=()):
    """Return the cell volume as a swiginac expression.

    Any non-empty ``derivatives`` yields numeric zero. Otherwise the
    result is abs(detG_sym * reference_volume), both taken from the
    geometry representation of ``self.formrep``. ``o`` and ``component``
    are not used.
    """
    if not derivatives:
        geometry = self.formrep.geomrep
        scaled = geometry.detG_sym * geometry.sfc_cell.reference_volume
        return swiginac.abs(scaled)
    return swiginac.numeric(0)
106
def argument(self, o, component=(), derivatives=()):
    """Return a swiginac representation of argument ``o``.

    The argument number is ``o.count()`` and the basis function index is
    read from ``self.current_basis_function``. With ``derivatives`` the
    Dv_sym/Dv_expr pair of the form representation is queried, otherwise
    the v_sym/v_expr pair. The expression is returned when it has no
    operands (``nops() == 0``), else the symbol is returned.
    """
    number = o.count()
    sfc_assert(number >= 0, "Argument count shouldn't be negative.")
    sfc_assert(isinstance(component, tuple), "Expecting tuple for component.")

    basis_index = self.current_basis_function[number]
    rep = self.formrep
    facet = self.on_facet

    if not derivatives:
        sym = rep.v_sym(number, basis_index, component, facet)
        value = rep.v_expr(number, basis_index, component)
    else:
        sym = rep.Dv_sym(number, basis_index, component, derivatives, facet)
        value = rep.Dv_expr(number, basis_index, component, derivatives,
                            False, facet)

    # An operand-free expression is returned directly instead of its symbol.
    return value if value.nops() == 0 else sym
126
def coefficient(self, o, component=(), derivatives=()):
    """Return a swiginac representation of coefficient ``o``.

    The coefficient number is ``o.count()``. With ``derivatives`` the
    Dw_sym/Dw_expr pair of the form representation is queried, otherwise
    the w_sym/w_expr pair. The expression is returned when it has no
    operands (``nops() == 0``), else the symbol is returned.
    """
    number = o.count()
    sfc_assert(number >= 0, "Coefficient count shouldn't be negative.")
    sfc_assert(isinstance(component, tuple), "Expecting tuple for component.")

    rep = self.formrep
    facet = self.on_facet

    if not derivatives:
        sym = rep.w_sym(number, component)
        value = rep.w_expr(number, component, False, facet)
    else:
        sym = rep.Dw_sym(number, component, derivatives)
        value = rep.Dw_expr(number, component, derivatives, False, facet)

    # An operand-free expression is returned directly instead of its symbol.
    return value if value.nops() == 0 else sym
144
145
146
148 return tuple(map(int, o))
149
160
173
174
175
176 - def power(self, o, a, b):
178
179 - def sum(self, o, *ops):
181
184
187
def abs(self, o, a):
    """Return swiginac.abs applied to ``a``; ``o`` is not used."""
    result = swiginac.abs(a)
    return result
190
191
192
def sqrt(self, o, a):
    """Return swiginac.sqrt applied to ``a``; ``o`` is not used."""
    result = swiginac.sqrt(a)
    return result
195
def exp(self, o, a):
    """Return swiginac.exp applied to ``a``; ``o`` is not used."""
    result = swiginac.exp(a)
    return result
198
def ln(self, o, a):
    """Return swiginac.log applied to ``a``; ``o`` is not used."""
    result = swiginac.log(a)
    return result
201
def cos(self, o, a):
    """Return swiginac.cos applied to ``a``; ``o`` is not used."""
    result = swiginac.cos(a)
    return result
204
def sin(self, o, a):
    """Return swiginac.sin applied to ``a``; ``o`` is not used."""
    result = swiginac.sin(a)
    return result
207
def tan(self, o, a):
    """Return swiginac.tan applied to ``a``; ``o`` is not used."""
    result = swiginac.tan(a)
    return result
210
def acos(self, o, a):
    """Return swiginac.acos applied to ``a``; ``o`` is not used."""
    result = swiginac.acos(a)
    return result
213
def asin(self, o, a):
    """Return swiginac.asin applied to ``a``; ``o`` is not used."""
    result = swiginac.asin(a)
    return result
216
def atan(self, o, a):
    """Return swiginac.atan applied to ``a``; ``o`` is not used."""
    result = swiginac.atan(a)
    return result
219
220
221
222
224 sfc_error("Should strip away variables before building graph.")
225
226
227
229 "Algorithm for evaluation of an UFL expression as a swiginac expression."
def __init__(self, formrep, use_symbols, on_facet):
    """Set up the transformer state.

    formrep, use_symbols and on_facet are stored as private attributes.
    formrep must provide ``rank`` (length of the basis function tuple)
    and ``cell.nsd`` (copied to ``self.nsd``).
    """
    Transformer.__init__(self)

    # Input data, stored as given.
    self._formrep = formrep
    self._use_symbols = use_symbols
    self._on_facet = on_facet

    # One zero per unit of formrep.rank.
    self._current_basis_function = (0,) * formrep.rank

    # Stack of component tuples and stacked index -> value mapping.
    self._components = Stack()
    self._index2value = StackDict()

    # Symbol lookup for variables and the list of collected tokens.
    self._variable2symbol = {}
    self._tokens = []

    # Copied from the cell so it can be read directly from the instance.
    self.nsd = self._formrep.cell.nsd
252
254
255 t = self._tokens
256 self._tokens = []
257 return t
258
260 self._current_basis_function = tuple(iota)
261
263 "Return current component tuple."
264 if len(self._components):
265 return self._components.peek()
266 return ()
267
268
269
271 sfc_error("Missing ufl to swiginac handler for type %s" % str(type(x)))
272
274 sfc_error("Missing ufl to swiginac handler for terminal type %s" % str(type(x)))
275
276
277
279
280 return swiginac.numeric(0)
281
283
284 return swiginac.numeric(x._value)
285
287 c = self.component()
288 v = 1 if c[0] == c[1] else 0
289 return swiginac.numeric(v)
290
292 iarg = x.count()
293 sfc_assert(iarg >= 0, "Argument count shouldn't be negative.")
294 j = self._current_basis_function[iarg]
295 c = self.component()
296 if self._use_symbols:
297 return self._formrep.v_sym(iarg, j, c, self._on_facet)
298 else:
299 return self._formrep.v_expr(iarg, j, c)
300
302 iarg = x.count()
303 c = self.component()
304 if self._use_symbols:
305 return self._formrep.w_sym(iarg, c)
306 else:
307
308 return self._formrep.w_expr(iarg, c, False, self._on_facet)
309
311 sfc_assert(self._on_facet, "Expecting to be on a facet in facet_normal.")
312 c, = self.component()
313 return self._formrep.n_sym[c]
314
316 c, = self.component()
317 return self._formrep.x_sym[c]
318
319
320
322 return self.visit(x._expression)
323
325 c = self.component()
326 index_values = tuple(self._index2value[k] for k in x._expression.free_indices())
327
328
329 key = (x.count(), c, index_values, self._current_basis_function)
330 vsym = self._variable2symbol.get(key)
331
332 if vsym is None:
333 expr = self.visit(x._expression)
334
335 compstr = "_".join("%d" % k for k in chain(c, index_values, self._current_basis_function))
336 vname = "_".join(("t_%d" % x.count(), compstr))
337 vsym = symbol(vname)
338 self._variable2symbol[key] = vsym
339 self._tokens.append((vsym, expr))
340
341
342
343 - def sum(self, x, *ops):
345
347 ops = []
348 summand, multiindex = x.operands()
349 index, = multiindex
350 for i in range(x.dimension()):
351 self._index2value.push(index, i)
352 ops.append(self.visit(summand))
353 self._index2value.pop()
354 return sum(ops)
355
357 sfc_assert(not self.component(), "Non-empty indexing component in product!")
358 ops = [self.visit(o) for o in x.operands()]
359 return product(ops)
360
361
364
365 - def power(self, x, a, b):
367
def abs(self, x, a):
    """Return swiginac.abs applied to ``a``; ``x`` is not used."""
    result = swiginac.abs(a)
    return result
370
371
def sqrt(self, x, y):
    """Return swiginac.sqrt applied to ``y``; ``x`` is not used."""
    result = swiginac.sqrt(y)
    return result
374
def exp(self, x, y):
    """Return swiginac.exp applied to ``y``; ``x`` is not used."""
    result = swiginac.exp(y)
    return result
377
def ln(self, x, y):
    """Return swiginac.log applied to ``y``; ``x`` is not used."""
    result = swiginac.log(y)
    return result
380
def cos(self, x, y):
    """Return swiginac.cos applied to ``y``; ``x`` is not used."""
    result = swiginac.cos(y)
    return result
383
def sin(self, x, y):
    """Return swiginac.sin applied to ``y``; ``x`` is not used."""
    result = swiginac.sin(y)
    return result
386
387
389 subcomp = []
390 for i in x:
391 if isinstance(i, FixedIndex):
392 subcomp.append(i._value)
393 elif isinstance(i, Index):
394 subcomp.append(self._index2value[i])
395 return tuple(subcomp)
396
398 A, ii = x.operands()
399 self._components.push(self.visit(ii))
400 result = self.visit(A)
401 self._components.pop()
402 return result
403
404
405
407 component = self.component()
408 sfc_assert(len(component) > 0 and \
409 all(isinstance(i, int) for i in component),
410 "Can't index tensor with %s." % repr(component))
411
412
413 self._components.push(())
414
415
416 e = x
417 for i in component:
418 e = e._expressions[i]
419 sfc_assert(e.shape() == (), "Expecting scalar expression "\
420 "after extracting component from tensor.")
421
422
423 r = self.visit(e)
424
425
426 self._components.pop()
427 return r
428
430
431 c = self.component()
432 c0, c1 = c[0], c[1:]
433 op = x.operands()[c0]
434
435 self._components.push(c1)
436 r = self.visit(op)
437 self._components.pop()
438 return r
439
441
442
443 expression, indices = x.operands()
444 sfc_assert(expression.shape() == (), "Expecting scalar base expression.")
445
446
447 comp = self.component()
448 sfc_assert(len(indices) == len(comp), "Index/component mismatch.")
449 for i, v in izip(indices._indices, comp):
450 self._index2value.push(i, v)
451 self._components.push(())
452
453
454 result = self.visit(expression)
455
456
457 for i in range(len(comp)):
458 self._index2value.pop()
459 self._components.pop()
460 return result
461
462
463
def _ddx(self, f, i):
    """Differentiate swiginac expression f w.r.t. x_i.

    Uses df/dx_i = df/dxi_j dxi_j/dx_i, where the factor for each j is
    taken from ``Ginv_sym[j, i]`` of the form representation and j runs
    over ``range(self.nsd)``.
    """
    rep = self._formrep
    Ginv, xi = rep.Ginv_sym, rep.xi_sym
    # Start from integer 0 and add out-of-place, as the builtin sum does.
    total = 0
    for j in range(self.nsd):
        total = total + Ginv[j, i] * swiginac.diff(f, xi[j])
    return total
470
472
473
474
475 f, ii = x.operands()
476
477 sfc_assert(isinstance(f, Terminal), \
478 "Expecting to differentiate a Terminal object, you must apply AD first!")
479
480
481 c = self.component()
482 der = self.visit(ii)
483
484
485 if isinstance(f, Argument):
486 iarg = f.count()
487 i = self._current_basis_function[iarg]
488 if self._use_symbols:
489 return self._formrep.Dv_sym(iarg, i, c, der, self._on_facet)
490 else:
491 return self._formrep.Dv_expr(iarg, i, c, der, False, self._on_facet)
492
493
494 if isinstance(f, Coefficient):
495 iarg = f.count()
496 if self._use_symbols:
497 return self._formrep.Dw_sym(iarg, c, der)
498 else:
499 return self._formrep.Dw_expr(iarg, c, der, False, self._on_facet)
500
501
502 if isinstance(f, FacetNormal):
503 return swiginac.numeric(0.0)
504
505 if isinstance(f, SpatialCoordinate):
506 c, = c
507 if der[0] == c:
508 return swiginac.numeric(1.0)
509 else:
510 return swiginac.numeric(0.0)
511
512 sfc_error("Eh?")
513
515 sfc_error("Derivative shouldn't occur here, you must apply AD first!")
516
517
518
520 sfc_error("TODO: Restrictions not implemented!")
521 return y
522
524 sfc_error("TODO: Restrictions not implemented!")
525 return y
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553