184 lines
4.7 KiB
Go
184 lines
4.7 KiB
Go
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
|
|
|
|
package conn
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// MultiPathBind implements Bind interface and sends/receives packets through multiple network paths
|
|
type MultiPathBind struct {
|
|
mu sync.RWMutex
|
|
binds []Bind
|
|
}
|
|
|
|
// NewMultiPathBind creates a new multi-path bind with multiple underlying binds
|
|
func NewMultiPathBind(binds []Bind) *MultiPathBind {
|
|
if len(binds) == 0 {
|
|
panic("MultiPathBind requires at least one bind")
|
|
}
|
|
|
|
return &MultiPathBind{
|
|
binds: binds,
|
|
}
|
|
}
|
|
|
|
// Open puts all binds into listening state and collects receive functions from all binds
|
|
func (mpb *MultiPathBind) Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) {
|
|
mpb.mu.Lock()
|
|
defer mpb.mu.Unlock()
|
|
|
|
// Open first bind to get the actual port
|
|
var firstBindFns []ReceiveFunc
|
|
firstBindFns, actualPort, err = mpb.binds[0].Open(port)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("failed to open bind 0: %w", err)
|
|
}
|
|
|
|
// Collect receive functions from the first bind
|
|
fns = append(fns, firstBindFns...)
|
|
|
|
// Open additional binds on the same port and collect their receive functions
|
|
for i, bind := range mpb.binds[1:] {
|
|
var bindFns []ReceiveFunc
|
|
var bindPort uint16
|
|
bindFns, bindPort, err = bind.Open(actualPort)
|
|
if err != nil {
|
|
// If any bind fails, close already opened binds
|
|
mpb.binds[0].Close()
|
|
for j := 0; j < i; j++ {
|
|
mpb.binds[j+1].Close()
|
|
}
|
|
return nil, 0, fmt.Errorf("failed to open bind %d: %w", i+1, err)
|
|
}
|
|
|
|
// Verify all binds use the same port
|
|
if bindPort != actualPort {
|
|
mpb.binds[0].Close()
|
|
for j := 0; j <= i; j++ {
|
|
mpb.binds[j+1].Close()
|
|
}
|
|
return nil, 0, fmt.Errorf("bind %d opened on different port %d vs %d", i+1, bindPort, actualPort)
|
|
}
|
|
|
|
// Collect receive functions from this bind
|
|
fns = append(fns, bindFns...)
|
|
}
|
|
|
|
return fns, actualPort, nil
|
|
}
|
|
|
|
// Close closes all underlying binds
|
|
func (mpb *MultiPathBind) Close() error {
|
|
mpb.mu.Lock()
|
|
defer mpb.mu.Unlock()
|
|
|
|
var firstErr error
|
|
for i, bind := range mpb.binds {
|
|
if err := bind.Close(); err != nil && firstErr == nil {
|
|
firstErr = fmt.Errorf("failed to close bind %d: %w", i, err)
|
|
}
|
|
}
|
|
return firstErr
|
|
}
|
|
|
|
// SetMark sets the mark for all underlying binds
|
|
func (mpb *MultiPathBind) SetMark(mark uint32) error {
|
|
mpb.mu.RLock()
|
|
defer mpb.mu.RUnlock()
|
|
|
|
for i, bind := range mpb.binds {
|
|
if err := bind.SetMark(mark); err != nil {
|
|
return fmt.Errorf("failed to set mark on bind %d: %w", i, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Send sends the same packets through ALL configured network paths
|
|
func (mpb *MultiPathBind) Send(bufs [][]byte, ep Endpoint) error {
|
|
mpb.mu.RLock()
|
|
defer mpb.mu.RUnlock()
|
|
|
|
var firstErr error
|
|
successCount := 0
|
|
|
|
// Send through all binds
|
|
for i, bind := range mpb.binds {
|
|
if err := bind.Send(bufs, ep); err != nil {
|
|
if firstErr == nil {
|
|
firstErr = fmt.Errorf("bind %d failed: %w", i, err)
|
|
}
|
|
} else {
|
|
successCount++
|
|
}
|
|
}
|
|
|
|
// Consider it successful if at least one path succeeded
|
|
if successCount > 0 {
|
|
return nil
|
|
}
|
|
|
|
return firstErr
|
|
}
|
|
|
|
// ParseEndpoint uses the first bind to parse endpoints
|
|
func (mpb *MultiPathBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
mpb.mu.RLock()
|
|
defer mpb.mu.RUnlock()
|
|
return mpb.binds[0].ParseEndpoint(s)
|
|
}
|
|
|
|
// BatchSize returns the minimum batch size among all binds
|
|
func (mpb *MultiPathBind) BatchSize() int {
|
|
mpb.mu.RLock()
|
|
defer mpb.mu.RUnlock()
|
|
|
|
if len(mpb.binds) == 0 {
|
|
return 1
|
|
}
|
|
|
|
minBatchSize := mpb.binds[0].BatchSize()
|
|
for _, bind := range mpb.binds[1:] {
|
|
if size := bind.BatchSize(); size < minBatchSize {
|
|
minBatchSize = size
|
|
}
|
|
}
|
|
return minBatchSize
|
|
}
|
|
|
|
// BindToInterface binds specific binds to specific interfaces
|
|
// This is a helper method for configuring each bind to use different interfaces
|
|
func (mpb *MultiPathBind) BindToInterface(bindIndex int, interfaceIndex uint32, blackhole bool) error {
|
|
mpb.mu.RLock()
|
|
defer mpb.mu.RUnlock()
|
|
|
|
if bindIndex >= len(mpb.binds) {
|
|
return fmt.Errorf("bind index %d out of range (have %d binds)", bindIndex, len(mpb.binds))
|
|
}
|
|
|
|
bind := mpb.binds[bindIndex]
|
|
if binder, ok := bind.(BindSocketToInterface); ok {
|
|
// Try IPv4 first
|
|
if err := binder.BindSocketToInterface4(interfaceIndex, blackhole); err != nil {
|
|
// If IPv4 fails, try IPv6
|
|
if err := binder.BindSocketToInterface6(interfaceIndex, blackhole); err != nil {
|
|
return fmt.Errorf("failed to bind to interface %d: %w", interfaceIndex, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("bind %d does not support interface binding", bindIndex)
|
|
}
|
|
|
|
// GetBindCount returns the number of configured network paths
|
|
func (mpb *MultiPathBind) GetBindCount() int {
|
|
mpb.mu.RLock()
|
|
defer mpb.mu.RUnlock()
|
|
return len(mpb.binds)
|
|
} |