Compare commits

...

1 Commits

Author SHA1 Message Date
braginini
a369357a85 Add Zitadel IdP 2025-12-19 07:08:08 -05:00
13 changed files with 2965 additions and 0 deletions

35
idp/cmd/env.go Normal file
View File

@@ -0,0 +1,35 @@
package cmd
import (
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_IDP_
func setFlagsFromEnvVars(cmd *cobra.Command) {
flags := cmd.PersistentFlags()
flags.VisitAll(func(f *pflag.Flag) {
newEnvVar := flagNameToEnvVar(f.Name, "NB_IDP_")
value, present := os.LookupEnv(newEnvVar)
if !present {
return
}
err := flags.Set(f.Name, value)
if err != nil {
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
}
})
}
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
// replacing dashes and making all uppercase (e.g. data-dir is converted to NB_IDP_DATA_DIR)
func flagNameToEnvVar(cmdFlag string, prefix string) string {
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
upper := strings.ToUpper(parsed)
return prefix + upper
}

148
idp/cmd/root.go Normal file
View File

@@ -0,0 +1,148 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/idp/oidcprovider"
"github.com/netbirdio/netbird/util"
)
// Config holds the IdP server configuration
type Config struct {
ListenPort int
Issuer string
DataDir string
LogLevel string
LogFile string
DevMode bool
DashboardRedirectURIs []string
CLIRedirectURIs []string
DashboardClientID string
CLIClientID string
}
var (
config *Config
rootCmd = &cobra.Command{
Use: "idp",
Short: "NetBird Identity Provider",
Long: "Embedded OIDC Identity Provider for NetBird",
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
_ = util.InitLog("trace", util.LogConsole)
config = &Config{}
rootCmd.PersistentFlags().IntVarP(&config.ListenPort, "port", "p", 33081, "port to listen on")
rootCmd.PersistentFlags().StringVarP(&config.Issuer, "issuer", "i", "", "OIDC issuer URL (default: http://localhost:<port>)")
rootCmd.PersistentFlags().StringVarP(&config.DataDir, "data-dir", "d", "/var/lib/netbird", "directory to store IdP data")
rootCmd.PersistentFlags().StringVar(&config.LogLevel, "log-level", "info", "log level (trace, debug, info, warn, error)")
rootCmd.PersistentFlags().StringVar(&config.LogFile, "log-file", "console", "log file path or 'console'")
rootCmd.PersistentFlags().BoolVar(&config.DevMode, "dev-mode", false, "enable development mode (allows HTTP)")
rootCmd.PersistentFlags().StringSliceVar(&config.DashboardRedirectURIs, "dashboard-redirect-uris", []string{
"http://localhost:3000/callback",
"http://localhost:3000/silent-callback",
}, "allowed redirect URIs for dashboard client")
rootCmd.PersistentFlags().StringSliceVar(&config.CLIRedirectURIs, "cli-redirect-uris", []string{
"http://localhost:53000",
"http://localhost:54000",
}, "allowed redirect URIs for CLI client")
rootCmd.PersistentFlags().StringVar(&config.DashboardClientID, "dashboard-client-id", "netbird-dashboard", "client ID for dashboard")
rootCmd.PersistentFlags().StringVar(&config.CLIClientID, "cli-client-id", "netbird-client", "client ID for CLI")
// Add subcommands
rootCmd.AddCommand(userCmd)
setFlagsFromEnvVars(rootCmd)
}
// Execute runs the root command
func Execute() error {
return rootCmd.Execute()
}
func execute(cmd *cobra.Command, args []string) error {
err := util.InitLog(config.LogLevel, config.LogFile)
if err != nil {
return fmt.Errorf("failed to initialize log: %s", err)
}
// Set default issuer if not provided
issuer := config.Issuer
if issuer == "" {
issuer = fmt.Sprintf("http://localhost:%d", config.ListenPort)
}
log.Infof("Starting NetBird Identity Provider")
log.Infof(" Port: %d", config.ListenPort)
log.Infof(" Issuer: %s", issuer)
log.Infof(" Data directory: %s", config.DataDir)
log.Infof(" Dev mode: %v", config.DevMode)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create provider config
providerConfig := &oidcprovider.Config{
Issuer: issuer,
Port: config.ListenPort,
DataDir: config.DataDir,
DevMode: config.DevMode,
}
// Create the provider
provider, err := oidcprovider.NewProvider(ctx, providerConfig)
if err != nil {
return fmt.Errorf("failed to create IdP: %w", err)
}
// Ensure default clients exist
if err := provider.EnsureDefaultClients(ctx, config.DashboardRedirectURIs, config.CLIRedirectURIs); err != nil {
return fmt.Errorf("failed to create default clients: %w", err)
}
// Start the provider
if err := provider.Start(ctx); err != nil {
return fmt.Errorf("failed to start IdP: %w", err)
}
log.Infof("IdP is running")
log.Infof(" Discovery: %s/.well-known/openid-configuration", issuer)
log.Infof(" Authorization: %s/authorize", issuer)
log.Infof(" Token: %s/oauth/token", issuer)
log.Infof(" Device authorization: %s/device_authorization", issuer)
log.Infof(" JWKS: %s/keys", issuer)
log.Infof(" Login: %s/login", issuer)
log.Infof(" Device flow: %s/device", issuer)
// Wait for exit signal
waitForExitSignal()
log.Infof("Shutting down IdP...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10)
defer shutdownCancel()
if err := provider.Stop(shutdownCtx); err != nil {
return fmt.Errorf("failed to stop IdP: %w", err)
}
log.Infof("IdP stopped")
return nil
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}

249
idp/cmd/user.go Normal file
View File

@@ -0,0 +1,249 @@
package cmd
import (
"context"
"fmt"
"os"
"syscall"
"text/tabwriter"
"github.com/spf13/cobra"
"golang.org/x/term"
"github.com/netbirdio/netbird/idp/oidcprovider"
)
var userCmd = &cobra.Command{
Use: "user",
Short: "Manage IdP users",
Long: "Commands for managing users in the embedded IdP",
}
var userAddCmd = &cobra.Command{
Use: "add",
Short: "Add a new user",
Long: "Add a new user to the embedded IdP",
RunE: userAdd,
}
var userListCmd = &cobra.Command{
Use: "list",
Short: "List all users",
Long: "List all users in the embedded IdP",
RunE: userList,
}
var userDeleteCmd = &cobra.Command{
Use: "delete <username>",
Short: "Delete a user",
Long: "Delete a user from the embedded IdP",
Args: cobra.ExactArgs(1),
RunE: userDelete,
}
var userPasswordCmd = &cobra.Command{
Use: "password <username>",
Short: "Change user password",
Long: "Change password for a user in the embedded IdP",
Args: cobra.ExactArgs(1),
RunE: userChangePassword,
}
// User add flags
var (
userUsername string
userEmail string
userFirstName string
userLastName string
userPassword string
)
func init() {
userAddCmd.Flags().StringVarP(&userUsername, "username", "u", "", "username (required)")
userAddCmd.Flags().StringVarP(&userEmail, "email", "e", "", "email address (required)")
userAddCmd.Flags().StringVarP(&userFirstName, "first-name", "f", "", "first name")
userAddCmd.Flags().StringVarP(&userLastName, "last-name", "l", "", "last name")
userAddCmd.Flags().StringVarP(&userPassword, "password", "p", "", "password (will prompt if not provided)")
_ = userAddCmd.MarkFlagRequired("username")
_ = userAddCmd.MarkFlagRequired("email")
userCmd.AddCommand(userAddCmd)
userCmd.AddCommand(userListCmd)
userCmd.AddCommand(userDeleteCmd)
userCmd.AddCommand(userPasswordCmd)
}
func getStore() (*oidcprovider.Store, error) {
ctx := context.Background()
store, err := oidcprovider.NewStore(ctx, config.DataDir)
if err != nil {
return nil, fmt.Errorf("failed to open store: %w", err)
}
return store, nil
}
func userAdd(cmd *cobra.Command, args []string) error {
store, err := getStore()
if err != nil {
return err
}
defer store.Close()
password := userPassword
if password == "" {
// Prompt for password
fmt.Print("Enter password: ")
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
return fmt.Errorf("failed to read password: %w", err)
}
fmt.Println()
fmt.Print("Confirm password: ")
byteConfirm, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
return fmt.Errorf("failed to read password confirmation: %w", err)
}
fmt.Println()
if string(bytePassword) != string(byteConfirm) {
return fmt.Errorf("passwords do not match")
}
password = string(bytePassword)
}
if password == "" {
return fmt.Errorf("password cannot be empty")
}
user := &oidcprovider.User{
Username: userUsername,
Email: userEmail,
FirstName: userFirstName,
LastName: userLastName,
Password: password,
EmailVerified: true, // Mark as verified since admin is creating the user
}
ctx := context.Background()
if err := store.CreateUser(ctx, user); err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
fmt.Printf("User '%s' created successfully (ID: %s)\n", userUsername, user.ID)
return nil
}
func userList(cmd *cobra.Command, args []string) error {
store, err := getStore()
if err != nil {
return err
}
defer store.Close()
ctx := context.Background()
users, err := store.ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
if len(users) == 0 {
fmt.Println("No users found")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tUSERNAME\tEMAIL\tNAME\tVERIFIED\tCREATED")
for _, user := range users {
name := fmt.Sprintf("%s %s", user.FirstName, user.LastName)
verified := "No"
if user.EmailVerified {
verified = "Yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
user.ID,
user.Username,
user.Email,
name,
verified,
user.CreatedAt.Format("2006-01-02 15:04"),
)
}
w.Flush()
return nil
}
func userDelete(cmd *cobra.Command, args []string) error {
username := args[0]
store, err := getStore()
if err != nil {
return err
}
defer store.Close()
ctx := context.Background()
// Find user by username
user, err := store.GetUserByUsername(ctx, username)
if err != nil {
return fmt.Errorf("user '%s' not found", username)
}
if err := store.DeleteUser(ctx, user.ID); err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
fmt.Printf("User '%s' deleted successfully\n", username)
return nil
}
func userChangePassword(cmd *cobra.Command, args []string) error {
username := args[0]
store, err := getStore()
if err != nil {
return err
}
defer store.Close()
ctx := context.Background()
// Find user by username
user, err := store.GetUserByUsername(ctx, username)
if err != nil {
return fmt.Errorf("user '%s' not found", username)
}
// Prompt for new password
fmt.Print("Enter new password: ")
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
return fmt.Errorf("failed to read password: %w", err)
}
fmt.Println()
fmt.Print("Confirm new password: ")
byteConfirm, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
return fmt.Errorf("failed to read password confirmation: %w", err)
}
fmt.Println()
if string(bytePassword) != string(byteConfirm) {
return fmt.Errorf("passwords do not match")
}
password := string(bytePassword)
if password == "" {
return fmt.Errorf("password cannot be empty")
}
if err := store.UpdateUserPassword(ctx, user.ID, password); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
fmt.Printf("Password updated for user '%s'\n", username)
return nil
}

13
idp/main.go Normal file
View File

@@ -0,0 +1,13 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

249
idp/oidcprovider/client.go Normal file
View File

@@ -0,0 +1,249 @@
package oidcprovider
import (
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
// OIDCClient wraps the database Client model and implements op.Client interface
type OIDCClient struct {
client *Client
loginURL func(string) string
redirectURIs []string
grantTypes []oidc.GrantType
responseTypes []oidc.ResponseType
}
// NewOIDCClient creates an OIDCClient from a database Client
func NewOIDCClient(client *Client, loginURL func(string) string) *OIDCClient {
return &OIDCClient{
client: client,
loginURL: loginURL,
redirectURIs: ParseJSONArray(client.RedirectURIs),
grantTypes: parseGrantTypes(client.GrantTypes),
responseTypes: parseResponseTypes(client.ResponseTypes),
}
}
// GetID returns the client ID
func (c *OIDCClient) GetID() string {
return c.client.ID
}
// RedirectURIs returns the registered redirect URIs
func (c *OIDCClient) RedirectURIs() []string {
return c.redirectURIs
}
// PostLogoutRedirectURIs returns the registered post-logout redirect URIs
func (c *OIDCClient) PostLogoutRedirectURIs() []string {
return ParseJSONArray(c.client.PostLogoutURIs)
}
// ApplicationType returns the application type (native, web, user_agent)
func (c *OIDCClient) ApplicationType() op.ApplicationType {
switch c.client.ApplicationType {
case "native":
return op.ApplicationTypeNative
case "web":
return op.ApplicationTypeWeb
case "user_agent":
return op.ApplicationTypeUserAgent
default:
return op.ApplicationTypeWeb
}
}
// AuthMethod returns the authentication method
func (c *OIDCClient) AuthMethod() oidc.AuthMethod {
switch c.client.AuthMethod {
case "none":
return oidc.AuthMethodNone
case "client_secret_basic":
return oidc.AuthMethodBasic
case "client_secret_post":
return oidc.AuthMethodPost
case "private_key_jwt":
return oidc.AuthMethodPrivateKeyJWT
default:
return oidc.AuthMethodNone
}
}
// ResponseTypes returns the allowed response types
func (c *OIDCClient) ResponseTypes() []oidc.ResponseType {
return c.responseTypes
}
// GrantTypes returns the allowed grant types
func (c *OIDCClient) GrantTypes() []oidc.GrantType {
return c.grantTypes
}
// LoginURL returns the login URL for this client
func (c *OIDCClient) LoginURL(authRequestID string) string {
if c.loginURL != nil {
return c.loginURL(authRequestID)
}
return "/login?authRequestID=" + authRequestID
}
// AccessTokenType returns the access token type
func (c *OIDCClient) AccessTokenType() op.AccessTokenType {
switch c.client.AccessTokenType {
case "jwt":
return op.AccessTokenTypeJWT
default:
return op.AccessTokenTypeBearer
}
}
// IDTokenLifetime returns the ID token lifetime
func (c *OIDCClient) IDTokenLifetime() time.Duration {
if c.client.IDTokenLifetime > 0 {
return time.Duration(c.client.IDTokenLifetime) * time.Second
}
return time.Hour // default 1 hour
}
// DevMode returns whether the client is in development mode
func (c *OIDCClient) DevMode() bool {
return c.client.DevMode
}
// RestrictAdditionalIdTokenScopes returns any restricted scopes for ID tokens
func (c *OIDCClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
return func(scopes []string) []string {
return scopes
}
}
// RestrictAdditionalAccessTokenScopes returns any restricted scopes for access tokens
func (c *OIDCClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
return func(scopes []string) []string {
return scopes
}
}
// IsScopeAllowed checks if a scope is allowed for this client
func (c *OIDCClient) IsScopeAllowed(scope string) bool {
// Allow all standard OIDC scopes
switch scope {
case oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, oidc.ScopeAddress, oidc.ScopeOfflineAccess:
return true
}
return true // Allow custom scopes as well
}
// IDTokenUserinfoClaimsAssertion returns whether userinfo claims should be included in ID token
func (c *OIDCClient) IDTokenUserinfoClaimsAssertion() bool {
return false
}
// ClockSkew returns the allowed clock skew for this client
func (c *OIDCClient) ClockSkew() time.Duration {
if c.client.ClockSkew > 0 {
return time.Duration(c.client.ClockSkew) * time.Second
}
return 0
}
// Helper functions for parsing grant types and response types
func parseGrantTypes(jsonStr string) []oidc.GrantType {
types := ParseJSONArray(jsonStr)
if len(types) == 0 {
// Default grant types
return []oidc.GrantType{
oidc.GrantTypeCode,
oidc.GrantTypeRefreshToken,
}
}
result := make([]oidc.GrantType, 0, len(types))
for _, t := range types {
switch t {
case "authorization_code":
result = append(result, oidc.GrantTypeCode)
case "refresh_token":
result = append(result, oidc.GrantTypeRefreshToken)
case "client_credentials":
result = append(result, oidc.GrantTypeClientCredentials)
case "urn:ietf:params:oauth:grant-type:device_code":
result = append(result, oidc.GrantTypeDeviceCode)
case "urn:ietf:params:oauth:grant-type:token-exchange":
result = append(result, oidc.GrantTypeTokenExchange)
}
}
return result
}
func parseResponseTypes(jsonStr string) []oidc.ResponseType {
types := ParseJSONArray(jsonStr)
if len(types) == 0 {
// Default response types
return []oidc.ResponseType{oidc.ResponseTypeCode}
}
result := make([]oidc.ResponseType, 0, len(types))
for _, t := range types {
switch t {
case "code":
result = append(result, oidc.ResponseTypeCode)
case "id_token":
result = append(result, oidc.ResponseTypeIDToken)
}
}
return result
}
// CreateNativeClient creates a native client configuration (for CLI/mobile apps with PKCE)
func CreateNativeClient(id, name string, redirectURIs []string) *Client {
return &Client{
ID: id,
Name: name,
RedirectURIs: ToJSONArray(redirectURIs),
ApplicationType: "native",
AuthMethod: "none", // Public client
ResponseTypes: ToJSONArray([]string{"code"}),
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"}),
AccessTokenType: "bearer",
DevMode: true,
IDTokenLifetime: 3600,
}
}
// CreateWebClient creates a web client configuration (for SPAs/web apps)
func CreateWebClient(id, secret, name string, redirectURIs []string) *Client {
return &Client{
ID: id,
Secret: secret,
Name: name,
RedirectURIs: ToJSONArray(redirectURIs),
ApplicationType: "web",
AuthMethod: "client_secret_basic",
ResponseTypes: ToJSONArray([]string{"code"}),
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}),
AccessTokenType: "bearer",
DevMode: false,
IDTokenLifetime: 3600,
}
}
// CreateSPAClient creates a Single Page Application client configuration (public client for SPAs)
func CreateSPAClient(id, name string, redirectURIs []string) *Client {
return &Client{
ID: id,
Name: name,
RedirectURIs: ToJSONArray(redirectURIs),
ApplicationType: "user_agent",
AuthMethod: "none", // Public client for SPA
ResponseTypes: ToJSONArray([]string{"code"}),
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}),
AccessTokenType: "bearer",
DevMode: true,
IDTokenLifetime: 3600,
}
}

220
idp/oidcprovider/device.go Normal file
View File

@@ -0,0 +1,220 @@
package oidcprovider
import (
"encoding/base64"
"html/template"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/gorilla/securecookie"
log "github.com/sirupsen/logrus"
)
// DeviceHandler handles the device authorization flow
type DeviceHandler struct {
storage *OIDCStorage
tmpl *template.Template
secureCookie *securecookie.SecureCookie
}
// NewDeviceHandler creates a new device handler
func NewDeviceHandler(storage *OIDCStorage) (*DeviceHandler, error) {
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
return nil, err
}
// Generate secure cookie keys
hashKey := securecookie.GenerateRandomKey(32)
blockKey := securecookie.GenerateRandomKey(32)
return &DeviceHandler{
storage: storage,
tmpl: tmpl,
secureCookie: securecookie.New(hashKey, blockKey),
}, nil
}
// Router returns the device flow router
func (h *DeviceHandler) Router() chi.Router {
r := chi.NewRouter()
r.Get("/", h.userCodePage)
r.Post("/login", h.handleLogin)
r.Post("/confirm", h.handleConfirm)
return r
}
// userCodePage displays the user code entry form
func (h *DeviceHandler) userCodePage(w http.ResponseWriter, r *http.Request) {
userCode := r.URL.Query().Get("user_code")
data := map[string]interface{}{
"UserCode": userCode,
"Error": "",
"Step": "code", // code, login, or confirm
}
if userCode != "" {
// Verify the user code exists
_, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
if err != nil {
data["Error"] = "Invalid or expired user code"
data["UserCode"] = ""
} else {
data["Step"] = "login"
}
}
if err := h.tmpl.ExecuteTemplate(w, "device.html", data); err != nil {
log.Errorf("failed to render device template: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
}
}
// handleLogin processes the login form on the device flow
func (h *DeviceHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
userCode := r.FormValue("user_code")
username := r.FormValue("username")
password := r.FormValue("password")
data := map[string]interface{}{
"UserCode": userCode,
"Error": "",
"Step": "login",
}
if userCode == "" || username == "" || password == "" {
data["Error"] = "Please fill in all fields"
h.tmpl.ExecuteTemplate(w, "device.html", data)
return
}
// Validate credentials
userID, err := h.storage.CheckUsernamePasswordSimple(username, password)
if err != nil {
log.Warnf("device login failed for user %s: %v", username, err)
data["Error"] = "Invalid username or password"
h.tmpl.ExecuteTemplate(w, "device.html", data)
return
}
// Get device authorization info
authState, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
if err != nil {
data["Error"] = "Invalid or expired user code"
data["Step"] = "code"
data["UserCode"] = ""
h.tmpl.ExecuteTemplate(w, "device.html", data)
return
}
// Set secure cookie with user info for confirmation step
cookieValue := map[string]string{
"user_code": userCode,
"user_id": userID,
}
encoded, err := h.secureCookie.Encode("device_auth", cookieValue)
if err != nil {
log.Errorf("failed to encode cookie: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
http.SetCookie(w, &http.Cookie{
Name: "device_auth",
Value: encoded,
Path: "/device",
HttpOnly: true,
Secure: r.TLS != nil,
SameSite: http.SameSiteStrictMode,
})
// Show confirmation page
data["Step"] = "confirm"
data["ClientID"] = authState.ClientID
data["Scopes"] = authState.Scopes
data["UserID"] = userID
h.tmpl.ExecuteTemplate(w, "device.html", data)
}
// handleConfirm processes the authorization decision
func (h *DeviceHandler) handleConfirm(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
// Get values from cookie
cookie, err := r.Cookie("device_auth")
if err != nil {
http.Redirect(w, r, "/device", http.StatusFound)
return
}
var cookieValue map[string]string
if err := h.secureCookie.Decode("device_auth", cookie.Value, &cookieValue); err != nil {
http.Redirect(w, r, "/device", http.StatusFound)
return
}
userCode := cookieValue["user_code"]
userID := cookieValue["user_id"]
action := r.FormValue("action")
data := map[string]interface{}{
"Step": "result",
}
// Clear the cookie
http.SetCookie(w, &http.Cookie{
Name: "device_auth",
Value: "",
Path: "/device",
MaxAge: -1,
HttpOnly: true,
})
if action == "allow" {
if err := h.storage.CompleteDeviceAuthorization(r.Context(), userCode, userID); err != nil {
log.Errorf("failed to complete device authorization: %v", err)
data["Error"] = "Failed to authorize device"
} else {
data["Success"] = true
data["Message"] = "Device authorized successfully! You can now close this window."
}
} else {
if err := h.storage.DenyDeviceAuthorization(r.Context(), userCode); err != nil {
log.Errorf("failed to deny device authorization: %v", err)
}
data["Success"] = false
data["Message"] = "Authorization denied. You can close this window."
}
h.tmpl.ExecuteTemplate(w, "device.html", data)
}
// GenerateUserCode generates a user-friendly code for device flow
func GenerateUserCode() string {
// Generate a base20 code (BCDFGHJKLMNPQRSTVWXZ - no vowels to avoid words)
chars := "BCDFGHJKLMNPQRSTVWXZ"
b := securecookie.GenerateRandomKey(8)
result := make([]byte, 8)
for i := range result {
result[i] = chars[int(b[i])%len(chars)]
}
// Format as XXXX-XXXX
return string(result[:4]) + "-" + string(result[4:])
}
// GenerateDeviceCode generates a secure device code
func GenerateDeviceCode() string {
b := securecookie.GenerateRandomKey(32)
return base64.RawURLEncoding.EncodeToString(b)
}

105
idp/oidcprovider/login.go Normal file
View File

@@ -0,0 +1,105 @@
package oidcprovider
import (
"embed"
"html/template"
"net/http"
"github.com/go-chi/chi/v5"
log "github.com/sirupsen/logrus"
)
//go:embed templates/*.html
var templateFS embed.FS
// LoginHandler handles the login flow
type LoginHandler struct {
storage *OIDCStorage
callback func(string) string
tmpl *template.Template
}
// NewLoginHandler creates a new login handler
func NewLoginHandler(storage *OIDCStorage, callback func(string) string) (*LoginHandler, error) {
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
return nil, err
}
return &LoginHandler{
storage: storage,
callback: callback,
tmpl: tmpl,
}, nil
}
// Router returns the login router
func (h *LoginHandler) Router() chi.Router {
r := chi.NewRouter()
r.Get("/", h.loginPage)
r.Post("/", h.handleLogin)
return r
}
// loginPage displays the login form
func (h *LoginHandler) loginPage(w http.ResponseWriter, r *http.Request) {
authRequestID := r.URL.Query().Get("authRequestID")
if authRequestID == "" {
http.Error(w, "missing auth request ID", http.StatusBadRequest)
return
}
data := map[string]interface{}{
"AuthRequestID": authRequestID,
"Error": "",
}
if err := h.tmpl.ExecuteTemplate(w, "login.html", data); err != nil {
log.Errorf("failed to render login template: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
}
}
// handleLogin processes the login form submission
func (h *LoginHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
authRequestID := r.FormValue("authRequestID")
username := r.FormValue("username")
password := r.FormValue("password")
if authRequestID == "" || username == "" || password == "" {
data := map[string]interface{}{
"AuthRequestID": authRequestID,
"Error": "Please fill in all fields",
}
h.tmpl.ExecuteTemplate(w, "login.html", data)
return
}
// Validate credentials and get user ID
userID, err := h.storage.CheckUsernamePasswordSimple(username, password)
if err != nil {
log.Warnf("login failed for user %s: %v", username, err)
data := map[string]interface{}{
"AuthRequestID": authRequestID,
"Error": "Invalid username or password",
}
h.tmpl.ExecuteTemplate(w, "login.html", data)
return
}
// Complete the auth request
if err := h.storage.CompleteAuthRequest(r.Context(), authRequestID, userID); err != nil {
log.Errorf("failed to complete auth request: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
// Redirect to callback
callbackURL := h.callback(authRequestID)
http.Redirect(w, r, callbackURL, http.StatusFound)
}

136
idp/oidcprovider/models.go Normal file
View File

@@ -0,0 +1,136 @@
package oidcprovider
import (
"time"
"golang.org/x/text/language"
)
// User represents an OIDC user stored in the database
type User struct {
ID string `gorm:"primaryKey"`
Username string `gorm:"uniqueIndex;not null"`
Password string `gorm:"not null"` // bcrypt hashed
Email string
EmailVerified bool
FirstName string
LastName string
Phone string
PhoneVerified bool
PreferredLanguage string // language tag string
IsAdmin bool
CreatedAt time.Time
UpdatedAt time.Time
}
// GetPreferredLanguage returns the user's preferred language as a language.Tag
func (u *User) GetPreferredLanguage() language.Tag {
if u.PreferredLanguage == "" {
return language.English
}
tag, err := language.Parse(u.PreferredLanguage)
if err != nil {
return language.English
}
return tag
}
// Client represents an OIDC client (application) stored in the database
type Client struct {
ID string `gorm:"primaryKey"`
Secret string // bcrypt hashed, empty for public clients
Name string
RedirectURIs string // JSON array of redirect URIs
PostLogoutURIs string // JSON array of post-logout redirect URIs
ApplicationType string // native, web, user_agent
AuthMethod string // none, client_secret_basic, client_secret_post, private_key_jwt
ResponseTypes string // JSON array: code, id_token, token
GrantTypes string // JSON array: authorization_code, refresh_token, client_credentials, urn:ietf:params:oauth:grant-type:device_code
AccessTokenType string // bearer or jwt
DevMode bool // allows non-HTTPS redirect URIs
IDTokenLifetime int64 // in seconds, default 3600 (1 hour)
ClockSkew int64 // in seconds, allowed clock skew
CreatedAt time.Time
UpdatedAt time.Time
}
// AuthRequest represents an ongoing authorization request
type AuthRequest struct {
ID string `gorm:"primaryKey"`
ClientID string `gorm:"index"`
Scopes string // JSON array of scopes
RedirectURI string
State string
Nonce string
ResponseType string
ResponseMode string
CodeChallenge string
CodeMethod string // S256 or plain
UserID string // set after user authentication
Done bool // true when user has authenticated
AuthTime time.Time
CreatedAt time.Time
MaxAge int64 // max authentication age in seconds
Prompt string // none, login, consent, select_account
UILocales string // space-separated list of locales
LoginHint string
ACRValues string // space-separated list of ACR values
}
// AuthCode represents an authorization code
type AuthCode struct {
Code string `gorm:"primaryKey"`
AuthRequestID string `gorm:"index"`
CreatedAt time.Time
ExpiresAt time.Time
}
// AccessToken represents an access token
type AccessToken struct {
ID string `gorm:"primaryKey"`
ApplicationID string `gorm:"index"`
Subject string `gorm:"index"`
Audience string // JSON array
Scopes string // JSON array
Expiration time.Time
CreatedAt time.Time
}
// RefreshToken represents a refresh token
type RefreshToken struct {
ID string `gorm:"primaryKey"`
Token string `gorm:"uniqueIndex"`
AuthRequestID string
ApplicationID string `gorm:"index"`
Subject string `gorm:"index"`
Audience string // JSON array
Scopes string // JSON array
AMR string // JSON array of authentication methods
AuthTime time.Time
Expiration time.Time
CreatedAt time.Time
}
// DeviceAuth represents a device authorization request
type DeviceAuth struct {
DeviceCode string `gorm:"primaryKey"`
UserCode string `gorm:"uniqueIndex"`
ClientID string `gorm:"index"`
Scopes string // JSON array
Subject string // set after user authentication
Audience string // JSON array
Done bool // true when user has authorized
Denied bool // true when user has denied
Expiration time.Time
CreatedAt time.Time
}
// SigningKey represents a signing key for JWTs
type SigningKey struct {
ID string `gorm:"primaryKey"`
Algorithm string // RS256
PrivateKey []byte // PEM encoded
PublicKey []byte // PEM encoded
CreatedAt time.Time
Active bool
}

View File

@@ -0,0 +1,662 @@
package oidcprovider
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"gorm.io/gorm"
)
// ErrInvalidRefreshToken is returned when a token is not a valid refresh token
var ErrInvalidRefreshToken = errors.New("invalid refresh token")
// OIDCStorage implements op.Storage interface for the OIDC provider
type OIDCStorage struct {
store *Store
issuer string
loginURL func(string) string
}
// NewOIDCStorage creates a new OIDCStorage
func NewOIDCStorage(store *Store, issuer string) *OIDCStorage {
return &OIDCStorage{
store: store,
issuer: issuer,
}
}
// SetLoginURL sets the login URL generator function
func (s *OIDCStorage) SetLoginURL(fn func(string) string) {
s.loginURL = fn
}
// Health checks if the storage is healthy
func (s *OIDCStorage) Health(ctx context.Context) error {
sqlDB, err := s.store.db.DB()
if err != nil {
return err
}
return sqlDB.PingContext(ctx)
}
// CreateAuthRequest creates and stores a new authorization request
func (s *OIDCStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
req := &AuthRequest{
ID: uuid.New().String(),
ClientID: authReq.ClientID,
Scopes: ToJSONArray(authReq.Scopes),
RedirectURI: authReq.RedirectURI,
State: authReq.State,
Nonce: authReq.Nonce,
ResponseType: string(authReq.ResponseType),
ResponseMode: string(authReq.ResponseMode),
CodeChallenge: authReq.CodeChallenge,
CodeMethod: string(authReq.CodeChallengeMethod),
UserID: userID,
Done: userID != "",
CreatedAt: time.Now(),
Prompt: spaceSeparated(authReq.Prompt),
UILocales: authReq.UILocales.String(),
LoginHint: authReq.LoginHint,
ACRValues: spaceSeparated(authReq.ACRValues),
}
if authReq.MaxAge != nil {
req.MaxAge = int64(*authReq.MaxAge)
}
if userID != "" {
req.AuthTime = time.Now()
}
if err := s.store.SaveAuthRequest(ctx, req); err != nil {
return nil, err
}
return &OIDCAuthRequest{req: req, storage: s}, nil
}
// AuthRequestByID retrieves an authorization request by ID
func (s *OIDCStorage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) {
req, err := s.store.GetAuthRequestByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("auth request not found: %s", id)
}
return nil, err
}
return &OIDCAuthRequest{req: req, storage: s}, nil
}
// AuthRequestByCode retrieves an authorization request by code
func (s *OIDCStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) {
authCode, err := s.store.GetAuthCodeByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("auth code not found: %s", code)
}
return nil, err
}
if time.Now().After(authCode.ExpiresAt) {
_ = s.store.DeleteAuthCode(ctx, code)
return nil, errors.New("auth code expired")
}
req, err := s.store.GetAuthRequestByID(ctx, authCode.AuthRequestID)
if err != nil {
return nil, err
}
return &OIDCAuthRequest{req: req, storage: s}, nil
}
// SaveAuthCode saves an authorization code linked to an auth request
func (s *OIDCStorage) SaveAuthCode(ctx context.Context, id, code string) error {
authCode := &AuthCode{
Code: code,
AuthRequestID: id,
ExpiresAt: time.Now().Add(10 * time.Minute),
}
return s.store.SaveAuthCode(ctx, authCode)
}
// DeleteAuthRequest deletes an authorization request
func (s *OIDCStorage) DeleteAuthRequest(ctx context.Context, id string) error {
return s.store.DeleteAuthRequest(ctx, id)
}
// CreateAccessToken creates and stores an access token
func (s *OIDCStorage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
tokenID := uuid.New().String()
expiration := time.Now().Add(5 * time.Minute)
// Get client ID from the request if possible
var clientID string
if authReq, ok := request.(op.AuthRequest); ok {
clientID = authReq.GetClientID()
} else if refreshReq, ok := request.(op.RefreshTokenRequest); ok {
clientID = refreshReq.GetClientID()
}
token := &AccessToken{
ID: tokenID,
ApplicationID: clientID,
Subject: request.GetSubject(),
Audience: ToJSONArray(request.GetAudience()),
Scopes: ToJSONArray(request.GetScopes()),
Expiration: expiration,
}
if err := s.store.SaveAccessToken(ctx, token); err != nil {
return "", time.Time{}, err
}
return tokenID, expiration, nil
}
// CreateAccessAndRefreshTokens creates both access and refresh tokens
func (s *OIDCStorage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
// Delete old refresh token if provided
if currentRefreshToken != "" {
_ = s.store.DeleteRefreshTokenByToken(ctx, currentRefreshToken)
}
// Create access token
accessTokenID, expiration, err = s.CreateAccessToken(ctx, request)
if err != nil {
return "", "", time.Time{}, err
}
// Get additional info from the request if possible
var clientID string
var authTime time.Time
var amr []string
if authReq, ok := request.(op.AuthRequest); ok {
clientID = authReq.GetClientID()
authTime = authReq.GetAuthTime()
amr = authReq.GetAMR()
} else if refreshReq, ok := request.(op.RefreshTokenRequest); ok {
clientID = refreshReq.GetClientID()
authTime = refreshReq.GetAuthTime()
amr = refreshReq.GetAMR()
}
// Create refresh token
refreshToken := &RefreshToken{
ID: uuid.New().String(),
Token: uuid.New().String(),
ApplicationID: clientID,
Subject: request.GetSubject(),
Audience: ToJSONArray(request.GetAudience()),
Scopes: ToJSONArray(request.GetScopes()),
AuthTime: authTime,
AMR: ToJSONArray(amr),
Expiration: time.Now().Add(5 * time.Hour), // 5 hour refresh token lifetime
}
if authReq, ok := request.(op.AuthRequest); ok {
refreshToken.AuthRequestID = authReq.GetID()
}
if err := s.store.SaveRefreshToken(ctx, refreshToken); err != nil {
return "", "", time.Time{}, err
}
return accessTokenID, refreshToken.Token, expiration, nil
}
// TokenRequestByRefreshToken retrieves token request info from refresh token
func (s *OIDCStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
token, err := s.store.GetRefreshToken(ctx, refreshToken)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("refresh token not found")
}
return nil, err
}
if time.Now().After(token.Expiration) {
_ = s.store.DeleteRefreshTokenByToken(ctx, refreshToken)
return nil, errors.New("refresh token expired")
}
return &OIDCRefreshToken{token: token}, nil
}
// TerminateSession terminates a user session
func (s *OIDCStorage) TerminateSession(ctx context.Context, userID, clientID string) error {
// For now, we don't track sessions separately
return nil
}
// RevokeToken revokes a token
func (s *OIDCStorage) RevokeToken(ctx context.Context, tokenOrID string, userID string, clientID string) *oidc.Error {
// Try to delete as refresh token
if err := s.store.DeleteRefreshTokenByToken(ctx, tokenOrID); err == nil {
return nil
}
// Try to delete as access token
if err := s.store.DeleteAccessToken(ctx, tokenOrID); err == nil {
return nil
}
return nil // Silently succeed even if token not found (per spec)
}
// GetRefreshTokenInfo returns info about a refresh token
func (s *OIDCStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
refreshToken, err := s.store.GetRefreshToken(ctx, token)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", ErrInvalidRefreshToken
}
return "", "", err
}
if refreshToken.ApplicationID != clientID {
return "", "", ErrInvalidRefreshToken
}
return refreshToken.Subject, refreshToken.ID, nil
}
// GetClientByClientID retrieves a client by ID
func (s *OIDCStorage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) {
client, err := s.store.GetClientByID(ctx, clientID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("client not found: %s", clientID)
}
return nil, err
}
return NewOIDCClient(client, s.loginURL), nil
}
// AuthorizeClientIDSecret validates client credentials
func (s *OIDCStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error {
_, err := s.store.ValidateClientSecret(ctx, clientID, clientSecret)
return err
}
// SetUserinfoFromScopes sets userinfo claims based on scopes
func (s *OIDCStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
return s.setUserinfo(ctx, userinfo, userID, scopes)
}
// SetUserinfoFromToken sets userinfo claims from an access token
func (s *OIDCStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
token, err := s.store.GetAccessTokenByID(ctx, tokenID)
if err != nil {
return err
}
return s.setUserinfo(ctx, userinfo, token.Subject, ParseJSONArray(token.Scopes))
}
// setUserinfo populates userinfo based on user data and scopes
func (s *OIDCStorage) setUserinfo(ctx context.Context, userinfo *oidc.UserInfo, userID string, scopes []string) error {
user, err := s.store.GetUserByID(ctx, userID)
if err != nil {
return err
}
for _, scope := range scopes {
switch scope {
case oidc.ScopeOpenID:
userinfo.Subject = user.ID
case oidc.ScopeProfile:
userinfo.Name = fmt.Sprintf("%s %s", user.FirstName, user.LastName)
userinfo.GivenName = user.FirstName
userinfo.FamilyName = user.LastName
userinfo.PreferredUsername = user.Username
userinfo.Locale = oidc.NewLocale(user.GetPreferredLanguage())
case oidc.ScopeEmail:
userinfo.Email = user.Email
userinfo.EmailVerified = oidc.Bool(user.EmailVerified)
case oidc.ScopePhone:
userinfo.PhoneNumber = user.Phone
userinfo.PhoneNumberVerified = user.PhoneVerified
}
}
return nil
}
// SetIntrospectionFromToken sets introspection response from token
func (s *OIDCStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
token, err := s.store.GetAccessTokenByID(ctx, tokenID)
if err != nil {
return err
}
introspection.Active = true
introspection.Subject = token.Subject
introspection.ClientID = token.ApplicationID
introspection.Scope = ParseJSONArray(token.Scopes)
introspection.Expiration = oidc.FromTime(token.Expiration)
introspection.IssuedAt = oidc.FromTime(token.CreatedAt)
introspection.Audience = ParseJSONArray(token.Audience)
introspection.Issuer = s.issuer
return nil
}
// GetPrivateClaimsFromScopes returns additional claims based on scopes
func (s *OIDCStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]any, error) {
return nil, nil
}
// GetKeyByIDAndClientID retrieves a key by ID for a client
func (s *OIDCStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) {
return nil, errors.New("not implemented")
}
// ValidateJWTProfileScopes validates scopes for JWT profile grant
func (s *OIDCStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) {
return scopes, nil
}
// SigningKey returns the active signing key for token signing
func (s *OIDCStorage) SigningKey(ctx context.Context) (op.SigningKey, error) {
key, err := s.store.GetSigningKey(ctx)
if err != nil {
return nil, err
}
block, _ := pem.Decode(key.PrivateKey)
if block == nil {
return nil, errors.New("failed to decode private key PEM")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
return &signingKey{
id: key.ID,
algorithm: jose.RS256,
privateKey: privateKey,
}, nil
}
// SignatureAlgorithms returns supported signature algorithms
func (s *OIDCStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
return []jose.SignatureAlgorithm{jose.RS256}, nil
}
// KeySet returns the public key set for token verification
func (s *OIDCStorage) KeySet(ctx context.Context) ([]op.Key, error) {
key, err := s.store.GetSigningKey(ctx)
if err != nil {
return nil, err
}
block, _ := pem.Decode(key.PublicKey)
if block == nil {
return nil, errors.New("failed to decode public key PEM")
}
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
rsaKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return nil, errors.New("public key is not RSA")
}
return []op.Key{
&publicKeyInfo{
id: key.ID,
algorithm: jose.RS256,
publicKey: rsaKey,
},
}, nil
}
// Device Authorization Flow methods
// StoreDeviceAuthorization stores a device authorization request
func (s *OIDCStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
auth := &DeviceAuth{
DeviceCode: deviceCode,
UserCode: userCode,
ClientID: clientID,
Scopes: ToJSONArray(scopes),
Expiration: expires,
}
return s.store.SaveDeviceAuth(ctx, auth)
}
// GetDeviceAuthorizationState retrieves the state of a device authorization
func (s *OIDCStorage) GetDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
auth, err := s.store.GetDeviceAuthByDeviceCode(ctx, deviceCode)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("device authorization not found")
}
return nil, err
}
if auth.ClientID != clientID {
return nil, errors.New("client ID mismatch")
}
if time.Now().After(auth.Expiration) {
_ = s.store.DeleteDeviceAuth(ctx, deviceCode)
return &op.DeviceAuthorizationState{Expires: auth.Expiration}, nil
}
state := &op.DeviceAuthorizationState{
ClientID: auth.ClientID,
Scopes: ParseJSONArray(auth.Scopes),
Expires: auth.Expiration,
}
if auth.Denied {
state.Denied = true
} else if auth.Done {
state.Done = true
state.Subject = auth.Subject
}
return state, nil
}
// GetDeviceAuthorizationByUserCode retrieves device auth by user code
func (s *OIDCStorage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("device authorization not found")
}
return nil, err
}
if time.Now().After(auth.Expiration) {
return nil, errors.New("device authorization expired")
}
return &op.DeviceAuthorizationState{
ClientID: auth.ClientID,
Scopes: ParseJSONArray(auth.Scopes),
Expires: auth.Expiration,
Done: auth.Done,
Denied: auth.Denied,
Subject: auth.Subject,
}, nil
}
// CompleteDeviceAuthorization marks a device authorization as complete
func (s *OIDCStorage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
if err != nil {
return err
}
auth.Done = true
auth.Subject = subject
return s.store.UpdateDeviceAuth(ctx, auth)
}
// DenyDeviceAuthorization marks a device authorization as denied
func (s *OIDCStorage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
if err != nil {
return err
}
auth.Denied = true
return s.store.UpdateDeviceAuth(ctx, auth)
}
// User authentication methods
// CheckUsernamePassword validates user credentials
func (s *OIDCStorage) CheckUsernamePassword(username, password, authRequestID string) error {
ctx := context.Background()
_, err := s.store.ValidateUserPassword(ctx, username, password)
if err != nil {
return err
}
return nil
}
// CheckUsernamePasswordSimple validates user credentials and returns the user ID
func (s *OIDCStorage) CheckUsernamePasswordSimple(username, password string) (string, error) {
ctx := context.Background()
user, err := s.store.ValidateUserPassword(ctx, username, password)
if err != nil {
return "", err
}
return user.ID, nil
}
// CompleteAuthRequest completes an auth request after user authentication
func (s *OIDCStorage) CompleteAuthRequest(ctx context.Context, authRequestID, userID string) error {
req, err := s.store.GetAuthRequestByID(ctx, authRequestID)
if err != nil {
return err
}
req.UserID = userID
req.Done = true
req.AuthTime = time.Now()
return s.store.UpdateAuthRequest(ctx, req)
}
// Helper types
// signingKey implements op.SigningKey
type signingKey struct {
id string
algorithm jose.SignatureAlgorithm
privateKey *rsa.PrivateKey
}
func (k *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm {
return k.algorithm
}
func (k *signingKey) Key() interface{} {
return k.privateKey
}
func (k *signingKey) ID() string {
return k.id
}
// publicKeyInfo implements op.Key
type publicKeyInfo struct {
id string
algorithm jose.SignatureAlgorithm
publicKey *rsa.PublicKey
}
func (k *publicKeyInfo) ID() string {
return k.id
}
func (k *publicKeyInfo) Algorithm() jose.SignatureAlgorithm {
return k.algorithm
}
func (k *publicKeyInfo) Use() string {
return "sig"
}
func (k *publicKeyInfo) Key() interface{} {
return k.publicKey
}
// OIDCAuthRequest wraps AuthRequest for the op.AuthRequest interface
type OIDCAuthRequest struct {
req *AuthRequest
storage *OIDCStorage
}
func (r *OIDCAuthRequest) GetID() string { return r.req.ID }
func (r *OIDCAuthRequest) GetACR() string { return "" }
func (r *OIDCAuthRequest) GetAMR() []string { return []string{"pwd"} }
func (r *OIDCAuthRequest) GetAudience() []string { return []string{r.req.ClientID} }
func (r *OIDCAuthRequest) GetAuthTime() time.Time { return r.req.AuthTime }
func (r *OIDCAuthRequest) GetClientID() string { return r.req.ClientID }
func (r *OIDCAuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
if r.req.CodeChallenge == "" {
return nil
}
return &oidc.CodeChallenge{
Challenge: r.req.CodeChallenge,
Method: oidc.CodeChallengeMethod(r.req.CodeMethod),
}
}
func (r *OIDCAuthRequest) GetNonce() string { return r.req.Nonce }
func (r *OIDCAuthRequest) GetRedirectURI() string { return r.req.RedirectURI }
func (r *OIDCAuthRequest) GetResponseType() oidc.ResponseType {
return oidc.ResponseType(r.req.ResponseType)
}
func (r *OIDCAuthRequest) GetResponseMode() oidc.ResponseMode {
return oidc.ResponseMode(r.req.ResponseMode)
}
func (r *OIDCAuthRequest) GetScopes() []string { return ParseJSONArray(r.req.Scopes) }
func (r *OIDCAuthRequest) GetState() string { return r.req.State }
func (r *OIDCAuthRequest) GetSubject() string { return r.req.UserID }
func (r *OIDCAuthRequest) Done() bool { return r.req.Done }
// OIDCRefreshToken wraps RefreshToken for the op.RefreshTokenRequest interface
type OIDCRefreshToken struct {
token *RefreshToken
}
func (r *OIDCRefreshToken) GetAMR() []string { return ParseJSONArray(r.token.AMR) }
func (r *OIDCRefreshToken) GetAudience() []string { return ParseJSONArray(r.token.Audience) }
func (r *OIDCRefreshToken) GetAuthTime() time.Time { return r.token.AuthTime }
func (r *OIDCRefreshToken) GetClientID() string { return r.token.ApplicationID }
func (r *OIDCRefreshToken) GetScopes() []string { return ParseJSONArray(r.token.Scopes) }
func (r *OIDCRefreshToken) GetSubject() string { return r.token.Subject }
func (r *OIDCRefreshToken) SetCurrentScopes(scopes []string) {}
// Helper functions
func spaceSeparated(items []string) string {
return strings.Join(items, " ")
}

View File

@@ -0,0 +1,265 @@
package oidcprovider
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
log "github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v3/pkg/op"
)
// Config holds the configuration for the OIDC provider
type Config struct {
// Issuer is the OIDC issuer URL (e.g., "https://idp.example.com")
Issuer string
// Port is the port to listen on
Port int
// DataDir is the directory to store OIDC data (SQLite database)
DataDir string
// DevMode enables development mode (allows HTTP, localhost)
DevMode bool
}
// Provider represents the embedded OIDC provider
type Provider struct {
config *Config
store *Store
storage *OIDCStorage
provider op.OpenIDProvider
router chi.Router
httpServer *http.Server
}
// NewProvider creates a new OIDC provider
func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Create the SQLite store
store, err := NewStore(ctx, config.DataDir)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC store: %w", err)
}
// Create the OIDC storage adapter
storage := NewOIDCStorage(store, config.Issuer)
p := &Provider{
config: config,
store: store,
storage: storage,
}
return p, nil
}
// Start starts the OIDC provider server
func (p *Provider) Start(ctx context.Context) error {
// Create the router
router := chi.NewRouter()
router.Use(middleware.Logger)
router.Use(middleware.Recoverer)
router.Use(middleware.RequestID)
// Create the OIDC provider
key := sha256.Sum256([]byte(p.config.Issuer + "encryption-key"))
opConfig := &op.Config{
CryptoKey: key,
DefaultLogoutRedirectURI: "/logged-out",
CodeMethodS256: true,
AuthMethodPost: true,
AuthMethodPrivateKeyJWT: true,
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormPath: "/device",
UserCode: op.UserCodeBase20,
},
}
// Set the login URL generator
p.storage.SetLoginURL(func(authRequestID string) string {
return fmt.Sprintf("/login?authRequestID=%s", authRequestID)
})
// Create the provider with options
var opts []op.Option
if p.config.DevMode {
opts = append(opts, op.WithAllowInsecure())
}
provider, err := op.NewProvider(opConfig, p.storage, op.StaticIssuer(p.config.Issuer), opts...)
if err != nil {
return fmt.Errorf("failed to create OIDC provider: %w", err)
}
p.provider = provider
// Set up login handler
loginHandler, err := NewLoginHandler(p.storage, func(authRequestID string) string {
return provider.AuthorizationEndpoint().Absolute("/authorize/callback") + "?id=" + authRequestID
})
if err != nil {
return fmt.Errorf("failed to create login handler: %w", err)
}
// Set up device handler
deviceHandler, err := NewDeviceHandler(p.storage)
if err != nil {
return fmt.Errorf("failed to create device handler: %w", err)
}
// Mount routes
router.Mount("/login", loginHandler.Router())
router.Mount("/device", deviceHandler.Router())
router.Get("/logged-out", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte(`<!DOCTYPE html><html><head><title>Logged Out</title></head><body><h1>You have been logged out</h1><p>You can close this window.</p></body></html>`))
})
// Mount the OIDC provider at root
router.Mount("/", provider)
p.router = router
// Create HTTP server
addr := fmt.Sprintf(":%d", p.config.Port)
p.httpServer = &http.Server{
Addr: addr,
Handler: router,
}
// Start server in goroutine
go func() {
log.Infof("Starting OIDC provider on %s (issuer: %s)", addr, p.config.Issuer)
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("OIDC provider server error: %v", err)
}
}()
// Start cleanup goroutine
go p.cleanupLoop(ctx)
return nil
}
// Stop stops the OIDC provider server
func (p *Provider) Stop(ctx context.Context) error {
if p.httpServer != nil {
if err := p.httpServer.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown OIDC server: %w", err)
}
}
if p.store != nil {
if err := p.store.Close(); err != nil {
return fmt.Errorf("failed to close OIDC store: %w", err)
}
}
return nil
}
// cleanupLoop periodically cleans up expired tokens
func (p *Provider) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := p.store.CleanupExpired(ctx); err != nil {
log.Warnf("OIDC cleanup error: %v", err)
}
}
}
}
// Store returns the underlying store for user/client management
func (p *Provider) Store() *Store {
return p.store
}
// GetIssuer returns the issuer URL
func (p *Provider) GetIssuer() string {
return p.config.Issuer
}
// GetDiscoveryEndpoint returns the OpenID Connect discovery endpoint
func (p *Provider) GetDiscoveryEndpoint() string {
return p.config.Issuer + "/.well-known/openid-configuration"
}
// GetTokenEndpoint returns the token endpoint
func (p *Provider) GetTokenEndpoint() string {
return p.config.Issuer + "/oauth/token"
}
// GetAuthorizationEndpoint returns the authorization endpoint
func (p *Provider) GetAuthorizationEndpoint() string {
return p.config.Issuer + "/authorize"
}
// GetDeviceAuthorizationEndpoint returns the device authorization endpoint
func (p *Provider) GetDeviceAuthorizationEndpoint() string {
return p.config.Issuer + "/device_authorization"
}
// GetJWKSEndpoint returns the JWKS endpoint
func (p *Provider) GetJWKSEndpoint() string {
return p.config.Issuer + "/keys"
}
// GetUserInfoEndpoint returns the userinfo endpoint
func (p *Provider) GetUserInfoEndpoint() string {
return p.config.Issuer + "/userinfo"
}
// EnsureDefaultClients ensures the default NetBird clients exist
func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardRedirectURIs, cliRedirectURIs []string) error {
// Check if CLI client exists
_, err := p.store.GetClientByID(ctx, "netbird-client")
if err != nil {
// Create CLI client (native, public, supports PKCE and device flow)
cliClient := CreateNativeClient("netbird-client", "NetBird CLI", cliRedirectURIs)
if err := p.store.CreateClient(ctx, cliClient); err != nil {
return fmt.Errorf("failed to create CLI client: %w", err)
}
log.Info("Created default NetBird CLI client")
}
// Check if dashboard client exists
_, err = p.store.GetClientByID(ctx, "netbird-dashboard")
if err != nil {
// Create dashboard client (SPA, public, supports PKCE)
dashboardClient := CreateSPAClient("netbird-dashboard", "NetBird Dashboard", dashboardRedirectURIs)
if err := p.store.CreateClient(ctx, dashboardClient); err != nil {
return fmt.Errorf("failed to create dashboard client: %w", err)
}
log.Info("Created default NetBird Dashboard client")
}
return nil
}
// CreateUser creates a new user (convenience method)
func (p *Provider) CreateUser(ctx context.Context, username, password, email, firstName, lastName string) (*User, error) {
user := &User{
Username: username,
Password: password, // Will be hashed by store
Email: email,
EmailVerified: false,
FirstName: firstName,
LastName: lastName,
}
if err := p.store.CreateUser(ctx, user); err != nil {
return nil, err
}
return user, nil
}

493
idp/oidcprovider/store.go Normal file
View File

@@ -0,0 +1,493 @@
package oidcprovider
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// Store handles persistence for OIDC provider data
type Store struct {
db *gorm.DB
mu sync.RWMutex
}
// NewStore creates a new Store with SQLite backend
func NewStore(ctx context.Context, dataDir string) (*Store, error) {
dbPath := fmt.Sprintf("%s/oidc.db", dataDir)
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, fmt.Errorf("failed to open OIDC database: %w", err)
}
// Enable WAL mode for better concurrency
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.WithContext(ctx).Warnf("failed to enable WAL mode: %v", err)
}
// Auto-migrate tables
if err := db.AutoMigrate(
&User{},
&Client{},
&AuthRequest{},
&AuthCode{},
&AccessToken{},
&RefreshToken{},
&DeviceAuth{},
&SigningKey{},
); err != nil {
return nil, fmt.Errorf("failed to migrate OIDC database: %w", err)
}
store := &Store{db: db}
// Ensure we have a signing key
if err := store.ensureSigningKey(ctx); err != nil {
return nil, fmt.Errorf("failed to ensure signing key: %w", err)
}
return store, nil
}
// Close closes the database connection
func (s *Store) Close() error {
sqlDB, err := s.db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// ensureSigningKey creates a signing key if one doesn't exist
func (s *Store) ensureSigningKey(ctx context.Context) error {
var key SigningKey
err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error
if err == nil {
return nil // Key exists
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// Generate new RSA key pair
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return fmt.Errorf("failed to generate RSA key: %w", err)
}
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
newKey := &SigningKey{
ID: uuid.New().String(),
Algorithm: "RS256",
PrivateKey: privateKeyPEM,
PublicKey: publicKeyPEM,
CreatedAt: time.Now(),
Active: true,
}
return s.db.WithContext(ctx).Create(newKey).Error
}
// GetSigningKey returns the active signing key
func (s *Store) GetSigningKey(ctx context.Context) (*SigningKey, error) {
var key SigningKey
err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error
if err != nil {
return nil, err
}
return &key, nil
}
// User operations
// CreateUser creates a new user with bcrypt hashed password
func (s *Store) CreateUser(ctx context.Context, user *User) error {
if user.ID == "" {
user.ID = uuid.New().String()
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
user.Password = string(hashedPassword)
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
return s.db.WithContext(ctx).Create(user).Error
}
// GetUserByID retrieves a user by ID
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
var user User
err := s.db.WithContext(ctx).Where("id = ?", id).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetUserByUsername retrieves a user by username
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
var user User
err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// ValidateUserPassword validates a user's password
func (s *Store) ValidateUserPassword(ctx context.Context, username, password string) (*User, error) {
user, err := s.GetUserByUsername(ctx, username)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
return nil, errors.New("invalid password")
}
return user, nil
}
// ListUsers returns all users
func (s *Store) ListUsers(ctx context.Context) ([]*User, error) {
var users []*User
err := s.db.WithContext(ctx).Find(&users).Error
return users, err
}
// UpdateUser updates a user
func (s *Store) UpdateUser(ctx context.Context, user *User) error {
user.UpdatedAt = time.Now()
return s.db.WithContext(ctx).Save(user).Error
}
// DeleteUser deletes a user
func (s *Store) DeleteUser(ctx context.Context, id string) error {
return s.db.WithContext(ctx).Delete(&User{}, "id = ?", id).Error
}
// UpdateUserPassword updates a user's password
func (s *Store) UpdateUserPassword(ctx context.Context, id, password string) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
return s.db.WithContext(ctx).Model(&User{}).Where("id = ?", id).Updates(map[string]interface{}{
"password": string(hashedPassword),
"updated_at": time.Now(),
}).Error
}
// Client operations
// CreateClient creates a new OIDC client
func (s *Store) CreateClient(ctx context.Context, client *Client) error {
if client.ID == "" {
client.ID = uuid.New().String()
}
// Hash secret if provided
if client.Secret != "" {
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(client.Secret), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash client secret: %w", err)
}
client.Secret = string(hashedSecret)
}
client.CreatedAt = time.Now()
client.UpdatedAt = time.Now()
return s.db.WithContext(ctx).Create(client).Error
}
// GetClientByID retrieves a client by ID
func (s *Store) GetClientByID(ctx context.Context, id string) (*Client, error) {
var client Client
err := s.db.WithContext(ctx).Where("id = ?", id).First(&client).Error
if err != nil {
return nil, err
}
return &client, nil
}
// ValidateClientSecret validates a client's secret
func (s *Store) ValidateClientSecret(ctx context.Context, clientID, secret string) (*Client, error) {
client, err := s.GetClientByID(ctx, clientID)
if err != nil {
return nil, err
}
// Public clients have no secret
if client.Secret == "" && secret == "" {
return client, nil
}
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(secret)); err != nil {
return nil, errors.New("invalid client secret")
}
return client, nil
}
// ListClients returns all clients
func (s *Store) ListClients(ctx context.Context) ([]*Client, error) {
var clients []*Client
err := s.db.WithContext(ctx).Find(&clients).Error
return clients, err
}
// DeleteClient deletes a client
func (s *Store) DeleteClient(ctx context.Context, id string) error {
return s.db.WithContext(ctx).Delete(&Client{}, "id = ?", id).Error
}
// AuthRequest operations
// SaveAuthRequest saves an authorization request
func (s *Store) SaveAuthRequest(ctx context.Context, req *AuthRequest) error {
if req.ID == "" {
req.ID = uuid.New().String()
}
req.CreatedAt = time.Now()
return s.db.WithContext(ctx).Create(req).Error
}
// GetAuthRequestByID retrieves an auth request by ID
func (s *Store) GetAuthRequestByID(ctx context.Context, id string) (*AuthRequest, error) {
var req AuthRequest
err := s.db.WithContext(ctx).Where("id = ?", id).First(&req).Error
if err != nil {
return nil, err
}
return &req, nil
}
// UpdateAuthRequest updates an auth request
func (s *Store) UpdateAuthRequest(ctx context.Context, req *AuthRequest) error {
return s.db.WithContext(ctx).Save(req).Error
}
// DeleteAuthRequest deletes an auth request
func (s *Store) DeleteAuthRequest(ctx context.Context, id string) error {
return s.db.WithContext(ctx).Delete(&AuthRequest{}, "id = ?", id).Error
}
// AuthCode operations
// SaveAuthCode saves an authorization code
func (s *Store) SaveAuthCode(ctx context.Context, code *AuthCode) error {
code.CreatedAt = time.Now()
if code.ExpiresAt.IsZero() {
code.ExpiresAt = time.Now().Add(10 * time.Minute) // 10 minute expiry
}
return s.db.WithContext(ctx).Create(code).Error
}
// GetAuthCodeByCode retrieves an auth code
func (s *Store) GetAuthCodeByCode(ctx context.Context, code string) (*AuthCode, error) {
var authCode AuthCode
err := s.db.WithContext(ctx).Where("code = ?", code).First(&authCode).Error
if err != nil {
return nil, err
}
return &authCode, nil
}
// DeleteAuthCode deletes an auth code
func (s *Store) DeleteAuthCode(ctx context.Context, code string) error {
return s.db.WithContext(ctx).Delete(&AuthCode{}, "code = ?", code).Error
}
// Token operations
// SaveAccessToken saves an access token
func (s *Store) SaveAccessToken(ctx context.Context, token *AccessToken) error {
if token.ID == "" {
token.ID = uuid.New().String()
}
token.CreatedAt = time.Now()
return s.db.WithContext(ctx).Create(token).Error
}
// GetAccessTokenByID retrieves an access token
func (s *Store) GetAccessTokenByID(ctx context.Context, id string) (*AccessToken, error) {
var token AccessToken
err := s.db.WithContext(ctx).Where("id = ?", id).First(&token).Error
if err != nil {
return nil, err
}
return &token, nil
}
// DeleteAccessToken deletes an access token
func (s *Store) DeleteAccessToken(ctx context.Context, id string) error {
return s.db.WithContext(ctx).Delete(&AccessToken{}, "id = ?", id).Error
}
// RefreshToken operations
// SaveRefreshToken saves a refresh token
func (s *Store) SaveRefreshToken(ctx context.Context, token *RefreshToken) error {
if token.ID == "" {
token.ID = uuid.New().String()
}
if token.Token == "" {
token.Token = uuid.New().String()
}
token.CreatedAt = time.Now()
return s.db.WithContext(ctx).Create(token).Error
}
// GetRefreshToken retrieves a refresh token by token value
func (s *Store) GetRefreshToken(ctx context.Context, token string) (*RefreshToken, error) {
var rt RefreshToken
err := s.db.WithContext(ctx).Where("token = ?", token).First(&rt).Error
if err != nil {
return nil, err
}
return &rt, nil
}
// DeleteRefreshToken deletes a refresh token
func (s *Store) DeleteRefreshToken(ctx context.Context, id string) error {
return s.db.WithContext(ctx).Delete(&RefreshToken{}, "id = ?", id).Error
}
// DeleteRefreshTokenByToken deletes a refresh token by token value
func (s *Store) DeleteRefreshTokenByToken(ctx context.Context, token string) error {
return s.db.WithContext(ctx).Delete(&RefreshToken{}, "token = ?", token).Error
}
// DeviceAuth operations
// SaveDeviceAuth saves a device authorization
func (s *Store) SaveDeviceAuth(ctx context.Context, auth *DeviceAuth) error {
auth.CreatedAt = time.Now()
return s.db.WithContext(ctx).Create(auth).Error
}
// GetDeviceAuthByDeviceCode retrieves device auth by device code
func (s *Store) GetDeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuth, error) {
var auth DeviceAuth
err := s.db.WithContext(ctx).Where("device_code = ?", deviceCode).First(&auth).Error
if err != nil {
return nil, err
}
return &auth, nil
}
// GetDeviceAuthByUserCode retrieves device auth by user code
func (s *Store) GetDeviceAuthByUserCode(ctx context.Context, userCode string) (*DeviceAuth, error) {
var auth DeviceAuth
err := s.db.WithContext(ctx).Where("user_code = ?", userCode).First(&auth).Error
if err != nil {
return nil, err
}
return &auth, nil
}
// UpdateDeviceAuth updates a device authorization
func (s *Store) UpdateDeviceAuth(ctx context.Context, auth *DeviceAuth) error {
return s.db.WithContext(ctx).Save(auth).Error
}
// DeleteDeviceAuth deletes a device authorization
func (s *Store) DeleteDeviceAuth(ctx context.Context, deviceCode string) error {
return s.db.WithContext(ctx).Delete(&DeviceAuth{}, "device_code = ?", deviceCode).Error
}
// Cleanup operations
// CleanupExpired removes expired tokens and auth requests
func (s *Store) CleanupExpired(ctx context.Context) error {
now := time.Now()
// Delete expired auth codes
if err := s.db.WithContext(ctx).Delete(&AuthCode{}, "expires_at < ?", now).Error; err != nil {
return err
}
// Delete expired access tokens
if err := s.db.WithContext(ctx).Delete(&AccessToken{}, "expiration < ?", now).Error; err != nil {
return err
}
// Delete expired refresh tokens
if err := s.db.WithContext(ctx).Delete(&RefreshToken{}, "expiration < ?", now).Error; err != nil {
return err
}
// Delete expired device authorizations
if err := s.db.WithContext(ctx).Delete(&DeviceAuth{}, "expiration < ?", now).Error; err != nil {
return err
}
// Delete old auth requests (older than 1 hour)
oneHourAgo := now.Add(-1 * time.Hour)
if err := s.db.WithContext(ctx).Delete(&AuthRequest{}, "created_at < ?", oneHourAgo).Error; err != nil {
return err
}
return nil
}
// Helper functions for JSON serialization
// ParseJSONArray parses a JSON array string into a slice
func ParseJSONArray(jsonStr string) []string {
if jsonStr == "" {
return nil
}
var result []string
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
return nil
}
return result
}
// ToJSONArray converts a slice to a JSON array string
func ToJSONArray(arr []string) string {
if len(arr) == 0 {
return "[]"
}
data, err := json.Marshal(arr)
if err != nil {
return "[]"
}
return string(data)
}

View File

@@ -0,0 +1,261 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Device Authorization - NetBird</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}
.container {
background: white;
padding: 40px;
border-radius: 12px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
width: 100%;
max-width: 450px;
}
.logo {
text-align: center;
margin-bottom: 30px;
}
.logo h1 {
font-size: 28px;
color: #333;
font-weight: 600;
}
.logo p {
color: #666;
margin-top: 8px;
}
.form-group {
margin-bottom: 20px;
}
label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: 500;
}
input[type="text"],
input[type="password"] {
width: 100%;
padding: 14px 16px;
border: 2px solid #e1e5eb;
border-radius: 8px;
font-size: 16px;
transition: border-color 0.2s, box-shadow 0.2s;
}
input.code-input {
text-align: center;
font-size: 24px;
letter-spacing: 4px;
text-transform: uppercase;
}
input[type="text"]:focus,
input[type="password"]:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.2);
}
button {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
margin-bottom: 10px;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
}
button:active {
transform: translateY(0);
}
button.secondary {
background: #e1e5eb;
color: #333;
}
button.secondary:hover {
background: #d1d5db;
box-shadow: none;
}
button.deny {
background: #dc2626;
}
button.deny:hover {
background: #b91c1c;
}
.error {
background: #fee;
color: #c00;
padding: 12px 16px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
border: 1px solid #fcc;
}
.success {
background: #d4edda;
color: #155724;
padding: 20px;
border-radius: 8px;
text-align: center;
font-size: 16px;
border: 1px solid #c3e6cb;
}
.info {
background: #e8f4fd;
color: #0c5460;
padding: 16px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
border: 1px solid #bee5eb;
}
.scopes {
background: #f8f9fa;
padding: 16px;
border-radius: 8px;
margin-bottom: 20px;
}
.scopes h3 {
font-size: 14px;
color: #666;
margin-bottom: 12px;
}
.scopes ul {
list-style: none;
padding: 0;
}
.scopes li {
padding: 8px 0;
border-bottom: 1px solid #e1e5eb;
color: #333;
}
.scopes li:last-child {
border-bottom: none;
}
.button-group {
display: flex;
gap: 12px;
}
.button-group button {
flex: 1;
}
.footer {
text-align: center;
margin-top: 24px;
color: #888;
font-size: 13px;
}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<h1>NetBird</h1>
<p>Device Authorization</p>
</div>
{{if .Error}}
<div class="error">{{.Error}}</div>
{{end}}
{{if eq .Step "code"}}
<!-- Step 1: Enter user code -->
<div class="info">
Enter the code shown on your device to authorize it.
</div>
<form method="GET" action="/device">
<div class="form-group">
<label for="user_code">Device Code</label>
<input type="text" id="user_code" name="user_code" class="code-input"
placeholder="XXXX-XXXX" required autofocus
pattern="[A-Za-z]{4}-?[A-Za-z]{4}">
</div>
<button type="submit">Continue</button>
</form>
{{end}}
{{if eq .Step "login"}}
<!-- Step 2: Login -->
<div class="info">
Sign in to authorize the device.
</div>
<form method="POST" action="/device/login">
<input type="hidden" name="user_code" value="{{.UserCode}}">
<div class="form-group">
<label for="username">Username</label>
<input type="text" id="username" name="username" required autofocus>
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit">Sign In</button>
</form>
{{end}}
{{if eq .Step "confirm"}}
<!-- Step 3: Confirm authorization -->
<div class="info">
<strong>{{.ClientID}}</strong> is requesting access to your account.
</div>
{{if .Scopes}}
<div class="scopes">
<h3>This application will have access to:</h3>
<ul>
{{range .Scopes}}
<li>{{.}}</li>
{{end}}
</ul>
</div>
{{end}}
<form method="POST" action="/device/confirm">
<div class="button-group">
<button type="submit" name="action" value="allow">Allow</button>
<button type="submit" name="action" value="deny" class="deny">Deny</button>
</div>
</form>
{{end}}
{{if eq .Step "result"}}
<!-- Result -->
{{if .Success}}
<div class="success">
{{.Message}}
</div>
{{else}}
<div class="info">
{{.Message}}
</div>
{{end}}
{{end}}
<div class="footer">
Powered by NetBird Identity Provider
</div>
</div>
</body>
</html>

View File

@@ -0,0 +1,129 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Login - NetBird</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}
.login-container {
background: white;
padding: 40px;
border-radius: 12px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
width: 100%;
max-width: 400px;
}
.logo {
text-align: center;
margin-bottom: 30px;
}
.logo h1 {
font-size: 28px;
color: #333;
font-weight: 600;
}
.logo p {
color: #666;
margin-top: 8px;
}
.form-group {
margin-bottom: 20px;
}
label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: 500;
}
input[type="text"],
input[type="password"] {
width: 100%;
padding: 14px 16px;
border: 2px solid #e1e5eb;
border-radius: 8px;
font-size: 16px;
transition: border-color 0.2s, box-shadow 0.2s;
}
input[type="text"]:focus,
input[type="password"]:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.2);
}
button {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
}
button:active {
transform: translateY(0);
}
.error {
background: #fee;
color: #c00;
padding: 12px 16px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
border: 1px solid #fcc;
}
.footer {
text-align: center;
margin-top: 24px;
color: #888;
font-size: 13px;
}
</style>
</head>
<body>
<div class="login-container">
<div class="logo">
<h1>NetBird</h1>
<p>Sign in to your account</p>
</div>
{{if .Error}}
<div class="error">{{.Error}}</div>
{{end}}
<form method="POST" action="/login">
<input type="hidden" name="authRequestID" value="{{.AuthRequestID}}">
<div class="form-group">
<label for="username">Username</label>
<input type="text" id="username" name="username" required autofocus>
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit">Sign In</button>
</form>
<div class="footer">
Powered by NetBird Identity Provider
</div>
</div>
</body>
</html>