Package sfc :: Package codegeneration :: Module dofmapcg
[hide private]
[frames] | no frames]

Source Code for Module sfc.codegeneration.dofmapcg

  1  #!/usr/bin/env python 
  2  # -*- coding: utf-8 -*- 
  3  """ 
  4  This module contains code generation tools for the ufc::dofmap class. 
  5  """ 
  6   
  7  # Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory 
  8  # 
  9  # This file is part of SyFi. 
 10  # 
 11  # SyFi is free software: you can redistribute it and/or modify 
 12  # it under the terms of the GNU General Public License as published by 
 13  # the Free Software Foundation, either version 2 of the License, or 
 14  # (at your option) any later version. 
 15  # 
 16  # SyFi is distributed in the hope that it will be useful, 
 17  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 18  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 
 19  # GNU General Public License for more details. 
 20  # 
 21  # You should have received a copy of the GNU General Public License 
 22  # along with SyFi. If not, see <http://www.gnu.org/licenses/>. 
 23  # 
 24  # First added:  2008-08-13 
 25  # Last changed: 2009-05-14 
 26   
 27  import ufl 
 28  from sfc.codegeneration.codeformatting import indent, CodeFormatter, gen_token_assignments 
 29  from sfc.geometry import gen_geometry_code 
 30  from sfc.symbolic_utils import symbol, symbols 
 31  from sfc.common import sfc_assert, sfc_error, sfc_warning, sfc_info 
 32   
class DofMapCG:
    """Code generator for the member functions of a ufc::dofmap subclass.

    Every ``gen_*`` method returns a string of C++ code forming the body
    of the ufc::dofmap member function named in its docstring.
    ``generate_code_dict`` collects all of them in a dict.

    Two numbering strategies are supported, selected by
    ``options.enable_dof_ptv``: an "implicit" numbering computed from the
    mesh entity indices, and a numbering built at run time through the
    ``Dof_Ptv`` helper class.
    """

    def __init__(self, elementrep):
        """Store the element representation and values derived from it.

        ``elementrep`` must provide ``dof_map_classname``, ``ufl_element``
        and ``options.code.dof_map``; the ``gen_*`` methods read further
        attributes from it on demand.
        """
        self.rep = elementrep
        self.classname = elementrep.dof_map_classname
        self.signature = repr(self.rep.ufl_element)
        self.options = self.rep.options.code.dof_map

        if self.options.enable_dof_ptv:
            # Variables for full initialization by construction,
            # reused in a few places below.
            ctor_vars = ["global_component_stride", "loc2glob_size"]
            self.constructor_vars = ctor_vars
            self.constructor_arg_string = ", ".join(["unsigned int %s_" % v for v in ctor_vars])
            self.constructor_arg_string2 = ", ".join(ctor_vars)

    def hincludes(self):
        """Return the list of headers to include in the generated .h file."""
        includes = []
        if self.options.enable_dof_ptv:
            includes.extend(["Ptv.h", "DofT.h", "Dof_Ptv.h"])
        return includes

    def cincludes(self):
        """Return the list of headers to include in the generated .cpp file."""
        includes = []
        return includes

    def generate_code_dict(self):
        """Return a dict mapping code snippet names to indented C++ code."""
        code_dict = {
            'classname'             : self.classname,
            'constructor'           : indent(self.gen_constructor()),
            "constructor_arguments" : indent(self.gen_constructor_arguments()),
            "initializer_list"      : indent(self.gen_initializer_list()),
            'destructor'            : indent(self.gen_destructor()),
            "create"                : indent(self.gen_create()),
            'signature'             : indent(self.gen_signature()),
            'needs_mesh_entities'   : indent(self.gen_needs_mesh_entities()),
            'init_mesh'             : indent(self.gen_init_mesh()),
            'init_cell'             : indent(self.gen_init_cell()),
            'init_cell_finalize'    : indent(self.gen_init_cell_finalize()),
            'global_dimension'      : indent(self.gen_global_dimension()),
            'local_dimension'       : indent(self.gen_local_dimension()),
            'max_local_dimension'   : indent(self.gen_max_local_dimension()),
            'geometric_dimension'   : indent(self.gen_geometric_dimension()),
            "topological_dimension" : indent(self.gen_topological_dimension()),
            'num_facet_dofs'        : indent(self.gen_num_facet_dofs()),
            'num_entity_dofs'       : indent(self.gen_num_entity_dofs()),
            'tabulate_dofs'         : indent(self.gen_tabulate_dofs()),
            'tabulate_facet_dofs'   : indent(self.gen_tabulate_facet_dofs()),
            'tabulate_entity_dofs'  : indent(self.gen_tabulate_entity_dofs()),
            'tabulate_coordinates'  : indent(self.gen_tabulate_coordinates()),
            'num_sub_dofmaps'       : indent(self.gen_num_sub_dofmaps()),
            'create_sub_dofmap'     : indent(self.gen_create_sub_dofmap()),
            'members'               : indent(self.gen_members()),
        }
        return code_dict

    def gen_constructor(self):
        """Return the (empty) constructor body."""
        return ""

    def gen_constructor_arguments(self):
        """Return the (empty) constructor argument list."""
        return ""

    def gen_initializer_list(self):
        """Return the (empty) constructor initializer list."""
        return ""

    def gen_destructor(self):
        """Return the (empty) destructor body."""
        return ""

    def gen_create(self):
        """Return code constructing a new default instance of this dofmap class."""
        code = "return new %s();" % self.classname
        return code

    def gen_signature(self):
        """const char* signature() const"""
        # The signature ends up inside a C++ string literal, so backslashes
        # and double quotes must be escaped to keep the literal well-formed.
        escaped = self.signature.replace("\\", "\\\\").replace('"', '\\"')
        return 'return "%s";' % escaped

    def gen_needs_mesh_entities(self):
        """bool needs_mesh_entities(unsigned int d) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return false;'

        # Pick the mesh entities we need: one flag per entity dimension.
        needs = tuple([('true' if n else 'false') for n in self.rep.num_entity_dofs])
        # Return false when a type of mesh entities is not needed.
        code = CodeFormatter()
        code.begin_switch("d")
        for i, n in enumerate(needs):
            code += "case %d: return %s;" % (i, n)
        code.end_switch()
        code += 'throw std::runtime_error("Invalid dimension in needs_mesh_entities.");'
        return str(code)

    def gen_init_mesh(self):
        """bool init_mesh(const mesh& m)"""
        nsd = self.rep.cell.nsd
        if self.rep.ufl_element.family() == "Real":
            return 'return false;'

        if not self.options.enable_dof_ptv:
            # Compute and store the global dimension as a symbolic sum over
            # the entity dimensions of (dofs per entity) * (number of entities).
            num_entities = symbols(["m.num_entities[%d]" % i for i in range(nsd+1)])
            global_dimension = sum(self.rep.num_entity_dofs[i]*num_entities[i] for i in range(nsd+1))
            code = '_global_dimension = %s;\n' % global_dimension.printc()
            code += "return false;"
            return code

        # This code doesn't work for general mixed elements!
        if isinstance(self.rep.ufl_element, ufl.MixedElement):
            assert isinstance(self.rep.ufl_element, (ufl.VectorElement, ufl.TensorElement))
        local_component_stride = (self.rep.local_dimension // len(self.rep.sub_elements))
        code = CodeFormatter()
        code += "// allocating space for loc2glob map"
        code += "dof.init(m.num_entities[%d], %d);" % (nsd, local_component_stride)
        code += "loc2glob_size = m.num_entities[%d] * %d;\n" % (nsd, local_component_stride)
        code += "return true;"
        return str(code)

    def gen_init_cell(self):
        """void init_cell(const mesh& m, const cell& c)"""
        if not self.options.enable_dof_ptv:
            return ""
        if self.rep.ufl_element.family() == "Real":
            return ""

        # FIXME: This code needs updating, e.g. it doesn't handle mixed elements.

        nsd = self.rep.cell.nsd
        nbf = self.rep.local_dimension

        code = CodeFormatter()
        code.new_text(gen_geometry_code(nsd, detG=False))
        code += "unsigned int element = c.entity_indices[%d][0];" % (nsd)

        nsc = len(self.rep.sub_elements)

        if nsc > 1:
            code += "// ASSUMING HERE THAT THE DOFS FOR EACH SUB COMPONENT ARE GROUPED"
            code += "// Only counting and numbering dofs for a single sub components"

        # Only the dofs of the first sub component are inserted; the others
        # are assumed to follow the same pattern (see the generated comment).
        for i in range(nbf // nsc):
            x_strings = [self.rep.dof_x[i][d].printc() for d in range(nsd)]
            dof_vals = ", ".join(x_strings)
            num_dof_vals = nsd

            if nsc > 1:
                # NOTE: an older numbering scheme, where sub component dofs
                # were intertwined, appended further values to dof_vals here.
                assert nsc == self.rep.value_size

            code += ""
            code += "double dof%d[%d] = { %s };" % (i, num_dof_vals, dof_vals)
            code += "Ptv pdof%d(%d, dof%d);" % (i, num_dof_vals, i)
            code += "dof.insert_dof(element, %d, pdof%d);" % (i, i)

        return str(code)

    def gen_init_cell_finalize(self):
        """void init_cell_finalize()"""
        if self.rep.ufl_element.family() == "Real":
            return ""

        code = ""
        if self.options.enable_dof_ptv:
            code += "loc2glob = dof.get_loc2glob_array();\n"
            # Constant for tabulating vector element dofs from a scalar element numbering.
            code += "global_component_stride = dof.global_dimension();\n"
            # Store global dimension.
            code += '_global_dimension = global_component_stride * %d;\n' % len(self.rep.sub_elements)
            # Clear temporary data structures.
            code += 'dof.clear();\n'
        return code

    def gen_global_dimension(self):
        """unsigned int global_dimension() const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return %d;' % self.rep.value_size
        return 'return _global_dimension;'

    def gen_local_dimension(self):
        """unsigned int local_dimension(const cell& c) const"""
        return 'return %d;' % self.rep.local_dimension

    def gen_max_local_dimension(self):
        """unsigned int max_local_dimension() const"""
        return 'return %d;' % self.rep.local_dimension

    def gen_geometric_dimension(self):
        """unsigned int geometric_dimension() const"""
        return 'return %d;' % self.rep.cell.nsd

    def gen_topological_dimension(self):
        """unsigned int topological_dimension() const"""
        return "return %d;" % self.rep.cell.nsd

    def gen_num_facet_dofs(self):
        """unsigned int num_facet_dofs() const"""
        return "return %d;" % self.rep.num_facet_dofs

    def gen_num_entity_dofs(self):
        """unsigned int num_entity_dofs(unsigned int d) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return 0;'
        code = CodeFormatter()
        code.begin_switch("d")
        for i in range(self.rep.cell.nsd+1):
            code.begin_case(i)
            code += "return %d;" % self.rep.num_entity_dofs[i]
            code.end_case()
        code.end_switch()
        code += 'throw std::runtime_error("Invalid entity dimension.");'
        return str(code)

    def gen_tabulate_dofs(self):
        """Return the body of tabulate_dofs, dispatching on the numbering strategy."""
        if self.rep.ufl_element.family() == "Real":
            return '\n'.join('dofs[%d] = %d;' % (i, i) for i in range(self.rep.value_size))

        if self.options.enable_dof_ptv:
            return self.gen_tabulate_dofs__dof_ptv()
        else:
            return self.gen_tabulate_dofs__implicit()

    def gen_tabulate_dofs__implicit(self):
        """void tabulate_dofs(unsigned int* dofs,
                              const mesh& m,
                              const cell& c) const"""
        code = CodeFormatter()

        cell = self.rep.cell
        nsd = cell.nsd

        # Symbols referencing entity index arrays.
        mesh_num_entities = symbols("m.num_entities[%d]" % d for d in range(nsd+1))
        cell_entity_indices = []
        for d in range(nsd+1):
            cell_entity_indices += [symbols("c.entity_indices[%d][%d]" % (d, i) for i in range(cell.num_entities[d]))]

        def iter_sub_elements(rep):
            "Flatten the sub element hierarchy into a list."
            if rep.sub_elements:
                for r in rep.sub_elements:
                    for s in iter_sub_elements(r):
                        yield s
            else:
                yield rep

        # (A) Iterate over all basic elements in order.
        local_subelement_offset = 0
        global_subelement_offset = symbol("global_subelement_offset")
        code += "int %s = 0;" % global_subelement_offset
        for rep in iter_sub_elements(self.rep):
            # (B) Loop over entity dimensions d in order.
            global_entity_offset = 0
            tokens = []
            for d in range(nsd+1):
                # The offset for dofs in this loop is
                # (global_subelement_offset + global_entity_offset).
                # (C) Loop over entities (d,i) in order.
                for i in range(cell.num_entities[d]):

                    # This is the global mesh index of cell entity (d,i).
                    entity_index = cell_entity_indices[d][i]

                    # For each entity (d,i) we have a list of dofs.
                    entity_dofs = rep.entity_dofs[d][i]
                    sfc_assert(len(entity_dofs) == rep.num_entity_dofs[d], "Inconsistency in entity dofs.")

                    for (j, dof) in enumerate(entity_dofs):
                        local_value = entity_index * rep.num_entity_dofs[d] + j
                        value = global_subelement_offset + global_entity_offset + local_value
                        name = symbol("dofs[%d]" % (local_subelement_offset + dof))
                        tokens.append((name, value))

                # (B) Accumulate offsets to dofs on entities of dimension d.
                global_entity_offset += mesh_num_entities[d] * rep.num_entity_dofs[d]

            # (A) Accumulate subelement offsets.
            sfc_assert(rep.local_dimension == len(tokens), "Collected too few dof tokens!")
            local_subelement_offset += rep.local_dimension
            global_subelement_size = global_entity_offset

            code.begin_block()
            code += "// Subelement with signature: %s" % rep.signature
            code += gen_token_assignments(tokens)
            code += "%s += %s;" % (global_subelement_offset.printc(), global_subelement_size.printc())
            code.end_block()

        sfc_assert(local_subelement_offset == self.rep.local_dimension,
                   "Dof computation didn't accumulate correctly!")
        return str(code)

    def gen_tabulate_dofs__dof_ptv(self):
        """void tabulate_dofs(unsigned int* dofs,
                              const mesh& m,
                              const cell& c) const"""
        if isinstance(self.rep.ufl_element, ufl.MixedElement):
            assert isinstance(self.rep.ufl_element, (ufl.VectorElement, ufl.TensorElement))
        local_component_stride = (self.rep.local_dimension // len(self.rep.sub_elements))

        code = CodeFormatter()
        code += "const unsigned int global_element_offset = %d * c.entity_indices[%d][0];" % (local_component_stride, self.rep.cell.nsd)
        code += "const unsigned int *scalar_dofs = loc2glob.get() + global_element_offset;"

        code += "for(unsigned int iloc=0; iloc<%d; iloc++)" % local_component_stride
        code.begin_block()

        code += "const unsigned int global_scalar_dof = scalar_dofs[iloc];"
        # One assignment per sub element, each shifted by a multiple of the
        # local and the global component stride.
        for i in range(len(self.rep.sub_elements)):
            code += "dofs[iloc + %d * %d] = global_scalar_dof + global_component_stride * %d;" % (local_component_stride, i, i)

        code.end_block()

        return str(code)

    def gen_tabulate_facet_dofs(self):
        """void tabulate_facet_dofs(unsigned int* dofs,
                                    unsigned int facet) const
        This implementation should be general for elements with point evaluation dofs on simplices.
        """
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_facet_dofs not implemented for Real elements.");'
        # Generate code for each facet: for each facet i, tabulate local dofs[j].
        code = CodeFormatter()
        code.begin_switch("facet")
        for i, fd in enumerate(self.rep.facet_dofs):
            code.begin_case(i)
            for j, d in enumerate(fd):
                code += "dofs[%d] = %d;" % (j, d)
            code.end_case()
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Invalid facet number.");'
        code.dedent()
        code.end_switch()

        return str(code)

    def gen_tabulate_entity_dofs(self):
        """void tabulate_entity_dofs(unsigned int* dofs,
                                     unsigned int d, unsigned int i) const
        """
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_entity_dofs not implemented for Real elements.");'
        code = CodeFormatter()
        # Define one case for each cell entity (d, i).
        code.begin_switch("d")
        for d in range(self.rep.cell.nsd+1):
            # Entity dimensions without any dofs get no case at all.
            if any(self.rep.entity_dofs[d]):
                code.begin_case(d)
                code.begin_switch("i")
                n = self.rep.cell.num_entities[d]
                for i in range(n):
                    # Get list of local dofs associated with cell entity (d, i).
                    dofs_on_entity = self.rep.entity_dofs[d][i]
                    sfc_assert(len(dofs_on_entity) == self.rep.num_entity_dofs[d], "Inconsistency in entity dofs.")
                    code.begin_case(i)
                    for k, ed in enumerate(dofs_on_entity):
                        code += "dofs[%d] = %d;" % (k, ed)
                    code.end_case()
                code.end_switch()
                code.end_case()
        code.end_switch()
        return str(code)

    def gen_tabulate_coordinates(self):
        """void tabulate_coordinates(double** coordinates,
                                     const cell& c) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_coordinates not implemented for Real elements.");'
        code = CodeFormatter()
        code += gen_geometry_code(self.rep.cell.nsd, detG=False)
        for i in range(self.rep.local_dimension):
            for k in range(self.rep.cell.nsd):
                # Generate code to compute component k of the coordinate for dof i.
                code += "coordinates[%d][%d] = %s;" % (i, k, self.rep.dof_x[i][k].printc())
        return str(code)

    def gen_num_sub_dofmaps(self):
        """unsigned int num_sub_dofmaps() const"""
        return "return %d;" % len(self.rep.sub_elements)

    def gen_create_sub_dofmap(self):
        """dofmap* create_sub_dofmap(unsigned int i) const"""
        if self.options.enable_dof_ptv:
            # Sub dofmaps share the already initialized numbering, which is
            # handed over through the additional constructor.
            if len(self.rep.sub_elements) > 1:
                code = CodeFormatter()
                code.begin_switch("i")
                for i, fe in enumerate(self.rep.sub_elements):
                    code += "case %d: return new %s(loc2glob, %s);" % (i, fe.dof_map_classname, self.constructor_arg_string2)
                code.end_switch()
                code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
            else:
                code = "return new %s(loc2glob, %s);" % (self.classname, self.constructor_arg_string2)
        else:
            if len(self.rep.sub_elements) > 1:
                code = CodeFormatter()
                code.begin_switch("i")
                for i, fe in enumerate(self.rep.sub_elements):
                    code += "case %d: return new %s();" % (i, fe.dof_map_classname)
                code.end_switch()
                code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
            else:
                code = "return new %s();" % self.classname  # FIXME: Should we throw error here instead now?
        return str(code)

    def gen_members(self):
        """Return the declarations of additional class members."""
        code = CodeFormatter()

        # Dof data structures.
        code += "public:"
        code.indent()
        if self.rep.ufl_element.family() != "Real":
            code += "unsigned int _global_dimension;"

        if self.options.enable_dof_ptv:
            code += "Dof_Ptv dof;"
            code += "std::tr1::shared_ptr<unsigned int> loc2glob;"
            # For tabulating vector element dofs from a scalar element numbering.
            code += "unsigned int global_component_stride;"
            code += 'unsigned int loc2glob_size;'
        code.dedent()

        if self.options.enable_dof_ptv:
            # Add additional constructor to pass initialization info
            # to share initialized dofmap memory.
            code += "public:"
            code.indent()
            args = self.constructor_arg_string
            code += "%s(std::tr1::shared_ptr<unsigned int> loc2glob, %s);" % (self.classname, args)
            code.dedent()

        return str(code)

    def generate_support_code(self):
        """Generate local utility functions."""
        code = CodeFormatter()

        if self.options.enable_dof_ptv:
            # Implement additional constructor for shared memory between
            # initialized dofmaps.
            # NOTE(review): only the constructor head and initializer list
            # are emitted here, no "{ }" body -- confirm that the surrounding
            # template supplies it.
            code += "%s::%s(std::tr1::shared_ptr<unsigned int> loc2glob_, %s):" % (self.classname, self.classname, self.constructor_arg_string)
            code.indent()
            code += "loc2glob(loc2glob_),"
            # All initializers but the last one are followed by a comma.
            for v in self.constructor_vars[:-1]:
                code += "%s(%s_)," % (v, v)
            v = self.constructor_vars[-1]
            code += "%s(%s_)" % (v, v)
            code.dedent()

        return str(code)
494