Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion acceptance/handle_config_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ username123: username
Expect(err).ToNot(HaveOccurred())

Eventually(session, "5s").Should(gexec.Exit(1))
Expect(session.Err.Contents()).To(ContainSubstring("unknown flag `username123'"))
Expect(session.Err.Contents()).To(ContainSubstring("unknown flag(s) [\"--username123\"]"))
})
})

Expand Down
7 changes: 0 additions & 7 deletions cmd/loadConfigFile_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package cmd

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
Expand All @@ -22,8 +20,3 @@ var _ = Describe("parseOptions", func() {

})
})

func TestCmds(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Cmds")
}
129 changes: 125 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,14 @@ func Main(sout io.Writer, serr io.Writer, version string, applySleepDurationStri
return err
}

parser.Options |= flags.HelpFlag

_, err = parser.ParseArgs(args)

// Strict flag validation block
err = validateCommandFlags(parser, args)
if err != nil {
return err
}

parser.Options |= flags.HelpFlag
if _, err = parser.ParseArgs(args); err != nil {
if e, ok := err.(*flags.Error); ok {
switch e.Type {
case flags.ErrHelp, flags.ErrCommandRequired:
Expand Down Expand Up @@ -787,3 +790,121 @@ func checkForVars(opts *options) error {

return nil
}

// validateCommandFlags checks if the provided command flags are valid for the given command.
func validateCommandFlags(parser *flags.Parser, args []string) error {
// If no args, return nil
if len(args) == 0 {
return nil
}

// Find the command to validate flags for
cmdName := args[0]
var selectedCmd *flags.Command
for _, cmd := range parser.Commands() {
if cmd.Name == cmdName || contains(cmd.Aliases, cmdName) {
selectedCmd = cmd
break
}
}
// Unknown command, let parser handle it
if selectedCmd == nil {
return nil
}

// Find unknown flags
invalidFlags := findUnknownFlags(selectedCmd, args)

// If there are unknown flags, print an error and return
if len(invalidFlags) > 0 {
fmt.Fprintf(os.Stderr, "Error: unknown flag(s) %q for command '%s'\n", invalidFlags, selectedCmd.Name)
fmt.Fprintf(os.Stderr, "See 'om %s --help' for available options.\n", selectedCmd.Name)
return fmt.Errorf("unknown flag(s) %q for command '%s'", invalidFlags, selectedCmd.Name)
}
return nil
}

// findUnknownFlags checks for unknown flags in the provided args for the given command.
func findUnknownFlags(selectedCmd *flags.Command, args []string) []string {
validFlags := make(map[string]bool)
addFlag := func(name string, takesValue bool) {
validFlags[name] = takesValue
}
for _, opt := range selectedCmd.Options() {
val := opt.Value()
_, isBool := val.(*bool)
_, isBoolSlice := val.(*[]bool)
takesValue := !(isBool || isBoolSlice)
if ln := opt.LongNameWithNamespace(); ln != "" {
addFlag("--"+ln, takesValue)
}
if opt.ShortName != 0 {
addFlag("-"+string(opt.ShortName), takesValue)
}
}
addFlag("--help", false)
addFlag("-h", false)

var invalidFlags []string
i := 1
for i < len(args) {
arg := args[i]
if !strings.HasPrefix(arg, "-") {
// Not a flag, just a value
// Example: args = ["upload-product", "file.pivotal"]
// "file.pivotal" is a positional argument or a value for a previous flag
i++
continue
}

// Split flag and value if --flag=value
flagName, hasEquals := arg, false
if eqIdx := strings.Index(arg, "="); eqIdx != -1 {
flagName = arg[:eqIdx]
hasEquals = true
// Example: arg = "--product=foo.pivotal" -> flagName = "--product", value = "foo.pivotal"
}

takesValue, isValid := validFlags[flagName]
if !isValid {
// Unknown flag
// Example: arg = "--notaflag" (not defined in command options)
invalidFlags = append(invalidFlags, flagName)
i++
continue
}

if takesValue {
if hasEquals {
// --flag=value, value is in this arg
// Example: arg = "--product=foo.pivotal"
i++
} else if i+1 < len(args) {
// --flag value, value is next arg (even if it looks like a flag)
// Example: args = ["--product", "--notaflag"]
// "--notaflag" is treated as the value for --product, not as a flag
i += 2
} else {
// --flag with missing value.
// No need to handle this here as this will handled appropriately by the parser.
// Example: args = ["--product"] (no value provided)
i++
}
} else {
// Boolean flag, no value expected
// Example: arg = "--help"
i++
}
}
return invalidFlags
}

// contains checks if a string is present in a list of strings.
func contains(list []string, s string) bool {
for _, v := range list {
if v == s {
return true
}
}
return false
}
13 changes: 13 additions & 0 deletions cmd/suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package cmd

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestCmd(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Cmd Suite")
}
50 changes: 50 additions & 0 deletions cmd/validate_command_flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package cmd

import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/jessevdk/go-flags"
)

var _ = Describe("validateCommandFlags", func() {
type uploadProductOptions struct {
Product string `long:"product" short:"p" description:"path to product" required:"true"`
PollingInterval int `long:"polling-interval" short:"i" description:"interval (in seconds) at which to print status" default:"1"`
Shasum string `long:"shasum" description:"shasum of the provided product file to be used for validation"`
Version string `long:"product-version" description:"version of the provided product file to be used for validation"`
Config string `long:"config" short:"c" description:"path to yml file for configuration"`
VarsEnv string `long:"vars-env" description:"load variables from environment variables matching the provided prefix"`
VarsFile []string `long:"vars-file" short:"l" description:"load variables from a YAML file"`
Var []string `long:"var" short:"v" description:"load variable from the command line. Format: VAR=VAL"`
}

var parser *flags.Parser

BeforeEach(func() {
parser = flags.NewParser(nil, flags.Default)
parser.AddCommand("upload-product", "desc", "long desc", &uploadProductOptions{})
})

DescribeTable("flag validation",
func(args []string, wantErr bool, errMsg string) {
err := validateCommandFlags(parser, args)
if wantErr {
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(errMsg))
} else {
Expect(err).ToNot(HaveOccurred())
}
},
Entry("no args", []string{}, false, ""),
Entry("valid flags", []string{"upload-product", "--product", "file.pivotal", "--polling-interval", "5", "--shasum", "abc123", "--product-version", "2.3.4"}, false, ""),
Entry("valid short flags", []string{"upload-product", "-p", "file.pivotal", "-i", "5"}, false, ""),
Entry("all config interpolation flags", []string{"upload-product", "-c", "config.yml", "--vars-env", "MY", "-l", "vars1.yml", "-l", "vars2.yml", "-v", "FOO=bar", "-v", "BAZ=qux", "-p", "file.pivotal"}, false, ""),
Entry("mix config and main flags", []string{"upload-product", "-p", "file.pivotal", "-c", "config.yml", "--vars-env", "MY", "--shasum", "abc123", "-l", "vars.yml", "-v", "FOO=bar"}, false, ""),
Entry("unknown flag with config flags", []string{"upload-product", "-p", "file.pivotal", "-c", "config.yml", "--notaflag"}, true, "unknown flag(s)"),
Entry("unknown flag", []string{"upload-product", "--notaflag"}, true, "unknown flag(s)"),
Entry("multiple unknown flags", []string{"upload-product", "--foo", "--bar"}, true, "unknown flag(s)"),
Entry("flag value looks like flag", []string{"upload-product", "--product", "--notaflag"}, false, ""),
Entry("unknown short flags", []string{"upload-product", "-p", "file.pivotal", "-x", "18000", "-z", "18000"}, true, "unknown flag(s)"),
)
})
22 changes: 11 additions & 11 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ require (
github.com/golang-jwt/jwt/v4 v4.5.1 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20241101162523-b92577c0c142 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
Expand Down Expand Up @@ -143,21 +143,21 @@ require (
go.opentelemetry.io/otel/sdk v1.32.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.32.0 // indirect
go.opentelemetry.io/otel/trace v1.32.0 // indirect
golang.org/x/crypto v0.29.0 // indirect
golang.org/x/mod v0.22.0 // indirect
golang.org/x/net v0.31.0 // indirect
golang.org/x/sync v0.9.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/term v0.26.0 // indirect
golang.org/x/text v0.20.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/term v0.30.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/time v0.8.0 // indirect
golang.org/x/tools v0.27.0 // indirect
golang.org/x/tools v0.31.0 // indirect
google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 // indirect
google.golang.org/grpc v1.68.0 // indirect
google.golang.org/grpc/stats/opentelemetry v0.0.0-20241028142157-ada6787961b3 // indirect
google.golang.org/protobuf v1.35.1 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect
)
Loading
Loading