Skip to content

Commit d34bda4

Browse files
committed
if signature_defs is empty, it does not generate an error, but automatically generates signature_defs based on the input/output information or the input name of the --rename
1 parent 9dbf608 commit d34bda4

File tree

1 file changed

+57
-13
lines changed

1 file changed

+57
-13
lines changed

tfliteiorewriter/main.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,26 +103,70 @@ def rewrite(
103103
flat_json = json.load(f)
104104

105105
# Checks if signature_defs are recorded in tflite
106+
enable_signature_defs = True
106107
if 'signature_defs' not in flat_json or not flat_json['signature_defs']:
108+
print('')
107109
print(
108110
f'{Color.YELLOW}WARNING:{Color.RESET} ' +
109-
f'Processing is aborted because signature_defs is not recorded in tflite.'
111+
f'signature_defs is not recorded in tflite.'
110112
)
111-
sys.exit(0)
113+
enable_signature_defs = False
112114

113115
flat_subgraphs = flat_json['subgraphs'][0]
114116
flat_tensors: List[Dict] = flat_subgraphs['tensors']
115-
flat_signature_def: Dict = flat_json['signature_defs'][0]
116-
flat_signature_def_inputs: List[Dict] = flat_signature_def['inputs']
117-
flat_signature_def_inputs_names = [
118-
flat_signature_def_input['name'] \
119-
for flat_signature_def_input in flat_signature_def_inputs
120-
]
121-
flat_signature_def_outputs: List[Dict] = flat_signature_def['outputs']
122-
flat_signature_def_outputs_names = [
123-
flat_signature_def_output['name'] \
124-
for flat_signature_def_output in flat_signature_def_outputs
125-
]
117+
flat_signature_def: Dict = {}
118+
flat_signature_def_inputs_names = []
119+
flat_signature_def_outputs_names = []
120+
121+
if enable_signature_defs:
122+
# Get the output name to be used for the replacement name from signature_defs
123+
# if signature_defs is already defined
124+
flat_signature_def = flat_json['signature_defs'][0]
125+
flat_signature_def_inputs: List[Dict] = flat_signature_def['inputs']
126+
flat_signature_def_inputs_names = [
127+
flat_signature_def_input['name'] \
128+
for flat_signature_def_input in flat_signature_def_inputs
129+
]
130+
flat_signature_def_outputs: List[Dict] = flat_signature_def['outputs']
131+
flat_signature_def_outputs_names = [
132+
flat_signature_def_output['name'] \
133+
for flat_signature_def_output in flat_signature_def_outputs
134+
]
135+
else:
136+
# Generate from tensors if signature_defs is undefined
137+
flat_json['signature_defs'] = []
138+
flat_subgraphs_inputs: List[int] = flat_subgraphs['inputs']
139+
flat_subgraphs_outputs: List[int] = flat_subgraphs['outputs']
140+
signature_def_inputs = []
141+
signature_def_outputs = []
142+
# inputs
143+
for idx in flat_subgraphs_inputs:
144+
for flat_tensor in flat_tensors:
145+
if int(flat_tensor['buffer']) == idx + 1:
146+
signature_def_inputs.append(
147+
{'name': flat_tensor['name'], 'tensor_index': idx}
148+
)
149+
flat_signature_def_inputs_names.append(flat_tensor['name'])
150+
break
151+
# outputs
152+
for idx in flat_subgraphs_outputs:
153+
for flat_tensor in flat_tensors:
154+
if int(flat_tensor['buffer']) == idx + 1:
155+
signature_def_outputs.append(
156+
{'name': flat_tensor['name'], 'tensor_index': idx}
157+
)
158+
flat_signature_def_outputs_names.append(flat_tensor['name'])
159+
break
160+
161+
signature_def = {
162+
"inputs": signature_def_inputs,
163+
"outputs": signature_def_outputs,
164+
"signature_key": "serving_default",
165+
"subgraph_index": 0,
166+
}
167+
flat_json['signature_defs'].append(signature_def)
168+
flat_signature_def_inputs: List[Dict] = signature_def['inputs']
169+
flat_signature_def_outputs: List[Dict] = signature_def['outputs']
126170

127171
# If the signature of the input OP and the signature of the output OP overlap,
128172
# rename the signature of the output OP.

0 commit comments

Comments
 (0)