Skip to content

Commit 8727154

Browse files
authored
Add function for stripping source retention options from a descriptor (#250)
This adds a helper function to the `options` sub-package that can strip "source retention" options from a descriptor. These are options that should only be retained in the descriptor in source form -- like when manipulated by a compiler or code generator -- and should not be present at runtime. Stripping these options results in a descriptor that could safely be embedded in generated code.
1 parent f62a9f6 commit 8727154

File tree

2 files changed

+848
-0
lines changed

2 files changed

+848
-0
lines changed
Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,385 @@
1+
// Copyright 2020-2024 Buf Technologies, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package options
16+
17+
import (
18+
"fmt"
19+
20+
"google.golang.org/protobuf/proto"
21+
"google.golang.org/protobuf/reflect/protoreflect"
22+
"google.golang.org/protobuf/types/descriptorpb"
23+
)
24+
25+
// StripSourceRetentionOptionsFromFile returns a file descriptor proto that omits any
26+
// options in file that are defined to be retained only in source. If file has no
27+
// such options, then it is returned as is. If it does have such options, a copy is
28+
// made; the given file will not be mutated.
29+
//
30+
// Even when a copy is returned, it is not a deep copy: it may share data with the
31+
// original file. So callers should not mutate the returned file unless mutating the
32+
// input file is also safe.
33+
func StripSourceRetentionOptionsFromFile(file *descriptorpb.FileDescriptorProto) (*descriptorpb.FileDescriptorProto, error) {
34+
var dirty bool
35+
newOpts, err := stripSourceRetentionOptions(file.GetOptions())
36+
if err != nil {
37+
return nil, err
38+
}
39+
if newOpts != file.GetOptions() {
40+
dirty = true
41+
}
42+
newMsgs, changed, err := updateAll(file.GetMessageType(), stripSourceRetentionOptionsFromMessage)
43+
if err != nil {
44+
return nil, err
45+
}
46+
if changed {
47+
dirty = true
48+
}
49+
newEnums, changed, err := updateAll(file.GetEnumType(), stripSourceRetentionOptionsFromEnum)
50+
if err != nil {
51+
return nil, err
52+
}
53+
if changed {
54+
dirty = true
55+
}
56+
newExts, changed, err := updateAll(file.GetExtension(), stripSourceRetentionOptionsFromField)
57+
if err != nil {
58+
return nil, err
59+
}
60+
if changed {
61+
dirty = true
62+
}
63+
newSvcs, changed, err := updateAll(file.GetService(), stripSourceRetentionOptionsFromService)
64+
if err != nil {
65+
return nil, err
66+
}
67+
if changed {
68+
dirty = true
69+
}
70+
71+
if !dirty {
72+
return file, nil
73+
}
74+
75+
newFile, err := shallowCopy(file)
76+
if err != nil {
77+
return nil, err
78+
}
79+
newFile.Options = newOpts
80+
newFile.MessageType = newMsgs
81+
newFile.EnumType = newEnums
82+
newFile.Extension = newExts
83+
newFile.Service = newSvcs
84+
return newFile, nil
85+
}
86+
87+
func stripSourceRetentionOptions[M proto.Message](options M) (M, error) {
88+
optionsRef := options.ProtoReflect()
89+
// See if there are any options to strip.
90+
var found bool
91+
var err error
92+
optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
93+
fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions)
94+
if !ok {
95+
err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts)
96+
return false
97+
}
98+
if fieldOpts.GetRetention() == descriptorpb.FieldOptions_RETENTION_SOURCE {
99+
found = true
100+
return false
101+
}
102+
return true
103+
})
104+
var zero M
105+
if err != nil {
106+
return zero, err
107+
}
108+
if !found {
109+
return options, nil
110+
}
111+
112+
// There is at least one. So we need to make a copy that does not have those options.
113+
newOptions := optionsRef.New()
114+
ret, ok := newOptions.Interface().(M)
115+
if !ok {
116+
return zero, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", newOptions.Interface(), zero)
117+
}
118+
optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
119+
fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions)
120+
if !ok {
121+
err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts)
122+
return false
123+
}
124+
if fieldOpts.GetRetention() != descriptorpb.FieldOptions_RETENTION_SOURCE {
125+
newOptions.Set(field, val)
126+
}
127+
return true
128+
})
129+
if err != nil {
130+
return zero, err
131+
}
132+
return ret, nil
133+
}
134+
135+
func stripSourceRetentionOptionsFromMessage(msg *descriptorpb.DescriptorProto) (*descriptorpb.DescriptorProto, error) {
136+
var dirty bool
137+
newOpts, err := stripSourceRetentionOptions(msg.Options)
138+
if err != nil {
139+
return nil, err
140+
}
141+
if newOpts != msg.Options {
142+
dirty = true
143+
}
144+
newFields, changed, err := updateAll(msg.Field, stripSourceRetentionOptionsFromField)
145+
if err != nil {
146+
return nil, err
147+
}
148+
if changed {
149+
dirty = true
150+
}
151+
newOneofs, changed, err := updateAll(msg.OneofDecl, stripSourceRetentionOptionsFromOneof)
152+
if err != nil {
153+
return nil, err
154+
}
155+
if changed {
156+
dirty = true
157+
}
158+
newExtRanges, changed, err := updateAll(msg.ExtensionRange, stripSourceRetentionOptionsFromExtensionRange)
159+
if err != nil {
160+
return nil, err
161+
}
162+
if changed {
163+
dirty = true
164+
}
165+
newMsgs, changed, err := updateAll(msg.NestedType, stripSourceRetentionOptionsFromMessage)
166+
if err != nil {
167+
return nil, err
168+
}
169+
if changed {
170+
dirty = true
171+
}
172+
newEnums, changed, err := updateAll(msg.EnumType, stripSourceRetentionOptionsFromEnum)
173+
if err != nil {
174+
return nil, err
175+
}
176+
if changed {
177+
dirty = true
178+
}
179+
newExts, changed, err := updateAll(msg.Extension, stripSourceRetentionOptionsFromField)
180+
if err != nil {
181+
return nil, err
182+
}
183+
if changed {
184+
dirty = true
185+
}
186+
187+
if !dirty {
188+
return msg, nil
189+
}
190+
191+
newMsg, err := shallowCopy(msg)
192+
if err != nil {
193+
return nil, err
194+
}
195+
newMsg.Options = newOpts
196+
newMsg.Field = newFields
197+
newMsg.OneofDecl = newOneofs
198+
newMsg.ExtensionRange = newExtRanges
199+
newMsg.NestedType = newMsgs
200+
newMsg.EnumType = newEnums
201+
newMsg.Extension = newExts
202+
return newMsg, nil
203+
}
204+
205+
func stripSourceRetentionOptionsFromField(field *descriptorpb.FieldDescriptorProto) (*descriptorpb.FieldDescriptorProto, error) {
206+
newOpts, err := stripSourceRetentionOptions(field.Options)
207+
if err != nil {
208+
return nil, err
209+
}
210+
if newOpts == field.Options {
211+
return field, nil
212+
}
213+
newField, err := shallowCopy(field)
214+
if err != nil {
215+
return nil, err
216+
}
217+
newField.Options = newOpts
218+
return newField, nil
219+
}
220+
221+
func stripSourceRetentionOptionsFromOneof(oneof *descriptorpb.OneofDescriptorProto) (*descriptorpb.OneofDescriptorProto, error) {
222+
newOpts, err := stripSourceRetentionOptions(oneof.Options)
223+
if err != nil {
224+
return nil, err
225+
}
226+
if newOpts == oneof.Options {
227+
return oneof, nil
228+
}
229+
newOneof, err := shallowCopy(oneof)
230+
if err != nil {
231+
return nil, err
232+
}
233+
newOneof.Options = newOpts
234+
return newOneof, nil
235+
}
236+
237+
func stripSourceRetentionOptionsFromExtensionRange(extRange *descriptorpb.DescriptorProto_ExtensionRange) (*descriptorpb.DescriptorProto_ExtensionRange, error) {
238+
newOpts, err := stripSourceRetentionOptions(extRange.Options)
239+
if err != nil {
240+
return nil, err
241+
}
242+
if newOpts == extRange.Options {
243+
return extRange, nil
244+
}
245+
newExtRange, err := shallowCopy(extRange)
246+
if err != nil {
247+
return nil, err
248+
}
249+
newExtRange.Options = newOpts
250+
return newExtRange, nil
251+
}
252+
253+
func stripSourceRetentionOptionsFromEnum(enum *descriptorpb.EnumDescriptorProto) (*descriptorpb.EnumDescriptorProto, error) {
254+
var dirty bool
255+
newOpts, err := stripSourceRetentionOptions(enum.Options)
256+
if err != nil {
257+
return nil, err
258+
}
259+
if newOpts != enum.Options {
260+
dirty = true
261+
}
262+
newVals, changed, err := updateAll(enum.Value, stripSourceRetentionOptionsFromEnumValue)
263+
if err != nil {
264+
return nil, err
265+
}
266+
if changed {
267+
dirty = true
268+
}
269+
270+
if !dirty {
271+
return enum, nil
272+
}
273+
274+
newEnum, err := shallowCopy(enum)
275+
if err != nil {
276+
return nil, err
277+
}
278+
newEnum.Options = newOpts
279+
newEnum.Value = newVals
280+
return newEnum, nil
281+
}
282+
283+
func stripSourceRetentionOptionsFromEnumValue(enumVal *descriptorpb.EnumValueDescriptorProto) (*descriptorpb.EnumValueDescriptorProto, error) {
284+
newOpts, err := stripSourceRetentionOptions(enumVal.Options)
285+
if err != nil {
286+
return nil, err
287+
}
288+
if newOpts == enumVal.Options {
289+
return enumVal, nil
290+
}
291+
newEnumVal, err := shallowCopy(enumVal)
292+
if err != nil {
293+
return nil, err
294+
}
295+
newEnumVal.Options = newOpts
296+
return newEnumVal, nil
297+
}
298+
299+
func stripSourceRetentionOptionsFromService(svc *descriptorpb.ServiceDescriptorProto) (*descriptorpb.ServiceDescriptorProto, error) {
300+
var dirty bool
301+
newOpts, err := stripSourceRetentionOptions(svc.Options)
302+
if err != nil {
303+
return nil, err
304+
}
305+
if newOpts != svc.Options {
306+
dirty = true
307+
}
308+
newMethods, changed, err := updateAll(svc.Method, stripSourceRetentionOptionsFromMethod)
309+
if err != nil {
310+
return nil, err
311+
}
312+
if changed {
313+
dirty = true
314+
}
315+
316+
if !dirty {
317+
return svc, nil
318+
}
319+
320+
newSvc, err := shallowCopy(svc)
321+
if err != nil {
322+
return nil, err
323+
}
324+
newSvc.Options = newOpts
325+
newSvc.Method = newMethods
326+
return newSvc, nil
327+
}
328+
329+
func stripSourceRetentionOptionsFromMethod(method *descriptorpb.MethodDescriptorProto) (*descriptorpb.MethodDescriptorProto, error) {
330+
newOpts, err := stripSourceRetentionOptions(method.Options)
331+
if err != nil {
332+
return nil, err
333+
}
334+
if newOpts == method.Options {
335+
return method, nil
336+
}
337+
newMethod, err := shallowCopy(method)
338+
if err != nil {
339+
return nil, err
340+
}
341+
newMethod.Options = newOpts
342+
return newMethod, nil
343+
}
344+
345+
func shallowCopy[M proto.Message](msg M) (M, error) {
346+
msgRef := msg.ProtoReflect()
347+
other := msgRef.New()
348+
ret, ok := other.Interface().(M)
349+
if !ok {
350+
return ret, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", other.Interface(), ret)
351+
}
352+
msgRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
353+
other.Set(field, val)
354+
return true
355+
})
356+
return ret, nil
357+
}
358+
359+
// updateAll applies the given function to each element in the given slice. It
360+
// returns the new slice and a bool indicating whether anything was actually
361+
// changed. If the second value is false, then the returned slice is the same
362+
// slice as the input slice. Usually, T is a pointer type, in which case the
363+
// given updateFunc should NOT mutate the input value. Instead, it should return
364+
// the input value if only if there is no update needed. If a mutation is needed,
365+
// it should return a new value.
366+
func updateAll[T comparable](slice []T, updateFunc func(T) (T, error)) ([]T, bool, error) {
367+
var updated []T // initialized lazily, only when/if a copy is needed
368+
for i, item := range slice {
369+
newItem, err := updateFunc(item)
370+
if err != nil {
371+
return nil, false, err
372+
}
373+
if updated != nil {
374+
updated[i] = newItem
375+
} else if newItem != item {
376+
updated = make([]T, len(slice))
377+
copy(updated[:i], slice)
378+
updated[i] = newItem
379+
}
380+
}
381+
if updated != nil {
382+
return updated, true, nil
383+
}
384+
return slice, false, nil
385+
}

0 commit comments

Comments
 (0)