Package sfc :: Package codegeneration :: Module codeformatting
[hide private]
[frames] | no frames]

Source Code for Module sfc.codegeneration.codeformatting

  1  #!/usr/bin/env python 
  2  # -*- coding: utf-8 -*- 
  3  """ 
  4  General utilities for code generation. 
  5  """ 
  6   
  7  # Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Resarch Laboratory 
  8  # 
  9  # This file is part of SyFi. 
 10  # 
 11  # SyFi is free software: you can redistribute it and/or modify 
 12  # it under the terms of the GNU General Public License as published by 
 13  # the Free Software Foundation, either version 2 of the License, or 
 14  # (at your option) any later version. 
 15  # 
 16  # SyFi is distributed in the hope that it will be useful, 
 17  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 18  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 
 19  # GNU General Public License for more details. 
 20  # 
 21  # You should have received a copy of the GNU General Public License 
 22  # along with SyFi. If not, see <http://www.gnu.org/licenses/>. 
 23  # 
 24  # First added:  2008-02-27 
 25  # Last changed: 2008-10-13 
 26   
 27  import re 
 28  import swiginac 
 29   
 30  from pprint import * 
 31   
 32   
 33  # global variable for consistent indentation  
 34  indent_size = 4 
 35   
36 -def indent(text, n=1):
37 """Indent a text n times. Text can also be a list of 38 strings, or recursively a list of lists of strings.""" 39 # must have something to indent 40 if not(text and n): 41 return text 42 43 # indent components of a list without merging 44 if isinstance(text, list): 45 return [indent(c, n) for c in text] 46 47 # fix to avoid extra spaces at end 48 append_lineend = text.endswith("\n") 49 50 # apply spaces to all lines 51 lines = text.split("\n") 52 if append_lineend: 53 lines.pop() 54 space = " "*(indent_size*n) 55 text = "\n".join((space + s) for s in lines) 56 57 # fix to avoid extra spaces at end 58 if append_lineend: 59 text += "\n" 60 61 return text
62 63
class CodeFormatter(object):
    """Utility class for assembling code strings into a multiline string.

    Keeps a stack of open constructs ("contexts") so that mismatched
    begin_*/end_* calls raise RuntimeError, and tracks an indentation level
    that is applied to every added line. This makes generated code more
    robust with respect to correctness of program flow and readability.

    Support for the following constructs:
    {} (basic block), if, else if, else, switch, case, while, do, class,
    plus #ifdef SFCDEBUG sections and comments.

    Typical usage:
    >>> c = CodeFormatter()
    >>> c.begin_switch("i")
    >>> c.begin_case(0)
    >>> c += "foo();"
    >>> c.end_case()
    >>> c.begin_case(1)
    >>> c.begin_if("a > b")
    >>> c += "bar();"
    >>> c.begin_else_if("c > b")
    >>> c += "bar2();"
    >>> c.end_if()
    >>> c.end_case()
    >>> c.end_switch()
    >>> print(str(c))
    switch(i)
    {
    case 0:
        foo();
        break;
    case 1:
        if( a > b )
        {
            bar();
        }
        else if( c > b )
        {
            bar2();
        }
        break;
    }
    """

    def __init__(self, name=""):
        # Optional label for this formatter; not used by this class's methods.
        self.name = name
        # Output fragments, joined by __str__; appending to a list avoids
        # repeated string concatenation.
        self.text = []
        # Current indentation level, in units of indent_size spaces.
        self.indentation = 0
        # Stack of open constructs; "ROOT" alone means everything is closed.
        self.context = ["ROOT"]

    def get_context(self):
        """Return the innermost open context name."""
        return self.context[-1]

    def add_context(self, context):
        """Push context onto the stack of open constructs."""
        self.context.append(context)

    def remove_context(self, context):
        """Pop the innermost context, which must equal context.

        Raises RuntimeError (via assert_context) if it does not.
        """
        self.assert_context(context)
        self.context.pop()

    def assert_context(self, context):
        """Raise RuntimeError unless the innermost context is context."""
        if self.context[-1] != context:
            raise RuntimeError("Expected to be in context '%s', state is %s." % (context, str(self.context)))

    def assert_contexts(self, contexts):
        """Raise RuntimeError unless the innermost open contexts match contexts.

        contexts is listed from the innermost context (contexts[0]) outwards.
        Asking for more contexts than are open is reported as RuntimeError
        too, rather than leaking an IndexError.
        """
        matches = len(contexts) <= len(self.context) and all(
            self.context[-(k + 1)] == expected for k, expected in enumerate(contexts))
        if not matches:
            raise RuntimeError("Expected to be in contexts %s, state is %s." % (str(contexts), str(self.context)))

    def assert_closed_code(self):
        """Raise RuntimeError unless all constructs are closed and indentation is 0."""
        self.assert_context("ROOT")
        if self.indentation != 0:
            raise RuntimeError("Code is not closed, indentation state is %d != 0." % self.indentation)

    def __str__(self):
        """Return the assembled code; raises RuntimeError if it is not closed."""
        self.assert_closed_code()
        return "".join(self.text)

    def new_text(self, text):
        "Add a block of text directly with no modifications."
        self.text.append(text)

    def new_line(self, text=""):
        "Add a line with auto-indentation."
        # Separate from previous output; the very first line gets no
        # leading newline.
        if self.text:
            self.text.append("\n")
        self.text.append(indent(text, self.indentation))

    def __iadd__(self, text):
        """c += text is shorthand for c.new_line(text)."""
        self.new_line(text)
        return self

    def indent(self):
        """Increase the indentation level by one."""
        self.indentation += 1

    def dedent(self):
        """Decrease the indentation level by one.

        Dedenting at level 0 means begin_*/end_* calls are unbalanced; a
        warning is written to stderr (not stdout, where generated code is
        often printed) and the level stays at 0.
        """
        if self.indentation <= 0:
            import sys
            sys.stderr.write("WARNING: Dedented non-indented code, something"
                             " is wrong in code generation program flow.\n")
            self.indentation = 0
        else:
            self.indentation -= 1

    def comment(self, text):
        """Add a comment: "// text" for one line, a /* ... */ block otherwise.

        Multi-line text is placed at the current indentation level, the same
        level as the /* and */ delimiters.
        """
        if "\n" in text:
            self.new_line("/*")
            # new_line() already applies the current indentation; indenting
            # the text here as well would indent it twice.
            self.new_line(text)
            self.new_line("*/")
        else:
            self.new_line("// " + text)

    def begin_debug(self):
        """Open an #ifdef SFCDEBUG section; close it with end_debug()."""
        self.add_context("debug")
        self.new_line("#ifdef SFCDEBUG")

    def end_debug(self):
        """Close the section opened by begin_debug()."""
        self.new_line("#endif // SFCDEBUG")
        self.remove_context("debug")

    def begin_block(self):
        """Open a { } block and indent its contents one level."""
        self.add_context("block")
        self.new_line("{")
        self.indent()

    def end_block(self):
        """Close the block opened by begin_block()."""
        self.dedent()
        self.new_line("}")
        self.remove_context("block")

    def begin_switch(self, arg):
        """Open "switch(arg) {"; case labels stay at the switch's own level."""
        self.new_line("switch(%s)" % str(arg))
        self.new_line("{")
        self.add_context("switch")

    def begin_case(self, arg, braces=False):
        """Open "case arg:" inside a switch; optionally wrap the body in { }."""
        self.assert_context("switch")
        self.add_context("case")
        self.new_line("case %s:" % str(arg))
        self.indent()
        if braces:
            self.begin_block()

    def end_case(self):
        """Emit "break;" and close the current case."""
        # A "block" on top of the stack means begin_case was called with braces=True.
        if self.get_context() == "block":
            self.end_block()
        self.new_line("break;")
        self.dedent()
        self.remove_context("case")

    def end_switch(self):
        """Close the switch opened by begin_switch()."""
        self.new_line("}")
        self.remove_context("switch")

    def begin_while(self, arg):
        """Open "while( arg ) {"."""
        self.add_context("while")
        self.new_line("while( %s )" % str(arg))
        self.begin_block()

    def end_while(self):
        """Close the loop opened by begin_while()."""
        self.end_block()
        self.remove_context("while")

    def begin_do(self):
        """Open "do {"; close with end_do(condition)."""
        self.add_context("do")
        self.new_line("do")
        self.new_line("{")
        self.indent()

    def end_do(self, arg):
        """Close a do loop with "} while(arg);"."""
        self.dedent()
        self.new_line("} while(%s);" % str(arg))
        self.remove_context("do")

    def begin_if(self, arg):
        """Open "if( arg ) {"."""
        self.add_context("if")
        self.new_line("if( %s )" % str(arg))
        self.begin_block()

    def begin_else_if(self, arg):
        """Close the current branch and open "else if( arg ) {"."""
        self.end_block()
        self.assert_context("if")
        self.new_line("else if( %s )" % str(arg))
        self.begin_block()

    def begin_else(self):
        """Close the current branch and open "else {"."""
        self.end_block()
        self.assert_context("if")
        self.new_line("else")
        self.begin_block()

    def end_if(self):
        """Close the last branch of an if statement."""
        self.end_block()
        self.remove_context("if")

    def begin_class(self, classname, bases=()):
        """Open "class classname: base1, base2 {"."""
        self.add_context("class")
        code = "class %s" % str(classname)
        if bases:
            code += ": " + ", ".join(bases)
        self.new_line(code)
        self.begin_block()

    def end_class(self):
        """Close a class definition with "};"."""
        self.end_block()
        # C++ class definitions are terminated by a semicolon.
        self.new_text(";")
        self.remove_context("class")

    def declare_function(self, name, return_type="void", args=(), const=False, virtual=False, inline=False, classname=None):
        """Add a function declaration line "signature;" (see function_signature)."""
        signature = function_signature(name, return_type, args, virtual=virtual, inline=inline,
                                       const=const, classname=classname)
        self.new_line("%s;" % signature)

    def define_function(self, name, return_type="void", args=(), const=False, virtual=False, inline=False, classname=None, body="// Empty body"):
        """Add a function definition: the signature followed by a { body } block."""
        signature = function_signature(name, return_type, args, virtual=virtual, inline=inline,
                                       const=const, classname=classname)
        self.new_line(signature)
        self.begin_block()
        self.new_line(body)
        self.end_block()

    def call_function(self, name, args=()):
        """Add a call statement "name(arg1, arg2, ...);"."""
        self.new_line("".join((name, "(", ", ".join(args), ");")))


297 298 299 # ... Utility functions for code generation 300
def row_major_index_string(i, shape):
    """Construct a C expression string for the row major (C order) flat
    index of the multi-index i into a flattened tensor of the given shape.

    rank 0 -> "0"
    rank 1 -> "i0"
    rank r >= 2 -> "stride_0*i_0 + ... + stride_{r-2}*i_{r-2} + i_{r-1}",
    where stride_k is the product of shape[k+1:]. For rank 2 this is
    "shape[1]*i[0] + i[1]". Strides and indices are formatted with %d.
    """
    rank = len(shape)
    if rank == 0:
        return "0"

    # Row-major strides, built from the innermost (last) axis outwards.
    strides = [1] * rank
    for k in range(rank - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]

    terms = ["%d*%d" % (strides[k], i[k]) for k in range(rank - 1)]
    # The innermost axis has stride 1, so it is written without a factor.
    terms.append("%d" % i[rank - 1])
    return " + ".join(terms)


# Redundant trailing zeros in the fractional part of an exponent-form float
# literal, e.g. the "000" in "1.5000e+00". Requiring a '.' before the digits
# keeps integer mantissas ("10e5") and hex literals ("0x10e") untouched.
_FLOAT_TRAILING_ZEROS = re.compile(r"(\.\d*?)0+([eE])")


def optimize_floats(code):
    """Optimize storage size of floating point numbers by removing unneeded
    trailing zeros before the exponent.

    "1.5000e+00" becomes "1.5e+00" and "2.000e-01" becomes "2.e-01" (still a
    valid C literal). Text without a decimal point before the zeros, such as
    "10e5" or "0x10e", is left unchanged.
    """
    return _FLOAT_TRAILING_ZEROS.sub(r"\1\2", code)

def function_signature(name, return_type="void", args=(), virtual=False, inline=False, const=False, classname=None):
    """Render a C++ function signature (no trailing ';' and no body).

    Arguments:
        name: function name.
        return_type: return type text; pass "" or None for constructors and
            destructors, which have no return type.
        args: sequence of arguments, each either a preformatted string
            ("double x"), a (type, name) pair, or a (type, name, suffix)
            triple whose suffix is appended directly to the name
            (e.g. ("double", "v", "[3]") -> "double v[3]").
        virtual, inline: prepend the corresponding keyword.
        const: append " const".
        classname: if given, qualify the name as "classname::name".

    Raises:
        RuntimeError: if an argument is neither a string nor a 2- or
        3-element sequence.
    """
    def joinarg(arg):
        if isinstance(arg, str):
            return arg
        # tuple() so that list-typed arguments format like tuples with %.
        if len(arg) == 2:
            return "%s %s" % tuple(arg)
        elif len(arg) == 3:
            return "%s %s%s" % tuple(arg)
        raise RuntimeError("Invalid arg: %s" % repr(arg))

    code = []
    if virtual:
        code.append("virtual ")
    if inline:
        code.append("inline ")

    # Constructors/destructors have no return type; avoid a stray leading space.
    if return_type:
        code.extend((return_type, " "))

    if classname:
        code.extend((classname, "::"))

    code.extend((name, "(", ", ".join(joinarg(arg) for arg in args), ")"))

    if const:
        code.append(" const")
    return "".join(code)

# TODO: These token code generation functions can be optimized by iterating differently
# TODO: A function that can apply two rules to the same stream and return two joined codes

# Default gen_token_code() rule: declare and initialize a double.
default_code_rule = r"double %(symbol)s = %(value)s;"
def gen_token_code(token, rule=default_code_rule):
    """Generate code for a (symbol, value) token by applying a %-format rule.

    token[0] is a symbol or string, or a swiginac matrix/lst of symbols;
    token[1] is a scalar, expression, or a matrix/lst of the same length with
    the corresponding values. rule is a %-format string using the keys
    "symbol" and "value"; the default is "double %(symbol)s = %(value)s;".
    For a matrix/lst token one line is generated per element, joined with
    newlines. Values are rendered with swiginac's printc().

    Raises RuntimeError if token[0] and token[1] differ in length.
    """
    # NOTE: this must be a plain string literal; the previous
    # '"""...""" % default_code_rule' form was an expression, not a
    # docstring, which left gen_token_code.__doc__ as None.

    def code_rule(s, v):
        return rule % {"symbol": s, "value": v}

    sym, val = token

    # Plain Python numbers have no printc(); wrap them as swiginac numerics.
    if isinstance(val, (int, float)):
        val = swiginac.numeric(val)

    # check if we have one or more tokens here
    if isinstance(sym, (swiginac.matrix, swiginac.lst)):
        if len(sym) != len(val):
            raise RuntimeError("sym and val must have same size.")
        return "\n".join(code_rule(sym[k].printc(), val[k].printc()) for k in range(len(sym)))

    return code_rule(str(sym), val.printc())


def gen_token_prints(tokens, indent=" "):
    """Return C++ statements that print every token symbol with its value.

    Each line has the form
    std::cout << "<indent><symbol> = " << <symbol> << std::endl;
    where indent is a literal prefix inside the printed label. (The
    parameter shadows the module-level indent() function in this scope.)
    """
    def render(token):
        print_rule = 'std::cout << "' + indent + '%(symbol)s = " << %(symbol)s << std::endl;'
        return gen_token_code(token, rule=print_rule)

    return "\n".join(map(render, tokens))


def gen_symbol_declarations(symbols):
    """Return one "double <name>;" C declaration line per symbol."""
    declarations = []
    for sym in symbols:
        declarations.append("double %s;" % str(sym))
    return "\n".join(declarations)


def gen_token_declarations(tokens):
    """Return "double <symbol>;" declarations for every token, one per line."""
    declaration_rule = "double %(symbol)s;"
    return "\n".join([gen_token_code(tok, rule=declaration_rule) for tok in tokens])


def gen_token_definitions(tokens):
    """Return "double <symbol> = <value>;" definitions for every token."""
    definition_rule = "double %(symbol)s = %(value)s;"
    return "\n".join([gen_token_code(tok, rule=definition_rule) for tok in tokens])


def gen_const_token_definitions(tokens):
    """Return "const double <symbol> = <value>;" definitions for every token."""
    const_rule = "const double %(symbol)s = %(value)s;"
    return "\n".join([gen_token_code(tok, rule=const_rule) for tok in tokens])


def gen_token_assignments(tokens):
    """Return "<symbol> = <value>;" assignments for every token."""
    assignment_rule = "%(symbol)s = %(value)s;"
    return "\n".join([gen_token_code(tok, rule=assignment_rule) for tok in tokens])


def gen_token_additions(tokens):
    """Return "<symbol> += <value>;" accumulation statements for every token."""
    addition_rule = "%(symbol)s += %(value)s;"
    return "\n".join([gen_token_code(tok, rule=addition_rule) for tok in tokens])


def gen_switch(argument, cases, default_case=None, braces=False):
    """Render a C/C++ switch statement as a string.

    Arguments:
        argument: the switch expression, inserted verbatim.
        cases: sequence of (label, code) pairs; label is converted with
            str(), code (possibly multi-line) is indented one level and
            followed by "break;".
        default_case: optional code for a trailing "default:" label (no
            break is emitted for it); skipped when falsy.
        braces: if true, wrap each case body in its own { } block.

    The case labels are indented one level inside the switch braces.
    """
    if braces:
        start_brace = "{\n"
        end_brace = "\n}"
    else:
        start_brace = ""
        end_brace = ""

    sections = ["case %s:\n%s%s\n    break;%s" % (str(c[0]), start_brace, indent(c[1]), end_brace)
                for c in cases]
    if default_case:
        # No whitespace between the default body and the closing brace/end.
        sections.append("default:\n%s%s%s" % (start_brace, indent(default_case), end_brace))

    # Joining the sections (instead of prefixing "\n") avoids a blank,
    # whitespace-only line when there are no regular cases.
    if not sections:
        return "switch(%s)\n{\n}" % (argument,)
    return "switch(%s)\n{\n%s\n}" % (argument, indent("\n".join(sections)))


def outline(name, tokens, targets, deps):
    """Outline a sequence of assignments into a separate C function.

    Arguments:
        name: name of the generated function.
        tokens: sequence of (symbol, value) pairs, emitted in order as
            "symbol = value;" statements in the function body.
        targets: token symbols found here become "double &" output
            parameters (in token order); all other token symbols are
            declared as local "double" variables.
        deps: symbols the computation reads; they become "double" input
            parameters (any sequence, e.g. list or tuple).

    Returns:
        (fundef, funcall): the function definition and a matching call
        statement passing the deps followed by the output symbols.
    """
    inputargs = list(deps)

    # figure out which variables to output and
    # which to declare on the local stack
    localtokens = []
    outputargs = []
    for t in tokens:
        s = t[0]
        if s in targets:
            outputargs.append(s)
        else:
            localtokens.append(s)

    # argument list in the function call: inputs first, then outputs
    callargumentlist = ", ".join(str(a) for a in inputargs + outputargs)

    # argument list in the function definition: inputs by value,
    # outputs by reference
    parameters = (["double %s" % str(a) for a in inputargs] +
                  ["double & %s" % str(a) for a in outputargs])
    argumentlist = ", ".join(parameters)

    # function body: local declarations, then the assignments in order;
    # building a line list avoids a stray blank line when there are no locals
    body = (["    double %s;" % s for s in localtokens] +
            ["    %s = %s;" % (t[0], t[1]) for t in tokens])

    # stitch together code pieces to return
    fundef = "\n".join(["void %s(%s)" % (name, argumentlist), "{"] + body + ["}"])
    funcall = "%s(%s);" % (name, callargumentlist)

    return fundef, funcall


def test_outliner():
    """Check that outline() passes targets by reference and keeps other
    token symbols as local variables."""
    tokens = [
        ("a", "u * v"),
        ("b", "u + w"),
    ]
    deps = ["u", "v", "w"]
    targets = ["a"]

    fundef, funcall = outline("foo", tokens, targets, deps)

    # inputs first, then the output symbol
    assert funcall == "foo(u, v, w, a);"
    assert fundef.startswith("void foo(double u, double v, double w, double & a)")
    # 'b' is not a target, so it lives on the local stack; 'a' does not
    assert "double b;" in fundef
    assert "double a;" not in fundef
    # all assignments are emitted
    assert "a = u * v;" in fundef
    assert "b = u + w;" in fundef

if __name__ == "__main__":
    # Demo 1: exercise CodeFormatter on the supported C++ constructs and
    # print the assembled code.
    c = CodeFormatter()

    c.begin_class("Foo")
    c += "dings"
    c.end_class()

    c += ""

    c.begin_class("Bar", ("fee", "foe"))
    c.comment("something funny")
    c.declare_function("blatti", "int", const=False, virtual=False)
    c.declare_function("foo", "double", const=True)
    c.declare_function("bar", virtual=True)
    c.end_class()

    c.define_function("blatti", "int", const=False, virtual=False, classname="Bar",
                      body='cout << "Hello world!" << endl;')
    c.call_function("blatti")

    c += ""

    # a basic if
    c.begin_if("a < b")
    c += "foo.bar();"
    c.end_if()

    c += ""

    # a compound if
    c.begin_if("a < b")
    c += "foo.bar();"
    c.begin_else_if("c < b")
    c += "foo.bar();"
    c.begin_else()
    c += "foo.bar();"
    c.end_if()

    c += ""

    # a simple do loop
    c.begin_do()
    c += "foo();"
    c.end_do("a > 0")

    c += ""

    # a simple while loop
    c.begin_while("a > 0")
    c += "foo();"
    c.end_while()

    c += ""

    # an empty switch
    c.begin_switch("i")
    c.end_switch()

    c += ""

    # a compound switch
    c.begin_switch("k")
    c.begin_case(0)
    c += "foo.bar();"
    c.end_case()
    c.begin_case("1")
    c += "bar.foo();"
    c += "bar.foo();"
    c.end_case()
    c.end_switch()

    c += ""

    # verify that the code is closed
    c.assert_closed_code()

    # print the result (print() with one argument works on Python 2 and 3)
    print(c)

    # Demo 2: token code generation for a few swiginac expressions.
    from swiginac import symbol

    x = symbol("x")
    y = symbol("y")
    z = symbol("z")
    pi = swiginac.Pi
    cos = swiginac.cos
    tokens = [(x, 1), (y, x**2 - 1), (z, cos(2*pi*x*y))]
    print(gen_token_declarations(tokens))
    print(gen_token_definitions(tokens))
    print(gen_const_token_definitions(tokens))
    print(gen_token_assignments(tokens))
    print(gen_token_additions(tokens))

637 -def _test():
638 import doctest 639 return doctest.testmod()
640 641 if __name__ == "__main__": 642 _test() 643