1
2 """
3 This module contains representation classes for integrals.
4 """
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28 from itertools import izip
29
30
31 import swiginac
32
33
34 from ufl.permutation import compute_indices
35 from ufl.algorithms import Graph, partition, expand_indices2, expand_indices, strip_variables, tree_format
36 from ufl.classes import Expr, Terminal, UtilityType, SpatialDerivative, Indexed
37
38 from sfc.common import sfc_assert, sfc_warning, sfc_debug, sfc_error
39 from sfc.common.utilities import indices_subset
40 from sfc.codegeneration.codeformatting import indent, row_major_index_string
41 from sfc.symbolic_utils import symbols, symbol
42 from sfc.representation.swiginac_eval import SwiginacEvaluator, EvaluateAsSwiginac
43
45
def find_locals(Vin, Vout, P):
    """Tell, for every vertex of partition *P*, whether all its parents are in *P*.

    Parameters:
        Vin: Vin[i] lists the vertices with an edge into vertex i (its parents).
        Vout: Vout[i] lists the vertices that vertex i has an edge to (its children).
        P: Sequence of vertex indices forming one partition.

    Returns:
        dict mapping each vertex index i in P to True when every incoming
        edge of i (len(Vin[i]) of them) originates from a vertex in P,
        otherwise False.
    """
    # Only vertices of P are ever queried, so count parents for those alone
    # instead of for every vertex in the graph.
    outside_parents = {i: len(Vin[i]) for i in P}
    for i in P:
        for j in Vout[i]:
            if j in outside_parents:
                # The edge i -> j stays inside P, so j has one fewer outside parent.
                outside_parents[j] -= 1
    return {i: count == 0 for i, count in outside_parents.items()}
55
def is_simple(e):
    """Return True if *e* reports no operands, i.e. ``e.nops() == 0``.

    iter_partition stores such expressions directly instead of giving
    them a temporary symbol.
    """
    operand_count = e.nops()
    return operand_count == 0
58
59
60
61
62
class IntegralData(object):
    """Per-integral state used while generating code for one UFL integral.

    Holds the integration settings, the integrand's expression graph and
    its partitions, the evaluator objects, and two caches:

    * ``V_symbols``: vertex index -> temporary symbol (filled by
      IntegralRepresentation.allocate, emptied by free_symbols);
    * ``vertex_data_set``: basis-function tuple -> {vertex index: value}.
    """

    # NOTE(review): the class and __init__ header lines are missing from the
    # listing this was recovered from; the signature follows the call
    # IntegralData(self, integral) in IntegralRepresentation.__init__.
    def __init__(self, itgrep, integral):
        """Create empty bookkeeping for *integral*, owned by representation *itgrep*."""
        self.itgrep = itgrep
        self.integral = integral

        # Taken from the representation-wide options; per-measure metadata
        # is not consulted here.
        self.integration_method = itgrep.options.integration_method
        self.integration_order = itgrep.options.integration_order

        # Set by IntegralRepresentation for quadrature integrals.
        self.quad_rule = None
        self.facet_quad_rules = None

        # Expression graph, per-vertex dependencies and remaining-parent
        # counts, and the partitions of the graph.
        self.G = None
        self.Vdeps = None
        self.Vin_count = None
        self.partitions = {}
        self.evaluator = None
        self.evaluate = None

        self.V_symbols = {}
        self.vertex_data_set = {}

    def store(self, value, i, basis_functions):
        """Store *value* for vertex *i* under the key *basis_functions*."""
        self.vertex_data_set.setdefault(basis_functions, {})[i] = value

    def get_vd(self, j, basis_functions):
        """Return the vertex-data dict that vertex *j* is stored in.

        Only the entries of *basis_functions* for form arguments that
        vertex *j* depends on (``"v%d" % k in self.Vdeps[j]``) are kept,
        via ``indices_subset``. When the resulting key has only None
        entries it becomes ``()``. Raises KeyError if nothing was stored
        under that key.
        """
        rank = self.itgrep.formrep.rank
        jkeep = [("v%d" % k) in self.Vdeps[j] for k in range(rank)]
        iota, = indices_subset([basis_functions], jkeep)
        if all(i is None for i in iota):
            iota = ()
        return self.vertex_data_set[iota]

    def fetch_storage(self, j, basis_functions):
        """Return the value stored for vertex *j* (KeyError if absent)."""
        return self.get_vd(j, basis_functions)[j]

    def free_storage(self, j, basis_functions):
        """Delete stored data for j (no-op if j has no entry)."""
        # Previously the dict was looked up but the entry was never removed,
        # contradicting this method's documented purpose.
        self.get_vd(j, basis_functions).pop(j, None)
109
110
# IntegralRepresentation.__init__: set up the shared element-tensor data and
# one IntegralData record per integral, stored in self.integral_data keyed by
# integral.measure().
# NOTE(review): this listing lost its indentation and comment lines, so the
# nesting of the statements below cannot be read off reliably. In particular,
# it is unclear whether the integrand/graph preprocessing (from "171" on)
# belongs to the quadrature branch, its non-safemode sub-branch, or the whole
# loop body. Confirm against the original source.
112 - def __init__(self, integrals, formrep, on_facet):
113 sfc_debug("Entering IntegralRepresentation.__init__")
114
# Keep the inputs; the generated class name is formrep.itg_names of the
# first integral.
115 self.integrals = integrals
116 self.formrep = formrep
117 self._on_facet = on_facet
118 self.classname = formrep.itg_names[integrals[0]]
119
# Integral code-generation options (formrep.options.code.integral).
120 self.options = self.formrep.options.code.integral
121
122
# Element tensor shape: one entry per form argument (formrep.rank of them),
# equal to the local_dimension of that argument's element representation.
123 A_shape = []
124 for i in range(self.formrep.rank):
125 element = self.formrep.formdata.elements[i]
126 rep = self.formrep.element_reps[element]
127 A_shape.append(rep.local_dimension)
128 self.A_shape = tuple(A_shape)
129
130
# All multi-indices into the element tensor, from compute_indices(A_shape).
131 self.indices = compute_indices(self.A_shape)
132
133
# One symbol "A[<row_major_index_string(i, A_shape)>]" per multi-index i;
# A_sym maps each multi-index to its symbol.
134 Asym = symbols("A[%s]" % row_major_index_string(i, self.A_shape) for i in self.indices)
135 self.A_sym = dict(izip(self.indices, Asym))
136
137
# Pool of released temporaries and the counter used to name new "s[k]"
# symbols (used by allocate / free_symbols).
138 self._free_symbols = []
139 self._symbol_counter = 0
140
141
# Filled in by the loop below.
142 self.symbolic_integral = None
143 self.quadrature_integrals = []
144 self.integral_data = {}
145
146
# NOTE(review): fd is assigned but never used in this method.
147 fd = self.formrep.formdata
148 for integral in self.integrals:
149 data = IntegralData(self, integral)
150
# Symbolic method: remember the integral and attach a SwiginacEvaluator
# constructed with use_symbols=False.
151 if data.integration_method == "symbolic":
152 self.symbolic_integral = integral
153 data.evaluator = SwiginacEvaluator(self.formrep, use_symbols=False, on_facet=self._on_facet)
154
# Quadrature method: share the form-level (facet) quadrature rules and
# record the integral as a quadrature integral.
155 elif data.integration_method == "quadrature":
156
157
158 data.quad_rule = self.formrep.quad_rule
159 data.facet_quad_rules = self.formrep.facet_quad_rules
160
161 self.quadrature_integrals.append(integral)
# safemode attaches a SwiginacEvaluator (use_symbols=True) as data.evaluator;
# otherwise an EvaluateAsSwiginac object is stored in data.evaluate.
# NOTE(review): iter_partition dereferences data.evaluate unconditionally,
# and it stays None in safemode; confirm safemode never reaches it.
162 if self.options.safemode:
163
164 data.evaluator = SwiginacEvaluator(self.formrep, use_symbols=True, on_facet=self._on_facet)
165
166 else:
167
168 data.evaluate = EvaluateAsSwiginac(self.formrep, self, data, on_facet=self._on_facet)
169
170
# Preprocess the integrand: strip variables, then expand indices with
# expand_indices2 or expand_indices depending on options.use_expand_indices2.
171 integrand = integral.integrand()
172 integrand = strip_variables(integrand)
173 if self.options.use_expand_indices2:
174 integrand = expand_indices2(integrand)
175 else:
176 integrand = expand_indices(integrand)
177
178
# Expression graph of the preprocessed integrand.
179 data.G = Graph(integrand)
180
181
# Per-vertex count of incoming edges (the entries of data.G.Vin()).
# NOTE(review): vj is never used; if the increment sits inside the inner
# loop, Vin_count[i] simply equals len(Vin[i]).
182 V = data.G.V()
183 n = len(V)
184 data.Vin_count = [0]*n
185 for i, vs in enumerate(data.G.Vin()):
186 for j in vs:
187 vj = V[j]
188
189 data.Vin_count[i] += 1
190
191
192
193
194
195
196
197
198
# Vertex partitions (looked up by a frozenset of dependencies in
# iter_partition) and per-vertex dependencies (read by IntegralData.get_vd).
199 data.partitions, data.Vdeps = partition(data.G)
200
201 self.integral_data[integral.measure()] = data
202
203 sfc_debug("Leaving IntegralRepresentation.__init__")
204
205
206
def __str__(self):
    """Return a multi-line, human-readable summary of this representation.

    Lists the generated class name, the element tensor shape, its index
    tuples and symbols, followed by each UFL integral (indented, separated
    by blank lines). Leading and trailing newlines are stripped.
    """
    # NOTE(review): the def line is missing from the listing this was
    # recovered from; the name __str__ is inferred from the body.
    # Every value is wrapped in a 1-tuple: A_shape is a tuple, and
    # "%s" % some_tuple would spread its items over the format string
    # (TypeError for any rank other than 1).
    parts = [
        "IntegralRepresentation:\n",
        " classname: %s\n" % (self.classname,),
        " A_shape: %s\n" % (self.A_shape,),
        " indices: %s\n" % (self.indices,),
        " A_sym: %s\n" % (self.A_sym,),
        " UFL Integrals:\n",
    ]
    parts.extend(indent(str(integral)) + "\n\n" for integral in self.integrals)
    return "".join(parts).strip("\n")
224
225
226
def allocate(self, data, j):
    """Allocate a symbol for vertex *j* and record it in ``data.V_symbols``.

    A previously released symbol from ``self._free_symbols`` is reused when
    one is available. Otherwise a new symbol ``s[k]`` is created from
    ``self._symbol_counter``, which is then incremented. Either way the
    symbol is recorded as ``data.V_symbols[j]`` so that
    ``free_symbols(data, j)`` can later return it to the pool.

    Returns:
        The allocated symbol.
    """
    if self._free_symbols:
        # Reuse a released symbol. It must be recorded like a fresh one,
        # otherwise free_symbols(data, j) cannot find and release it again.
        s = self._free_symbols.pop()
    else:
        s = symbol("s[%d]" % self._symbol_counter)
        self._symbol_counter += 1
    data.V_symbols[j] = s
    return s
236
def free_symbols(self, data, j):
    """Release the symbol held by vertex *j* back into the free pool.

    The symbol is removed from ``data.V_symbols`` and pushed onto
    ``self._free_symbols`` so ``allocate`` can hand it out again. If
    vertex *j* holds no symbol, a debug message is emitted and nothing
    changes.
    """
    sym = data.V_symbols.get(j)
    if sym is not None:
        self._free_symbols.append(sym)
        del data.V_symbols[j]
    else:
        sfc_debug("Trying to deallocate symbols that are not allocated!")
245
246
247
# IntegralRepresentation.iter_partition. The def line is missing from this
# listing; the name comes from the debug messages, and the body reads `data`,
# `deps` and `basis_functions`.
# Generator: walks the vertices of the partition selected by `deps`,
# evaluates each relevant vertex, and yields (symbol, expression) pairs for
# results that need a temporary symbol. Results with no operands are only
# stored, not yielded.
# NOTE(review): indentation was lost in this listing; the nesting described
# in the comments below is reconstructed from control flow. Confirm it.
249 sfc_debug("Entering IntegralRepresentation.iter_partition")
250
# Partitions are looked up by a frozenset of dependencies.
251 deps = frozenset(deps)
252
253 P = data.partitions.get(deps)
# Nothing to do for a missing or empty partition.
254 if not P:
255 sfc_debug("Leaving IntegralRepresentation.iter_partition, empty")
256 return
257
# Record the current basis-function tuple on the evaluator.
258 data.evaluate.current_basis_function = basis_functions
259
260
261 G = data.G
262 V = G.V()
# NOTE(review): E is not used below.
263 E = G.E()
264 Vin = G.Vin()
265 Vout = G.Vout()
266
267
# is_local[i] is True when every parent of vertex i lies in P.
268 is_local = find_locals(Vin, Vout, P)
269
270
271 for i in P:
272 v = V[i]
273
# Utility vertices are skipped.
274 if isinstance(v, UtilityType):
275
276 continue
277
# Shaped vertices are expected to be Terminals or SpatialDerivatives.
# Anything else is dumped to stdout and reported through sfc_error.
# Presumably the `continue` then skips every shaped vertex.
278 if v.shape():
279
280 if not isinstance(v, (Terminal, SpatialDerivative)):
281 print "="*30
282 print "type:", type(v)
283 print "str:", str(v)
284 print "child types:", [type(V[j]) for j in Vout[i]]
285 print "child str:"
286 print "\n".join( " vertex %d: %s" % (j, str(V[j])) for j in Vout[i] )
287 print "number of parents:", len(Vin[i])
288 if len(Vin[i]) < 5:
289 print "parent types:", [type(V[j]) for j in Vin[i]]
290
291
292 sfc_error("Expecting all indexing to have been propagated to terminals?")
293 continue
294
# Vertices whose parents are all SpatialDerivatives are skipped after a
# similar sanity check. Presumably the parent derivative handles them
# through its raw operands (next branch).
295 if Vin[i] and all(isinstance(V[j], SpatialDerivative) for j in Vin[i]):
296
297
298
299 if not isinstance(v, (Terminal, SpatialDerivative, Indexed)):
300 print "="*30
301 print type(v)
302 print str(v)
303 sfc_error("Expecting all indexing to have been propagated to terminals?")
304 continue
305
# Indexed and SpatialDerivative vertices are evaluated directly on their UFL
# operands (v.operands()). Operands that are neither Expr nor swiginac.basic
# are dumped for debugging (presumably evaluation still proceeds).
306 if isinstance(v, (Indexed, SpatialDerivative)):
307 ops = v.operands()
308
309
310 if not all(isinstance(o, (Expr, swiginac.basic)) for o in ops):
311 print ";"*80
312 print tree_format(v)
313 print str(v)
314 print type(ops)
315 print str(ops)
316 print repr(ops)
317 print "types:"
318 print "\n".join(str(type(o)) for o in ops)
319 print ";"*80
320
321 e = data.evaluate(v, *ops)
322
# All other vertices are evaluated on the stored results of their children.
323 else:
324
325
326
327 ops = []
328 for j in Vout[i]:
329 try:
330
331 e = data.fetch_storage(j, basis_functions)
# NOTE(review): the bare except also catches KeyboardInterrupt/SystemExit,
# and raising a bare RuntimeError discards the original error. Consider
# catching the specific lookup error (fetch_storage does dict lookups) and
# re-raising with a message.
332 except:
333 print "Failed to fetch expression for vertex %d," % j
334 print " V[%d] = %s" % (j, repr(V[j]))
335 print " parent V[%d] = %s" % (i, repr(V[i]))
336 raise RuntimeError
337 ops.append(e)
338 ops = tuple(ops)
339 e = data.evaluate(v, *ops)
340
341
342
343
344
345
# Decrement the remaining-parent count of each child. Freeing the children's
# symbols and storage is currently disabled by `if False:`.
346 for j in Vout[i]:
347 data.Vin_count[j] -= 1
348 if False:
349
350 self.free_symbols(data, j)
351 data.free_storage(j, basis_functions)
352
353
# Results with no operands are stored as-is. Other results get a (possibly
# recycled) symbol, which is stored in place of the expression and yielded
# together with it.
354 if is_simple(e):
355
356
357 data.store(e, i, basis_functions)
358 else:
# NOTE(review): is_local is computed but not acted upon here.
359 if is_local[i]:
360 pass
361
362
363 s = self.allocate(data, i)
364 data.store(s, i, basis_functions)
365
366 yield (s, e)
367
368 sfc_debug("Leaving IntegralRepresentation.iter_partition")
369
370
371
372
373
374
375
376
377
378
379
# NOTE(review): the def line of this method is missing from this listing;
# the body reads `data` and is a generator.
# Yields (symbol, expression) pairs for the basis function values
# (fr.v_sym / fr.v_expr) and their first derivatives (fr.dv_sym / fr.dv_expr,
# one per direction d < fr.cell.nsd). This covers every argument and
# coefficient (fr.rank + fr.num_coefficients of them), every local basis
# function i, and every value component. Zero symbols, and symbols already
# yielded (tracked in `generated`), are skipped.
381 "Return an iterator over member tokens dependent of spatial variables. Overload in subclasses!"
382
# Only valid for quadrature integrals (note: assert is stripped under -O).
383 assert data.integration_method == "quadrature"
384
385
386
387
388
389
390 fr = self.formrep
391 fd = fr.formdata
392 generated = set()
393 for iarg in range(fr.rank + fr.num_coefficients):
394 elm = fd.elements[iarg]
395 rep = fr.element_reps[elm]
396 for i in range(rep.local_dimension):
397 for component in rep.value_components:
398
# Basis function value token.
399 s = fr.v_sym(iarg, i, component, self._on_facet)
400 if not (s == 0 or s in generated):
401 e = fr.v_expr(iarg, i, component)
402 t = (s, e)
403 yield t
404 generated.add(s)
405
# First-derivative tokens, one per spatial direction.
406 for d in range(fr.cell.nsd):
407 der = (d,)
408 s = fr.dv_sym(iarg, i, component, der, self._on_facet)
409 if not (s == 0 or s in generated):
410 e = fr.dv_expr(iarg, i, component, der)
411 t = (s, e)
412 yield t
413 generated.add(s)
414
# NOTE(review): the def line of this method is missing from this listing;
# the body reads only `self` and is a generator.
# Yields (symbol, expression) pairs element by element (via nops()/op(i))
# for fr.vx_sym/vx_expr, fr.G_sym/G_expr and fr.Ginv_sym/Ginv_expr. It also
# yields the scalar detGtmp and detG pairs. The names suggest vertex
# coordinates, the geometry mapping, its determinant and its inverse;
# TODO confirm in the form representation.
416 "Return an iterator over geometry tokens independent of spatial variables. Overload in subclasses!"
417 fr = self.formrep
418
419
420
421
422 for (ss,ee) in zip(fr.vx_sym, fr.vx_expr):
423 for i in range(ss.nops()):
424 yield (ss.op(i), ee.op(i))
425
426
427 (ss,ee) = (fr.G_sym, fr.G_expr)
428 for i in range(ss.nops()):
429 yield (ss.op(i), ee.op(i))
430
431
432 yield (fr.detGtmp_sym, fr.detGtmp_expr)
433 yield (fr.detG_sym, fr.detG_expr)
434
435
436 (ss,ee) = (fr.Ginv_sym, fr.Ginv_expr)
437 for i in range(ss.nops()):
438 yield (ss.op(i), ee.op(i))
439
# On facets, also yield the detG sign token. Otherwise, when a symbolic
# integral is present, D is defined as detG.
440 if self._on_facet:
441
442 yield (fr.detG_sign_sym, fr.detG_sign_expr)
443 else:
444 if self.symbolic_integral is not None:
445
446 yield (fr.D_sym, fr.detG_sym)
447
# NOTE(review): the def line of this method is missing from this listing;
# the body reads `data` and is a generator.
# Yields (symbol, expression) pairs for:
#   * the Dv_sym/Dv_expr derivative tokens of every argument/coefficient
#     basis function (one per direction d < fr.cell.nsd);
#   * the coefficient values w_sym/w_expr and their derivatives Dw_sym/Dw_expr;
#   * finally D = fr.quad_weight_sym times fr.facet_D_sym (on facets) or
#     times fr.detG_sym (otherwise).
449 "Return an iterator over runtime tokens dependent of spatial variables. Overload in subclasses!"
# Only valid for quadrature integrals (note: assert is stripped under -O).
450 assert data.integration_method == "quadrature"
451
452
453
454
455
456
457 fr = self.formrep
458 fd = fr.formdata
459 generated = set()
460 for iarg in range(fr.rank + fr.num_coefficients):
461 elm = fd.elements[iarg]
462 rep = fr.element_reps[elm]
463 for i in range(rep.local_dimension):
464 for component in rep.value_components:
465
466 for d in range(fr.cell.nsd):
467 der = (d,)
468 s = fr.Dv_sym(iarg, i, component, der, self._on_facet)
469 if not (s == 0 or s in generated):
470 e = fr.Dv_expr(iarg, i, component, der, True, self._on_facet)
471 t = (s, e)
472 yield t
473 generated.add(s)
474
475
# Coefficient tokens. The duplicate tracker is reset here, and coefficient
# elements are found at offset fr.rank in fd.elements.
476 generated = set()
477 for iarg in range(fr.num_coefficients):
478 elm = fd.elements[fr.rank+iarg]
479 rep = fr.element_reps[elm]
480 for component in rep.value_components:
481
# Value token: only duplicates are skipped here, not zero symbols
# (unlike the other tokens in this method).
482 s = fr.w_sym(iarg, component)
483 if not s in generated:
484 e = fr.w_expr(iarg, component, True, self._on_facet)
485 t = (s, e)
486 yield t
487 generated.add(s)
488
489 for d in range(fr.cell.nsd):
490 der = (d,)
491 s = fr.Dw_sym(iarg, component, der)
492 if not (s == 0 or s in generated):
493 e = fr.Dw_expr(iarg, component, der, True, self._on_facet)
494 t = (s, e)
495 yield t
496 generated.add(s)
497
498
# Scale factor D: quadrature weight times facet_D_sym on facets, or times
# detG_sym otherwise.
499 if self._on_facet:
500
501
502
503 D_expr = fr.quad_weight_sym*fr.facet_D_sym
504 else:
505
506 D_expr = fr.quad_weight_sym*fr.detG_sym
507 yield (fr.D_sym, D_expr)
508
509
510
# NOTE(review): the def line of this method is missing from this listing;
# the body reads `data` and `facet` and is a generator.
# Yields (self.A_sym[iota], self.compute_A(data, iota, facet)) for every
# multi-index iota of the element tensor, in self.indices order.
512 "Iterate over all A[iota] tokens."
513 for iota in self.indices:
514 A_sym = self.A_sym[iota]
515 A_expr = self.compute_A(data, iota, facet)
516 yield (A_sym, A_expr)
517
def compute_A(self, data, iota, facet=None):
    """Return the symbolic expression for element-tensor entry A[iota].

    This is an abstract hook: the base implementation always raises
    NotImplementedError, and concrete integral representations must
    override it.
    """
    raise NotImplementedError()
521