@@ -241,7 +241,11 @@ def _rslv_case_expr(switch_attr, switch_type, expr):
241241 )
242242 if isinstance (expr , list ):
243243 assert switch_type [0 ] == "custom" , "Unknown enum type ! %s" % repr (expr )
244- identifier = switch_type [1 ].name
244+ if isinstance (switch_type [1 ], str ):
245+ # implicit
246+ identifier = switch_type [1 ]
247+ else :
248+ identifier = switch_type [1 ].name
245249 vals = [
246250 "%s.%s"
247251 % (
@@ -690,11 +694,14 @@ def __init__(self, name, ptr_lvl, fields, idl_attributes, struct_name):
690694
691695 def match_struct_attributes (self , fields ):
692696 """
693- List all `length_is` and `size_is` in fields and mirror into `length_of` and `count_of`
697+ List all `length_is` and `size_is` in fields and mirror into `length_of` and `count_of`.
698+ Also handles implicit switch_type of union fields
694699 """
695700 if not isinstance (fields , list ): # ScapyEnum
696701 return fields
697702 mapped_fields = {x .name : x for x in fields }
703+
704+ # Handle length_is/size_is
698705 for fld in (
699706 x
700707 for x in fields
@@ -725,6 +732,26 @@ def _procfld(mainfld, length_or_size_is):
725732 # if length and size point to the same, prioritize length
726733 if sizefld and (not lengthfld or lengthfld != sizefld ):
727734 _procfld (sizefld , size_is )
735+
736+ # handle switch_is
737+ for fld in (
738+ x for x in fields if any (y [0 ] == "switch_is" for y in x .idl_attributes )
739+ ):
740+ # Field might have an implicit switch_type, in which case we shall add it.
741+ switch_is = _lkp (fld .idl_attributes , "switch_is" )[0 ]
742+ switch_type = _lkp (fld .idl_attributes , "switch_type" )
743+
744+ if not switch_type and isinstance (switch_is , str ):
745+ # inline: add implicit switch_type for processing
746+ reffld = mapped_fields [switch_is ]
747+ # This only makes sense if the subtype is an enum
748+ if isinstance (reffld , ScapyStructField ) and isinstance (
749+ reffld .subtype , ScapyEnum
750+ ):
751+ fld .idl_attributes .append (
752+ ("switch_type" , ("custom" , reffld .field_type ))
753+ )
754+
728755 return fields
729756
730757 def __repr__ (self ):
0 commit comments