|
| 1 | +// Copyright 2020-2024 Buf Technologies, Inc. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +package options |
| 16 | + |
| 17 | +import ( |
| 18 | + "fmt" |
| 19 | + |
| 20 | + "google.golang.org/protobuf/proto" |
| 21 | + "google.golang.org/protobuf/reflect/protoreflect" |
| 22 | + "google.golang.org/protobuf/types/descriptorpb" |
| 23 | +) |
| 24 | + |
| 25 | +// StripSourceRetentionOptionsFromFile returns a file descriptor proto that omits any |
| 26 | +// options in file that are defined to be retained only in source. If file has no |
| 27 | +// such options, then it is returned as is. If it does have such options, a copy is |
| 28 | +// made; the given file will not be mutated. |
| 29 | +// |
| 30 | +// Even when a copy is returned, it is not a deep copy: it may share data with the |
| 31 | +// original file. So callers should not mutate the returned file unless mutating the |
| 32 | +// input file is also safe. |
| 33 | +func StripSourceRetentionOptionsFromFile(file *descriptorpb.FileDescriptorProto) (*descriptorpb.FileDescriptorProto, error) { |
| 34 | + var dirty bool |
| 35 | + newOpts, err := stripSourceRetentionOptions(file.GetOptions()) |
| 36 | + if err != nil { |
| 37 | + return nil, err |
| 38 | + } |
| 39 | + if newOpts != file.GetOptions() { |
| 40 | + dirty = true |
| 41 | + } |
| 42 | + newMsgs, changed, err := updateAll(file.GetMessageType(), stripSourceRetentionOptionsFromMessage) |
| 43 | + if err != nil { |
| 44 | + return nil, err |
| 45 | + } |
| 46 | + if changed { |
| 47 | + dirty = true |
| 48 | + } |
| 49 | + newEnums, changed, err := updateAll(file.GetEnumType(), stripSourceRetentionOptionsFromEnum) |
| 50 | + if err != nil { |
| 51 | + return nil, err |
| 52 | + } |
| 53 | + if changed { |
| 54 | + dirty = true |
| 55 | + } |
| 56 | + newExts, changed, err := updateAll(file.GetExtension(), stripSourceRetentionOptionsFromField) |
| 57 | + if err != nil { |
| 58 | + return nil, err |
| 59 | + } |
| 60 | + if changed { |
| 61 | + dirty = true |
| 62 | + } |
| 63 | + newSvcs, changed, err := updateAll(file.GetService(), stripSourceRetentionOptionsFromService) |
| 64 | + if err != nil { |
| 65 | + return nil, err |
| 66 | + } |
| 67 | + if changed { |
| 68 | + dirty = true |
| 69 | + } |
| 70 | + |
| 71 | + if !dirty { |
| 72 | + return file, nil |
| 73 | + } |
| 74 | + |
| 75 | + newFile, err := shallowCopy(file) |
| 76 | + if err != nil { |
| 77 | + return nil, err |
| 78 | + } |
| 79 | + newFile.Options = newOpts |
| 80 | + newFile.MessageType = newMsgs |
| 81 | + newFile.EnumType = newEnums |
| 82 | + newFile.Extension = newExts |
| 83 | + newFile.Service = newSvcs |
| 84 | + return newFile, nil |
| 85 | +} |
| 86 | + |
| 87 | +func stripSourceRetentionOptions[M proto.Message](options M) (M, error) { |
| 88 | + optionsRef := options.ProtoReflect() |
| 89 | + // See if there are any options to strip. |
| 90 | + var found bool |
| 91 | + var err error |
| 92 | + optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { |
| 93 | + fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions) |
| 94 | + if !ok { |
| 95 | + err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts) |
| 96 | + return false |
| 97 | + } |
| 98 | + if fieldOpts.GetRetention() == descriptorpb.FieldOptions_RETENTION_SOURCE { |
| 99 | + found = true |
| 100 | + return false |
| 101 | + } |
| 102 | + return true |
| 103 | + }) |
| 104 | + var zero M |
| 105 | + if err != nil { |
| 106 | + return zero, err |
| 107 | + } |
| 108 | + if !found { |
| 109 | + return options, nil |
| 110 | + } |
| 111 | + |
| 112 | + // There is at least one. So we need to make a copy that does not have those options. |
| 113 | + newOptions := optionsRef.New() |
| 114 | + ret, ok := newOptions.Interface().(M) |
| 115 | + if !ok { |
| 116 | + return zero, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", newOptions.Interface(), zero) |
| 117 | + } |
| 118 | + optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { |
| 119 | + fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions) |
| 120 | + if !ok { |
| 121 | + err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts) |
| 122 | + return false |
| 123 | + } |
| 124 | + if fieldOpts.GetRetention() != descriptorpb.FieldOptions_RETENTION_SOURCE { |
| 125 | + newOptions.Set(field, val) |
| 126 | + } |
| 127 | + return true |
| 128 | + }) |
| 129 | + if err != nil { |
| 130 | + return zero, err |
| 131 | + } |
| 132 | + return ret, nil |
| 133 | +} |
| 134 | + |
| 135 | +func stripSourceRetentionOptionsFromMessage(msg *descriptorpb.DescriptorProto) (*descriptorpb.DescriptorProto, error) { |
| 136 | + var dirty bool |
| 137 | + newOpts, err := stripSourceRetentionOptions(msg.Options) |
| 138 | + if err != nil { |
| 139 | + return nil, err |
| 140 | + } |
| 141 | + if newOpts != msg.Options { |
| 142 | + dirty = true |
| 143 | + } |
| 144 | + newFields, changed, err := updateAll(msg.Field, stripSourceRetentionOptionsFromField) |
| 145 | + if err != nil { |
| 146 | + return nil, err |
| 147 | + } |
| 148 | + if changed { |
| 149 | + dirty = true |
| 150 | + } |
| 151 | + newOneofs, changed, err := updateAll(msg.OneofDecl, stripSourceRetentionOptionsFromOneof) |
| 152 | + if err != nil { |
| 153 | + return nil, err |
| 154 | + } |
| 155 | + if changed { |
| 156 | + dirty = true |
| 157 | + } |
| 158 | + newExtRanges, changed, err := updateAll(msg.ExtensionRange, stripSourceRetentionOptionsFromExtensionRange) |
| 159 | + if err != nil { |
| 160 | + return nil, err |
| 161 | + } |
| 162 | + if changed { |
| 163 | + dirty = true |
| 164 | + } |
| 165 | + newMsgs, changed, err := updateAll(msg.NestedType, stripSourceRetentionOptionsFromMessage) |
| 166 | + if err != nil { |
| 167 | + return nil, err |
| 168 | + } |
| 169 | + if changed { |
| 170 | + dirty = true |
| 171 | + } |
| 172 | + newEnums, changed, err := updateAll(msg.EnumType, stripSourceRetentionOptionsFromEnum) |
| 173 | + if err != nil { |
| 174 | + return nil, err |
| 175 | + } |
| 176 | + if changed { |
| 177 | + dirty = true |
| 178 | + } |
| 179 | + newExts, changed, err := updateAll(msg.Extension, stripSourceRetentionOptionsFromField) |
| 180 | + if err != nil { |
| 181 | + return nil, err |
| 182 | + } |
| 183 | + if changed { |
| 184 | + dirty = true |
| 185 | + } |
| 186 | + |
| 187 | + if !dirty { |
| 188 | + return msg, nil |
| 189 | + } |
| 190 | + |
| 191 | + newMsg, err := shallowCopy(msg) |
| 192 | + if err != nil { |
| 193 | + return nil, err |
| 194 | + } |
| 195 | + newMsg.Options = newOpts |
| 196 | + newMsg.Field = newFields |
| 197 | + newMsg.OneofDecl = newOneofs |
| 198 | + newMsg.ExtensionRange = newExtRanges |
| 199 | + newMsg.NestedType = newMsgs |
| 200 | + newMsg.EnumType = newEnums |
| 201 | + newMsg.Extension = newExts |
| 202 | + return newMsg, nil |
| 203 | +} |
| 204 | + |
| 205 | +func stripSourceRetentionOptionsFromField(field *descriptorpb.FieldDescriptorProto) (*descriptorpb.FieldDescriptorProto, error) { |
| 206 | + newOpts, err := stripSourceRetentionOptions(field.Options) |
| 207 | + if err != nil { |
| 208 | + return nil, err |
| 209 | + } |
| 210 | + if newOpts == field.Options { |
| 211 | + return field, nil |
| 212 | + } |
| 213 | + newField, err := shallowCopy(field) |
| 214 | + if err != nil { |
| 215 | + return nil, err |
| 216 | + } |
| 217 | + newField.Options = newOpts |
| 218 | + return newField, nil |
| 219 | +} |
| 220 | + |
| 221 | +func stripSourceRetentionOptionsFromOneof(oneof *descriptorpb.OneofDescriptorProto) (*descriptorpb.OneofDescriptorProto, error) { |
| 222 | + newOpts, err := stripSourceRetentionOptions(oneof.Options) |
| 223 | + if err != nil { |
| 224 | + return nil, err |
| 225 | + } |
| 226 | + if newOpts == oneof.Options { |
| 227 | + return oneof, nil |
| 228 | + } |
| 229 | + newOneof, err := shallowCopy(oneof) |
| 230 | + if err != nil { |
| 231 | + return nil, err |
| 232 | + } |
| 233 | + newOneof.Options = newOpts |
| 234 | + return newOneof, nil |
| 235 | +} |
| 236 | + |
| 237 | +func stripSourceRetentionOptionsFromExtensionRange(extRange *descriptorpb.DescriptorProto_ExtensionRange) (*descriptorpb.DescriptorProto_ExtensionRange, error) { |
| 238 | + newOpts, err := stripSourceRetentionOptions(extRange.Options) |
| 239 | + if err != nil { |
| 240 | + return nil, err |
| 241 | + } |
| 242 | + if newOpts == extRange.Options { |
| 243 | + return extRange, nil |
| 244 | + } |
| 245 | + newExtRange, err := shallowCopy(extRange) |
| 246 | + if err != nil { |
| 247 | + return nil, err |
| 248 | + } |
| 249 | + newExtRange.Options = newOpts |
| 250 | + return newExtRange, nil |
| 251 | +} |
| 252 | + |
| 253 | +func stripSourceRetentionOptionsFromEnum(enum *descriptorpb.EnumDescriptorProto) (*descriptorpb.EnumDescriptorProto, error) { |
| 254 | + var dirty bool |
| 255 | + newOpts, err := stripSourceRetentionOptions(enum.Options) |
| 256 | + if err != nil { |
| 257 | + return nil, err |
| 258 | + } |
| 259 | + if newOpts != enum.Options { |
| 260 | + dirty = true |
| 261 | + } |
| 262 | + newVals, changed, err := updateAll(enum.Value, stripSourceRetentionOptionsFromEnumValue) |
| 263 | + if err != nil { |
| 264 | + return nil, err |
| 265 | + } |
| 266 | + if changed { |
| 267 | + dirty = true |
| 268 | + } |
| 269 | + |
| 270 | + if !dirty { |
| 271 | + return enum, nil |
| 272 | + } |
| 273 | + |
| 274 | + newEnum, err := shallowCopy(enum) |
| 275 | + if err != nil { |
| 276 | + return nil, err |
| 277 | + } |
| 278 | + newEnum.Options = newOpts |
| 279 | + newEnum.Value = newVals |
| 280 | + return newEnum, nil |
| 281 | +} |
| 282 | + |
| 283 | +func stripSourceRetentionOptionsFromEnumValue(enumVal *descriptorpb.EnumValueDescriptorProto) (*descriptorpb.EnumValueDescriptorProto, error) { |
| 284 | + newOpts, err := stripSourceRetentionOptions(enumVal.Options) |
| 285 | + if err != nil { |
| 286 | + return nil, err |
| 287 | + } |
| 288 | + if newOpts == enumVal.Options { |
| 289 | + return enumVal, nil |
| 290 | + } |
| 291 | + newEnumVal, err := shallowCopy(enumVal) |
| 292 | + if err != nil { |
| 293 | + return nil, err |
| 294 | + } |
| 295 | + newEnumVal.Options = newOpts |
| 296 | + return newEnumVal, nil |
| 297 | +} |
| 298 | + |
| 299 | +func stripSourceRetentionOptionsFromService(svc *descriptorpb.ServiceDescriptorProto) (*descriptorpb.ServiceDescriptorProto, error) { |
| 300 | + var dirty bool |
| 301 | + newOpts, err := stripSourceRetentionOptions(svc.Options) |
| 302 | + if err != nil { |
| 303 | + return nil, err |
| 304 | + } |
| 305 | + if newOpts != svc.Options { |
| 306 | + dirty = true |
| 307 | + } |
| 308 | + newMethods, changed, err := updateAll(svc.Method, stripSourceRetentionOptionsFromMethod) |
| 309 | + if err != nil { |
| 310 | + return nil, err |
| 311 | + } |
| 312 | + if changed { |
| 313 | + dirty = true |
| 314 | + } |
| 315 | + |
| 316 | + if !dirty { |
| 317 | + return svc, nil |
| 318 | + } |
| 319 | + |
| 320 | + newSvc, err := shallowCopy(svc) |
| 321 | + if err != nil { |
| 322 | + return nil, err |
| 323 | + } |
| 324 | + newSvc.Options = newOpts |
| 325 | + newSvc.Method = newMethods |
| 326 | + return newSvc, nil |
| 327 | +} |
| 328 | + |
| 329 | +func stripSourceRetentionOptionsFromMethod(method *descriptorpb.MethodDescriptorProto) (*descriptorpb.MethodDescriptorProto, error) { |
| 330 | + newOpts, err := stripSourceRetentionOptions(method.Options) |
| 331 | + if err != nil { |
| 332 | + return nil, err |
| 333 | + } |
| 334 | + if newOpts == method.Options { |
| 335 | + return method, nil |
| 336 | + } |
| 337 | + newMethod, err := shallowCopy(method) |
| 338 | + if err != nil { |
| 339 | + return nil, err |
| 340 | + } |
| 341 | + newMethod.Options = newOpts |
| 342 | + return newMethod, nil |
| 343 | +} |
| 344 | + |
| 345 | +func shallowCopy[M proto.Message](msg M) (M, error) { |
| 346 | + msgRef := msg.ProtoReflect() |
| 347 | + other := msgRef.New() |
| 348 | + ret, ok := other.Interface().(M) |
| 349 | + if !ok { |
| 350 | + return ret, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", other.Interface(), ret) |
| 351 | + } |
| 352 | + msgRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { |
| 353 | + other.Set(field, val) |
| 354 | + return true |
| 355 | + }) |
| 356 | + return ret, nil |
| 357 | +} |
| 358 | + |
| 359 | +// updateAll applies the given function to each element in the given slice. It |
| 360 | +// returns the new slice and a bool indicating whether anything was actually |
| 361 | +// changed. If the second value is false, then the returned slice is the same |
| 362 | +// slice as the input slice. Usually, T is a pointer type, in which case the |
| 363 | +// given updateFunc should NOT mutate the input value. Instead, it should return |
| 364 | +// the input value if only if there is no update needed. If a mutation is needed, |
| 365 | +// it should return a new value. |
| 366 | +func updateAll[T comparable](slice []T, updateFunc func(T) (T, error)) ([]T, bool, error) { |
| 367 | + var updated []T // initialized lazily, only when/if a copy is needed |
| 368 | + for i, item := range slice { |
| 369 | + newItem, err := updateFunc(item) |
| 370 | + if err != nil { |
| 371 | + return nil, false, err |
| 372 | + } |
| 373 | + if updated != nil { |
| 374 | + updated[i] = newItem |
| 375 | + } else if newItem != item { |
| 376 | + updated = make([]T, len(slice)) |
| 377 | + copy(updated[:i], slice) |
| 378 | + updated[i] = newItem |
| 379 | + } |
| 380 | + } |
| 381 | + if updated != nil { |
| 382 | + return updated, true, nil |
| 383 | + } |
| 384 | + return slice, false, nil |
| 385 | +} |
0 commit comments