66
77from __future__ import annotations
88
9- from types import MethodType
109from typing import Callable , Optional , TypeVar
1110
1211import jaclang .compiler .absyntree as ast
1312from jaclang .compiler .passes import Pass
14- from jaclang .compiler .passes .transform import Transform
1513from jaclang .settings import settings
1614from jaclang .utils .helpers import pascal_to_snake
1715from jaclang .vendor .mypy .nodes import Node as VNode # bit of a hack
2422T = TypeVar ("T" , bound = ast .AstSymbolNode )
2523
2624
27- # List of expression nodes which we'll be extracting the type info from.
28- JAC_EXPR_NODES = (
29- ast .AwaitExpr ,
30- ast .BinaryExpr ,
31- ast .CompareExpr ,
32- ast .BoolExpr ,
33- ast .LambdaExpr ,
34- ast .UnaryExpr ,
35- ast .IfElseExpr ,
36- ast .AtomTrailer ,
37- ast .AtomUnit ,
38- ast .YieldExpr ,
39- ast .YieldExpr ,
40- ast .FuncCall ,
41- ast .EdgeRefTrailer ,
42- ast .ListVal ,
43- ast .SetVal ,
44- ast .TupleVal ,
45- ast .DictVal ,
46- ast .ListCompr ,
47- ast .DictCompr ,
48- )
49-
50-
5125class FuseTypeInfoPass (Pass ):
5226 """Python and bytecode file self.__debug_printing pass."""
5327
5428 node_type_hash : dict [MypyNodes .Node | VNode , MyType ] = {}
5529
56- @staticmethod
57- def enter_expr (self : FuseTypeInfoPass , node : ast .Expr ) -> None :
58- """
59- Enter an expression node.
60-
61- This function is dynamically bound as a method on insntace of this class, since the
62- group of functions to handle expressions has a the exact same logic.
63- """
64- if len (node .gen .mypy_ast ) == 0 :
65- return
66-
67- # If the corrosponding mypy ast node type has stored here, get the values.
68- mypy_node = node .gen .mypy_ast [0 ]
69- if mypy_node in self .node_type_hash :
70- mytype : MyType = self .node_type_hash [mypy_node ]
71- node .expr_type = str (mytype )
72-
73- # TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
74- # expression. Time and memory wasted here.
75- collection_types_map = {
76- ast .ListVal : "builtins.list" ,
77- ast .SetVal : "builtins.set" ,
78- ast .TupleVal : "builtins.tuple" ,
79- ast .DictVal : "builtins.dict" ,
80- ast .ListCompr : None ,
81- ast .DictCompr : None ,
82- }
83-
84- # Set they symbol type for collection expression.
85- if type (node ) in tuple (collection_types_map .keys ()):
86- assert isinstance (node , ast .AtomExpr ) # To make mypy happy.
87- if mypy_node in self .node_type_hash :
88- node .name_spec .sym_type = str (mytype )
89- collection_type = collection_types_map [type (node )]
90- if collection_type is not None :
91- node .name_spec .sym_type = collection_type
92-
93- def __init__ (self , input_ir : T , prior : Optional [Transform ]) -> None :
94- """Initialize the FuseTpeInfoPass instance."""
95- for expr_node in JAC_EXPR_NODES :
96- method_name = "enter_" + pascal_to_snake (expr_node .__name__ )
97- method = MethodType (
98- FuseTypeInfoPass .__handle_node (FuseTypeInfoPass .enter_expr ), self
99- )
100- setattr (self , method_name , method )
101- super ().__init__ (input_ir , prior )
30+ # Override this to support enter expression.
31+ def enter_node (self , node : ast .AstNode ) -> None :
32+ """Run on entering node."""
33+ if hasattr (self , f"enter_{ pascal_to_snake (type (node ).__name__ )} " ):
34+ getattr (self , f"enter_{ pascal_to_snake (type (node ).__name__ )} " )(node )
35+ elif isinstance (node , ast .Expr ):
36+ self .enter_expr (node )
10237
10338 def __debug_print (self , * argv : object ) -> None :
10439 if settings .fuse_type_info_debug :
10540 self .log_info ("FuseTypeInfo::" , * argv )
10641
107- def __call_type_handler (
108- self , node : ast .AstSymbolNode , mypy_type : MypyTypes .ProperType
109- ) -> None :
42+ def __call_type_handler (self , mypy_type : MypyTypes .Type ) -> Optional [str ]:
11043 mypy_type_name = pascal_to_snake (mypy_type .__class__ .__name__ )
11144 type_handler_name = f"get_type_from_{ mypy_type_name } "
11245 if hasattr (self , type_handler_name ):
113- getattr (self , type_handler_name )(node , mypy_type )
114- else :
115- self . __debug_print (
116- f' { node . loc } "MypyTypes:: { mypy_type . __class__ . __name__ } " isn \' t supported yet'
117- )
46+ return getattr (self , type_handler_name )(mypy_type )
47+ self . __debug_print (
48+ f'"MypyTypes:: { mypy_type . __class__ . __name__ } " isn \' t supported yet'
49+ )
50+ return None
11851
11952 def __set_sym_table_link (self , node : ast .AstSymbolNode ) -> None :
12053 typ = node .sym_type .split ("." )
@@ -244,7 +177,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
244177 mypy_node = mypy_node .node
245178
246179 if isinstance (mypy_node , (MypyNodes .Var , MypyNodes .FuncDef )):
247- self .__call_type_handler (node , mypy_node .type )
180+ node .name_spec .sym_type = (
181+ self .__call_type_handler (mypy_node .type ) or node .name_spec .sym_type
182+ )
248183
249184 elif isinstance (mypy_node , MypyNodes .MypyFile ):
250185 node .name_spec .sym_type = "types.ModuleType"
@@ -253,7 +188,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
253188 node .name_spec .sym_type = mypy_node .fullname
254189
255190 elif isinstance (mypy_node , MypyNodes .OverloadedFuncDef ):
256- self .__call_type_handler (node , mypy_node .items [0 ].func .type )
191+ node .name_spec .sym_type = (
192+ self .__call_type_handler (mypy_node .items [0 ].func .type )
193+ or node .name_spec .sym_type
194+ )
257195
258196 elif mypy_node is None :
259197 node .name_spec .sym_type = "None"
@@ -269,17 +207,67 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
269207 node .name_spec .sym_type = mypy_node .fullname
270208 self .__set_sym_table_link (node )
271209 elif isinstance (mypy_node , MypyNodes .FuncDef ):
272- self .__call_type_handler (node , mypy_node .type )
210+ node .name_spec .sym_type = (
211+ self .__call_type_handler (mypy_node .type ) or node .name_spec .sym_type
212+ )
273213 elif isinstance (mypy_node , MypyNodes .Argument ):
274- self .__call_type_handler (node , mypy_node .variable .type )
214+ node .name_spec .sym_type = (
215+ self .__call_type_handler (mypy_node .variable .type )
216+ or node .name_spec .sym_type
217+ )
275218 elif isinstance (mypy_node , MypyNodes .Decorator ):
276- self .__call_type_handler (node , mypy_node .func .type .ret_type )
219+ node .name_spec .sym_type = (
220+ self .__call_type_handler (mypy_node .func .type .ret_type )
221+ or node .name_spec .sym_type
222+ )
277223 else :
278224 self .__debug_print (
279225 f'"{ node .loc } ::{ node .__class__ .__name__ } " mypy node isn\' t supported' ,
280226 type (mypy_node ),
281227 )
282228
229+ # NOTE: Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node
230+ # and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as
231+ # valid nodes that can have symbols. At this point I'm leaving this like this and lemme know
232+ # otherwise.
233+ def enter_expr (self : FuseTypeInfoPass , node : ast .Expr ) -> None :
234+ """
235+ Enter an expression node.
236+
237+ This function is dynamically bound as a method on insntace of this class, since the
238+ group of functions to handle expressions has a the exact same logic.
239+ """
240+ if len (node .gen .mypy_ast ) == 0 :
241+ return
242+
243+ # If the corrosponding mypy ast node type has stored here, get the values.
244+ mypy_node = node .gen .mypy_ast [0 ]
245+ if mypy_node in self .node_type_hash :
246+ mytype : MyType = self .node_type_hash [mypy_node ]
247+ node .expr_type = self .__call_type_handler (mytype ) or ""
248+
249+ # TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
250+ # expression. Time and memory wasted here.
251+ collection_types_map = {
252+ ast .ListVal : "builtins.list" ,
253+ ast .SetVal : "builtins.set" ,
254+ ast .TupleVal : "builtins.tuple" ,
255+ ast .DictVal : "builtins.dict" ,
256+ ast .ListCompr : None ,
257+ ast .DictCompr : None ,
258+ }
259+
260+ # Set they symbol type for collection expression.
261+ if type (node ) in tuple (collection_types_map .keys ()):
262+ assert isinstance (node , ast .AtomExpr ) # To make mypy happy.
263+ collection_type = collection_types_map [type (node )]
264+ if collection_type is not None :
265+ node .name_spec .sym_type = collection_type
266+ if mypy_node in self .node_type_hash :
267+ node .name_spec .sym_type = (
268+ self .__call_type_handler (mytype ) or node .name_spec .sym_type
269+ )
270+
283271 @__handle_node
284272 def enter_name (self , node : ast .NameAtom ) -> None :
285273 """Pass handler for name nodes."""
@@ -319,7 +307,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None:
319307 def enter_ability (self , node : ast .Ability ) -> None :
320308 """Pass handler for Ability nodes."""
321309 if isinstance (node .gen .mypy_ast [0 ], MypyNodes .FuncDef ):
322- self .__call_type_handler (node , node .gen .mypy_ast [0 ].type .ret_type )
310+ node .name_spec .sym_type = (
311+ self .__call_type_handler (node .gen .mypy_ast [0 ].type .ret_type )
312+ or node .name_spec .sym_type
313+ )
323314 else :
324315 self .__debug_print (
325316 f"{ node .loc } : Can't get type of an ability from mypy node other than Ability." ,
@@ -330,7 +321,10 @@ def enter_ability(self, node: ast.Ability) -> None:
330321 def enter_ability_def (self , node : ast .AbilityDef ) -> None :
331322 """Pass handler for AbilityDef nodes."""
332323 if isinstance (node .gen .mypy_ast [0 ], MypyNodes .FuncDef ):
333- self .__call_type_handler (node , node .gen .mypy_ast [0 ].type .ret_type )
324+ node .name_spec .sym_type = (
325+ self .__call_type_handler (node .gen .mypy_ast [0 ].type .ret_type )
326+ or node .name_spec .sym_type
327+ )
334328 else :
335329 self .__debug_print (
336330 f"{ node .loc } : Can't get type of an AbilityDef from mypy node other than FuncDef." ,
@@ -343,7 +337,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None:
343337 if isinstance (node .gen .mypy_ast [0 ], MypyNodes .Argument ):
344338 mypy_node : MypyNodes .Argument = node .gen .mypy_ast [0 ]
345339 if mypy_node .variable .type :
346- self .__call_type_handler (node , mypy_node .variable .type )
340+ node .name_spec .sym_type = (
341+ self .__call_type_handler (mypy_node .variable .type )
342+ or node .name_spec .sym_type
343+ )
347344 else :
348345 self .__debug_print (
349346 f"{ node .loc } : Can't get parameter value from mypyNode other than Argument"
@@ -357,7 +354,9 @@ def enter_has_var(self, node: ast.HasVar) -> None:
357354 if isinstance (mypy_node , MypyNodes .AssignmentStmt ):
358355 n = mypy_node .lvalues [0 ].node
359356 if isinstance (n , (MypyNodes .Var , MypyNodes .FuncDef )):
360- self .__call_type_handler (node , n .type )
357+ node .name_spec .sym_type = (
358+ self .__call_type_handler (n .type ) or node .name_spec .sym_type
359+ )
361360 else :
362361 self .__debug_print (
363362 "Getting type of 'AssignmentStmt' is only supported with Var and FuncDef"
@@ -396,7 +395,9 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None:
396395 self .__set_sym_table_link (node )
397396 elif isinstance (node .gen .mypy_ast [0 ], MypyNodes .FuncDef ):
398397 mypy_node2 : MypyNodes .FuncDef = node .gen .mypy_ast [0 ]
399- self .__call_type_handler (node , mypy_node2 .type )
398+ node .name_spec .sym_type = (
399+ self .__call_type_handler (mypy_node2 .type ) or node .name_spec .sym_type
400+ )
400401 else :
401402 self .__debug_print (
402403 f"{ node .loc } : Can't get ArchRef value from mypyNode other than ClassDef" ,
@@ -448,42 +449,34 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None:
448449 """Pass handler for BuiltinType nodes."""
449450 self .__collect_type_from_symbol (node )
450451
451- def get_type_from_instance (
452- self , node : ast .AstSymbolNode , mypy_type : MypyTypes .Instance
453- ) -> None :
452+ def get_type_from_instance (self , mypy_type : MypyTypes .Instance ) -> Optional [str ]:
454453 """Get type info from mypy type Instance."""
455- node . name_spec . sym_type = str (mypy_type )
454+ return str (mypy_type )
456455
457456 def get_type_from_callable_type (
458- self , node : ast . AstSymbolNode , mypy_type : MypyTypes .CallableType
459- ) -> None :
457+ self , mypy_type : MypyTypes .CallableType
458+ ) -> Optional [ str ] :
460459 """Get type info from mypy type CallableType."""
461- node . name_spec . sym_type = str (mypy_type .ret_type )
460+ return str (mypy_type .ret_type )
462461
463462 # TODO: Which overloaded function to get the return value from?
464463 def get_type_from_overloaded (
465- self , node : ast . AstSymbolNode , mypy_type : MypyTypes .Overloaded
466- ) -> None :
464+ self , mypy_type : MypyTypes .Overloaded
465+ ) -> Optional [ str ] :
467466 """Get type info from mypy type Overloaded."""
468- self .__call_type_handler (node , mypy_type .items [0 ])
467+ return self .__call_type_handler (mypy_type .items [0 ])
469468
470- def get_type_from_none_type (
471- self , node : ast .AstSymbolNode , mypy_type : MypyTypes .NoneType
472- ) -> None :
469+ def get_type_from_none_type (self , mypy_type : MypyTypes .NoneType ) -> Optional [str ]:
473470 """Get type info from mypy type NoneType."""
474- node . name_spec . sym_type = "None"
471+ return "None"
475472
476- def get_type_from_any_type (
477- self , node : ast .AstSymbolNode , mypy_type : MypyTypes .AnyType
478- ) -> None :
473+ def get_type_from_any_type (self , mypy_type : MypyTypes .AnyType ) -> Optional [str ]:
479474 """Get type info from mypy type NoneType."""
480- node . name_spec . sym_type = "Any"
475+ return "Any"
481476
482- def get_type_from_tuple_type (
483- self , node : ast .AstSymbolNode , mypy_type : MypyTypes .TupleType
484- ) -> None :
477+ def get_type_from_tuple_type (self , mypy_type : MypyTypes .TupleType ) -> Optional [str ]:
485478 """Get type info from mypy type TupleType."""
486- node . name_spec . sym_type = "builtins.tuple"
479+ return "builtins.tuple"
487480
488481 def exit_assignment (self , node : ast .Assignment ) -> None :
489482 """Add new symbols in the symbol table in case of self."""
0 commit comments