Skip to content
This repository was archived by the owner on Sep 12, 2024. It is now read-only.

Commit dae5363

Browse files
committed
enter_expr refactored
1 parent a9c6a2b commit dae5363

File tree

3 files changed

+108
-113
lines changed

3 files changed

+108
-113
lines changed

jaclang/compiler/passes/main/fuse_typeinfo_pass.py

Lines changed: 105 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66

77
from __future__ import annotations
88

9-
from types import MethodType
109
from typing import Callable, Optional, TypeVar
1110

1211
import jaclang.compiler.absyntree as ast
1312
from jaclang.compiler.passes import Pass
14-
from jaclang.compiler.passes.transform import Transform
1513
from jaclang.settings import settings
1614
from jaclang.utils.helpers import pascal_to_snake
1715
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack
@@ -24,97 +22,32 @@
2422
T = TypeVar("T", bound=ast.AstSymbolNode)
2523

2624

27-
# List of expression nodes which we'll be extracting the type info from.
28-
JAC_EXPR_NODES = (
29-
ast.AwaitExpr,
30-
ast.BinaryExpr,
31-
ast.CompareExpr,
32-
ast.BoolExpr,
33-
ast.LambdaExpr,
34-
ast.UnaryExpr,
35-
ast.IfElseExpr,
36-
ast.AtomTrailer,
37-
ast.AtomUnit,
38-
ast.YieldExpr,
39-
ast.YieldExpr,
40-
ast.FuncCall,
41-
ast.EdgeRefTrailer,
42-
ast.ListVal,
43-
ast.SetVal,
44-
ast.TupleVal,
45-
ast.DictVal,
46-
ast.ListCompr,
47-
ast.DictCompr,
48-
)
49-
50-
5125
class FuseTypeInfoPass(Pass):
5226
"""Python and bytecode file self.__debug_printing pass."""
5327

5428
node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}
5529

56-
@staticmethod
57-
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
58-
"""
59-
Enter an expression node.
60-
61-
This function is dynamically bound as a method on insntace of this class, since the
62-
group of functions to handle expressions has a the exact same logic.
63-
"""
64-
if len(node.gen.mypy_ast) == 0:
65-
return
66-
67-
# If the corrosponding mypy ast node type has stored here, get the values.
68-
mypy_node = node.gen.mypy_ast[0]
69-
if mypy_node in self.node_type_hash:
70-
mytype: MyType = self.node_type_hash[mypy_node]
71-
node.expr_type = str(mytype)
72-
73-
# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
74-
# expression. Time and memory wasted here.
75-
collection_types_map = {
76-
ast.ListVal: "builtins.list",
77-
ast.SetVal: "builtins.set",
78-
ast.TupleVal: "builtins.tuple",
79-
ast.DictVal: "builtins.dict",
80-
ast.ListCompr: None,
81-
ast.DictCompr: None,
82-
}
83-
84-
# Set they symbol type for collection expression.
85-
if type(node) in tuple(collection_types_map.keys()):
86-
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
87-
if mypy_node in self.node_type_hash:
88-
node.name_spec.sym_type = str(mytype)
89-
collection_type = collection_types_map[type(node)]
90-
if collection_type is not None:
91-
node.name_spec.sym_type = collection_type
92-
93-
def __init__(self, input_ir: T, prior: Optional[Transform]) -> None:
94-
"""Initialize the FuseTpeInfoPass instance."""
95-
for expr_node in JAC_EXPR_NODES:
96-
method_name = "enter_" + pascal_to_snake(expr_node.__name__)
97-
method = MethodType(
98-
FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self
99-
)
100-
setattr(self, method_name, method)
101-
super().__init__(input_ir, prior)
30+
# Override this to support enter expression.
31+
def enter_node(self, node: ast.AstNode) -> None:
32+
"""Run on entering node."""
33+
if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"):
34+
getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node)
35+
elif isinstance(node, ast.Expr):
36+
self.enter_expr(node)
10237

10338
def __debug_print(self, *argv: object) -> None:
10439
if settings.fuse_type_info_debug:
10540
self.log_info("FuseTypeInfo::", *argv)
10641

107-
def __call_type_handler(
108-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.ProperType
109-
) -> None:
42+
def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]:
11043
mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__)
11144
type_handler_name = f"get_type_from_{mypy_type_name}"
11245
if hasattr(self, type_handler_name):
113-
getattr(self, type_handler_name)(node, mypy_type)
114-
else:
115-
self.__debug_print(
116-
f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
117-
)
46+
return getattr(self, type_handler_name)(mypy_type)
47+
self.__debug_print(
48+
f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
49+
)
50+
return None
11851

11952
def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None:
12053
typ = node.sym_type.split(".")
@@ -244,7 +177,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
244177
mypy_node = mypy_node.node
245178

246179
if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)):
247-
self.__call_type_handler(node, mypy_node.type)
180+
node.name_spec.sym_type = (
181+
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
182+
)
248183

249184
elif isinstance(mypy_node, MypyNodes.MypyFile):
250185
node.name_spec.sym_type = "types.ModuleType"
@@ -253,7 +188,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
253188
node.name_spec.sym_type = mypy_node.fullname
254189

255190
elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef):
256-
self.__call_type_handler(node, mypy_node.items[0].func.type)
191+
node.name_spec.sym_type = (
192+
self.__call_type_handler(mypy_node.items[0].func.type)
193+
or node.name_spec.sym_type
194+
)
257195

258196
elif mypy_node is None:
259197
node.name_spec.sym_type = "None"
@@ -269,17 +207,67 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
269207
node.name_spec.sym_type = mypy_node.fullname
270208
self.__set_sym_table_link(node)
271209
elif isinstance(mypy_node, MypyNodes.FuncDef):
272-
self.__call_type_handler(node, mypy_node.type)
210+
node.name_spec.sym_type = (
211+
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
212+
)
273213
elif isinstance(mypy_node, MypyNodes.Argument):
274-
self.__call_type_handler(node, mypy_node.variable.type)
214+
node.name_spec.sym_type = (
215+
self.__call_type_handler(mypy_node.variable.type)
216+
or node.name_spec.sym_type
217+
)
275218
elif isinstance(mypy_node, MypyNodes.Decorator):
276-
self.__call_type_handler(node, mypy_node.func.type.ret_type)
219+
node.name_spec.sym_type = (
220+
self.__call_type_handler(mypy_node.func.type.ret_type)
221+
or node.name_spec.sym_type
222+
)
277223
else:
278224
self.__debug_print(
279225
f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported',
280226
type(mypy_node),
281227
)
282228

229+
# NOTE: Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node
230+
# and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as
231+
# valid nodes that can have symbols. At this point I'm leaving this like this and lemme know
232+
# otherwise.
233+
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
234+
"""
235+
Enter an expression node.
236+
237+
This function is dynamically bound as a method on insntace of this class, since the
238+
group of functions to handle expressions has a the exact same logic.
239+
"""
240+
if len(node.gen.mypy_ast) == 0:
241+
return
242+
243+
# If the corrosponding mypy ast node type has stored here, get the values.
244+
mypy_node = node.gen.mypy_ast[0]
245+
if mypy_node in self.node_type_hash:
246+
mytype: MyType = self.node_type_hash[mypy_node]
247+
node.expr_type = self.__call_type_handler(mytype) or ""
248+
249+
# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
250+
# expression. Time and memory wasted here.
251+
collection_types_map = {
252+
ast.ListVal: "builtins.list",
253+
ast.SetVal: "builtins.set",
254+
ast.TupleVal: "builtins.tuple",
255+
ast.DictVal: "builtins.dict",
256+
ast.ListCompr: None,
257+
ast.DictCompr: None,
258+
}
259+
260+
# Set they symbol type for collection expression.
261+
if type(node) in tuple(collection_types_map.keys()):
262+
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
263+
collection_type = collection_types_map[type(node)]
264+
if collection_type is not None:
265+
node.name_spec.sym_type = collection_type
266+
if mypy_node in self.node_type_hash:
267+
node.name_spec.sym_type = (
268+
self.__call_type_handler(mytype) or node.name_spec.sym_type
269+
)
270+
283271
@__handle_node
284272
def enter_name(self, node: ast.NameAtom) -> None:
285273
"""Pass handler for name nodes."""
@@ -319,7 +307,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None:
319307
def enter_ability(self, node: ast.Ability) -> None:
320308
"""Pass handler for Ability nodes."""
321309
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
322-
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
310+
node.name_spec.sym_type = (
311+
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
312+
or node.name_spec.sym_type
313+
)
323314
else:
324315
self.__debug_print(
325316
f"{node.loc}: Can't get type of an ability from mypy node other than Ability.",
@@ -330,7 +321,10 @@ def enter_ability(self, node: ast.Ability) -> None:
330321
def enter_ability_def(self, node: ast.AbilityDef) -> None:
331322
"""Pass handler for AbilityDef nodes."""
332323
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
333-
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
324+
node.name_spec.sym_type = (
325+
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
326+
or node.name_spec.sym_type
327+
)
334328
else:
335329
self.__debug_print(
336330
f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef.",
@@ -343,7 +337,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None:
343337
if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument):
344338
mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0]
345339
if mypy_node.variable.type:
346-
self.__call_type_handler(node, mypy_node.variable.type)
340+
node.name_spec.sym_type = (
341+
self.__call_type_handler(mypy_node.variable.type)
342+
or node.name_spec.sym_type
343+
)
347344
else:
348345
self.__debug_print(
349346
f"{node.loc}: Can't get parameter value from mypyNode other than Argument"
@@ -357,7 +354,9 @@ def enter_has_var(self, node: ast.HasVar) -> None:
357354
if isinstance(mypy_node, MypyNodes.AssignmentStmt):
358355
n = mypy_node.lvalues[0].node
359356
if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)):
360-
self.__call_type_handler(node, n.type)
357+
node.name_spec.sym_type = (
358+
self.__call_type_handler(n.type) or node.name_spec.sym_type
359+
)
361360
else:
362361
self.__debug_print(
363362
"Getting type of 'AssignmentStmt' is only supported with Var and FuncDef"
@@ -396,7 +395,9 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None:
396395
self.__set_sym_table_link(node)
397396
elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
398397
mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0]
399-
self.__call_type_handler(node, mypy_node2.type)
398+
node.name_spec.sym_type = (
399+
self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type
400+
)
400401
else:
401402
self.__debug_print(
402403
f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef",
@@ -448,42 +449,34 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None:
448449
"""Pass handler for BuiltinType nodes."""
449450
self.__collect_type_from_symbol(node)
450451

451-
def get_type_from_instance(
452-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance
453-
) -> None:
452+
def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> Optional[str]:
454453
"""Get type info from mypy type Instance."""
455-
node.name_spec.sym_type = str(mypy_type)
454+
return str(mypy_type)
456455

457456
def get_type_from_callable_type(
458-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType
459-
) -> None:
457+
self, mypy_type: MypyTypes.CallableType
458+
) -> Optional[str]:
460459
"""Get type info from mypy type CallableType."""
461-
node.name_spec.sym_type = str(mypy_type.ret_type)
460+
return str(mypy_type.ret_type)
462461

463462
# TODO: Which overloaded function to get the return value from?
464463
def get_type_from_overloaded(
465-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded
466-
) -> None:
464+
self, mypy_type: MypyTypes.Overloaded
465+
) -> Optional[str]:
467466
"""Get type info from mypy type Overloaded."""
468-
self.__call_type_handler(node, mypy_type.items[0])
467+
return self.__call_type_handler(mypy_type.items[0])
469468

470-
def get_type_from_none_type(
471-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType
472-
) -> None:
469+
def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]:
473470
"""Get type info from mypy type NoneType."""
474-
node.name_spec.sym_type = "None"
471+
return "None"
475472

476-
def get_type_from_any_type(
477-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType
478-
) -> None:
473+
def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> Optional[str]:
479474
"""Get type info from mypy type NoneType."""
480-
node.name_spec.sym_type = "Any"
475+
return "Any"
481476

482-
def get_type_from_tuple_type(
483-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType
484-
) -> None:
477+
def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> Optional[str]:
485478
"""Get type info from mypy type TupleType."""
486-
node.name_spec.sym_type = "builtins.tuple"
479+
return "builtins.tuple"
487480

488481
def exit_assignment(self, node: ast.Assignment) -> None:
489482
"""Add new symbols in the symbol table in case of self."""

jaclang/compiler/passes/main/tests/test_type_check_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ def test_type_coverage(self) -> None:
5959
self.assertIn("HasVar - species - Type: builtins.str", out)
6060
self.assertIn("myDog - Type: type_info.Dog", out)
6161
self.assertIn("Body - Type: type_info.Dog.Body", out)
62-
self.assertEqual(out.count("Type: builtins.str"), 28)
62+
self.assertEqual(out.count("Type: builtins.str"), 29)
6363
for i in lis:
6464
self.assertNotIn(i, out)

jaclang/utils/treeprinter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def __node_repr_in_tree(node: AstNode) -> str:
135135
)
136136
out += f" SymbolPath: {symbol}"
137137
return out
138+
elif isinstance(node, ast.Expr):
139+
return f"{node.__class__.__name__} - Type: {node.expr_type}"
138140
else:
139141
return f"{node.__class__.__name__}, {access}"
140142

0 commit comments

Comments
 (0)