Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 91 additions & 4 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,42 @@ const (
PanicOnError
)

// UnknownFlagsHandling decides how to handle unknown flags
type UnknownFlagsHandling int

const (
// UnknownFlagsHandlingErrorOnUnknown will return an error if an unknown flag is found
UnknownFlagsHandlingErrorOnUnknown UnknownFlagsHandling = iota
// UnknownFlagsHandlingIgnoreUnknown will ignore unknown flags and continue parsing rest of the flags
UnknownFlagsHandlingIgnoreUnknown
// UnknownFlagsHandlingPassUnknownToArgs will treat unknown flags as non-flag arguments.
// Combined shorthand flags mixed with known ones and unknown ones results
// combined flags only with unknown ones.
// E.g. -fghi results -gh if only `f` and `i` are known.
UnknownFlagsHandlingPassUnknownToArgs
Comment on lines +144 to +152
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reeeally long names...
I worry names like ErrorOnUnknown would cause name conflict.

)

// ParseErrorsAllowlist defines the parsing errors that can be ignored
type ParseErrorsAllowlist struct {
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
// Deprecated: Use UnknownFlagsHandling instead
UnknownFlags bool

// UnknownFlagsHandling decides how to handle unknown flags. Defaults to UnknownFlagsHandlingErrorOnUnknown.
UnknownFlagsHandling UnknownFlagsHandling
}

// getUnknownFlagsHandling returns the UnknownFlagsHandling value, considering deprecated UnknownFlags field
func (a *ParseErrorsAllowlist) getUnknownFlagsHandling() UnknownFlagsHandling {
// if UnknownFlagsHandling is set, use it
if a.UnknownFlagsHandling != UnknownFlagsHandlingErrorOnUnknown {
return a.UnknownFlagsHandling
}

if a.UnknownFlags {
return UnknownFlagsHandlingIgnoreUnknown
}
return UnknownFlagsHandlingErrorOnUnknown
}

// NormalizedName is a flag name that has been normalized according to rules
Expand Down Expand Up @@ -967,6 +999,17 @@ func stripUnknownFlagValue(args []string) []string {
return nil
}

// errUnknownFlag is used for internal unknown flag handling.
type unknownFlagError struct {
// UnknownFlags is flags that are unknown and unprocessed.
// It depends on the context whether this has a prefix like '-' or '--'.
UnknownFlags string
}

func (e *unknownFlagError) Error() string {
return fmt.Sprintf("unknown flag: %v", e.UnknownFlags)
}

func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
name := s[2:]
Expand All @@ -978,20 +1021,25 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
split := strings.SplitN(name, "=", 2)
name = split[0]
flag, exists := f.formal[f.normalizeFlagName(name)]
unknownFlagsHandling := f.ParseErrorsAllowlist.getUnknownFlagsHandling()

if !exists {
switch {
case name == "help":
f.usage()
return a, ErrHelp
case f.ParseErrorsAllowlist.UnknownFlags:
case unknownFlagsHandling == UnknownFlagsHandlingIgnoreUnknown:
// --unknown=unknownval arg ...
// we do not want to lose arg in this case
if len(split) >= 2 {
return a, nil
}

return stripUnknownFlagValue(a), nil
case unknownFlagsHandling == UnknownFlagsHandlingPassUnknownToArgs:
return a, &unknownFlagError{
UnknownFlags: s,
}
default:
err = f.fail(&NotExistError{name: name, messageType: flagUnknownFlagMessage})
return
Expand Down Expand Up @@ -1037,12 +1085,14 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

flag, exists := f.shorthands[c]
if !exists {
unknownFlagsHandling := f.ParseErrorsAllowlist.getUnknownFlagsHandling()

switch {
case c == 'h':
f.usage()
err = ErrHelp
return
case f.ParseErrorsAllowlist.UnknownFlags:
case unknownFlagsHandling == UnknownFlagsHandlingIgnoreUnknown:
// '-f=arg arg ...'
// we do not want to lose arg in this case
if len(shorthands) > 2 && shorthands[1] == '=' {
Expand All @@ -1052,6 +1102,20 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

outArgs = stripUnknownFlagValue(outArgs)
return
case unknownFlagsHandling == UnknownFlagsHandlingPassUnknownToArgs:
// '-f=arg': pass all the argument
if len(shorthands) > 2 && shorthands[1] == '=' {
outShorts = ""
err = &unknownFlagError{
UnknownFlags: shorthands,
}
return
}
// '-fgh': pass only the first switch
err = &unknownFlagError{
UnknownFlags: shorthands[0:1],
}
return
default:
err = f.fail(&NotExistError{
name: string(c),
Expand Down Expand Up @@ -1102,14 +1166,31 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
shorthands := s[1:]
var errUnknownFlagAll *unknownFlagError

// "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv").
for len(shorthands) > 0 {
shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn)
if err != nil {
return
if errUnknownFlag, ok := err.(*unknownFlagError); ok {
// this means f.ParseErrorsAllowlist.UnknownFlagsHandling is set to UnknownFlagsHandlingPassUnknownToArgs
if errUnknownFlagAll == nil {
errUnknownFlagAll = &unknownFlagError{
UnknownFlags: "-",
}
}

errUnknownFlagAll.UnknownFlags = errUnknownFlagAll.UnknownFlags +
errUnknownFlag.UnknownFlags
err = nil
} else {
return
}
}
}
if errUnknownFlagAll != nil {
err = errUnknownFlagAll
}

return
}
Expand Down Expand Up @@ -1139,7 +1220,13 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) {
args, err = f.parseShortArg(s, args, fn)
}
if err != nil {
return
if errUnknownFlag, ok := err.(*unknownFlagError); ok {
// this means f.ParseErrorsAllowlist.UnknownFlagsHandling is set to UnknownFlagsHandlingPassUnknownToArgs
f.args = append(f.args, errUnknownFlag.UnknownFlags)
err = nil
} else {
return
}
}
}
return
Expand Down
109 changes: 109 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,111 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
}
}

func testParseWithUnknownFlagsAndPassToArgs(f *FlagSet, t *testing.T) {
if f.Parsed() {
t.Fatal("f.Parse() = true before Parse")
}
f.ParseErrorsAllowlist.UnknownFlagsHandling = UnknownFlagsHandlingPassUnknownToArgs
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f.SetInterspersed(true)

f.BoolP("boola", "a", false, "bool value")
f.BoolP("boolb", "b", false, "bool2 value")
f.BoolP("boolc", "c", false, "bool3 value")
f.BoolP("boold", "d", false, "bool4 value")
f.BoolP("boole", "e", false, "bool4 value")
f.StringP("stringa", "s", "0", "string value")
f.StringP("stringz", "z", "0", "string value")
f.StringP("stringx", "x", "0", "string value")
f.StringP("stringy", "y", "0", "string value")
f.StringP("stringo", "o", "0", "string value")
f.Lookup("stringx").NoOptDefVal = "1"
args := []string{
"-ab",
// -f and -g is unknown
"-fcgs=xx",
"--stringz=something",
"--unknown1",
"unknown1Value",
"-d=true",
"-x",
"--unknown2=unknown2Value",
"-u=unknown3Value",
"-p",
"unknown4Value",
"-q", //another unknown with bool value
"-y",
"ee",
"--unknown7=unknown7value",
"--stringo=ovalue",
"--unknown8=unknown8value",
"--boole",
"--unknown6",
"",
"-uuuuu",
"",
"--unknown10",
"--unknown11",
"arg0",
"arg1",
}
want := []string{
"boola", "true",
"boolb", "true",
"boolc", "true",
"stringa", "xx",
"stringz", "something",
"boold", "true",
"stringx", "1",
"stringy", "ee",
"stringo", "ovalue",
"boole", "true",
}
wantArgs := []string{
"-fg",
"--unknown1",
"unknown1Value",
"--unknown2=unknown2Value",
"-u=unknown3Value",
"-p",
"unknown4Value",
"-q", //another unknown with bool value
"--unknown7=unknown7value",
"--unknown8=unknown8value",
"--unknown6",
"",
"-uuuuu",
"",
"--unknown10",
"--unknown11",
"arg0",
"arg1",
}
got := []string{}
store := func(flag *Flag, value string) error {
got = append(got, flag.Name)
if len(value) > 0 {
got = append(got, value)
}
return nil
}
if err := f.ParseAll(args, store); err != nil {
t.Errorf("expected no error, got %s", err)
}
if !f.Parsed() {
t.Errorf("f.Parse() = false after Parse")
}
if !reflect.DeepEqual(got, want) {
t.Errorf("f.ParseAll() fail to restore the args")
t.Errorf("Got: %v", got)
t.Errorf("Want: %v", want)
}
if !reflect.DeepEqual(f.Args(), wantArgs) {
t.Errorf("f.ParseAll() fail to restore the args")
t.Errorf("Got: %v", f.Args())
t.Errorf("Want: %v", wantArgs)
}
}

func TestShorthand(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
if f.Parsed() {
Expand Down Expand Up @@ -652,6 +757,10 @@ func TestIgnoreUnknownFlags(t *testing.T) {
testParseWithUnknownFlags(GetCommandLine(), t)
}

func TestIgnoreUnknownFlagsAndPassToArgs(t *testing.T) {
ResetForTesting(func() { t.Error("bad parse") })
testParseWithUnknownFlagsAndPassToArgs(GetCommandLine(), t)
}
func TestFlagSetParse(t *testing.T) {
testParse(NewFlagSet("test", ContinueOnError), t)
}
Expand Down
Loading