574 lines
20 KiB
Python
574 lines
20 KiB
Python
# ------------------------------------------------------------------------------
|
|
# pycparser: c_generator.py
|
|
#
|
|
# C code generator from pycparser AST nodes.
|
|
#
|
|
# Eli Bendersky [https://eli.thegreenplace.net/]
|
|
# License: BSD
|
|
# ------------------------------------------------------------------------------
|
|
from typing import Callable, List, Optional
|
|
|
|
from . import c_ast
|
|
|
|
|
|
class CGenerator:
|
|
"""Uses the same visitor pattern as c_ast.NodeVisitor, but modified to
|
|
return a value from each visit method, using string accumulation in
|
|
generic_visit.
|
|
"""
|
|
|
|
indent_level: int
|
|
reduce_parentheses: bool
|
|
|
|
def __init__(self, reduce_parentheses: bool = False) -> None:
|
|
"""Constructs C-code generator
|
|
|
|
reduce_parentheses:
|
|
if True, eliminates needless parentheses on binary operators
|
|
"""
|
|
# Statements start with indentation of self.indent_level spaces, using
|
|
# the _make_indent method.
|
|
self.indent_level = 0
|
|
self.reduce_parentheses = reduce_parentheses
|
|
|
|
def _make_indent(self) -> str:
|
|
return " " * self.indent_level
|
|
|
|
def visit(self, node: c_ast.Node) -> str:
|
|
method = "visit_" + node.__class__.__name__
|
|
return getattr(self, method, self.generic_visit)(node)
|
|
|
|
def generic_visit(self, node: Optional[c_ast.Node]) -> str:
|
|
if node is None:
|
|
return ""
|
|
else:
|
|
return "".join(self.visit(c) for c_name, c in node.children())
|
|
|
|
def visit_Constant(self, n: c_ast.Constant) -> str:
|
|
return n.value
|
|
|
|
def visit_ID(self, n: c_ast.ID) -> str:
|
|
return n.name
|
|
|
|
def visit_Pragma(self, n: c_ast.Pragma) -> str:
|
|
ret = "#pragma"
|
|
if n.string:
|
|
ret += " " + n.string
|
|
return ret
|
|
|
|
def visit_ArrayRef(self, n: c_ast.ArrayRef) -> str:
|
|
arrref = self._parenthesize_unless_simple(n.name)
|
|
return arrref + "[" + self.visit(n.subscript) + "]"
|
|
|
|
def visit_StructRef(self, n: c_ast.StructRef) -> str:
|
|
sref = self._parenthesize_unless_simple(n.name)
|
|
return sref + n.type + self.visit(n.field)
|
|
|
|
def visit_FuncCall(self, n: c_ast.FuncCall) -> str:
|
|
fref = self._parenthesize_unless_simple(n.name)
|
|
args = self.visit(n.args) if n.args is not None else ""
|
|
return fref + "(" + args + ")"
|
|
|
|
def visit_UnaryOp(self, n: c_ast.UnaryOp) -> str:
|
|
match n.op:
|
|
case "sizeof":
|
|
# Always parenthesize the argument of sizeof since it can be
|
|
# a name.
|
|
return f"sizeof({self.visit(n.expr)})"
|
|
case "p++":
|
|
operand = self._parenthesize_unless_simple(n.expr)
|
|
return f"{operand}++"
|
|
case "p--":
|
|
operand = self._parenthesize_unless_simple(n.expr)
|
|
return f"{operand}--"
|
|
case _:
|
|
operand = self._parenthesize_unless_simple(n.expr)
|
|
return f"{n.op}{operand}"
|
|
|
|
# Precedence map of binary operators:
|
|
precedence_map = {
|
|
# Should be in sync with c_parser.CParser.precedence
|
|
# Higher numbers are stronger binding
|
|
"||": 0, # weakest binding
|
|
"&&": 1,
|
|
"|": 2,
|
|
"^": 3,
|
|
"&": 4,
|
|
"==": 5,
|
|
"!=": 5,
|
|
">": 6,
|
|
">=": 6,
|
|
"<": 6,
|
|
"<=": 6,
|
|
">>": 7,
|
|
"<<": 7,
|
|
"+": 8,
|
|
"-": 8,
|
|
"*": 9,
|
|
"/": 9,
|
|
"%": 9, # strongest binding
|
|
}
|
|
|
|
def visit_BinaryOp(self, n: c_ast.BinaryOp) -> str:
|
|
# Note: all binary operators are left-to-right associative
|
|
#
|
|
# If `n.left.op` has a stronger or equally binding precedence in
|
|
# comparison to `n.op`, no parenthesis are needed for the left:
|
|
# e.g., `(a*b) + c` is equivalent to `a*b + c`, as well as
|
|
# `(a+b) - c` is equivalent to `a+b - c` (same precedence).
|
|
# If the left operator is weaker binding than the current, then
|
|
# parentheses are necessary:
|
|
# e.g., `(a+b) * c` is NOT equivalent to `a+b * c`.
|
|
lval_str = self._parenthesize_if(
|
|
n.left,
|
|
lambda d: not (
|
|
self._is_simple_node(d)
|
|
or self.reduce_parentheses
|
|
and isinstance(d, c_ast.BinaryOp)
|
|
and self.precedence_map[d.op] >= self.precedence_map[n.op]
|
|
),
|
|
)
|
|
# If `n.right.op` has a stronger -but not equal- binding precedence,
|
|
# parenthesis can be omitted on the right:
|
|
# e.g., `a + (b*c)` is equivalent to `a + b*c`.
|
|
# If the right operator is weaker or equally binding, then parentheses
|
|
# are necessary:
|
|
# e.g., `a * (b+c)` is NOT equivalent to `a * b+c` and
|
|
# `a - (b+c)` is NOT equivalent to `a - b+c` (same precedence).
|
|
rval_str = self._parenthesize_if(
|
|
n.right,
|
|
lambda d: not (
|
|
self._is_simple_node(d)
|
|
or self.reduce_parentheses
|
|
and isinstance(d, c_ast.BinaryOp)
|
|
and self.precedence_map[d.op] > self.precedence_map[n.op]
|
|
),
|
|
)
|
|
return f"{lval_str} {n.op} {rval_str}"
|
|
|
|
def visit_Assignment(self, n: c_ast.Assignment) -> str:
|
|
rval_str = self._parenthesize_if(
|
|
n.rvalue, lambda n: isinstance(n, c_ast.Assignment)
|
|
)
|
|
return f"{self.visit(n.lvalue)} {n.op} {rval_str}"
|
|
|
|
def visit_IdentifierType(self, n: c_ast.IdentifierType) -> str:
|
|
return " ".join(n.names)
|
|
|
|
def _visit_expr(self, n: c_ast.Node) -> str:
|
|
match n:
|
|
case c_ast.InitList():
|
|
return "{" + self.visit(n) + "}"
|
|
case c_ast.ExprList() | c_ast.Compound():
|
|
return "(" + self.visit(n) + ")"
|
|
case _:
|
|
return self.visit(n)
|
|
|
|
def visit_Decl(self, n: c_ast.Decl, no_type: bool = False) -> str:
|
|
# no_type is used when a Decl is part of a DeclList, where the type is
|
|
# explicitly only for the first declaration in a list.
|
|
#
|
|
s = n.name if no_type else self._generate_decl(n)
|
|
if n.bitsize:
|
|
s += " : " + self.visit(n.bitsize)
|
|
if n.init:
|
|
s += " = " + self._visit_expr(n.init)
|
|
return s
|
|
|
|
def visit_DeclList(self, n: c_ast.DeclList) -> str:
|
|
s = self.visit(n.decls[0])
|
|
if len(n.decls) > 1:
|
|
s += ", " + ", ".join(
|
|
self.visit_Decl(decl, no_type=True) for decl in n.decls[1:]
|
|
)
|
|
return s
|
|
|
|
def visit_Typedef(self, n: c_ast.Typedef) -> str:
|
|
s = ""
|
|
if n.storage:
|
|
s += " ".join(n.storage) + " "
|
|
s += self._generate_type(n.type)
|
|
return s
|
|
|
|
def visit_Cast(self, n: c_ast.Cast) -> str:
|
|
s = "(" + self._generate_type(n.to_type, emit_declname=False) + ")"
|
|
return s + " " + self._parenthesize_unless_simple(n.expr)
|
|
|
|
def visit_ExprList(self, n: c_ast.ExprList) -> str:
|
|
visited_subexprs = []
|
|
for expr in n.exprs:
|
|
visited_subexprs.append(self._visit_expr(expr))
|
|
return ", ".join(visited_subexprs)
|
|
|
|
def visit_InitList(self, n: c_ast.InitList) -> str:
|
|
visited_subexprs = []
|
|
for expr in n.exprs:
|
|
visited_subexprs.append(self._visit_expr(expr))
|
|
return ", ".join(visited_subexprs)
|
|
|
|
def visit_Enum(self, n: c_ast.Enum) -> str:
|
|
return self._generate_struct_union_enum(n, name="enum")
|
|
|
|
def visit_Alignas(self, n: c_ast.Alignas) -> str:
|
|
return "_Alignas({})".format(self.visit(n.alignment))
|
|
|
|
def visit_Enumerator(self, n: c_ast.Enumerator) -> str:
|
|
if not n.value:
|
|
return "{indent}{name},\n".format(
|
|
indent=self._make_indent(),
|
|
name=n.name,
|
|
)
|
|
else:
|
|
return "{indent}{name} = {value},\n".format(
|
|
indent=self._make_indent(),
|
|
name=n.name,
|
|
value=self.visit(n.value),
|
|
)
|
|
|
|
def visit_FuncDef(self, n: c_ast.FuncDef) -> str:
|
|
decl = self.visit(n.decl)
|
|
self.indent_level = 0
|
|
body = self.visit(n.body)
|
|
if n.param_decls:
|
|
knrdecls = ";\n".join(self.visit(p) for p in n.param_decls)
|
|
return decl + "\n" + knrdecls + ";\n" + body + "\n"
|
|
else:
|
|
return decl + "\n" + body + "\n"
|
|
|
|
def visit_FileAST(self, n: c_ast.FileAST) -> str:
|
|
s = ""
|
|
for ext in n.ext:
|
|
match ext:
|
|
case c_ast.FuncDef():
|
|
s += self.visit(ext)
|
|
case c_ast.Pragma():
|
|
s += self.visit(ext) + "\n"
|
|
case _:
|
|
s += self.visit(ext) + ";\n"
|
|
return s
|
|
|
|
def visit_Compound(self, n: c_ast.Compound) -> str:
|
|
s = self._make_indent() + "{\n"
|
|
self.indent_level += 2
|
|
if n.block_items:
|
|
s += "".join(self._generate_stmt(stmt) for stmt in n.block_items)
|
|
self.indent_level -= 2
|
|
s += self._make_indent() + "}\n"
|
|
return s
|
|
|
|
def visit_CompoundLiteral(self, n: c_ast.CompoundLiteral) -> str:
|
|
return "(" + self.visit(n.type) + "){" + self.visit(n.init) + "}"
|
|
|
|
def visit_EmptyStatement(self, n: c_ast.EmptyStatement) -> str:
|
|
return ";"
|
|
|
|
def visit_ParamList(self, n: c_ast.ParamList) -> str:
|
|
return ", ".join(self.visit(param) for param in n.params)
|
|
|
|
def visit_Return(self, n: c_ast.Return) -> str:
|
|
s = "return"
|
|
if n.expr:
|
|
s += " " + self.visit(n.expr)
|
|
return s + ";"
|
|
|
|
def visit_Break(self, n: c_ast.Break) -> str:
|
|
return "break;"
|
|
|
|
def visit_Continue(self, n: c_ast.Continue) -> str:
|
|
return "continue;"
|
|
|
|
def visit_TernaryOp(self, n: c_ast.TernaryOp) -> str:
|
|
s = "(" + self._visit_expr(n.cond) + ") ? "
|
|
s += "(" + self._visit_expr(n.iftrue) + ") : "
|
|
s += "(" + self._visit_expr(n.iffalse) + ")"
|
|
return s
|
|
|
|
def visit_If(self, n: c_ast.If) -> str:
|
|
s = "if ("
|
|
if n.cond:
|
|
s += self.visit(n.cond)
|
|
s += ")\n"
|
|
s += self._generate_stmt(n.iftrue, add_indent=True)
|
|
if n.iffalse:
|
|
s += self._make_indent() + "else\n"
|
|
s += self._generate_stmt(n.iffalse, add_indent=True)
|
|
return s
|
|
|
|
def visit_For(self, n: c_ast.For) -> str:
|
|
s = "for ("
|
|
if n.init:
|
|
s += self.visit(n.init)
|
|
s += ";"
|
|
if n.cond:
|
|
s += " " + self.visit(n.cond)
|
|
s += ";"
|
|
if n.next:
|
|
s += " " + self.visit(n.next)
|
|
s += ")\n"
|
|
s += self._generate_stmt(n.stmt, add_indent=True)
|
|
return s
|
|
|
|
def visit_While(self, n: c_ast.While) -> str:
|
|
s = "while ("
|
|
if n.cond:
|
|
s += self.visit(n.cond)
|
|
s += ")\n"
|
|
s += self._generate_stmt(n.stmt, add_indent=True)
|
|
return s
|
|
|
|
def visit_DoWhile(self, n: c_ast.DoWhile) -> str:
|
|
s = "do\n"
|
|
s += self._generate_stmt(n.stmt, add_indent=True)
|
|
s += self._make_indent() + "while ("
|
|
if n.cond:
|
|
s += self.visit(n.cond)
|
|
s += ");"
|
|
return s
|
|
|
|
def visit_StaticAssert(self, n: c_ast.StaticAssert) -> str:
|
|
s = "_Static_assert("
|
|
s += self.visit(n.cond)
|
|
if n.message:
|
|
s += ","
|
|
s += self.visit(n.message)
|
|
s += ")"
|
|
return s
|
|
|
|
def visit_Switch(self, n: c_ast.Switch) -> str:
|
|
s = "switch (" + self.visit(n.cond) + ")\n"
|
|
s += self._generate_stmt(n.stmt, add_indent=True)
|
|
return s
|
|
|
|
def visit_Case(self, n: c_ast.Case) -> str:
|
|
s = "case " + self.visit(n.expr) + ":\n"
|
|
for stmt in n.stmts:
|
|
s += self._generate_stmt(stmt, add_indent=True)
|
|
return s
|
|
|
|
def visit_Default(self, n: c_ast.Default) -> str:
|
|
s = "default:\n"
|
|
for stmt in n.stmts:
|
|
s += self._generate_stmt(stmt, add_indent=True)
|
|
return s
|
|
|
|
def visit_Label(self, n: c_ast.Label) -> str:
|
|
return n.name + ":\n" + self._generate_stmt(n.stmt)
|
|
|
|
def visit_Goto(self, n: c_ast.Goto) -> str:
|
|
return "goto " + n.name + ";"
|
|
|
|
def visit_EllipsisParam(self, n: c_ast.EllipsisParam) -> str:
|
|
return "..."
|
|
|
|
def visit_Struct(self, n: c_ast.Struct) -> str:
|
|
return self._generate_struct_union_enum(n, "struct")
|
|
|
|
def visit_Typename(self, n: c_ast.Typename) -> str:
|
|
return self._generate_type(n.type)
|
|
|
|
def visit_Union(self, n: c_ast.Union) -> str:
|
|
return self._generate_struct_union_enum(n, "union")
|
|
|
|
def visit_NamedInitializer(self, n: c_ast.NamedInitializer) -> str:
|
|
s = ""
|
|
for name in n.name:
|
|
if isinstance(name, c_ast.ID):
|
|
s += "." + name.name
|
|
else:
|
|
s += "[" + self.visit(name) + "]"
|
|
s += " = " + self._visit_expr(n.expr)
|
|
return s
|
|
|
|
def visit_FuncDecl(self, n: c_ast.FuncDecl) -> str:
|
|
return self._generate_type(n)
|
|
|
|
def visit_ArrayDecl(self, n: c_ast.ArrayDecl) -> str:
|
|
return self._generate_type(n, emit_declname=False)
|
|
|
|
def visit_TypeDecl(self, n: c_ast.TypeDecl) -> str:
|
|
return self._generate_type(n, emit_declname=False)
|
|
|
|
def visit_PtrDecl(self, n: c_ast.PtrDecl) -> str:
|
|
return self._generate_type(n, emit_declname=False)
|
|
|
|
def _generate_struct_union_enum(
|
|
self, n: c_ast.Struct | c_ast.Union | c_ast.Enum, name: str
|
|
) -> str:
|
|
"""Generates code for structs, unions, and enums. name should be
|
|
'struct', 'union', or 'enum'.
|
|
"""
|
|
if name in ("struct", "union"):
|
|
assert isinstance(n, (c_ast.Struct, c_ast.Union))
|
|
members = n.decls
|
|
body_function = self._generate_struct_union_body
|
|
else:
|
|
assert name == "enum"
|
|
assert isinstance(n, c_ast.Enum)
|
|
members = None if n.values is None else n.values.enumerators
|
|
body_function = self._generate_enum_body
|
|
s = name + " " + (n.name or "")
|
|
if members is not None:
|
|
# None means no members
|
|
# Empty sequence means an empty list of members
|
|
s += "\n"
|
|
s += self._make_indent()
|
|
self.indent_level += 2
|
|
s += "{\n"
|
|
s += body_function(members)
|
|
self.indent_level -= 2
|
|
s += self._make_indent() + "}"
|
|
return s
|
|
|
|
def _generate_struct_union_body(self, members: List[c_ast.Node]) -> str:
|
|
return "".join(self._generate_stmt(decl) for decl in members)
|
|
|
|
def _generate_enum_body(self, members: List[c_ast.Enumerator]) -> str:
|
|
# `[:-2] + '\n'` removes the final `,` from the enumerator list
|
|
return "".join(self.visit(value) for value in members)[:-2] + "\n"
|
|
|
|
def _generate_stmt(self, n: c_ast.Node, add_indent: bool = False) -> str:
|
|
"""Generation from a statement node. This method exists as a wrapper
|
|
for individual visit_* methods to handle different treatment of
|
|
some statements in this context.
|
|
"""
|
|
if add_indent:
|
|
self.indent_level += 2
|
|
indent = self._make_indent()
|
|
if add_indent:
|
|
self.indent_level -= 2
|
|
|
|
match n:
|
|
case (
|
|
c_ast.Decl()
|
|
| c_ast.Assignment()
|
|
| c_ast.Cast()
|
|
| c_ast.UnaryOp()
|
|
| c_ast.BinaryOp()
|
|
| c_ast.TernaryOp()
|
|
| c_ast.FuncCall()
|
|
| c_ast.ArrayRef()
|
|
| c_ast.StructRef()
|
|
| c_ast.Constant()
|
|
| c_ast.ID()
|
|
| c_ast.Typedef()
|
|
| c_ast.ExprList()
|
|
):
|
|
# These can also appear in an expression context so no semicolon
|
|
# is added to them automatically
|
|
#
|
|
return indent + self.visit(n) + ";\n"
|
|
case c_ast.Compound():
|
|
# No extra indentation required before the opening brace of a
|
|
# compound - because it consists of multiple lines it has to
|
|
# compute its own indentation.
|
|
#
|
|
return self.visit(n)
|
|
case c_ast.If():
|
|
return indent + self.visit(n)
|
|
case _:
|
|
return indent + self.visit(n) + "\n"
|
|
|
|
def _generate_decl(self, n: c_ast.Decl) -> str:
|
|
"""Generation from a Decl node."""
|
|
s = ""
|
|
if n.funcspec:
|
|
s = " ".join(n.funcspec) + " "
|
|
if n.storage:
|
|
s += " ".join(n.storage) + " "
|
|
if n.align:
|
|
s += self.visit(n.align[0]) + " "
|
|
s += self._generate_type(n.type)
|
|
return s
|
|
|
|
def _generate_type(
|
|
self,
|
|
n: c_ast.Node,
|
|
modifiers: List[c_ast.Node] = [],
|
|
emit_declname: bool = True,
|
|
) -> str:
|
|
"""Recursive generation from a type node. n is the type node.
|
|
modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers
|
|
encountered on the way down to a TypeDecl, to allow proper
|
|
generation from it.
|
|
"""
|
|
# ~ print(n, modifiers)
|
|
match n:
|
|
case c_ast.TypeDecl():
|
|
s = ""
|
|
if n.quals:
|
|
s += " ".join(n.quals) + " "
|
|
s += self.visit(n.type)
|
|
|
|
nstr = n.declname if n.declname and emit_declname else ""
|
|
# Resolve modifiers.
|
|
# Wrap in parens to distinguish pointer to array and pointer to
|
|
# function syntax.
|
|
#
|
|
for i, modifier in enumerate(modifiers):
|
|
match modifier:
|
|
case c_ast.ArrayDecl():
|
|
if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl):
|
|
nstr = "(" + nstr + ")"
|
|
nstr += "["
|
|
if modifier.dim_quals:
|
|
nstr += " ".join(modifier.dim_quals) + " "
|
|
if modifier.dim is not None:
|
|
nstr += self.visit(modifier.dim)
|
|
nstr += "]"
|
|
case c_ast.FuncDecl():
|
|
if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl):
|
|
nstr = "(" + nstr + ")"
|
|
args = (
|
|
self.visit(modifier.args)
|
|
if modifier.args is not None
|
|
else ""
|
|
)
|
|
nstr += "(" + args + ")"
|
|
case c_ast.PtrDecl():
|
|
if modifier.quals:
|
|
quals = " ".join(modifier.quals)
|
|
suffix = f" {nstr}" if nstr else ""
|
|
nstr = f"* {quals}{suffix}"
|
|
else:
|
|
nstr = "*" + nstr
|
|
if nstr:
|
|
s += " " + nstr
|
|
return s
|
|
case c_ast.Decl():
|
|
return self._generate_decl(n.type)
|
|
case c_ast.Typename():
|
|
return self._generate_type(n.type, emit_declname=emit_declname)
|
|
case c_ast.IdentifierType():
|
|
return " ".join(n.names) + " "
|
|
case c_ast.ArrayDecl() | c_ast.PtrDecl() | c_ast.FuncDecl():
|
|
return self._generate_type(
|
|
n.type, modifiers + [n], emit_declname=emit_declname
|
|
)
|
|
case _:
|
|
return self.visit(n)
|
|
|
|
def _parenthesize_if(
|
|
self, n: c_ast.Node, condition: Callable[[c_ast.Node], bool]
|
|
) -> str:
|
|
"""Visits 'n' and returns its string representation, parenthesized
|
|
if the condition function applied to the node returns True.
|
|
"""
|
|
s = self._visit_expr(n)
|
|
if condition(n):
|
|
return "(" + s + ")"
|
|
else:
|
|
return s
|
|
|
|
def _parenthesize_unless_simple(self, n: c_ast.Node) -> str:
|
|
"""Common use case for _parenthesize_if"""
|
|
return self._parenthesize_if(n, lambda d: not self._is_simple_node(d))
|
|
|
|
def _is_simple_node(self, n: c_ast.Node) -> bool:
|
|
"""Returns True for nodes that are "simple" - i.e. nodes that always
|
|
have higher precedence than operators.
|
|
"""
|
|
return isinstance(
|
|
n,
|
|
(c_ast.Constant, c_ast.ID, c_ast.ArrayRef, c_ast.StructRef, c_ast.FuncCall),
|
|
)
|