Home | Trees | Indices | Help |
|
---|
|
1 #!/usr/bin/env python 2 # -*- coding: utf-8 -*- 3 """ 4 General utilities for code generation. 5 """ 6 7 # Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Resarch Laboratory 8 # 9 # This file is part of SyFi. 10 # 11 # SyFi is free software: you can redistribute it and/or modify 12 # it under the terms of the GNU General Public License as published by 13 # the Free Software Foundation, either version 2 of the License, or 14 # (at your option) any later version. 15 # 16 # SyFi is distributed in the hope that it will be useful, 17 # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 # GNU General Public License for more details. 20 # 21 # You should have received a copy of the GNU General Public License 22 # along with SyFi. If not, see <http://www.gnu.org/licenses/>. 23 # 24 # First added: 2008-02-27 25 # Last changed: 2008-10-13 26 27 import re 28 import swiginac 29 30 from pprint import * 31 32 33 # global variable for consistent indentation 34 indent_size = 4 3537 """Indent a text n times. Text can also be a list of 38 strings, or recursively a list of lists of strings.""" 39 # must have something to indent 40 if not(text and n): 41 return text 42 43 # indent components of a list without merging 44 if isinstance(text, list): 45 return [indent(c, n) for c in text] 46 47 # fix to avoid extra spaces at end 48 append_lineend = text.endswith("\n") 49 50 # apply spaces to all lines 51 lines = text.split("\n") 52 if append_lineend: 53 lines.pop() 54 space = " "*(indent_size*n) 55 text = "\n".join((space + s) for s in lines) 56 57 # fix to avoid extra spaces at end 58 if append_lineend: 59 text += "\n" 60 61 return text62 6365 """Utility class for assembling code strings into a multiline string. 66 67 Supports checking for matching parenteses and applying indentation 68 to generate code that is more robust with respect to correctness 69 of program flow and readability of code. 70 71 Support for the following constructs: 72 {} (basic block), if, else if, else, switch, case, while, do, class 73 74 Typical usage: 75 >>> c = CodeFormatter() 76 >>> c.begin_switch("i") 77 >>> c.begin_case(0) 78 >>> c += "foo();" 79 >>> c.end_case() 80 >>> c.begin_case(1) 81 >>> c.begin_if("a > b") 82 >>> c += "bar();" 83 >>> c.begin_else_if("c > b") 84 >>> c += "bar2();" 85 >>> c.end_if() 86 >>> c.end_case() 87 >>> c.end_switch() 88 >>> print str(c) 89 switch(i) 90 { 91 case 0: 92 foo(); 93 break; 94 case 1: 95 if( a > b ) 96 { 97 bar(); 98 } 99 else if( c > b ) 100 { 101 bar2(); 102 } 103 break; 104 } 105 """297 298 299 # ... Utility functions for code generation 300107 self.name = name 108 #self.text = "" 109 self.text = [] 110 self.indentation = 0 111 self.context = ["ROOT"]112 115 118120 self.assert_context(context) 121 c = self.context.pop(-1) 122 if not c == context: 123 raise RuntimeError("Expected context '%s' to be '%s'." % (c, context))124126 if not (self.context[-1] == context): 127 raise RuntimeError("Expected to be in context '%s', state is %s." % (context, str(self.context)))128130 for i in range(len(contexts)): 131 if not self.context[-(i+1)] == contexts[i]: 132 raise RuntimeError("Expected to be in contexts %s, state is %s." % (str(contexts), str(self.context)))133135 self.assert_context("ROOT") 136 if self.indentation != 0: 137 raise RuntimeError("Code is not closed, indentation state is %d != 0." % self.indentation)138140 #try: 141 self.assert_closed_code() 142 #except: 143 # raise RuntimeError("Converting a CodeFormatter to str, but context is not closed: %s" % str(self.context)) 144 #return self.text 145 return "".join(self.text)146148 "Add a block of text directly with no modifications." 149 #self.text += text 150 self.text.append(text)151153 "Add a line with auto-indentation." 154 if self.text: 155 #self.text += "\n" 156 self.text.append("\n") 157 #self.text += indent(text, self.indentation) 158 self.text.append(indent(text, self.indentation))159161 if self.text: 162 #self.text += "\n" 163 self.text.append("\n") 164 #self.text += indent(text, self.indentation) 165 self.text.append(indent(text, self.indentation)) 166 return self167 170172 if self.indentation <= 0: 173 print "WARNING: Dedented non-indented code, something"\ 174 " is wrong in code generation program flow." 175 self.indentation = 0 176 else: 177 self.indentation -= 1178180 if "\n" in text: 181 self.new_line( "/*" ) 182 self.new_line( indent(text, self.indentation) ) # TODO: indent text to proper level! 183 self.new_line( "*/" ) 184 else: 185 self.new_line( "// " + text )186 190 194 199 204 209211 self.assert_context("switch") 212 self.add_context("case") 213 self.new_line("case %s:" % str(arg)) 214 self.indent() 215 if braces: self.begin_block()216218 if self.get_context() == "block": # braces = True in begin_case 219 self.end_block() 220 self.new_line("break;") 221 self.dedent() 222 self.remove_context("case")223 227 232 236 242 247 252254 self.end_block() 255 self.assert_context("if") 256 self.new_line("else if( %s )" % str(arg)) 257 self.begin_block()258 264 268270 self.add_context("class") 271 code = "class %s" % str(classname) 272 if bases: 273 code += ": " + ", ".join(bases) 274 self.new_line( code ) 275 self.begin_block()276 281282 - def declare_function(self, name, return_type="void", args=[], const=False, virtual=False, inline=False, classname=None):283 code = "%s;" % function_signature(name, return_type, args, virtual, inline, const, classname) 284 self.new_line(code)285286 - def define_function(self, name, return_type="void", args=[], const=False, virtual=False, inline=False, classname=None, body="// Empty body"):287 signature = function_signature(name, return_type, args, virtual, inline, const, classname) 288 self.new_line(signature) 289 self.begin_block() 290 self.new_line(body) 291 self.end_block()292302 """Constructs an index string for a row major (C type) 303 indexing of a flattened tensor of rank 0, 1, or 2.""" 304 if len(shape) == 0: 305 return "0" 306 if len(shape) == 1: 307 return "%d" % i[0] 308 if len(shape) == 2: 309 return "%d*%d + %d" % (shape[1], i[0], i[1]) 310 raise ValueError("Rank 3 or higher not supported in row_major_index_string()")311 312314 """Optimize storage size of floating point numbers by removing unneeded trailing zeros.""" 315 regexp = re.compile('0+e') 316 return regexp.sub('e', code)317318 -def function_signature(name, return_type="void", args=[], virtual=False, inline=False, const=False, classname=None):319 "Render function signature from arguments." 320 code = [] 321 if virtual: 322 code.append("virtual ") 323 if inline: 324 code.append("inline ") 325 326 code.extend((return_type, " ")) 327 328 if classname: 329 code.extend((classname, "::")) 330 331 def joinarg(arg): 332 if isinstance(arg, str): 333 return arg 334 if len(arg) == 2: 335 return "%s %s" % arg 336 elif len(arg) == 3: 337 return "%s %s%s" % arg 338 raise RuntimeError("Invalid arg: %s" % repr(arg))339 args = ", ".join(joinarg(arg) for arg in args) 340 341 code.extend((name, "(", args, ")")) 342 343 if const: 344 code.append(" const") 345 return "".join(code) 346 347 # TODO: These token code generation functions can be optimized by iterating differently 348 # TODO: A function that can apply two rules to the same stream and return two joined codes 349 350 default_code_rule = r"double %(symbol)s = %(value)s;"352 """token[0] is a symbol or string or matrix with symbols, token[1] is a scalar, expression or matrix of the same shape with the corresponding values. 353 Generates code based on rule, default %s, for all elements of token[0] and token[1].""" % default_code_rule 354 355 def code_rule(s, v): 356 return rule % { "symbol": s, "value": v }357 358 # make symbol and value swiginac objects if they're not 359 sym, val = token 360 361 if isinstance(val, (int, float)): 362 val = swiginac.numeric(val) 363 364 # check if we have one or more tokens here 365 if isinstance(sym, (swiginac.matrix, swiginac.lst)): 366 if len(sym) != len(val): 367 raise RuntimeError("sym and val must have same size.") 368 code = "\n".join( code_rule(sym[i].printc(), val[i].printc()) for i in xrange(len(sym)) ) 369 else: 370 code = code_rule(str(sym), val.printc()) 371 372 return code 373 374 #=============================================================================== 375 # def gen_symbol_declaration(symbol, prefix="double ", postfix=";\n"): 376 # """symbol is a swiginac.symbol or matrix or lst with symbols. 377 # Generates code for declaration of all symbols in symbol.""" 378 # 379 # if not isinstance(symbol, (swiginac.matrix, swiginac.lst) ): 380 # symbol = swiginac.matrix(1,1, [symbol]) 381 # 382 # # cat all symbols in comma separated string 383 # symbol_list = ", ".join( str(symbol[i]) for i in xrange(len(symbol)) ) 384 # 385 # code = prefix + symbol_list + postfix 386 # return code 387 #=============================================================================== 388 389391 return "\n".join(gen_token_code(token, rule='std::cout << "'+indent+'%(symbol)s = " << %(symbol)s << std::endl;') for token in tokens)392 393395 return "\n".join("double %s;" % str(s) for s in symbols)396 397399 return "\n".join(gen_token_code(token, rule="double %(symbol)s;") for token in tokens)400 401403 return "\n".join(gen_token_code(token, rule="double %(symbol)s = %(value)s;") for token in tokens)404 405407 return "\n".join(gen_token_code(token, rule="const double %(symbol)s = %(value)s;") for token in tokens)408 409411 return "\n".join(gen_token_code(token, rule="%(symbol)s = %(value)s;") for token in tokens)412 413415 return "\n".join(gen_token_code(token, rule="%(symbol)s += %(value)s;") for token in tokens)416 417419 if braces: 420 start_brace = "{\n" 421 end_brace = "\n}" 422 else: 423 start_brace = "" 424 end_brace = "" 425 case_code = "\n".join([ "case %s:\n%s%s\n break;%s" % (str(c[0]), start_brace, indent(c[1]), end_brace) for c in cases]) 426 if default_case: 427 case_code += "\ndefault:\n%s%s %s" % (start_brace, indent(default_case), end_brace) 428 return "switch(%s)\n{\n%s\n}" % (argument, indent(case_code))429 430 431 #class Switch: 432 # def __init__(self, argument, cases=[], default_case=None, braces=False): 433 # self.argument = argument 434 # self.cases = cases 435 # self.default_case = default_case 436 # self.braces = braces 437 # 438 # def __str__(self): 439 # return gen_switch(self.argument, self.cases, self.default_case, self.braces) 440 # 441 # 442 #class IfElse: 443 # def __init__(self, cases=[]): 444 # self.cases = cases 445 # 446 # def __str__(self): 447 # c = self.cases[0] 448 # code = "if(%s)\n{\n%s\n}" % (c[0], indent(c[1])) 449 # for c in self.cases[1:]: 450 # code += "\nelse if(%s)\n{\n%s\n}" % (c[0], indent(c[1])) 451 # return code 452 # 453 # 454 #class Struct: 455 # def __init__(self, name, variables=[]): 456 # self.name = name 457 # self.variables = variables 458 # 459 # def __str__(self): 460 # inner_code = "" 461 # for v in self.variables: 462 # if isinstance(v, tuple): 463 # inner_code += "%s %s;\n" % (v[0], v[1]) 464 # else: 465 # inner_code += "double %s;\n" % v 466 # return "struct %s\n{\n%s}" % (self.name, indent( inner_code )) 467 468 469471 472 inputargs = deps 473 474 # figure out which variables to output and 475 # which to declare on the local stack 476 localtokens = [] 477 outputargs = [] 478 for t in tokens: 479 s = t[0] 480 if s in targets: 481 outputargs.append(s) 482 else: 483 localtokens.append(s) 484 485 # generate code for argument list in the function call 486 allargs = inputargs + outputargs 487 callargumentlist = ", ".join(str(a) for a in allargs) 488 489 # generate code for the argument list in the function definition, 490 # input args and output args separately 491 argumentlist1 = ", ".join("double %s" % str(a) for a in inputargs) 492 argumentlist2 = ", ".join("double & %s" % str(a) for a in outputargs) 493 494 # join input and output arguments to a single argument list 495 argumentlist = argumentlist1 496 if argumentlist1 and argumentlist2: 497 argumentlist += ", " 498 argumentlist += argumentlist2 499 500 # generate function body to compute targets 501 body = "" 502 body += "\n".join(" double %s;" % s for s in localtokens) 503 body += "\n" 504 body += "\n".join(" %s = %s;" % (t[0], t[1]) for t in tokens) 505 506 # stich together code pieces to return 507 fundef = "void %s(%s)\n{\n%s\n}" % (name, argumentlist, body) 508 funcall = "%s(%s);" % (name, callargumentlist) 509 510 return fundef, funcall511 512514 tokens = [ 515 ("a", "u * v"), 516 ("b", "u + w"), 517 ] 518 deps = ["u", "v", "w"] 519 targets = ["a"] 520 521 fundef, funcall = outline("foo", tokens, targets, deps) 522 print fundef 523 print funcall524 525 526 527 #year, month, day, hour, minute = time.localtime()[:5] 528 #date_string = "%d:%d, %d/%d %d" % (hour, minute, day, month, year) 529 530 531 if __name__ == "__main__": 532 c = CodeFormatter() 533 534 c.begin_class( "Foo" ) 535 c += "dings" 536 c.end_class() 537 538 c += "" 539 540 c.begin_class( "Bar", ("fee", "foe") ) 541 c.comment( "something funny" ) 542 c.declare_function("blatti", "int", const=False, virtual=False) 543 c.declare_function("foo", "double", const=True) 544 c.declare_function("bar", virtual=True) 545 c.end_class() 546 547 c.define_function("blatti", "int", const=False, virtual=False, classname="Bar", body='cout << "Hello world!" << endl;') 548 c.call_function("blatti") 549 550 c += "" 551 552 # a basic if 553 c.begin_if( "a < b" ) 554 c += "foo.bar();" 555 c.end_if() 556 557 c += "" 558 559 # a compound if 560 c.begin_if( "a < b" ) 561 c += "foo.bar();" 562 c.begin_else_if( "c < b" ) 563 c += "foo.bar();" 564 c.begin_else() 565 c += "foo.bar();" 566 c.end_if() 567 568 c += "" 569 570 # a simple do loop 571 c.begin_do() 572 c += "foo();" 573 c.end_do("a > 0") 574 575 c += "" 576 577 # a simple while loop 578 c.begin_while("a > 0") 579 c += "foo();" 580 c.end_while() 581 582 c += "" 583 584 # an empty switch 585 c.begin_switch("i") 586 c.end_switch() 587 588 c += "" 589 590 # a compound switch 591 c.begin_switch("k") 592 c.begin_case(0) 593 c += "foo.bar();" 594 c.end_case() 595 c.begin_case("1") 596 c += "bar.foo();" 597 c += "bar.foo();" 598 c.end_case() 599 c.end_switch() 600 601 c += "" 602 603 # verify that the code is closed 604 c.assert_closed_code() 605 606 # print the result 607 print c 608 609 610 611 if __name__ == '__main__': 612 from swiginac import symbol 613 614 x = symbol("x") 615 y = symbol("y") 616 z = symbol("z") 617 pi = swiginac.Pi 618 cos = swiginac.cos 619 tokens = [ (x, 1), (y, x**2-1), (z, cos(2*pi*x*y)) ] 620 print gen_token_declarations(tokens) 621 print gen_token_definitions(tokens) 622 print gen_const_token_definitions(tokens) 623 print gen_token_assignments(tokens) 624 print gen_token_additions(tokens) 625 626 # print Switch("i", [(1, "foo();"), (2, "bar();")], "foe();", False) 627 # print Switch("i", [(1, "foo();"), (2, "bar();")], None, True) 628 # print Switch("i", [(1, "foo();"), (2, "bar();")], "foe();", True) 629 # print Switch("i", [(1, "foo();"), (2, "bar();")], None, False) 630 # 631 # print IfElse([("i==0", "foo();"), ("i!=1", "bar();")]) 632 # 633 # print Struct("foostruct", ["a", "b", "c"]) 634 # print Struct("barstruct", [("int", "a"), ("bool", "b"), "c"]) 635 636 640 641 if __name__ == "__main__": 642 _test() 643
Home | Trees | Indices | Help |
|
---|
Generated by Epydoc 3.0.1 on Mon Jun 11 11:34:33 2012 | http://epydoc.sourceforge.net |