[client] Open browser for ssh automatically (#4838)

This commit is contained in:
Viktor Liu
2025-11-26 16:06:47 +01:00
committed by GitHub
parent f31bba87b4
commit 02200d790b
7 changed files with 107 additions and 39 deletions

View File

@@ -4,14 +4,12 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := openBrowser(verificationURIComplete); err != nil {
if err := util.OpenBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -51,6 +51,7 @@ var (
identityFile string
skipCachedToken bool
requestPTY bool
sshNoBrowser bool
)
var (
@@ -81,6 +82,7 @@ func init() {
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
@@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
return defaultValue
}
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() {
port = sshserver.DefaultSSHPort
@@ -196,6 +213,7 @@ func resetSSHGlobals() {
strictHostKeyChecking = true
knownHostsFile = ""
identityFile = ""
sshNoBrowser = false
}
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
@@ -370,6 +388,7 @@ type sshFlags struct {
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
NoBrowser bool
ConfigPath string
LogLevel string
LocalForwards []string
@@ -381,6 +400,7 @@ type sshFlags struct {
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
@@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
@@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
sshNoBrowser = flags.NoBrowser
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
@@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
NoBrowser: sshNoBrowser,
})
if err != nil {
@@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
// where command-line flags cannot be passed. Default is to open browser.
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util"
)
const (
@@ -278,6 +279,7 @@ type DialOptions struct {
DaemonAddr string
SkipCachedToken bool
InsecureSkipVerify bool
NoBrowser bool
}
// Dial connects to the given ssh server with specified options
@@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
config.Auth = append(config.Auth, authMethod)
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser)
}
// dialSSH establishes an SSH connection without JWT authentication
@@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig
}
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("parse address %s: %w", addr, err)
@@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser)
if err != nil {
return nil, fmt.Errorf("request JWT token: %w", err)
}
@@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
}
// requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) {
hint := profilemanager.GetLoginHint()
conn, err := connectToDaemon(daemonAddr)
@@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener)
}
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon

View File

@@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
}
// printAuthInstructions prints authentication instructions to stderr
func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
if browserWillOpen {
_, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.")
_, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:")
_, _ = fmt.Fprintln(stderr)
}
_, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
}
if browserWillOpen {
_, _ = fmt.Fprintln(stderr)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) {
req := &proto.RequestJWTAuthRequest{}
if hint != "" {
req.Hint = &hint
@@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo
}
if stderr != nil {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
printAuthInstructions(stderr, authResponse, openBrowser != nil)
}
if openBrowser != nil {
if err := openBrowser(authResponse.VerificationURIComplete); err != nil {
log.Debugf("open browser: %v", err)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{

View File

@@ -35,15 +35,16 @@ const (
)
type SSHProxy struct {
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
browserOpener func(string) error
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
@@ -51,12 +52,13 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
}
return &SSHProxy{
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
browserOpener: browserOpener,
}, nil
}
@@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error {
func (p *SSHProxy) Connect(ctx context.Context) error {
hint := profilemanager.GetLoginHint()
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}

View File

@@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) {
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()

View File

@@ -1,6 +1,19 @@
package util
import "os"
import (
"os"
"os/exec"
"github.com/skratchdot/open-golang/open"
)
// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func OpenBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// SliceDiff returns the elements in slice `x` that are not in slice `y`
func SliceDiff(x, y []string) []string {