Using syntax trees in Python

July 2, 2024

I've recently been working on a project where I needed to parse and update Django model files by adding new fields. Initially, I used plain string manipulation for parsing, but it quickly became too complex and error-prone. To address this, I switched to using syntax trees, specifically concrete syntax trees (CSTs), for parsing. In this post, I'll discuss how I leveraged syntax trees in Python to handle Django model files more effectively.

For those unfamiliar with syntax trees, I'll provide a brief introduction to abstract syntax trees (ASTs), what they are and how they differ from concrete syntax trees (CSTs). I'll then demonstrate how to parse Python code into ASTs and CSTs using the built-in ast module and the LibCST library, respectively. To keep things more relevant to the work I've been doing, after the introduction to both ASTs and CSTs, I'll stick to using CSTs with LibCST for the remainder of the post. Bring your snorkel and let's dive in! 🤿

What are ASTs?

ASTs are tree representations of the abstract syntactic structure of source code written in a programming language. They capture the structure of the code without including every detail of the original code. ASTs are commonly used in compilers and interpreters to analyze, transform, and generate code.

In Python, the ast module provides tools for working with ASTs. The module includes functions for parsing Python source code into an AST and for traversing, analyzing, and modifying the AST.
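As a rough illustration of the traversal side of the ast module, here is a minimal sketch that collects the names assigned at each assignment statement. The AssignmentCollector class and the inline source string are made up for this example; ast.parse() and ast.NodeVisitor are part of the standard library.

```python
import ast

# Hypothetical inline source, used only for this sketch.
source = "a = 1\nb = 2\nc = a + b\n"


class AssignmentCollector(ast.NodeVisitor):
    """Records the name on the left-hand side of every simple assignment."""

    def __init__(self):
        self.names = []

    def visit_Assign(self, node):
        for target in node.targets:
            if isinstance(target, ast.Name):
                self.names.append(target.id)
        # Keep walking in case of nested nodes.
        self.generic_visit(node)


collector = AssignmentCollector()
collector.visit(ast.parse(source))
print(collector.names)  # ['a', 'b', 'c']
```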

Parsing Python code into an AST

Let's start by creating an example Python file that we'll parse into an AST. Here's a simple script, saved as example.py, that we'll use for the remainder of this post:

  a = 1 # Assigning a value to variable a
  b = 2 # Assigning a value to variable b
  c = a + b # Assigning the sum of a and b to variable c

Reading and parsing the Python file

To parse this code into an AST, we can use the ast.parse() function from the ast module:

  import ast

  with open("example.py", "r") as file: # Read the Python file
      source = file.read()

  tree = ast.parse(source)

The ast.parse() function takes a string of Python source code and returns the root of the AST. We read the contents of the example.py file into the source variable and then parse it using ast.parse(). We now have the AST representation of the code in the tree variable, and we can print it to see the structure of the AST using the ast.dump() function:

  print(ast.dump(tree, indent=4))

The ast.dump() function returns a readable string representation of the AST, with optional indentation, which we then print. Running this code will output the following:

  Module(
      body=[
          Assign(
              targets=[
                  Name(id='a', ctx=Store())],
              value=Constant(value=1)),
          Assign(
              targets=[
                  Name(id='b', ctx=Store())],
              value=Constant(value=2)),
          Assign(
              targets=[
                  Name(id='c', ctx=Store())],
              value=BinOp(
                  left=Name(id='a', ctx=Load()),
                  op=Add(),
                  right=Name(id='b', ctx=Load())))],
      type_ignores=[])

With this output, you can see the structure of the AST, including the nodes representing the assignments and the binary operation. The AST abstracts away details like whitespace and comments, focusing on the syntactic structure of the code.
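One quick way to see what gets abstracted away is to turn the AST back into source code. The sketch below assumes Python 3.9 or newer, where ast.unparse() is available, and uses the contents of example.py as an inline string rather than reading the file.

```python
import ast

# Same contents as example.py, inlined so the sketch is self-contained.
source = (
    "a = 1 # Assigning a value to variable a\n"
    "b = 2 # Assigning a value to variable b\n"
    "c = a + b # Assigning the sum of a and b to variable c\n"
)

# Round-trip: source -> AST -> source.
regenerated = ast.unparse(ast.parse(source))
print(regenerated)
# a = 1
# b = 2
# c = a + b
```

The statements survive the round trip, but the comments are gone because they were never stored in the tree.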

However, for our particular use case this is no good, as we need to retain the exact syntax of the code, including whitespace and comments. Thankfully, this is where CSTs come in!

What are CSTs?

CSTs are a type of syntax tree that preserves the concrete syntax of the source code. Unlike traditional ASTs, which abstract away details like whitespace and comments, CSTs retain these details. This can be useful when you need to work with the exact syntax of the code, such as when you're modifying existing code.

In Python, the Instagram-developed LibCST library provides support for working with CSTs. LibCST is a concrete syntax tree parser and serializer for Python source code. It allows you to parse Python code into a CST, manipulate the CST, and serialize it back into Python code.

Parsing Python code into a CST

To parse Python code into a CST using LibCST, you first need to install the library. You can install it using pip:

  pip install libcst

Reading and parsing the Python file

Once you have LibCST installed, you can parse our example.py file into a CST using the parse_module() function from libcst:

  import libcst as cst

  with open("example.py", "r") as file: # Read the Python file
      source = file.read()

  tree = cst.parse_module(source)

The parse_module() function takes a string of Python source code and returns the root of the CST. We read the contents of the example.py file into the source variable and then parse it using parse_module(). We now have the CST representation of the code in the tree variable, and we can inspect its structure by printing it:

  print(tree)

Running this code will output the following:

  Module(
      body=[
          SimpleStatementLine(
              body=[
                  Assign(
                      targets=[
                          AssignTarget(
                              target=Name(
                                  value='a',
                                  lpar=[],
                                  rpar=[],
                              ),
                              whitespace_before_equal=SimpleWhitespace(
                                  value=' ',
                              ),
                              whitespace_after_equal=SimpleWhitespace(
                                  value=' ',
                              ),
                          ),
                      ],
                      value=Integer(
                          value='1',
                          lpar=[],
                          rpar=[],
                      ),
                      semicolon=MaybeSentinel.DEFAULT,
                  ),
              ],
              leading_lines=[],
              trailing_whitespace=TrailingWhitespace(
                  whitespace=SimpleWhitespace(
                      value=' ',
                  ),
                  comment=Comment(
                      value='# Assigning a value to variable a',
                  ),
                  newline=Newline(
                      value=None,
                  ),
              ),
          ),
          SimpleStatementLine(
              body=[
                  Assign(
                      targets=[
                          AssignTarget(
                              target=Name(
                                  value='b',
                                  lpar=[],
                                  rpar=[],
                              ),
                              whitespace_before_equal=SimpleWhitespace(
                                  value=' ',
                              ),
                              whitespace_after_equal=SimpleWhitespace(
                                  value=' ',
                              ),
                          ),
                      ],
                      value=Integer(
                          value='2',
                          lpar=[],
                          rpar=[],
                      ),
                      semicolon=MaybeSentinel.DEFAULT,
                  ),
              ],
              leading_lines=[],
              trailing_whitespace=TrailingWhitespace(
                  whitespace=SimpleWhitespace(
                      value=' ',
                  ),
                  comment=Comment(
                      value='# Assigning a value to variable b',
                  ),
                  newline=Newline(
                      value=None,
                  ),
              ),
          ),
          SimpleStatementLine(
              body=[
                  Assign(
                      targets=[
                          AssignTarget(
                              target=Name(
                                  value='c',
                                  lpar=[],
                                  rpar=[],
                              ),
                              whitespace_before_equal=SimpleWhitespace(
                                  value=' ',
                              ),
                              whitespace_after_equal=SimpleWhitespace(
                                  value=' ',
                              ),
                          ),
                      ],
                      value=BinaryOperation(
                          left=Name(
                              value='a',
                              lpar=[],
                              rpar=[],
                          ),
                          operator=Add(
                              whitespace_before=SimpleWhitespace(
                                  value=' ',
                              ),
                              whitespace_after=SimpleWhitespace(
                                  value=' ',
                              ),
                          ),
                          right=Name(
                              value='b',
                              lpar=[],
                              rpar=[],
                          ),
                          lpar=[],
                          rpar=[],
                      ),
                      semicolon=MaybeSentinel.DEFAULT,
                  ),
              ],
              leading_lines=[],
              trailing_whitespace=TrailingWhitespace(
                  whitespace=SimpleWhitespace(
                      value=' ',
                  ),
                  comment=Comment(
                      value='# Assigning the sum of a and b to variable c',
                  ),
                  newline=Newline(
                      value=None,
                  ),
              ),
          ),
      ],
      header=[
          EmptyLine(
              indent=True,
              whitespace=SimpleWhitespace(
                  value='',
              ),
              comment=None,
              newline=Newline(
                  value=None,
              ),
          ),
      ],
      footer=[],
      encoding='utf-8',
      default_indent='    ',
      default_newline='\n',
      has_trailing_newline=True,
  )

As you can see, the CST retains the exact syntax of the code, including whitespace and comments; however, this also makes the CST considerably more verbose and complex than the equivalent AST.

Modifying a CST

Let's say we want to replace the assignment c = a + b with d = b - a in our example.py file. We can achieve this by manipulating the CST: we walk the tree, locate the node representing the assignment we want to replace, and swap in a new node. (LibCST nodes are immutable, so a "modification" always produces a new node rather than changing the old one.) Here's how we can do it using LibCST.

Modifying a tree using a CSTTransformer

In LibCST, to modify a CST, you first need to traverse the tree and locate the node you want to modify. There are two main ways to traverse a CST: using a CSTVisitor or a CSTTransformer. A CSTVisitor allows you to visit nodes in the tree without modifying them, while a CSTTransformer allows you to modify nodes as you traverse the tree. Since we want to modify the tree, we'll use a CSTTransformer.

  import libcst as cst

  class Transformer(cst.CSTTransformer):
    ...

Filtering based on node type

With a CSTTransformer, you define methods that handle specific node types in the tree. When the transformer traverses the tree, it calls the appropriate method for each node type. You can then modify the node in the method and return the modified node. There are two main methods you can define in a CSTTransformer: visit_<NodeType> and leave_<NodeType>. The visit_<NodeType> method is called when the transformer enters a node of type NodeType, and the leave_<NodeType> method is called when the transformer exits a node of type NodeType. Examining the output of our CST from the previous section, we can see that each line in the code is represented by a SimpleStatementLine node. We can define a transformer that replaces the assignment c = a + b with d = b - a by first targeting the SimpleStatementLine nodes.

  import libcst as cst

  class Transformer(cst.CSTTransformer):
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine):
      ...

Some of the other node types you might encounter in the CST include:

  • Assign: Represents an assignment statement,
  • AssignTarget: Represents the target of an assignment,
  • Name: Represents a variable name,
  • BinaryOperation: Represents a binary operation,
  • Add: Represents an addition operation,
  • Subtract: Represents a subtraction operation,
  • ...

You can find a complete list of node types in the LibCST documentation, but for the most part you can simply print out the entire tree and look at the structure to determine the node types you need to target.

Locating the node to modify

You may be wondering... how can we determine the line we want to modify? Well, we can use the matchers module from LibCST to match specific patterns in the CST. The matchers module provides a way to define patterns that match specific nodes in the CST. We can use these patterns to locate the nodes we want to modify. In our case, we want to iterate over the SimpleStatementLine nodes and match the Assign node that assigns the sum of a and b to c; or in other words, we want to look for the Assign node where the target is c.

  import libcst as cst
  import libcst.matchers as m

  class Transformer(cst.CSTTransformer):
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine):
      if m.matches(
            original_node,
            m.SimpleStatementLine(
                body=[
                    m.Assign(
                        targets=[
                            m.AssignTarget(
                                target=m.Name(
                                    value='c'
                                )
                            )
                        ]
                    )
                ]
            )
        ):
          ...

Matching is very verbose, as you can see from the example above, but it allows you to target specific nodes in the CST based on their structure. The matches() function takes the node you want to match against and a pattern that describes the structure you're looking for.

We have to use the m. prefix before each node type, rather than cst., because the m module provides matcher classes that describe patterns, where any field you leave out is treated as a wildcard. The cst. prefix, on the other hand, is used to create real nodes, which require all of their mandatory fields. If you were to use cst. instead of m. in the pattern above, Python would raise an error indicating that the node type's __init__ method is missing its value argument.

Modifying the node

Once we've located the node we want to modify, we can replace it with the new assignment d = b - a. We can do this by creating a new Assign node with the updated target and value and returning it from the method. We'll also add a comment to the new assignment to explain what it does.

  import libcst as cst
  import libcst.matchers as m

  class Transformer(cst.CSTTransformer):
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine):
      if m.matches(
            original_node,
            m.SimpleStatementLine(
                body=[
                    m.Assign(
                        targets=[
                            m.AssignTarget(
                                target=m.Name(
                                    value='c'
                                )
                            )
                        ]
                    )
                ]
            )
        ):
          return updated_node.with_changes( # Build on updated_node so changes made to child nodes are kept
            body=[
                cst.Assign(
                    targets=[
                        cst.AssignTarget(
                            target=cst.Name(
                                value='d'
                            )
                        )
                    ],
                    value=cst.BinaryOperation(
                        left=cst.Name(
                            value='b'
                        ),
                        operator=cst.Subtract(
                            whitespace_before=cst.SimpleWhitespace(value=' '),
                            whitespace_after=cst.SimpleWhitespace(value=' ')
                        ),
                        right=cst.Name(
                            value='a'
                        )
                    )
                )
            ],
            trailing_whitespace=cst.TrailingWhitespace(
              whitespace=cst.SimpleWhitespace(
                value=' ',
              ),
              comment=cst.Comment(
                value='# Assigning d to the difference of b and a'
              )
            )
          )

      return updated_node # Not a match: keep the node, including any changes made to its children

In this method, we call with_changes() to create a copy of the SimpleStatementLine node with a new body (an Assign node with the updated target and value) and new trailing whitespace that carries the comment explaining the assignment. If the node doesn't match the pattern we're looking for, we return it unchanged. This way, we only modify the nodes we're interested in.

Applying the transformer and serializing the CST

Finally, we can apply the transformer to the CST and serialize the modified CST back into Python code using the code attribute of the tree:

  import libcst as cst

  with open("example.py", "r") as file: # Read the Python file
      source = file.read()

  tree = cst.parse_module(source)

  new_tree = tree.visit(Transformer()) # Using the Transformer outlined above

  with open("example.py", "w") as file: # Write the modified Python file
      file.write(new_tree.code)

Running this code will modify the example.py file to replace the assignment c = a + b with d = b - a. You can verify the changes by opening the example.py file and checking that the contents match the following expected output:

  a = 1 # Assigning a value to variable a
  b = 2 # Assigning a value to variable b
  d = b - a # Assigning d to the difference of b and a

Conclusion

In this post, we explored the usage of ASTs and CSTs in Python for parsing and modifying source code. We started by parsing Python code into ASTs using the built-in ast module and into CSTs using the LibCST library. We then demonstrated how to modify a CST by replacing an assignment in the code. Using CSTs with LibCST can be a powerful tool when you need to work with the exact syntax of the code, such as when modifying existing code. I hope this post has given you a better understanding of how to leverage ASTs and CSTs in Python for parsing and manipulating source code!

References