Skip to content
This repository was archived by the owner on Sep 12, 2024. It is now read-only.

Commit b99ee72

Browse files
committed
expression typeinfo implemented
1 parent 4fdfafd commit b99ee72

File tree

4 files changed

+169
-106
lines changed

4 files changed

+169
-106
lines changed

jaclang/compiler/absyntree.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
5050
self.meta: dict[str, str] = {}
5151
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())
5252

53+
# NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the
54+
# subclasses, Adding it here, this needs a review.
55+
self.expr_type: str = ""
56+
5357
@property
5458
def sym_tab(self) -> SymbolTable:
5559
"""Get symbol table."""

jaclang/compiler/passes/main/fuse_typeinfo_pass.py

Lines changed: 72 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import annotations
88

9+
from types import MethodType
910
from typing import Callable, TypeVar
1011

1112
import jaclang.compiler.absyntree as ast
@@ -14,7 +15,6 @@
1415
from jaclang.utils.helpers import pascal_to_snake
1516
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack
1617

17-
1818
import mypy.nodes as MypyNodes # noqa N812
1919
import mypy.types as MypyTypes # noqa N812
2020
from mypy.checkexpr import Type as MyType
@@ -23,11 +23,82 @@
2323
T = TypeVar("T", bound=ast.AstSymbolNode)
2424

2525

26+
# List of expression nodes which we'll be extracting the type info from.
27+
JAC_EXPR_NODES = (
28+
ast.AwaitExpr,
29+
ast.BinaryExpr,
30+
ast.CompareExpr,
31+
ast.BoolExpr,
32+
ast.LambdaExpr,
33+
ast.UnaryExpr,
34+
ast.IfElseExpr,
35+
ast.AtomTrailer,
36+
ast.AtomUnit,
37+
ast.YieldExpr,
38+
ast.YieldExpr,
39+
ast.FuncCall,
40+
ast.EdgeRefTrailer,
41+
ast.ListVal,
42+
ast.SetVal,
43+
ast.TupleVal,
44+
ast.DictVal,
45+
ast.ListCompr,
46+
ast.DictCompr,
47+
)
48+
49+
2650
class FuseTypeInfoPass(Pass):
2751
"""Python and bytecode file self.__debug_printing pass."""
2852

2953
node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}
3054

55+
@staticmethod
56+
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
57+
"""
58+
Enter an expression node.
59+
60+
This function is dynamically bound as a method on insntace of this class, since the
61+
group of functions to handle expressions has a the exact same logic.
62+
"""
63+
if len(node.gen.mypy_ast) == 0:
64+
return
65+
66+
# If the corrosponding mypy ast node type has stored here, get the values.
67+
mypy_node = node.gen.mypy_ast[0]
68+
if mypy_node in self.node_type_hash:
69+
mytype: MyType = self.node_type_hash[mypy_node]
70+
node.expr_type = str(mytype)
71+
72+
# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
73+
# expression. Time and memory wasted here.
74+
collection_types_map = {
75+
ast.ListVal: "builtins.list",
76+
ast.SetVal: "builtins.set",
77+
ast.TupleVal: "builtins.tuple",
78+
ast.DictVal: "builtins.dict",
79+
ast.ListCompr: None,
80+
ast.DictCompr: None,
81+
}
82+
83+
# Set they symbol type for collection expression.
84+
if type(node) in tuple(collection_types_map.keys()):
85+
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
86+
if mypy_node in self.node_type_hash:
87+
node.name_spec.sym_type = str(mytype)
88+
collection_type = collection_types_map[type(node)]
89+
if collection_type is not None:
90+
node.name_spec.sym_type = collection_type
91+
92+
def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003
93+
"""Initialize the FuseTpeInfoPass instance."""
94+
for expr_node in JAC_EXPR_NODES:
95+
method_name = "enter_" + pascal_to_snake(expr_node.__name__)
96+
method = MethodType(
97+
FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self
98+
)
99+
setattr(self, method_name, method)
100+
super().__init__(*args, **kwargs)
101+
31102
def __debug_print(self, *argv: object) -> None:
32103
if settings.fuse_type_info_debug:
33104
self.log_info("FuseTypeInfo::", *argv)
@@ -310,54 +381,6 @@ def enter_f_string(self, node: ast.FString) -> None:
310381
"""Pass handler for FString nodes."""
311382
self.__debug_print("Getting type not supported in", type(node))
312383

313-
@__handle_node
314-
def enter_list_val(self, node: ast.ListVal) -> None:
315-
"""Pass handler for ListVal nodes."""
316-
mypy_node = node.gen.mypy_ast[0]
317-
if mypy_node in self.node_type_hash:
318-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
319-
else:
320-
node.name_spec.sym_type = "builtins.list"
321-
322-
@__handle_node
323-
def enter_set_val(self, node: ast.SetVal) -> None:
324-
"""Pass handler for SetVal nodes."""
325-
mypy_node = node.gen.mypy_ast[0]
326-
if mypy_node in self.node_type_hash:
327-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
328-
else:
329-
node.name_spec.sym_type = "builtins.set"
330-
331-
@__handle_node
332-
def enter_tuple_val(self, node: ast.TupleVal) -> None:
333-
"""Pass handler for TupleVal nodes."""
334-
mypy_node = node.gen.mypy_ast[0]
335-
if mypy_node in self.node_type_hash:
336-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
337-
else:
338-
node.name_spec.sym_type = "builtins.tuple"
339-
340-
@__handle_node
341-
def enter_dict_val(self, node: ast.DictVal) -> None:
342-
"""Pass handler for DictVal nodes."""
343-
mypy_node = node.gen.mypy_ast[0]
344-
if mypy_node in self.node_type_hash:
345-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
346-
else:
347-
node.name_spec.sym_type = "builtins.dict"
348-
349-
@__handle_node
350-
def enter_list_compr(self, node: ast.ListCompr) -> None:
351-
"""Pass handler for ListCompr nodes."""
352-
mypy_node = node.gen.mypy_ast[0]
353-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
354-
355-
@__handle_node
356-
def enter_dict_compr(self, node: ast.DictCompr) -> None:
357-
"""Pass handler for DictCompr nodes."""
358-
mypy_node = node.gen.mypy_ast[0]
359-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
360-
361384
@__handle_node
362385
def enter_index_slice(self, node: ast.IndexSlice) -> None:
363386
"""Pass handler for IndexSlice nodes."""

jaclang/compiler/passes/utils/mypy_ast_build.py

Lines changed: 91 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,20 @@
44

55
import ast
66
import os
7+
from types import MethodType
78

89
from jaclang.compiler.absyntree import AstNode
910
from jaclang.compiler.passes import Pass
1011
from jaclang.compiler.passes.main.fuse_typeinfo_pass import (
1112
FuseTypeInfoPass,
1213
)
14+
from jaclang.utils.helpers import pascal_to_snake
1315

1416
import mypy.build as myb
1517
import mypy.checkexpr as mycke
1618
import mypy.errors as mye
1719
import mypy.fastparse as myfp
20+
import mypy.nodes as mypy_nodes
1821
from mypy.build import BuildSource
1922
from mypy.build import BuildSourceSet
2023
from mypy.build import FileSystemCache
@@ -29,6 +32,55 @@
2932
from mypy.semanal_main import semantic_analysis_for_scc
3033

3134

35+
# All the expression nodes of mypy.
36+
EXPRESSION_NODES = (
37+
mypy_nodes.AssertTypeExpr,
38+
mypy_nodes.AssignmentExpr,
39+
mypy_nodes.AwaitExpr,
40+
mypy_nodes.BytesExpr,
41+
mypy_nodes.CallExpr,
42+
mypy_nodes.CastExpr,
43+
mypy_nodes.ComparisonExpr,
44+
mypy_nodes.ComplexExpr,
45+
mypy_nodes.ConditionalExpr,
46+
mypy_nodes.DictionaryComprehension,
47+
mypy_nodes.DictExpr,
48+
mypy_nodes.EllipsisExpr,
49+
mypy_nodes.EnumCallExpr,
50+
mypy_nodes.Expression,
51+
mypy_nodes.FloatExpr,
52+
mypy_nodes.GeneratorExpr,
53+
mypy_nodes.IndexExpr,
54+
mypy_nodes.IntExpr,
55+
mypy_nodes.LambdaExpr,
56+
mypy_nodes.ListComprehension,
57+
mypy_nodes.ListExpr,
58+
mypy_nodes.MemberExpr,
59+
mypy_nodes.NamedTupleExpr,
60+
mypy_nodes.NameExpr,
61+
mypy_nodes.NewTypeExpr,
62+
mypy_nodes.OpExpr,
63+
mypy_nodes.ParamSpecExpr,
64+
mypy_nodes.PromoteExpr,
65+
mypy_nodes.RefExpr,
66+
mypy_nodes.RevealExpr,
67+
mypy_nodes.SetComprehension,
68+
mypy_nodes.SetExpr,
69+
mypy_nodes.SliceExpr,
70+
mypy_nodes.StarExpr,
71+
mypy_nodes.StrExpr,
72+
mypy_nodes.SuperExpr,
73+
mypy_nodes.TupleExpr,
74+
mypy_nodes.TypeAliasExpr,
75+
mypy_nodes.TypedDictExpr,
76+
mypy_nodes.TypeVarExpr,
77+
mypy_nodes.TypeVarTupleExpr,
78+
mypy_nodes.UnaryExpr,
79+
mypy_nodes.YieldExpr,
80+
mypy_nodes.YieldFromExpr,
81+
)
82+
83+
3284
mypy_to_jac_node_map: dict[
3385
tuple[int, int | None, int | None, int | None], list[AstNode]
3486
] = {}
@@ -87,63 +139,45 @@ def __init__(
87139
"""Override to mypy expression checker for direct AST pass through."""
88140
super().__init__(tc, msg, plugin, per_line_checking_time_ns)
89141

90-
def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type:
91-
"""Type check a list expression [...]."""
92-
out = super().visit_list_expr(e)
93-
FuseTypeInfoPass.node_type_hash[e] = out
94-
return out
95-
96-
def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type:
97-
"""Type check a set expression {...}."""
98-
out = super().visit_set_expr(e)
99-
FuseTypeInfoPass.node_type_hash[e] = out
100-
return out
101-
102-
def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type:
103-
"""Type check a tuple expression (...)."""
104-
out = super().visit_tuple_expr(e)
105-
FuseTypeInfoPass.node_type_hash[e] = out
106-
return out
107-
108-
def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type:
109-
"""Type check a dictionary expression {...}."""
110-
out = super().visit_dict_expr(e)
111-
FuseTypeInfoPass.node_type_hash[e] = out
112-
return out
113-
114-
def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type:
115-
"""Type check a list comprehension."""
116-
out = super().visit_list_comprehension(e)
117-
FuseTypeInfoPass.node_type_hash[e] = out
118-
return out
119-
120-
def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type:
121-
"""Type check a set comprehension."""
122-
out = super().visit_set_comprehension(e)
123-
FuseTypeInfoPass.node_type_hash[e] = out
124-
return out
125-
126-
def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type:
127-
"""Type check a generator expression."""
128-
out = super().visit_generator_expr(e)
129-
FuseTypeInfoPass.node_type_hash[e] = out
130-
return out
131-
132-
def visit_dictionary_comprehension(
133-
self, e: myfp.DictionaryComprehension
134-
) -> myb.Type:
135-
"""Type check a dict comprehension."""
136-
out = super().visit_dictionary_comprehension(e)
137-
FuseTypeInfoPass.node_type_hash[e] = out
138-
return out
139-
140-
def visit_member_expr(
141-
self, e: myfp.MemberExpr, is_lvalue: bool = False
142-
) -> myb.Type:
143-
"""Type check a member expr."""
144-
out = super().visit_member_expr(e, is_lvalue)
145-
FuseTypeInfoPass.node_type_hash[e] = out
146-
return out
142+
# For every expression there, create attach a method on this instance (self) named "enter_expr()"
143+
for expr_node in EXPRESSION_NODES:
144+
method_name = "visit_" + pascal_to_snake(expr_node.__name__)
145+
146+
# We call the super() version of the method so ensure the parent class has the method or else continue.
147+
if not hasattr(mycke.ExpressionChecker, method_name):
148+
continue
149+
150+
# If the method already overriden then don't override it again here. Continue. Note that the method exists
151+
# on the parent class and if it's also exists on this class and it's a different object that means it's
152+
# overrident method.
153+
if getattr(mycke.ExpressionChecker, method_name) != getattr(
154+
ExpressionChecker, method_name
155+
):
156+
continue
157+
158+
# Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the
159+
# "method_name" variable is used inside a loop and by the time the closure close the "method_name" value,
160+
# it'll be changed by the loop, so we need another method ("make_closure") to persist the value.
161+
def make_closure(method_name: str): # noqa: ANN201
162+
def closure(
163+
self: ExpressionChecker,
164+
e: mycke.Expression,
165+
*args, # noqa: ANN002
166+
**kwargs, # noqa: ANN003
167+
) -> mycke.Type:
168+
# Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm
169+
# (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269).
170+
out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023
171+
self, e, *args, **kwargs
172+
)
173+
FuseTypeInfoPass.node_type_hash[e] = out
174+
return out
175+
176+
return closure
177+
178+
# Attach the new "visit_expr()" method to this instance.
179+
method = make_closure(method_name)
180+
setattr(self, method_name, MethodType(method, self))
147181

148182

149183
class State(myb.State):

jaclang/utils/treeprinter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def __node_repr_in_tree(node: AstNode) -> str:
114114
)
115115
out += f" SymbolPath: {symbol}"
116116
return out
117+
elif isinstance(node, ast.Expr):
118+
return f"{node.__class__.__name__} - Type: {node.expr_type}"
117119
elif isinstance(node, Token):
118120
return f"{node.__class__.__name__} - {node.value}, {access}"
119121
elif (

0 commit comments

Comments
 (0)