1
2
3 """
4 This module contains code generation tools for the ufc::dofmap class.
5 """
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27 import ufl
28 from sfc.codegeneration.codeformatting import indent, CodeFormatter, gen_token_assignments
29 from sfc.geometry import gen_geometry_code
30 from sfc.symbolic_utils import symbol, symbols
31 from sfc.common import sfc_assert, sfc_error, sfc_warning, sfc_info
32
35 self.rep = elementrep
36 self.classname = elementrep.dof_map_classname
37 self.signature = repr(self.rep.ufl_element)
38 self.options = self.rep.options.code.dof_map
39
40 if self.options.enable_dof_ptv:
41
42 vars = ["global_component_stride", "loc2glob_size"]
43 self.constructor_vars = vars
44 self.constructor_arg_string = ", ".join(["unsigned int %s_" % v for v in vars])
45 self.constructor_arg_string2 = ", ".join(vars)
46
48 l = []
49 if self.options.enable_dof_ptv:
50 l.extend(["Ptv.h", "DofT.h", "Dof_Ptv.h"])
51 return l
52
56
58 vars = {
59 'classname' : self.classname,
60 'constructor' : indent(self.gen_constructor()),
61 "constructor_arguments" : indent(self.gen_constructor_arguments()),
62 "initializer_list" : indent(self.gen_initializer_list()),
63 'destructor' : indent(self.gen_destructor()),
64 "create" : indent(self.gen_create()),
65 'signature' : indent(self.gen_signature()),
66 'needs_mesh_entities' : indent(self.gen_needs_mesh_entities()),
67 'init_mesh' : indent(self.gen_init_mesh()),
68 'init_cell' : indent(self.gen_init_cell()),
69 'init_cell_finalize' : indent(self.gen_init_cell_finalize()),
70 'global_dimension' : indent(self.gen_global_dimension()),
71 'local_dimension' : indent(self.gen_local_dimension()),
72 'max_local_dimension' : indent(self.gen_max_local_dimension()),
73 'geometric_dimension' : indent(self.gen_geometric_dimension()),
74 "topological_dimension" : indent(self.gen_topological_dimension()),
75 'num_facet_dofs' : indent(self.gen_num_facet_dofs()),
76 'num_entity_dofs' : indent(self.gen_num_entity_dofs()),
77 'tabulate_dofs' : indent(self.gen_tabulate_dofs()),
78 'tabulate_facet_dofs' : indent(self.gen_tabulate_facet_dofs()),
79 'tabulate_entity_dofs' : indent(self.gen_tabulate_entity_dofs()),
80 'tabulate_coordinates' : indent(self.gen_tabulate_coordinates()),
81 'num_sub_dofmaps' : indent(self.gen_num_sub_dofmaps()),
82 'create_sub_dofmap' : indent(self.gen_create_sub_dofmap()),
83 'members' : indent(self.gen_members()),
84 }
85 return vars
86
89
92
95
98
100 code = "return new %s();" % self.classname
101 return code
102
104 """const char* signature() const"""
105 return 'return "%s";' % self.signature
106
108 """bool needs_mesh_entities(unsigned int d) const"""
109 if self.rep.ufl_element.family() == "Real":
110 return 'return false;'
111
112
113 needs = tuple( [ ('true' if n else 'false') for n in self.rep.num_entity_dofs] )
114
115 code = CodeFormatter()
116 code.begin_switch("d")
117 for i, n in enumerate(needs):
118 code += "case %d: return %s;" % (i, n)
119 code.end_switch()
120 code += 'throw std::runtime_error("Invalid dimension in needs_mesh_entities.");'
121 return str(code)
122
124 """bool init_mesh(const mesh& m)"""
125 nsd = self.rep.cell.nsd
126 if self.rep.ufl_element.family() == "Real":
127 return 'return false;'
128
129 if not self.options.enable_dof_ptv:
130
131 num_entities = symbols(["m.num_entities[%d]" % i for i in range(nsd+1)])
132 global_dimension = sum(self.rep.num_entity_dofs[i]*num_entities[i] for i in range(nsd+1))
133 code = '_global_dimension = %s;\n' % global_dimension.printc()
134 code += "return false;"
135 return code
136
137
138 if isinstance(self.rep.ufl_element, ufl.MixedElement):
139 assert isinstance(self.rep.ufl_element, (ufl.VectorElement, ufl.TensorElement))
140 local_component_stride = (self.rep.local_dimension // len(self.rep.sub_elements))
141 code= CodeFormatter()
142 code += "// allocating space for loc2glob map"
143 code += "dof.init(m.num_entities[%d], %d);" % (nsd, local_component_stride)
144 code += "loc2glob_size = m.num_entities[%d] * %d;\n" % (nsd, local_component_stride)
145 code += "return true;"
146 return str(code)
147
149 """void init_cell(const mesh& m, const cell& c)"""
150
151 if not self.options.enable_dof_ptv:
152 return ""
153 if self.rep.ufl_element.family() == "Real":
154 return ""
155
156
157
158
159 nsd = self.rep.cell.nsd
160 nbf = self.rep.local_dimension
161
162 code = CodeFormatter()
163 code.new_text( gen_geometry_code(nsd, detG=False) )
164 code += "unsigned int element = c.entity_indices[%d][0];" % (nsd)
165
166 nsc = len(self.rep.sub_elements)
167
168 if nsc > 1:
169 code += "// ASSUMING HERE THAT THE DOFS FOR EACH SUB COMPONENT ARE GROUPED"
170 code += "// Only counting and numbering dofs for a single sub components"
171
172 for i in range(nbf // nsc):
173 x_strings = [self.rep.dof_x[i][d].printc() for d in range(nsd)]
174 dof_vals = ", ".join( x_strings )
175 num_dof_vals = nsd
176
177 if nsc > 1:
178 assert nsc == self.rep.value_size
179
180
181
182
183
184
185
186
187 code += ""
188 code += "double dof%d[%d] = { %s };" % (i, num_dof_vals, dof_vals)
189 code += "Ptv pdof%d(%d, dof%d);" % (i, num_dof_vals, i)
190 code += "dof.insert_dof(element, %d, pdof%d);" % (i, i)
191
192 return str(code)
193
195 """void init_cell_finalize()"""
196 if self.rep.ufl_element.family() == "Real":
197 return ""
198
199 code = ""
200 if self.options.enable_dof_ptv:
201
202 code += "loc2glob = dof.get_loc2glob_array();\n"
203
204 code += "global_component_stride = dof.global_dimension();\n"
205
206 code += '_global_dimension = global_component_stride * %d;\n' % len(self.rep.sub_elements)
207
208 code += 'dof.clear();\n'
209 return code
210
212 """unsigned int global_dimension() const"""
213 if self.rep.ufl_element.family() == "Real":
214 return 'return %d;' % self.rep.value_size
215 return 'return _global_dimension;'
216
218 """unsigned int local_dimension(const cell& c) const"""
219 return 'return %d;' % self.rep.local_dimension
220
222 """unsigned int max_local_dimension() const"""
223 return 'return %d;' % self.rep.local_dimension
224
226 """unsigned int geometric_dimension() const"""
227 return 'return %d;' % self.rep.cell.nsd
228
230 return "return %d;" % self.rep.cell.nsd
231
233 """unsigned int num_facet_dofs() const"""
234 return "return %d;" % self.rep.num_facet_dofs
235
237 """unsigned int num_entity_dofs(unsigned int d) const"""
238 if self.rep.ufl_element.family() == "Real":
239 return 'return 0;'
240 code = CodeFormatter()
241 code.begin_switch("d")
242 for i in range(self.rep.cell.nsd+1):
243 code.begin_case(i)
244 code += "return %d;" % self.rep.num_entity_dofs[i]
245 code.end_case()
246 code.end_switch()
247 code += 'throw std::runtime_error("Invalid entity dimension.");'
248 return str(code)
249
258
260 """void tabulate_dofs(unsigned int* dofs,
261 const mesh& m,
262 const cell& c) const"""
263 code = CodeFormatter()
264
265 cell = self.rep.cell
266 nsd = cell.nsd
267
268
269 mesh_num_entities = symbols("m.num_entities[%d]" % d for d in range(nsd+1))
270 cell_entity_indices = []
271 for d in range(nsd+1):
272 cell_entity_indices += [symbols( "c.entity_indices[%d][%d]" % (d, i) for i in range(cell.num_entities[d]) )]
273
274 def iter_sub_elements(rep):
275 "Flatten the sub element hierarchy into a list."
276 if rep.sub_elements:
277 for r in rep.sub_elements:
278 for s in iter_sub_elements(r):
279 yield s
280 else:
281 yield rep
282
283
284 local_subelement_offset = 0
285 global_subelement_offset = symbol("global_subelement_offset")
286 code += "int %s = 0;" % global_subelement_offset
287 for rep in iter_sub_elements(self.rep):
288
289 local_entity_offset = 0
290 global_entity_offset = 0
291 tokens = []
292 for d in range(nsd+1):
293
294
295 for i in range(cell.num_entities[d]):
296
297
298 entity_index = cell_entity_indices[d][i]
299
300
301 entity_dofs = rep.entity_dofs[d][i]
302 sfc_assert(len(entity_dofs) == rep.num_entity_dofs[d], "Inconsistency in entity dofs.")
303
304 for (j,dof) in enumerate(entity_dofs):
305 local_value = entity_index * rep.num_entity_dofs[d] + j
306 value = global_subelement_offset + global_entity_offset + local_value
307 name = symbol("dofs[%d]" % (local_subelement_offset + dof))
308 tokens.append((name, value))
309
310
311 local_entity_offset += cell.num_entities[d] * rep.num_entity_dofs[d]
312 global_entity_offset += mesh_num_entities[d] * rep.num_entity_dofs[d]
313
314
315 sfc_assert(rep.local_dimension == len(tokens), "Collected too few dof tokens!")
316 local_subelement_offset += rep.local_dimension
317 global_subelement_size = global_entity_offset
318
319 code.begin_block()
320 code += "// Subelement with signature: %s" % rep.signature
321 code += gen_token_assignments(tokens)
322 code += "%s += %s;" % (global_subelement_offset.printc(), global_subelement_size.printc())
323 code.end_block()
324
325 sfc_assert(local_subelement_offset == self.rep.local_dimension,
326 "Dof computation didn't accumulate correctly!")
327 return str(code)
328
330 """void tabulate_dofs(unsigned int* dofs,
331 const mesh& m,
332 const cell& c) const"""
333 if isinstance(self.rep.ufl_element, ufl.MixedElement):
334 assert isinstance(self.rep.ufl_element, (ufl.VectorElement, ufl.TensorElement))
335 local_component_stride = (self.rep.local_dimension // len(self.rep.sub_elements))
336
337 code = CodeFormatter()
338 code += "const unsigned int global_element_offset = %d * c.entity_indices[%d][0];" % (local_component_stride, self.rep.cell.nsd)
339 code += "const unsigned int *scalar_dofs = loc2glob.get() + global_element_offset;"
340
341 code += "for(unsigned int iloc=0; iloc<%d; iloc++)" % local_component_stride
342 code.begin_block()
343
344 code += "const unsigned int global_scalar_dof = scalar_dofs[iloc];"
345 for i in range(len(self.rep.sub_elements)):
346 code += "dofs[iloc + %d * %d] = global_scalar_dof + global_component_stride * %d;" % (local_component_stride, i, i)
347
348 code.end_block()
349
350 return str(code)
351
353 """void tabulate_facet_dofs(unsigned int* dofs,
354 unsigned int facet) const
355 This implementation should be general for elements with point evaluation dofs on simplices.
356 """
357 if self.rep.ufl_element.family() == "Real":
358 return 'throw std::runtime_error("tabulate_facet_dofs not implemented for Real elements.");'
359
360 code = CodeFormatter()
361 code.begin_switch("facet")
362 for i, fd in enumerate(self.rep.facet_dofs):
363 code.begin_case(i)
364 for j, d in enumerate(fd):
365 code += "dofs[%d] = %d;" % (j, d)
366 code.end_case()
367 code += "default:"
368 code.indent()
369 code += 'throw std::runtime_error("Invalid facet number.");'
370 code.dedent()
371 code.end_switch()
372
373 return str(code)
374
376 """void tabulate_entity_dofs(unsigned int* dofs,
377 unsigned int d, unsigned int i) const
378 """
379 if self.rep.ufl_element.family() == "Real":
380 return 'throw std::runtime_error("tabulate_entity_dofs not implemented for Real elements.");'
381 code = CodeFormatter()
382
383 code.begin_switch("d")
384 for d in range(self.rep.cell.nsd+1):
385 if any(self.rep.entity_dofs[d]):
386 code.begin_case(d)
387 code.begin_switch("i")
388 n = self.rep.cell.num_entities[d]
389 for i in range(n):
390
391 dofs_on_entity = self.rep.entity_dofs[d][i]
392 sfc_assert(len(dofs_on_entity) == self.rep.num_entity_dofs[d], "Inconsistency in entity dofs.")
393 code.begin_case(i)
394 for k, ed in enumerate(dofs_on_entity):
395 code += "dofs[%d] = %d;" % (k, ed)
396 code.end_case()
397 code.end_switch()
398 code.end_case()
399 code.end_switch()
400 return str(code)
401
403 """void tabulate_coordinates(double** coordinates,
404 const cell& c) const"""
405 if self.rep.ufl_element.family() == "Real":
406 return 'throw std::runtime_error("tabulate_coordinates not implemented for Real elements.");'
407 code = CodeFormatter()
408 code += gen_geometry_code(self.rep.cell.nsd, detG=False)
409 for i in range(self.rep.local_dimension):
410 for k in range(self.rep.cell.nsd):
411
412 code += "coordinates[%d][%d] = %s;" % (i, k, self.rep.dof_x[i][k].printc())
413 return str(code)
414
416 """unsigned int num_sub_dofmaps() const"""
417 return "return %d;" % len(self.rep.sub_elements)
418
420 """dofmap* create_sub_dofmap(unsigned int i) const"""
421 if self.options.enable_dof_ptv:
422 if len(self.rep.sub_elements) > 1:
423 code = CodeFormatter()
424 code.begin_switch("i")
425 for i, fe in enumerate(self.rep.sub_elements):
426 code += "case %d: return new %s(loc2glob, %s);" % (i, fe.dof_map_classname, self.constructor_arg_string2)
427 code.end_switch()
428 code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
429 else:
430 code = "return new %s(loc2glob, %s);" % (self.classname, self.constructor_arg_string2)
431 else:
432 if len(self.rep.sub_elements) > 1:
433 code = CodeFormatter()
434 code.begin_switch("i")
435 for i, fe in enumerate(self.rep.sub_elements):
436 code += "case %d: return new %s();" % (i, fe.dof_map_classname)
437 code.end_switch()
438 code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
439 else:
440 code = "return new %s();" % self.classname
441 return str(code)
442
444 cell = self.rep.cell
445 nsd = cell.nsd
446 code = CodeFormatter()
447
448
449
450 code += "public:"
451 code.indent()
452 if self.rep.ufl_element.family() != "Real":
453 code += "unsigned int _global_dimension;"
454
455 if self.options.enable_dof_ptv:
456 code += "Dof_Ptv dof;"
457 code += "std::tr1::shared_ptr<unsigned int> loc2glob;"
458 code += "unsigned int global_component_stride;"
459 code += 'unsigned int loc2glob_size;'
460 code.dedent()
461
462 if self.options.enable_dof_ptv:
463
464 code += "public:"
465 code.indent()
466 args = self.constructor_arg_string
467 code += "%s(std::tr1::shared_ptr<unsigned int> loc2glob, %s);" % (self.classname, args)
468 code.dedent()
469
470 return str(code)
471
473 """Generate local utility functions."""
474 nsd = self.rep.cell.nsd
475
476 code = CodeFormatter()
477
478
479
480
481
482 if self.options.enable_dof_ptv:
483
484 code += "%s::%s(std::tr1::shared_ptr<unsigned int> loc2glob_, %s):" % (self.classname, self.classname, self.constructor_arg_string)
485 code.indent()
486 code += "loc2glob(loc2glob_),"
487 for v in self.constructor_vars[:-1]:
488 code += "%s(%s_)," % (v, v)
489 v = self.constructor_vars[-1]
490 code += "%s(%s_)" % (v, v)
491 code.dedent()
492
493 return str(code)
494