Package sfc :: Package codegeneration :: Module dofmapcg
[hide private]
[frames] | no frames]

Source Code for Module sfc.codegeneration.dofmapcg

  1  #!/usr/bin/env python 
  2  # -*- coding: utf-8 -*- 
  3  """ 
  4  This module contains code generation tools for the ufc::dofmap class. 
  5  """ 
  6   
  7  # Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory 
  8  # 
  9  # This file is part of SyFi. 
 10  # 
 11  # SyFi is free software: you can redistribute it and/or modify 
 12  # it under the terms of the GNU General Public License as published by 
 13  # the Free Software Foundation, either version 2 of the License, or 
 14  # (at your option) any later version. 
 15  # 
 16  # SyFi is distributed in the hope that it will be useful, 
 17  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 18  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 
 19  # GNU General Public License for more details. 
 20  # 
 21  # You should have received a copy of the GNU General Public License 
 22  # along with SyFi. If not, see <http://www.gnu.org/licenses/>. 
 23  # 
 24  # First added:  2008-08-13 
 25  # Last changed: 2009-05-14 
 26   
 27  import ufl 
 28  from sfc.codegeneration.codeformatting import indent, CodeFormatter, gen_token_assignments 
 29  from sfc.geometry import gen_geometry_code 
 30  from sfc.symbolic_utils import symbol, symbols 
 31  from sfc.common import sfc_assert, sfc_error, sfc_warning, sfc_info 
 32  from sfc.common.utilities import unique 
 33  import ufc_utils 
 34   
class DofMapCG:
    """Generate the C++ method bodies of a ufc::dofmap implementation.

    One instance wraps one element representation and produces the strings
    substituted into the ufc_utils dofmap header/implementation templates
    (see ``generate_code_dict``).

    Two dof numbering strategies are supported, selected by
    ``options.code.dof_map.enable_dof_ptv``:

    * implicit: global dofs are computed directly from mesh entity indices
      in ``tabulate_dofs``;
    * Dof_Ptv: dof coordinates are inserted into a ``Dof_Ptv`` structure in
      ``init_cell`` and the resulting loc2glob array is stored in
      ``init_cell_finalize``.
    """

    def __init__(self, elementrep):
        """Store the element representation and the dof map code options.

        :param elementrep: element representation providing
            ``dof_map_classname``, ``ufl_element``, ``options`` and the dof
            layout data read by the generator methods.
        """
        self.rep = elementrep
        self.classname = elementrep.dof_map_classname
        self.signature = repr(self.rep.ufl_element)
        self.options = self.rep.options.code.dof_map

        if self.options.enable_dof_ptv:
            # Variables for full initialization by construction (used by the
            # extra constructor that shares initialized dofmap data); reused
            # in a few places.
            ctor_vars = ["global_component_stride", "loc2glob_size"]
            self.constructor_vars = ctor_vars
            self.constructor_arg_string = ", ".join("unsigned int %s_" % v for v in ctor_vars)
            self.constructor_arg_string2 = ", ".join(ctor_vars)

    def hincludes(self):
        """Return the extra headers needed in the generated .h file."""
        if self.options.enable_dof_ptv:
            return ["Ptv.h", "DofT.h", "Dof_Ptv.h"]
        return []

    def cincludes(self):
        """Return the extra headers needed in the generated .cpp file."""
        return []

    def local_includes(self):
        """Return the (unique) headers of the sub element dofmap classes."""
        return unique("%s.h" % e.dof_map_classname
                      for e in self.rep.sub_elements)

    def header_template(self):
        """Return the ufc_utils template for the dofmap header."""
        return ufc_utils.dofmap_header

    def impl_template(self):
        """Return the ufc_utils template for the dofmap implementation."""
        return ufc_utils.dofmap_implementation

    def generate_code_dict(self):
        """Return the mapping from template keys to generated code strings."""
        code_dict = {
            'classname': self.classname,
            'constructor': indent(self.gen_constructor()),
            "constructor_arguments": indent(self.gen_constructor_arguments()),
            "initializer_list": indent(self.gen_initializer_list()),
            'destructor': indent(self.gen_destructor()),
            "create": indent(self.gen_create()),
            'signature': indent(self.gen_signature()),
            'needs_mesh_entities': indent(self.gen_needs_mesh_entities()),
            'init_mesh': indent(self.gen_init_mesh()),
            'init_cell': indent(self.gen_init_cell()),
            'init_cell_finalize': indent(self.gen_init_cell_finalize()),
            'global_dimension': indent(self.gen_global_dimension()),
            'local_dimension': indent(self.gen_local_dimension()),
            'max_local_dimension': indent(self.gen_max_local_dimension()),
            'geometric_dimension': indent(self.gen_geometric_dimension()),
            "topological_dimension": indent(self.gen_topological_dimension()),
            'num_facet_dofs': indent(self.gen_num_facet_dofs()),
            'num_entity_dofs': indent(self.gen_num_entity_dofs()),
            'tabulate_dofs': indent(self.gen_tabulate_dofs()),
            'tabulate_facet_dofs': indent(self.gen_tabulate_facet_dofs()),
            'tabulate_entity_dofs': indent(self.gen_tabulate_entity_dofs()),
            'tabulate_coordinates': indent(self.gen_tabulate_coordinates()),
            'num_sub_dofmaps': indent(self.gen_num_sub_dofmaps()),
            'create_sub_dofmap': indent(self.gen_create_sub_dofmap()),
            'members': indent(self.gen_members()),
        }
        return code_dict

    # --- helpers for the Dof_Ptv numbering ---

    def _num_components(self):
        """Return the number of components numbered by the Dof_Ptv scheme.

        Basic (non-mixed) elements have an empty ``sub_elements`` list but
        still consist of one component. Using ``len(sub_elements)`` directly
        would divide by zero, or give a zero global dimension, for them.
        """
        return max(len(self.rep.sub_elements), 1)

    def _local_component_stride(self):
        """Return the number of local dofs per component (Dof_Ptv scheme)."""
        ufl_element = self.rep.ufl_element
        if isinstance(ufl_element, ufl.MixedElement):
            # Every component reuses one scalar numbering (plus a component
            # offset), so only vector/tensor elements are supported here.
            sfc_assert(isinstance(ufl_element, (ufl.VectorElement, ufl.TensorElement)),
                       "Dof_Ptv numbering only supports vector and tensor mixed elements.")
        return self.rep.local_dimension // self._num_components()

    # --- ufc::dofmap method bodies ---

    def gen_constructor(self):
        return ""

    def gen_constructor_arguments(self):
        return ""

    def gen_initializer_list(self):
        return ""

    def gen_destructor(self):
        return ""

    def gen_create(self):
        """dofmap* create() const"""
        return "return new %s();" % self.classname

    def gen_signature(self):
        """const char* signature() const"""
        return 'return "%s";' % self.signature

    def gen_needs_mesh_entities(self):
        """bool needs_mesh_entities(unsigned int d) const

        Mesh entities of dimension d are needed iff the element has dofs on them.
        """
        if self.rep.ufl_element.family() == "Real":
            return 'return false;'

        code = CodeFormatter()
        code.begin_switch("d")
        for d, ndofs in enumerate(self.rep.num_entity_dofs):
            code += "case %d: return %s;" % (d, "true" if ndofs else "false")
        code.end_switch()
        code += 'throw std::runtime_error("Invalid dimension in needs_mesh_entities.");'
        return str(code)

    def gen_init_mesh(self):
        """bool init_mesh(const mesh& m)

        Implicit numbering: compute and store the global dimension, return false.
        Dof_Ptv numbering: allocate the loc2glob map and return true, so that
        init_cell is called for every cell.
        """
        if self.rep.ufl_element.family() == "Real":
            return 'return false;'
        nsd = self.rep.cell.nsd

        if not self.options.enable_dof_ptv:
            # compute and store global dimension
            num_entities = symbols(["m.num_entities[%d]" % i for i in range(nsd+1)])
            global_dimension = sum(self.rep.num_entity_dofs[i]*num_entities[i] for i in range(nsd+1))
            code = '_global_dimension = %s;\n' % global_dimension.printc()
            code += "return false;"
            return code

        # This code doesn't work for general mixed elements!
        local_component_stride = self._local_component_stride()
        code = CodeFormatter()
        code += "// allocating space for loc2glob map"
        code += "dof.init(m.num_entities[%d], %d);" % (nsd, local_component_stride)
        code += "loc2glob_size = m.num_entities[%d] * %d;\n" % (nsd, local_component_stride)
        code += "return true;"
        return str(code)

    def gen_init_cell(self):
        """void init_cell(const mesh& m, const cell& c)

        Dof_Ptv numbering only: insert the coordinates of the dofs of a
        single component of this cell into the Dof_Ptv structure.
        """
        if not self.options.enable_dof_ptv:
            return ""
        if self.rep.ufl_element.family() == "Real":
            return ""

        # FIXME: This code needs updating, e.g. it doesn't handle general mixed elements.
        nsd = self.rep.cell.nsd
        nbf = self.rep.local_dimension
        nsc = self._num_components()

        code = CodeFormatter()
        code.new_text(gen_geometry_code(nsd, detG=False))
        code += "unsigned int element = c.entity_indices[%d][0];" % nsd

        if nsc > 1:
            sfc_assert(nsc == self.rep.value_size,
                       "Expecting one sub element per value component.")
            code += "// ASSUMING HERE THAT THE DOFS FOR EACH SUB COMPONENT ARE GROUPED"
            code += "// Only counting and numbering dofs for a single sub components"

        # NOTE: the old numbering scheme, where sub component dofs were
        # intertwined, appended the extra entries of each SyFi dof
        # (fe.dof(i)[1:]) to the point values below.
        for i in range(nbf // nsc):
            dof_vals = ", ".join(self.rep.dof_x[i][d].printc() for d in range(nsd))
            code += ""
            code += "double dof%d[%d] = { %s };" % (i, nsd, dof_vals)
            code += "Ptv pdof%d(%d, dof%d);" % (i, nsd, i)
            code += "dof.insert_dof(element, %d, pdof%d);" % (i, i)

        return str(code)

    def gen_init_cell_finalize(self):
        """void init_cell_finalize()

        Dof_Ptv numbering only: store the loc2glob array, the per-component
        global stride and the global dimension, then clear temporary data.
        """
        if self.rep.ufl_element.family() == "Real":
            return ""

        code = ""
        if self.options.enable_dof_ptv:
            code += "loc2glob = dof.get_loc2glob_array();\n"
            # constant for tabulating vector element dofs from a scalar element numbering
            code += "global_component_stride = dof.global_dimension();\n"
            # store global dimension
            code += '_global_dimension = global_component_stride * %d;\n' % self._num_components()
            # clear temporary datastructures
            code += 'dof.clear();\n'
        return code

    def gen_global_dimension(self):
        """unsigned int global_dimension() const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return %d;' % self.rep.value_size
        return 'return _global_dimension;'

    def gen_local_dimension(self):
        """unsigned int local_dimension(const cell& c) const"""
        return 'return %d;' % self.rep.local_dimension

    def gen_max_local_dimension(self):
        """unsigned int max_local_dimension() const"""
        return 'return %d;' % self.rep.local_dimension

    def gen_geometric_dimension(self):
        """unsigned int geometric_dimension() const"""
        return 'return %d;' % self.rep.cell.nsd

    def gen_topological_dimension(self):
        """unsigned int topological_dimension() const"""
        return "return %d;" % self.rep.cell.nsd

    def gen_num_facet_dofs(self):
        """unsigned int num_facet_dofs() const"""
        return "return %d;" % self.rep.num_facet_dofs

    def gen_num_entity_dofs(self):
        """unsigned int num_entity_dofs(unsigned int d) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return 0;'
        code = CodeFormatter()
        code.begin_switch("d")
        for d in range(self.rep.cell.nsd+1):
            code.begin_case(d)
            code += "return %d;" % self.rep.num_entity_dofs[d]
            code.end_case()
        code.end_switch()
        code += 'throw std::runtime_error("Invalid entity dimension.");'
        return str(code)

    def gen_tabulate_dofs(self):
        """void tabulate_dofs(unsigned int* dofs, const mesh& m, const cell& c) const"""
        if self.rep.ufl_element.family() == "Real":
            return '\n'.join('dofs[%d] = %d;' % (i, i) for i in range(self.rep.value_size))

        if self.options.enable_dof_ptv:
            return self.gen_tabulate_dofs__dof_ptv()
        return self.gen_tabulate_dofs__implicit()

    def gen_tabulate_dofs__implicit(self):
        """void tabulate_dofs(unsigned int* dofs,
                              const mesh& m,
                              const cell& c) const

        Number dofs directly from mesh entity indices. For each basic sub
        element, dof j on cell entity (d, i) gets the global number
        global_subelement_offset
          + sum over d' < d of m.num_entities[d'] * num_entity_dofs[d']
          + entity_index * num_entity_dofs[d] + j.
        """
        code = CodeFormatter()

        cell = self.rep.cell
        nsd = cell.nsd

        # symbols referencing entity index arrays
        mesh_num_entities = symbols("m.num_entities[%d]" % d for d in range(nsd+1))
        cell_entity_indices = [symbols("c.entity_indices[%d][%d]" % (d, i) for i in range(cell.num_entities[d]))
                               for d in range(nsd+1)]

        def iter_sub_elements(rep):
            "Flatten the sub element hierarchy into a list."
            if rep.sub_elements:
                for r in rep.sub_elements:
                    for s in iter_sub_elements(r):
                        yield s
            else:
                yield rep

        # (A) Iterate over all basic elements in order
        local_subelement_offset = 0
        global_subelement_offset = symbol("global_subelement_offset")
        code += "int %s = 0;" % global_subelement_offset
        for rep in iter_sub_elements(self.rep):
            # (B) Loop over entity dimensions d in order
            global_entity_offset = 0
            tokens = []
            for d in range(nsd+1):
                ndofs = rep.num_entity_dofs[d]
                # The offset for dofs in this loop is (global_subelement_offset + global_entity_offset)
                # (C) Loop over entities (d,i) in order
                for i in range(cell.num_entities[d]):
                    # this is the global mesh index of cell entity (d,i)
                    entity_index = cell_entity_indices[d][i]

                    # For each entity (d,i) we have a list of dofs
                    entity_dofs = rep.entity_dofs[d][i]
                    sfc_assert(len(entity_dofs) == ndofs, "Inconsistency in entity dofs.")

                    for j, dof in enumerate(entity_dofs):
                        local_value = entity_index * ndofs + j
                        value = global_subelement_offset + global_entity_offset + local_value
                        name = symbol("dofs[%d]" % (local_subelement_offset + dof))
                        tokens.append((name, value))

                # (B) Accumulate offset to the dofs on entities of dimension d
                global_entity_offset += mesh_num_entities[d] * ndofs

            # (A) Accumulate subelement offsets
            sfc_assert(rep.local_dimension == len(tokens), "Collected too few dof tokens!")
            local_subelement_offset += rep.local_dimension
            global_subelement_size = global_entity_offset

            code.begin_block()
            code += "// Subelement with signature: %s" % rep.signature
            code += gen_token_assignments(tokens)
            code += "%s += %s;" % (global_subelement_offset.printc(), global_subelement_size.printc())
            code.end_block()

        sfc_assert(local_subelement_offset == self.rep.local_dimension,
                   "Dof computation didn't accumulate correctly!")
        return str(code)

    def gen_tabulate_dofs__dof_ptv(self):
        """void tabulate_dofs(unsigned int* dofs,
                              const mesh& m,
                              const cell& c) const

        Look up the scalar dofs of the cell in loc2glob. Component i of every
        scalar dof is offset by i * global_component_stride.
        """
        local_component_stride = self._local_component_stride()

        code = CodeFormatter()
        code += "const unsigned int global_element_offset = %d * c.entity_indices[%d][0];" % (local_component_stride, self.rep.cell.nsd)
        code += "const unsigned int *scalar_dofs = loc2glob.get() + global_element_offset;"

        code += "for(unsigned int iloc=0; iloc<%d; iloc++)" % local_component_stride
        code.begin_block()
        code += "const unsigned int global_scalar_dof = scalar_dofs[iloc];"
        for i in range(self._num_components()):
            code += "dofs[iloc + %d * %d] = global_scalar_dof + global_component_stride * %d;" % (local_component_stride, i, i)
        code.end_block()

        return str(code)

    def gen_tabulate_facet_dofs(self):
        """void tabulate_facet_dofs(unsigned int* dofs,
                                    unsigned int facet) const
        This implementation should be general for elements with point evaluation dofs on simplices.
        """
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_facet_dofs not implemented for Real elements.");'
        # generate code for each facet: for each facet i, tabulate local dofs[j]
        code = CodeFormatter()
        code.begin_switch("facet")
        for i, fd in enumerate(self.rep.facet_dofs):
            code.begin_case(i)
            for j, d in enumerate(fd):
                code += "dofs[%d] = %d;" % (j, d)
            code.end_case()
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Invalid facet number.");'
        code.dedent()
        code.end_switch()
        return str(code)

    def gen_tabulate_entity_dofs(self):
        """void tabulate_entity_dofs(unsigned int* dofs,
                                     unsigned int d, unsigned int i) const
        """
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_entity_dofs not implemented for Real elements.");'
        code = CodeFormatter()
        # define one case for each cell entity (d, i); dimensions without dofs are skipped
        code.begin_switch("d")
        for d in range(self.rep.cell.nsd+1):
            if any(self.rep.entity_dofs[d]):
                code.begin_case(d)
                code.begin_switch("i")
                for i in range(self.rep.cell.num_entities[d]):
                    # get list of local dofs associated with cell entity (d, i)
                    dofs_on_entity = self.rep.entity_dofs[d][i]
                    sfc_assert(len(dofs_on_entity) == self.rep.num_entity_dofs[d], "Inconsistency in entity dofs.")
                    code.begin_case(i)
                    for k, ed in enumerate(dofs_on_entity):
                        code += "dofs[%d] = %d;" % (k, ed)
                    code.end_case()
                code.end_switch()
                code.end_case()
        code.end_switch()
        return str(code)

    def gen_tabulate_coordinates(self):
        """void tabulate_coordinates(double** coordinates,
                                     const cell& c) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_coordinates not implemented for Real elements.");'
        code = CodeFormatter()
        code += gen_geometry_code(self.rep.cell.nsd, detG=False)
        for i in range(self.rep.local_dimension):
            for k in range(self.rep.cell.nsd):
                # generate code to compute component k of the coordinate for dof i
                code += "coordinates[%d][%d] = %s;" % (i, k, self.rep.dof_x[i][k].printc())
        return str(code)

    def gen_num_sub_dofmaps(self):
        """unsigned int num_sub_dofmaps() const"""
        return "return %d;" % len(self.rep.sub_elements)

    def gen_create_sub_dofmap(self):
        """dofmap* create_sub_dofmap(unsigned int i) const

        Return a new dofmap of sub element i. In Dof_Ptv mode the shared
        loc2glob data is passed on to the sub dofmap's extra constructor.
        """
        if self.options.enable_dof_ptv:
            ctor_args = "loc2glob, %s" % self.constructor_arg_string2
        else:
            ctor_args = ""

        sub_elements = self.rep.sub_elements
        if not sub_elements:
            # FIXME: Should we throw error here instead now?
            return "return new %s(%s);" % (self.classname, ctor_args)

        # Any sub element (even a single one) needs its own dofmap class;
        # returning this class would hand back the parent dofmap.
        code = CodeFormatter()
        code.begin_switch("i")
        for i, fe in enumerate(sub_elements):
            code += "case %d: return new %s(%s);" % (i, fe.dof_map_classname, ctor_args)
        code.end_switch()
        code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
        return str(code)

    def gen_members(self):
        """Generate the member variable and extra constructor declarations."""
        code = CodeFormatter()

        # dof data structures
        code += "public:"
        code.indent()
        if self.rep.ufl_element.family() != "Real":
            code += "unsigned int _global_dimension;"

        if self.options.enable_dof_ptv:
            code += "Dof_Ptv dof;"
            code += "std::tr1::shared_ptr<unsigned int> loc2glob;"
            # for tabulating vector element dofs from a scalar element numbering
            code += "unsigned int global_component_stride;"
            code += 'unsigned int loc2glob_size;'
        code.dedent()

        if self.options.enable_dof_ptv:
            # additional constructor to pass initialization info to share initialized dofmap memory
            code += "public:"
            code.indent()
            code += "%s(std::tr1::shared_ptr<unsigned int> loc2glob, %s);" % (self.classname, self.constructor_arg_string)
            code.dedent()

        return str(code)

    def generate_support_code(self):
        """Generate local utility functions (the extra Dof_Ptv constructor)."""
        code = CodeFormatter()

        if self.options.enable_dof_ptv:
            # Implement additional constructor for shared memory between initialized dofmaps
            code += "%s::%s(std::tr1::shared_ptr<unsigned int> loc2glob_, %s):" % (self.classname, self.classname, self.constructor_arg_string)
            code.indent()
            code += "loc2glob(loc2glob_),"
            for v in self.constructor_vars[:-1]:
                code += "%s(%s_)," % (v, v)
            v = self.constructor_vars[-1]
            code += "%s(%s_)" % (v, v)
            code.dedent()

        return str(code)
505