# Source listing metadata (residue from the page this file was copied from): 93 lines, 3.5 KiB, Python.
"""API for traversing the AST nodes.

The compiler and the meta introspection helpers implement their node
walkers on top of the classes defined here.
"""

import typing as t

from .nodes import Node

if t.TYPE_CHECKING:
    # Only needed by static type checkers; nothing here exists at runtime.
    import typing_extensions as te

    class VisitCallable(te.Protocol):
        # Shape of a ``visit_*`` method as invoked by ``NodeVisitor.visit``:
        # the node first, followed by any extra positional/keyword arguments.
        def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
            """Visit *node* and return an arbitrary result."""


class NodeVisitor:
    """Walks the abstract syntax tree and calls a visitor function for every
    node found.  Whatever a visitor function returns is passed back to the
    caller of `visit`.

    By default the visitor function for a node is named ``'visit_'`` plus
    the class name of the node, so a `TryFinally` node is handled by
    `visit_TryFinally`.  Override the `get_visitor` method to change this
    lookup.  When no visitor function exists for a node (`get_visitor`
    returns `None`), `generic_visit` is used instead.
    """

    def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
        """Look up the visitor method for *node*.

        Returns `None` when this visitor defines no matching method, in
        which case `visit` falls back to `generic_visit`.
        """
        method_name = f"visit_{type(node).__name__}"
        return getattr(self, method_name, None)

    def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Dispatch *node* to its visitor method and return that method's
        result; extra arguments are forwarded unchanged."""
        handler = self.get_visitor(node)

        if handler is None:
            return self.generic_visit(node, *args, **kwargs)

        return handler(node, *args, **kwargs)

    def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Fallback used when no explicit visitor exists: visit every child
        node, discarding the results, and return `None`."""
        for child in node.iter_child_nodes():
            self.visit(child, *args, **kwargs)


class NodeTransformer(NodeVisitor):
    """Walks the abstract syntax tree and allows modifications of nodes.

    The `NodeTransformer` walks the AST and uses the return value of the
    visitor functions to replace or remove the old node.  If a visitor
    function returns `None`, the node is removed from its previous
    location; otherwise it is replaced with the return value.  The return
    value may be the original node, in which case no replacement takes
    place.
    """

    def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
        """Transform the children of *node* in place and return *node*.

        For a field holding a list, every `Node` item is visited:

        - a `None` result drops the item,
        - a `Node` result replaces the item,
        - any other result (e.g. a list of nodes) is spliced in item by item.

        Items that are not `Node` instances are kept unchanged.  The list
        object stored on *node* is updated in place.

        For a field holding a single `Node`, a `None` result deletes the
        attribute from *node*; any other result is assigned to it.
        """
        for field, old_value in node.iter_fields():
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, Node):
                        value = self.visit(value, *args, **kwargs)
                        if value is None:
                            continue
                        elif not isinstance(value, Node):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                # Slice assignment keeps the original list object, so any
                # other reference to it sees the transformed items.
                old_value[:] = new_values
            elif isinstance(old_value, Node):
                new_node = self.visit(old_value, *args, **kwargs)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node

    def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
        """As transformers may return lists in some places, this method
        can be used to enforce a list as return value.

        Returns an empty list when the visitor removed the node (returned
        `None`), a one-element list when it returned a single value, and
        the visitor's own list unchanged when it already returned one.
        """
        rv = self.visit(node, *args, **kwargs)

        # `None` means "remove this node" (see the class docstring), which
        # as a list is empty -- not `[None]`, which would leak a non-node
        # into callers expecting `List[Node]`.
        if rv is None:
            return []

        if not isinstance(rv, list):
            return [rv]

        return rv