Files
netbird/client/internal/routeselector/routeselector.go

291 lines
6.8 KiB
Go

package routeselector
import (
"encoding/json"
"fmt"
"slices"
"sync"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route"
)
// exitNodeCIDR is the network string that isExitNode compares the first route
// of a group against.
const exitNodeCIDR = "0.0.0.0/0"
// RouteSelector records per-route selection state keyed by route.NetID.
// The exported methods in this file lock mu before touching the other fields.
type RouteSelector struct {
// mu guards all fields below.
mu sync.RWMutex
// deselectedRoutes holds the NetIDs that IsSelected and the filter methods
// report as not selected.
deselectedRoutes map[route.NetID]struct{}
// selectedRoutes holds the NetIDs added through SelectRoutes.
selectedRoutes map[route.NetID]struct{}
// deselectAll, when true, makes every route report as not selected.
deselectAll bool
}
// NewRouteSelector returns a RouteSelector whose selected and deselected sets
// are allocated and empty, with the deselect-all flag off.
func NewRouteSelector() *RouteSelector {
	rs := new(RouteSelector)
	rs.selectedRoutes = make(map[route.NetID]struct{})
	rs.deselectedRoutes = make(map[route.NetID]struct{})
	return rs
}
// SelectRoutes marks the given routes as selected.
//
// When appendRoute is false, or when all routes are currently deselected, the
// previous state is discarded first: both sets are emptied, every entry of
// allRoutes is marked deselected, and only the requested routes are then
// selected. When appendRoute is true the requested routes are added on top of
// the existing state.
//
// A requested route that is not present in allRoutes is skipped and adds a
// "not available" error to the combined error that is returned. The
// deselect-all flag is cleared in every case.
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	// Allocate the sets before either path runs so that a zero-value
	// RouteSelector cannot panic on a nil-map write in the append path.
	if rs.deselectedRoutes == nil {
		rs.deselectedRoutes = map[route.NetID]struct{}{}
	}
	if rs.selectedRoutes == nil {
		rs.selectedRoutes = map[route.NetID]struct{}{}
	}

	if !appendRoute || rs.deselectAll {
		maps.Clear(rs.deselectedRoutes)
		maps.Clear(rs.selectedRoutes)
		for _, id := range allRoutes {
			rs.deselectedRoutes[id] = struct{}{}
		}
	}

	var err *multierror.Error
	// The loop variable is named id so it does not shadow the route package.
	for _, id := range routes {
		if !slices.Contains(allRoutes, id) {
			err = multierror.Append(err, fmt.Errorf("route '%s' is not available", id))
			continue
		}
		delete(rs.deselectedRoutes, id)
		rs.selectedRoutes[id] = struct{}{}
	}
	rs.deselectAll = false

	return errors.FormatErrorOrNil(err)
}
// SelectAllRoutes clears the deselect-all flag and empties both the selected
// and the deselected set, allocating either one if it is nil.
func (rs *RouteSelector) SelectAllRoutes() {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	if rs.deselectedRoutes != nil {
		maps.Clear(rs.deselectedRoutes)
	} else {
		rs.deselectedRoutes = make(map[route.NetID]struct{})
	}

	if rs.selectedRoutes != nil {
		maps.Clear(rs.selectedRoutes)
	} else {
		rs.selectedRoutes = make(map[route.NetID]struct{})
	}

	rs.deselectAll = false
}
// DeselectRoutes marks the given routes as deselected and removes them from
// the selected set.
//
// If all routes are already deselected the call is a no-op and returns nil.
// A requested route that is not present in allRoutes is skipped and adds a
// "not available" error to the combined error that is returned.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	if rs.deselectAll {
		return nil
	}

	// Writing to a nil map panics, so allocate the set for a zero-value
	// RouteSelector. selectedRoutes needs no guard: delete on a nil map is a
	// no-op.
	if rs.deselectedRoutes == nil {
		rs.deselectedRoutes = map[route.NetID]struct{}{}
	}

	var err *multierror.Error
	// The loop variable is named id so it does not shadow the route package.
	for _, id := range routes {
		if !slices.Contains(allRoutes, id) {
			err = multierror.Append(err, fmt.Errorf("route '%s' is not available", id))
			continue
		}
		rs.deselectedRoutes[id] = struct{}{}
		delete(rs.selectedRoutes, id)
	}

	return errors.FormatErrorOrNil(err)
}
// DeselectAllRoutes sets the deselect-all flag and empties both the selected
// and the deselected set, allocating either one if it is nil.
func (rs *RouteSelector) DeselectAllRoutes() {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	if rs.deselectedRoutes != nil {
		maps.Clear(rs.deselectedRoutes)
	} else {
		rs.deselectedRoutes = make(map[route.NetID]struct{})
	}

	if rs.selectedRoutes != nil {
		maps.Clear(rs.selectedRoutes)
	} else {
		rs.selectedRoutes = make(map[route.NetID]struct{})
	}

	rs.deselectAll = true
}
// IsSelected reports whether routeID counts as selected: it returns false when
// the deselect-all flag is set or routeID is in the deselected set, and true
// otherwise.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
	rs.mu.RLock()
	defer rs.mu.RUnlock()

	if rs.deselectAll {
		return false
	}
	if _, found := rs.deselectedRoutes[routeID]; found {
		return false
	}
	return true
}
// FilterSelected returns a new map holding the entries of routes whose NetID
// is not in the deselected set. The input map is not modified. When the
// deselect-all flag is set the result is empty.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
	rs.mu.RLock()
	defer rs.mu.RUnlock()

	kept := route.HAMap{}
	if rs.deselectAll {
		return kept
	}

	for haID, group := range routes {
		if _, skip := rs.deselectedRoutes[haID.NetID()]; skip {
			continue
		}
		kept[haID] = group
	}
	return kept
}
// HasUserSelectionForRoute reports whether routeID appears in either the
// selected set or the deselected set.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
	rs.mu.RLock()
	defer rs.mu.RUnlock()

	if _, ok := rs.selectedRoutes[routeID]; ok {
		return true
	}
	_, ok := rs.deselectedRoutes[routeID]
	return ok
}
// FilterSelectedExitNodes returns a new map built from routes while holding
// the read lock for the whole pass:
//   - entries whose NetID is deselected are dropped;
//   - entries that are not exit nodes are copied unchanged;
//   - exit-node entries are handed to applyExitNodeFilter, which decides
//     whether (and in what form) they are added.
//
// When the deselect-all flag is set the result is empty.
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
	rs.mu.RLock()
	defer rs.mu.RUnlock()

	if rs.deselectAll {
		return route.HAMap{}
	}

	result := make(route.HAMap, len(routes))
	for haID, group := range routes {
		netID := haID.NetID()
		switch {
		case rs.isDeselected(netID):
			// Deselected: leave it out of the result.
		case isExitNode(group):
			rs.applyExitNodeFilter(haID, netID, group, result)
		default:
			result[haID] = group
		}
	}
	return result
}
// isDeselected reports whether netID is in the deselected set or the
// deselect-all flag is set. It does not lock rs.mu itself.
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
	if _, found := rs.deselectedRoutes[netID]; found {
		return true
	}
	return rs.deselectAll
}
// isExitNode reports whether rt is non-empty and the network of its first
// route prints as exitNodeCIDR.
func isExitNode(rt []*route.Route) bool {
	if len(rt) == 0 {
		return false
	}
	return rt[0].Network.String() == exitNodeCIDR
}
// applyExitNodeFilter decides whether the exit-node route group rt is added to
// out under id.
//
// If any explicit selection or deselection exists, the whole group is added
// when netID is not deselected. Otherwise only the routes in rt whose
// SkipAutoApply is false are added, and nothing is added if none qualify.
//
// It reads selector state without locking, so it must be called with rs.mu
// already held.
func (rs *RouteSelector) applyExitNodeFilter(
	id route.HAUniqueID,
	netID route.NetID,
	rt []*route.Route,
	out route.HAMap,
) {
	if rs.hasUserSelections() {
		// User made explicit selects/deselects. The caller already holds
		// rs.mu for reading; calling the exported IsSelected here would RLock
		// the same RWMutex again, which can deadlock if a writer queues up
		// between the two acquisitions. isDeselected performs the same check
		// without locking.
		if !rs.isDeselected(netID) {
			out[id] = rt
		}
		return
	}

	// No explicit selections: only include routes marked !SkipAutoApply.
	if sel := collectSelected(rt); len(sel) > 0 {
		out[id] = sel
	}
}
// hasUserSelections reports whether the selected set or the deselected set
// contains at least one entry. It does not lock rs.mu itself.
func (rs *RouteSelector) hasUserSelections() bool {
	noneSelected := len(rs.selectedRoutes) == 0
	noneDeselected := len(rs.deselectedRoutes) == 0
	return !(noneSelected && noneDeselected)
}
// collectSelected returns, in order, the routes of rt whose SkipAutoApply is
// false. The result is nil when no route qualifies.
func collectSelected(rt []*route.Route) []*route.Route {
	var autoApplied []*route.Route
	for i := range rt {
		if rt[i].SkipAutoApply {
			continue
		}
		autoApplied = append(autoApplied, rt[i])
	}
	return autoApplied
}
// MarshalJSON implements the json.Marshaler interface. It encodes the
// selected set, the deselected set and the deselect-all flag under the keys
// "selected_routes", "deselected_routes" and "deselect_all" while holding the
// read lock.
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
	type snapshot struct {
		SelectedRoutes   map[route.NetID]struct{} `json:"selected_routes"`
		DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
		DeselectAll      bool                     `json:"deselect_all"`
	}

	rs.mu.RLock()
	defer rs.mu.RUnlock()

	var out snapshot
	out.SelectedRoutes = rs.selectedRoutes
	out.DeselectedRoutes = rs.deselectedRoutes
	out.DeselectAll = rs.deselectAll
	return json.Marshal(out)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
//
// Empty input or a literal null resets the selector to the same state
// NewRouteSelector produces. Otherwise the keys "selected_routes",
// "deselected_routes" and "deselect_all" are decoded; a set that is absent or
// null in the input is replaced by an empty one. On a decode error the
// selector is left unchanged and the error is returned.
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	var decoded struct {
		SelectedRoutes   map[route.NetID]struct{} `json:"selected_routes"`
		DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
		DeselectAll      bool                     `json:"deselect_all"`
	}

	// Empty or null input skips decoding; the zero-valued struct then yields
	// the fresh state below.
	if len(data) != 0 && string(data) != "null" {
		if err := json.Unmarshal(data, &decoded); err != nil {
			return err
		}
	}

	if decoded.SelectedRoutes == nil {
		decoded.SelectedRoutes = make(map[route.NetID]struct{})
	}
	if decoded.DeselectedRoutes == nil {
		decoded.DeselectedRoutes = make(map[route.NetID]struct{})
	}

	rs.selectedRoutes = decoded.SelectedRoutes
	rs.deselectedRoutes = decoded.DeselectedRoutes
	rs.deselectAll = decoded.DeselectAll
	return nil
}