Files
clawd/venv/lib/python3.12/site-packages/pycparser/c_generator.py

574 lines
20 KiB
Python

# ------------------------------------------------------------------------------
# pycparser: c_generator.py
#
# C code generator from pycparser AST nodes.
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
# ------------------------------------------------------------------------------
from typing import Callable, List, Optional
from . import c_ast
class CGenerator:
"""Uses the same visitor pattern as c_ast.NodeVisitor, but modified to
return a value from each visit method, using string accumulation in
generic_visit.
"""
indent_level: int
reduce_parentheses: bool
def __init__(self, reduce_parentheses: bool = False) -> None:
"""Constructs C-code generator
reduce_parentheses:
if True, eliminates needless parentheses on binary operators
"""
# Statements start with indentation of self.indent_level spaces, using
# the _make_indent method.
self.indent_level = 0
self.reduce_parentheses = reduce_parentheses
def _make_indent(self) -> str:
return " " * self.indent_level
def visit(self, node: c_ast.Node) -> str:
method = "visit_" + node.__class__.__name__
return getattr(self, method, self.generic_visit)(node)
def generic_visit(self, node: Optional[c_ast.Node]) -> str:
if node is None:
return ""
else:
return "".join(self.visit(c) for c_name, c in node.children())
def visit_Constant(self, n: c_ast.Constant) -> str:
return n.value
def visit_ID(self, n: c_ast.ID) -> str:
return n.name
def visit_Pragma(self, n: c_ast.Pragma) -> str:
ret = "#pragma"
if n.string:
ret += " " + n.string
return ret
def visit_ArrayRef(self, n: c_ast.ArrayRef) -> str:
arrref = self._parenthesize_unless_simple(n.name)
return arrref + "[" + self.visit(n.subscript) + "]"
def visit_StructRef(self, n: c_ast.StructRef) -> str:
sref = self._parenthesize_unless_simple(n.name)
return sref + n.type + self.visit(n.field)
def visit_FuncCall(self, n: c_ast.FuncCall) -> str:
fref = self._parenthesize_unless_simple(n.name)
args = self.visit(n.args) if n.args is not None else ""
return fref + "(" + args + ")"
def visit_UnaryOp(self, n: c_ast.UnaryOp) -> str:
match n.op:
case "sizeof":
# Always parenthesize the argument of sizeof since it can be
# a name.
return f"sizeof({self.visit(n.expr)})"
case "p++":
operand = self._parenthesize_unless_simple(n.expr)
return f"{operand}++"
case "p--":
operand = self._parenthesize_unless_simple(n.expr)
return f"{operand}--"
case _:
operand = self._parenthesize_unless_simple(n.expr)
return f"{n.op}{operand}"
# Precedence map of binary operators:
precedence_map = {
# Should be in sync with c_parser.CParser.precedence
# Higher numbers are stronger binding
"||": 0, # weakest binding
"&&": 1,
"|": 2,
"^": 3,
"&": 4,
"==": 5,
"!=": 5,
">": 6,
">=": 6,
"<": 6,
"<=": 6,
">>": 7,
"<<": 7,
"+": 8,
"-": 8,
"*": 9,
"/": 9,
"%": 9, # strongest binding
}
def visit_BinaryOp(self, n: c_ast.BinaryOp) -> str:
# Note: all binary operators are left-to-right associative
#
# If `n.left.op` has a stronger or equally binding precedence in
# comparison to `n.op`, no parenthesis are needed for the left:
# e.g., `(a*b) + c` is equivalent to `a*b + c`, as well as
# `(a+b) - c` is equivalent to `a+b - c` (same precedence).
# If the left operator is weaker binding than the current, then
# parentheses are necessary:
# e.g., `(a+b) * c` is NOT equivalent to `a+b * c`.
lval_str = self._parenthesize_if(
n.left,
lambda d: not (
self._is_simple_node(d)
or self.reduce_parentheses
and isinstance(d, c_ast.BinaryOp)
and self.precedence_map[d.op] >= self.precedence_map[n.op]
),
)
# If `n.right.op` has a stronger -but not equal- binding precedence,
# parenthesis can be omitted on the right:
# e.g., `a + (b*c)` is equivalent to `a + b*c`.
# If the right operator is weaker or equally binding, then parentheses
# are necessary:
# e.g., `a * (b+c)` is NOT equivalent to `a * b+c` and
# `a - (b+c)` is NOT equivalent to `a - b+c` (same precedence).
rval_str = self._parenthesize_if(
n.right,
lambda d: not (
self._is_simple_node(d)
or self.reduce_parentheses
and isinstance(d, c_ast.BinaryOp)
and self.precedence_map[d.op] > self.precedence_map[n.op]
),
)
return f"{lval_str} {n.op} {rval_str}"
def visit_Assignment(self, n: c_ast.Assignment) -> str:
rval_str = self._parenthesize_if(
n.rvalue, lambda n: isinstance(n, c_ast.Assignment)
)
return f"{self.visit(n.lvalue)} {n.op} {rval_str}"
def visit_IdentifierType(self, n: c_ast.IdentifierType) -> str:
return " ".join(n.names)
def _visit_expr(self, n: c_ast.Node) -> str:
match n:
case c_ast.InitList():
return "{" + self.visit(n) + "}"
case c_ast.ExprList() | c_ast.Compound():
return "(" + self.visit(n) + ")"
case _:
return self.visit(n)
def visit_Decl(self, n: c_ast.Decl, no_type: bool = False) -> str:
# no_type is used when a Decl is part of a DeclList, where the type is
# explicitly only for the first declaration in a list.
#
s = n.name if no_type else self._generate_decl(n)
if n.bitsize:
s += " : " + self.visit(n.bitsize)
if n.init:
s += " = " + self._visit_expr(n.init)
return s
def visit_DeclList(self, n: c_ast.DeclList) -> str:
s = self.visit(n.decls[0])
if len(n.decls) > 1:
s += ", " + ", ".join(
self.visit_Decl(decl, no_type=True) for decl in n.decls[1:]
)
return s
def visit_Typedef(self, n: c_ast.Typedef) -> str:
s = ""
if n.storage:
s += " ".join(n.storage) + " "
s += self._generate_type(n.type)
return s
def visit_Cast(self, n: c_ast.Cast) -> str:
s = "(" + self._generate_type(n.to_type, emit_declname=False) + ")"
return s + " " + self._parenthesize_unless_simple(n.expr)
def visit_ExprList(self, n: c_ast.ExprList) -> str:
visited_subexprs = []
for expr in n.exprs:
visited_subexprs.append(self._visit_expr(expr))
return ", ".join(visited_subexprs)
def visit_InitList(self, n: c_ast.InitList) -> str:
visited_subexprs = []
for expr in n.exprs:
visited_subexprs.append(self._visit_expr(expr))
return ", ".join(visited_subexprs)
def visit_Enum(self, n: c_ast.Enum) -> str:
return self._generate_struct_union_enum(n, name="enum")
def visit_Alignas(self, n: c_ast.Alignas) -> str:
return "_Alignas({})".format(self.visit(n.alignment))
def visit_Enumerator(self, n: c_ast.Enumerator) -> str:
if not n.value:
return "{indent}{name},\n".format(
indent=self._make_indent(),
name=n.name,
)
else:
return "{indent}{name} = {value},\n".format(
indent=self._make_indent(),
name=n.name,
value=self.visit(n.value),
)
def visit_FuncDef(self, n: c_ast.FuncDef) -> str:
decl = self.visit(n.decl)
self.indent_level = 0
body = self.visit(n.body)
if n.param_decls:
knrdecls = ";\n".join(self.visit(p) for p in n.param_decls)
return decl + "\n" + knrdecls + ";\n" + body + "\n"
else:
return decl + "\n" + body + "\n"
def visit_FileAST(self, n: c_ast.FileAST) -> str:
s = ""
for ext in n.ext:
match ext:
case c_ast.FuncDef():
s += self.visit(ext)
case c_ast.Pragma():
s += self.visit(ext) + "\n"
case _:
s += self.visit(ext) + ";\n"
return s
def visit_Compound(self, n: c_ast.Compound) -> str:
s = self._make_indent() + "{\n"
self.indent_level += 2
if n.block_items:
s += "".join(self._generate_stmt(stmt) for stmt in n.block_items)
self.indent_level -= 2
s += self._make_indent() + "}\n"
return s
def visit_CompoundLiteral(self, n: c_ast.CompoundLiteral) -> str:
return "(" + self.visit(n.type) + "){" + self.visit(n.init) + "}"
def visit_EmptyStatement(self, n: c_ast.EmptyStatement) -> str:
return ";"
def visit_ParamList(self, n: c_ast.ParamList) -> str:
return ", ".join(self.visit(param) for param in n.params)
def visit_Return(self, n: c_ast.Return) -> str:
s = "return"
if n.expr:
s += " " + self.visit(n.expr)
return s + ";"
def visit_Break(self, n: c_ast.Break) -> str:
return "break;"
def visit_Continue(self, n: c_ast.Continue) -> str:
return "continue;"
def visit_TernaryOp(self, n: c_ast.TernaryOp) -> str:
s = "(" + self._visit_expr(n.cond) + ") ? "
s += "(" + self._visit_expr(n.iftrue) + ") : "
s += "(" + self._visit_expr(n.iffalse) + ")"
return s
def visit_If(self, n: c_ast.If) -> str:
s = "if ("
if n.cond:
s += self.visit(n.cond)
s += ")\n"
s += self._generate_stmt(n.iftrue, add_indent=True)
if n.iffalse:
s += self._make_indent() + "else\n"
s += self._generate_stmt(n.iffalse, add_indent=True)
return s
def visit_For(self, n: c_ast.For) -> str:
s = "for ("
if n.init:
s += self.visit(n.init)
s += ";"
if n.cond:
s += " " + self.visit(n.cond)
s += ";"
if n.next:
s += " " + self.visit(n.next)
s += ")\n"
s += self._generate_stmt(n.stmt, add_indent=True)
return s
def visit_While(self, n: c_ast.While) -> str:
s = "while ("
if n.cond:
s += self.visit(n.cond)
s += ")\n"
s += self._generate_stmt(n.stmt, add_indent=True)
return s
def visit_DoWhile(self, n: c_ast.DoWhile) -> str:
s = "do\n"
s += self._generate_stmt(n.stmt, add_indent=True)
s += self._make_indent() + "while ("
if n.cond:
s += self.visit(n.cond)
s += ");"
return s
def visit_StaticAssert(self, n: c_ast.StaticAssert) -> str:
s = "_Static_assert("
s += self.visit(n.cond)
if n.message:
s += ","
s += self.visit(n.message)
s += ")"
return s
def visit_Switch(self, n: c_ast.Switch) -> str:
s = "switch (" + self.visit(n.cond) + ")\n"
s += self._generate_stmt(n.stmt, add_indent=True)
return s
def visit_Case(self, n: c_ast.Case) -> str:
s = "case " + self.visit(n.expr) + ":\n"
for stmt in n.stmts:
s += self._generate_stmt(stmt, add_indent=True)
return s
def visit_Default(self, n: c_ast.Default) -> str:
s = "default:\n"
for stmt in n.stmts:
s += self._generate_stmt(stmt, add_indent=True)
return s
def visit_Label(self, n: c_ast.Label) -> str:
return n.name + ":\n" + self._generate_stmt(n.stmt)
def visit_Goto(self, n: c_ast.Goto) -> str:
return "goto " + n.name + ";"
def visit_EllipsisParam(self, n: c_ast.EllipsisParam) -> str:
return "..."
def visit_Struct(self, n: c_ast.Struct) -> str:
return self._generate_struct_union_enum(n, "struct")
def visit_Typename(self, n: c_ast.Typename) -> str:
return self._generate_type(n.type)
def visit_Union(self, n: c_ast.Union) -> str:
return self._generate_struct_union_enum(n, "union")
def visit_NamedInitializer(self, n: c_ast.NamedInitializer) -> str:
s = ""
for name in n.name:
if isinstance(name, c_ast.ID):
s += "." + name.name
else:
s += "[" + self.visit(name) + "]"
s += " = " + self._visit_expr(n.expr)
return s
def visit_FuncDecl(self, n: c_ast.FuncDecl) -> str:
return self._generate_type(n)
def visit_ArrayDecl(self, n: c_ast.ArrayDecl) -> str:
return self._generate_type(n, emit_declname=False)
def visit_TypeDecl(self, n: c_ast.TypeDecl) -> str:
return self._generate_type(n, emit_declname=False)
def visit_PtrDecl(self, n: c_ast.PtrDecl) -> str:
return self._generate_type(n, emit_declname=False)
def _generate_struct_union_enum(
self, n: c_ast.Struct | c_ast.Union | c_ast.Enum, name: str
) -> str:
"""Generates code for structs, unions, and enums. name should be
'struct', 'union', or 'enum'.
"""
if name in ("struct", "union"):
assert isinstance(n, (c_ast.Struct, c_ast.Union))
members = n.decls
body_function = self._generate_struct_union_body
else:
assert name == "enum"
assert isinstance(n, c_ast.Enum)
members = None if n.values is None else n.values.enumerators
body_function = self._generate_enum_body
s = name + " " + (n.name or "")
if members is not None:
# None means no members
# Empty sequence means an empty list of members
s += "\n"
s += self._make_indent()
self.indent_level += 2
s += "{\n"
s += body_function(members)
self.indent_level -= 2
s += self._make_indent() + "}"
return s
def _generate_struct_union_body(self, members: List[c_ast.Node]) -> str:
return "".join(self._generate_stmt(decl) for decl in members)
def _generate_enum_body(self, members: List[c_ast.Enumerator]) -> str:
# `[:-2] + '\n'` removes the final `,` from the enumerator list
return "".join(self.visit(value) for value in members)[:-2] + "\n"
def _generate_stmt(self, n: c_ast.Node, add_indent: bool = False) -> str:
"""Generation from a statement node. This method exists as a wrapper
for individual visit_* methods to handle different treatment of
some statements in this context.
"""
if add_indent:
self.indent_level += 2
indent = self._make_indent()
if add_indent:
self.indent_level -= 2
match n:
case (
c_ast.Decl()
| c_ast.Assignment()
| c_ast.Cast()
| c_ast.UnaryOp()
| c_ast.BinaryOp()
| c_ast.TernaryOp()
| c_ast.FuncCall()
| c_ast.ArrayRef()
| c_ast.StructRef()
| c_ast.Constant()
| c_ast.ID()
| c_ast.Typedef()
| c_ast.ExprList()
):
# These can also appear in an expression context so no semicolon
# is added to them automatically
#
return indent + self.visit(n) + ";\n"
case c_ast.Compound():
# No extra indentation required before the opening brace of a
# compound - because it consists of multiple lines it has to
# compute its own indentation.
#
return self.visit(n)
case c_ast.If():
return indent + self.visit(n)
case _:
return indent + self.visit(n) + "\n"
def _generate_decl(self, n: c_ast.Decl) -> str:
"""Generation from a Decl node."""
s = ""
if n.funcspec:
s = " ".join(n.funcspec) + " "
if n.storage:
s += " ".join(n.storage) + " "
if n.align:
s += self.visit(n.align[0]) + " "
s += self._generate_type(n.type)
return s
def _generate_type(
self,
n: c_ast.Node,
modifiers: List[c_ast.Node] = [],
emit_declname: bool = True,
) -> str:
"""Recursive generation from a type node. n is the type node.
modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers
encountered on the way down to a TypeDecl, to allow proper
generation from it.
"""
# ~ print(n, modifiers)
match n:
case c_ast.TypeDecl():
s = ""
if n.quals:
s += " ".join(n.quals) + " "
s += self.visit(n.type)
nstr = n.declname if n.declname and emit_declname else ""
# Resolve modifiers.
# Wrap in parens to distinguish pointer to array and pointer to
# function syntax.
#
for i, modifier in enumerate(modifiers):
match modifier:
case c_ast.ArrayDecl():
if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl):
nstr = "(" + nstr + ")"
nstr += "["
if modifier.dim_quals:
nstr += " ".join(modifier.dim_quals) + " "
if modifier.dim is not None:
nstr += self.visit(modifier.dim)
nstr += "]"
case c_ast.FuncDecl():
if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl):
nstr = "(" + nstr + ")"
args = (
self.visit(modifier.args)
if modifier.args is not None
else ""
)
nstr += "(" + args + ")"
case c_ast.PtrDecl():
if modifier.quals:
quals = " ".join(modifier.quals)
suffix = f" {nstr}" if nstr else ""
nstr = f"* {quals}{suffix}"
else:
nstr = "*" + nstr
if nstr:
s += " " + nstr
return s
case c_ast.Decl():
return self._generate_decl(n.type)
case c_ast.Typename():
return self._generate_type(n.type, emit_declname=emit_declname)
case c_ast.IdentifierType():
return " ".join(n.names) + " "
case c_ast.ArrayDecl() | c_ast.PtrDecl() | c_ast.FuncDecl():
return self._generate_type(
n.type, modifiers + [n], emit_declname=emit_declname
)
case _:
return self.visit(n)
def _parenthesize_if(
self, n: c_ast.Node, condition: Callable[[c_ast.Node], bool]
) -> str:
"""Visits 'n' and returns its string representation, parenthesized
if the condition function applied to the node returns True.
"""
s = self._visit_expr(n)
if condition(n):
return "(" + s + ")"
else:
return s
def _parenthesize_unless_simple(self, n: c_ast.Node) -> str:
"""Common use case for _parenthesize_if"""
return self._parenthesize_if(n, lambda d: not self._is_simple_node(d))
def _is_simple_node(self, n: c_ast.Node) -> bool:
"""Returns True for nodes that are "simple" - i.e. nodes that always
have higher precedence than operators.
"""
return isinstance(
n,
(c_ast.Constant, c_ast.ID, c_ast.ArrayRef, c_ast.StructRef, c_ast.FuncCall),
)