// Command tunnel manages ssh port-forwarding tunnels: `up` ensures an ssh
// master connection (-M) to a host is running and adds a local port forward,
// `down` asks that master to exit, and `show` lists control sockets found in
// ~/.ssh.
//
// NOTE(review): `show` presumes ~/.ssh/config sets a ControlPath under ~/.ssh
// whose file name starts with "control-" followed by the host label
// (e.g. ~/.ssh/control-%h-%p) — confirm against the ssh configuration in use.
package main

import (
	"errors"
	"fmt"
	"io/fs"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"github.com/spf13/cobra"
)

// logFail aborts the program with a formatted message when cond is false.
func logFail(cond bool, s string, args ...interface{}) {
	if !cond {
		log.Fatalf(s, args...)
	}
}

// ok runs cmd to completion and reports whether it exited successfully.
func ok(cmd *exec.Cmd) bool {
	return cmd.Run() == nil
}

// sshCmd builds an ssh invocation whose stderr is attached to ours, so ssh's
// own diagnostics (unknown host, authentication failure, port already in use,
// ...) reach the user instead of being discarded. Because os.Stderr is an
// *os.File it is handed to the child directly (no pipe), so commands using -f,
// which leave a process running in the background, do not make Run block.
func sshCmd(args ...string) *exec.Cmd {
	cmd := exec.Command("ssh", args...)
	cmd.Stderr = os.Stderr
	return cmd
}

// portString returns the -L forwarding spec that maps the local port to the
// same port on the remote side's localhost, e.g. "8080:localhost:8080".
func portString() string {
	return fmt.Sprintf("%[1]d:localhost:%[1]d", port)
}

// checkControl reports whether `ssh -O check` succeeds for host. Its stderr is
// intentionally left discarded: "no master running" is the normal case here.
func checkControl() bool {
	return ok(exec.Command("ssh", "-O", "check", host))
}

// activateSshControl starts a backgrounded (-f) ssh master (-M) for host that
// runs no remote command (-N).
func activateSshControl() bool {
	return ok(sshCmd("-M", "-f", "-N", host))
}

// connectPort opens a backgrounded local port forward (see portString) to host.
func connectPort() bool {
	return ok(sshCmd("-fNL", portString(), host))
}

// deactivateSshControl sends the "exit" control command for host.
func deactivateSshControl() bool {
	return ok(sshCmd("-O", "exit", host))
}

// tunnelUp starts an ssh master for host if `-O check` fails, then forwards
// port. It exits the program on any failure.
func tunnelUp() {
	// Validate before touching ssh so a bad flag never leaves a master behind.
	logFail(port >= 1 && port <= 65535, "invalid port: %d (must be 1-65535)", port)
	if !checkControl() {
		logFail(activateSshControl(), "failed to activate ssh for host: %s", host)
	}
	logFail(connectPort(), "failed to connect to host: %s with port: %d", host, port)
}

// tunnelDown asks the ssh master for host to exit, exiting the program if
// ssh reports failure.
func tunnelDown() {
	logFail(deactivateSshControl(), "failed to disable ssh control %s", host)
}

// getControlSockets returns the names of the Unix sockets in sshDir whose
// names start with "control". A missing sshDir yields no sockets; any other
// read error aborts the program.
func getControlSockets() []string {
	entries, err := os.ReadDir(sshDir)
	if errors.Is(err, fs.ErrNotExist) {
		return nil
	}
	logFail(err == nil, "failed to read ssh directory %s: %v", sshDir, err)

	var sockets []string
	for _, e := range entries {
		// Only sockets can be control connections; regular files or
		// directories that merely share the prefix are ignored.
		if e.Type()&fs.ModeSocket != 0 && strings.HasPrefix(e.Name(), "control") {
			sockets = append(sockets, e.Name())
		}
	}
	return sockets
}

// controlHost extracts the host label from a control socket file name by
// taking its second "-"-separated field, falling back to the whole name when
// the name contains no "-".
// NOTE(review): a host label that itself contains "-" is cut at its first
// "-" — confirm against the ControlPath format in use.
func controlHost(name string) string {
	fields := strings.SplitN(name, "-", 3)
	if len(fields) < 2 {
		return name
	}
	return fields[1]
}

// tunnelShow prints how many control sockets exist in sshDir and the host
// label of each. A socket left behind by a dead master is still counted.
func tunnelShow() {
	controls := getControlSockets()
	fmt.Printf("%d active connections\n", len(controls))
	if len(controls) > 0 {
		fmt.Println("hosts:")
		for _, c := range controls {
			fmt.Printf(" %s\n", controlHost(c))
		}
	}
}

// genSubCmd builds a subcommand that takes exactly one hostname argument,
// stores it in host, and then calls run.
func genSubCmd(use string, short string, run func()) *cobra.Command {
	return &cobra.Command{
		Use:   use,
		Short: short,
		Args:  cobra.ExactArgs(1),
		Run: func(cmd *cobra.Command, args []string) {
			host = args[0]
			run()
		},
	}
}

var (
	// sshDir is the directory tunnelShow scans for control sockets.
	sshDir = filepath.Join(os.Getenv("HOME"), ".ssh")

	rootCmd = &cobra.Command{Use: "tunnel", Short: "control ssh tunnels"}
	upCmd   = genSubCmd("up hostname [flags]", "activate ssh tunnel", tunnelUp)
	downCmd = genSubCmd("down hostname [flags]", "deactivate ssh tunnel", tunnelDown)

	showCmd = &cobra.Command{
		Use:   "show",
		Short: "show active tunnels",
		Args:  cobra.NoArgs,
		Run:   func(cmd *cobra.Command, args []string) { tunnelShow() },
	}

	// port is the port forwarded by `up` (set via --port/-p).
	port uint64
	// host is the ssh destination passed as the subcommand's argument.
	host string
)

func init() {
	rootCmd.CompletionOptions.HiddenDefaultCmd = true
	rootCmd.AddCommand(upCmd, downCmd, showCmd)
	upCmd.Flags().Uint64VarP(&port, "port", "p", 0, "port number")
	// MarkFlagRequired only fails when the flag is undefined, i.e. a bug here.
	if err := upCmd.MarkFlagRequired("port"); err != nil {
		panic(err)
	}
}

func main() {
	// By default cobra reports the error itself; only the exit status is set here.
	if err := rootCmd.Execute(); err != nil {
		os.Exit(1)
	}
}