diff --git a/cli/cli.go b/cli/cli.go index 3afb9fd..2dd5f9b 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -28,6 +28,9 @@ type CLI struct { // Functions to run before the server starts up. prestart []func() + + // Functions to run after parsing args + argsParsed []func() } // NewRouter creates a new router, new CLI, sets the default middlware, and @@ -56,6 +59,11 @@ func New(router *huma.Router) *CLI { app.root = &cobra.Command{ Use: filepath.Base(os.Args[0]), Version: app.GetVersion(), + PersistentPreRun: func(cmd *cobra.Command, args []string) { + for _, f := range app.argsParsed { + f() + } + }, Run: func(cmd *cobra.Command, args []string) { fmt.Printf("Starting %s %s on %s:%v\n", app.GetTitle(), app.GetVersion(), viper.Get("host"), viper.Get("port")) @@ -160,6 +168,14 @@ func (c *CLI) PreStart(f func()) { c.prestart = append(c.prestart, f) } +// ArgsParsed registers a function to run after arguments have been parsed +// but before any command handler has been run. It is similar to a PreStart +// function but runs *before* PreStart functions and can be used for more +// than server startup, i.e. custom commands as well. +func (c *CLI) ArgsParsed(f func()) { + c.argsParsed = append(c.argsParsed, f) +} + // Run runs the CLI. func (c *CLI) Run() { if err := c.root.Execute(); err != nil { diff --git a/cli/cli_test.go b/cli/cli_test.go index abecad4..5959d1d 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -3,9 +3,12 @@ package cli import ( "context" "os" + "sync" "testing" "time" + "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" ) @@ -32,3 +35,35 @@ func TestCLI(t *testing.T) { assert.Equal(t, true, started) } + +func TestParsedArgs(t *testing.T) { + app := NewRouter("Test API", "1.0.0") + + foo := "" + app.Flag("foo", "f", "desc", "") + + wg := sync.WaitGroup{} + wg.Add(1) + + app.Root().AddCommand(&cobra.Command{ + Use: "foo-test", + Run: func(cmd *cobra.Command, args []string) { + // Command does nothing... + }, + }) + + app.ArgsParsed(func() { + foo = viper.GetString("foo") + wg.Done() + }) + + app.Root().SetArgs([]string{"foo-test", "--foo=bar"}) + + go func() { + app.Root().Execute() + }() + + wg.Wait() + + assert.Equal(t, "bar", foo) +}