@@ -103,26 +103,70 @@ def rewrite(
103103 flat_json = json .load (f )
104104
105105 # Checks if signature_defs are recorded in tflite
106+ enable_signature_defs = True
106107 if 'signature_defs' not in flat_json or not flat_json ['signature_defs' ]:
108+ print ('' )
107109 print (
108110 f'{ Color .YELLOW } WARNING:{ Color .RESET } ' +
109- f'Processing is aborted because signature_defs is not recorded in tflite.'
111+ f'signature_defs is not recorded in tflite.'
110112 )
111- sys . exit ( 0 )
113+ enable_signature_defs = False
112114
113115 flat_subgraphs = flat_json ['subgraphs' ][0 ]
114116 flat_tensors : List [Dict ] = flat_subgraphs ['tensors' ]
115- flat_signature_def : Dict = flat_json ['signature_defs' ][0 ]
116- flat_signature_def_inputs : List [Dict ] = flat_signature_def ['inputs' ]
117- flat_signature_def_inputs_names = [
118- flat_signature_def_input ['name' ] \
119- for flat_signature_def_input in flat_signature_def_inputs
120- ]
121- flat_signature_def_outputs : List [Dict ] = flat_signature_def ['outputs' ]
122- flat_signature_def_outputs_names = [
123- flat_signature_def_output ['name' ] \
124- for flat_signature_def_output in flat_signature_def_outputs
125- ]
117+ flat_signature_def : Dict = {}
118+ flat_signature_def_inputs_names = []
119+ flat_signature_def_outputs_names = []
120+
121+ if enable_signature_defs :
122+ # Get the output name to be used for the replacement name from signature_defs
123+ # if signature_defs is already defined
124+ flat_signature_def = flat_json ['signature_defs' ][0 ]
125+ flat_signature_def_inputs : List [Dict ] = flat_signature_def ['inputs' ]
126+ flat_signature_def_inputs_names = [
127+ flat_signature_def_input ['name' ] \
128+ for flat_signature_def_input in flat_signature_def_inputs
129+ ]
130+ flat_signature_def_outputs : List [Dict ] = flat_signature_def ['outputs' ]
131+ flat_signature_def_outputs_names = [
132+ flat_signature_def_output ['name' ] \
133+ for flat_signature_def_output in flat_signature_def_outputs
134+ ]
135+ else :
136+ # Generate from tensors if signature_defs is undefined
137+ flat_json ['signature_defs' ] = []
138+ flat_subgraphs_inputs : List [int ] = flat_subgraphs ['inputs' ]
139+ flat_subgraphs_outputs : List [int ] = flat_subgraphs ['outputs' ]
140+ signature_def_inputs = []
141+ signature_def_outputs = []
142+ # inputs
143+ for idx in flat_subgraphs_inputs :
144+ for flat_tensor in flat_tensors :
145+ if int (flat_tensor ['buffer' ]) == idx + 1 :
146+ signature_def_inputs .append (
147+ {'name' : flat_tensor ['name' ], 'tensor_index' : idx }
148+ )
149+ flat_signature_def_inputs_names .append (flat_tensor ['name' ])
150+ break
151+ # outputs
152+ for idx in flat_subgraphs_outputs :
153+ for flat_tensor in flat_tensors :
154+ if int (flat_tensor ['buffer' ]) == idx + 1 :
155+ signature_def_outputs .append (
156+ {'name' : flat_tensor ['name' ], 'tensor_index' : idx }
157+ )
158+ flat_signature_def_outputs_names .append (flat_tensor ['name' ])
159+ break
160+
161+ signature_def = {
162+ "inputs" : signature_def_inputs ,
163+ "outputs" : signature_def_outputs ,
164+ "signature_key" : "serving_default" ,
165+ "subgraph_index" : 0 ,
166+ }
167+ flat_json ['signature_defs' ].append (signature_def )
168+ flat_signature_def_inputs : List [Dict ] = signature_def ['inputs' ]
169+ flat_signature_def_outputs : List [Dict ] = signature_def ['outputs' ]
126170
127171 # If the signature of the input OP and the signature of the output OP overlap,
128172 # rename the signature of the output OP.
0 commit comments