SyFi  0.3
test_codeformatter.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
__authors__ = "Martin Sandve Alnes"
# Date range of the module; both dates are ISO 8601 (YYYY-MM-DD).
__date__ = "2008-09-04 -- 2012-05-16"
00005 
00006 import unittest
00007 
00008 from sfc.codegeneration.codeformatting import (CodeFormatter,
00009                                                gen_token_declarations,
00010                                                gen_token_definitions)
00011 
# Expected output of CodeFormattingTest.test_switch: one case per facet, each
# assigning three dofs and breaking, followed by a default that throws.
# Lines are joined with "\n", so there is no trailing newline.
test_switch_result = "\n".join([
    "switch(facet)",
    "{",
    "case 0:",
    "    dofs[0] = 2;",
    "    dofs[1] = 0;",
    "    dofs[2] = 1;",
    "    break;",
    "case 1:",
    "    dofs[0] = 5;",
    "    dofs[1] = 3;",
    "    dofs[2] = 4;",
    "    break;",
    "case 2:",
    "    dofs[0] = 6;",
    "    dofs[1] = 7;",
    "    dofs[2] = 8;",
    "    break;",
    "default:",
    '    throw std::runtime_error("Invalid facet number.");',
    "}",
])
00032 
# Expected output of CodeFormattingTest.test_gen_tokens: the token
# declarations inside a braced block, then the token definitions written once
# at eight spaces and once at four spaces of indentation. No trailing newline.
gen_tokens_result = "\n".join([
    "{",
    "    double s1;",
    "    double s2;",
    "}",
    "        double s1 = e1;",
    "        double s2 = e2;",
    "    double s1 = e1;",
    "    double s2 = e2;",
])
00041 
# Expected output of CodeFormattingTest.test_functions: a declaration, a
# definition with a comment-only body, and a call, each followed by a blank
# line (the trailing "\n" comes from the final new_line("")).
functions_result = """inline void myfunction(double a[3], double b[3], double c[3]) const;

inline void myfunction(double a[3], double b[3], double c[3]) const
{
    // Empty body!
}

myfunction(a, b, c);
"""


class CodeFormattingTest(unittest.TestCase):
    """Tests for CodeFormatter and the gen_token_* helpers."""

    def _compare_codes(self, code, correct):
        """Assert that generated ``code`` equals the expected ``correct`` text.

        On mismatch both strings are printed wrapped in triple quotes so the
        actual output can be inspected before the assertion fails.
        """
        if code != correct:
            print("Failure, got code:")
            print('"""%s"""' % code)
            print("but expecting:")
            print('"""%s"""' % correct)
        # assertEqual (unlike assertTrue(a == b)) reports a diff on failure.
        self.assertEqual(code, correct)

    def test_switch(self):
        """Build a switch with one case per facet and a throwing default."""
        code = CodeFormatter()
        code.begin_switch("facet")
        facet_dofs = [(2, 0, 1), (5, 3, 4), (6, 7, 8)]
        for facet, dofs in enumerate(facet_dofs):
            code.begin_case(facet)
            for j, dof in enumerate(dofs):
                code += "dofs[%d] = %d;" % (j, dof)
            code.end_case()
        # The default label is emitted by hand (no begin_case call) and its
        # body is indented explicitly.
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Invalid facet number.");'
        code.dedent()
        code.end_switch()

        self._compare_codes(str(code), test_switch_result)

    def test_gen_tokens(self):
        """Emit token declarations in a block, then definitions at two indents."""
        code = CodeFormatter()

        class MockObject(object):
            """Minimal stand-in exposing printc() and str(), both returning its text."""

            def __init__(self, text):
                self._text = text

            def printc(self):
                return self._text

            def __str__(self):
                return self._text

        s1 = MockObject("s1")
        e1 = MockObject("e1")
        s2 = MockObject("s2")
        e2 = MockObject("e2")
        # (symbol, expression) pairs, rendered as "double s1 = e1;" etc.
        tokens = [(s1, e1), (s2, e2)]

        code.begin_block()
        code += gen_token_declarations(tokens)
        code.end_block()
        code.indent()
        code.indent()
        code += gen_token_definitions(tokens)
        code.dedent()
        code += gen_token_definitions(tokens)
        code.dedent()

        self._compare_codes(str(code), gen_tokens_result)

    def test_functions(self):
        """Declare, define and call an inline const function."""
        code = CodeFormatter()

        name = "myfunction"

        argnames = ["a", "b", "c"]
        # Each argument is a (type, name, suffix) triple, rendered as
        # "double a[3]". Use a loop variable distinct from `name`: under
        # Python 2 a list comprehension leaks its variable, so reusing `name`
        # silently renamed the function under test to "c".
        args = [("double", argname, "[3]") for argname in argnames]

        code.declare_function(name, args=args, const=True, inline=True)
        code.new_line("")

        body = "// Empty body!"
        code.define_function(name, args=args, const=True, inline=True, body=body)
        code.new_line("")

        code.call_function(name, args=argnames)
        code.new_line("")

        self._compare_codes(str(code), functions_result)


if __name__ == "__main__":
    unittest.main()
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Defines