Files
wireguard-go/conn/multipath_bind.go
T
dingfeng.wong 9f0133a5c9 add
2025-07-25 18:01:53 +08:00

178 lines
4.5 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"net"
"sync"
)
// MultiPathBind is a Bind that duplicates every outgoing packet across
// several underlying binds (network paths) while exposing the receive
// functions of only a single primary bind.
//
// Create it with NewMultiPathBind. The zero value has no primary bind and
// is not usable (Open would call a method on a nil Bind).
type MultiPathBind struct {
	// mu guards binds and primaryBind. Open and Close take it exclusively;
	// every other method takes it for reading.
	mu sync.RWMutex

	// binds lists every path in order. NewMultiPathBind makes binds[0]
	// the primary bind as well.
	binds []Bind

	// primaryBind is the only bind whose receive functions are returned by
	// Open, and the bind that ParseEndpoint delegates to.
	primaryBind Bind
}

// Compile-time check that MultiPathBind satisfies the Bind interface.
var _ Bind = (*MultiPathBind)(nil)
// NewMultiPathBind creates a MultiPathBind that sends over every bind in
// binds. binds[0] becomes the primary bind: Open returns only its receive
// functions, and ParseEndpoint delegates to it.
//
// The slice is copied, so later changes the caller makes to binds do not
// affect the returned MultiPathBind and cannot bypass its locking.
//
// NewMultiPathBind panics if binds is empty.
func NewMultiPathBind(binds []Bind) *MultiPathBind {
	if len(binds) == 0 {
		panic("MultiPathBind requires at least one bind")
	}
	// Take ownership of a private copy: storing the caller's slice directly
	// would let external writes race with reads guarded by mpb.mu.
	owned := make([]Bind, len(binds))
	copy(owned, binds)
	return &MultiPathBind{
		binds:       owned,
		primaryBind: owned[0], // first bind is the receiving (primary) path
	}
}
// Open opens every underlying bind on a common port.
//
// The primary bind is opened first on the requested port (which may be 0).
// The port it reports is then used for every additional bind, and each of
// those must report that same port back. Only the primary bind's receive
// functions are returned.
//
// If an additional bind fails to open or reports a different port, every
// bind opened so far (primary included) is closed again and an error is
// returned.
func (mpb *MultiPathBind) Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) {
	mpb.mu.Lock()
	defer mpb.mu.Unlock()

	// The primary bind decides the effective port and supplies the
	// receive functions.
	fns, actualPort, err = mpb.primaryBind.Open(port)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to open primary bind: %w", err)
	}

	// rollback closes the primary bind plus the first n secondary binds
	// (mpb.binds[1:1+n]), in that order. Close errors are deliberately
	// ignored: the caller is already getting the error that triggered the
	// rollback.
	rollback := func(n int) {
		_ = mpb.primaryBind.Close()
		for _, opened := range mpb.binds[1 : 1+n] {
			_ = opened.Close()
		}
	}

	for i, bind := range mpb.binds[1:] {
		// NOTE(review): the secondary binds' receive functions are
		// discarded here, so nothing returned by Open reads from those
		// sockets — confirm this is intended.
		_, bindPort, bindErr := bind.Open(actualPort)
		if bindErr != nil {
			// This bind did not open; only the i secondaries before it did.
			rollback(i)
			return nil, 0, fmt.Errorf("failed to open bind %d: %w", i+1, bindErr)
		}
		if bindPort != actualPort {
			// This bind did open (on the wrong port), so close it as well.
			rollback(i + 1)
			return nil, 0, fmt.Errorf("bind %d opened on different port %d vs %d", i+1, bindPort, actualPort)
		}
	}
	return fns, actualPort, nil
}
// Close shuts down every underlying bind, the primary one included.
//
// Every bind is closed even when some of them fail. The returned error, if
// any, describes the first failure in bind order; later failures are dropped.
func (mpb *MultiPathBind) Close() error {
	mpb.mu.Lock()
	defer mpb.mu.Unlock()

	var firstErr error
	for idx := 0; idx < len(mpb.binds); idx++ {
		closeErr := mpb.binds[idx].Close()
		if closeErr == nil || firstErr != nil {
			// Either it closed cleanly or we already have an error to report.
			continue
		}
		firstErr = fmt.Errorf("failed to close bind %d: %w", idx, closeErr)
	}
	return firstErr
}
// SetMark applies mark to each underlying bind in order.
//
// It stops at the first bind that rejects the mark and returns that error.
// Binds earlier in the list keep the new mark, and binds after the failing
// one are not touched, so a failure can leave the mark partially applied.
func (mpb *MultiPathBind) SetMark(mark uint32) error {
	mpb.mu.RLock()
	defer mpb.mu.RUnlock()

	for idx := range mpb.binds {
		err := mpb.binds[idx].SetMark(mark)
		if err == nil {
			continue
		}
		return fmt.Errorf("failed to set mark on bind %d: %w", idx, err)
	}
	return nil
}
// Send transmits bufs to ep through every underlying bind in turn, so each
// packet goes out once per configured path.
//
// The call counts as successful as soon as at least one bind accepted the
// packets. Only when every bind fails is an error returned, and it describes
// the first failure. With no binds configured, Send returns nil.
//
// NOTE(review): the same ep is handed to every bind unchanged, which
// presumably requires all binds to accept the endpoint type produced by the
// primary bind's ParseEndpoint — confirm when mixing bind implementations.
func (mpb *MultiPathBind) Send(bufs [][]byte, ep Endpoint) error {
	mpb.mu.RLock()
	defer mpb.mu.RUnlock()

	var (
		firstErr error
		anyOK    bool
	)
	for idx, path := range mpb.binds {
		err := path.Send(bufs, ep)
		switch {
		case err == nil:
			anyOK = true
		case firstErr == nil:
			firstErr = fmt.Errorf("bind %d failed: %w", idx, err)
		}
	}
	if anyOK {
		return nil
	}
	return firstErr
}
// ParseEndpoint parses s using the primary bind only, so the returned
// Endpoint is in the primary bind's format.
func (mpb *MultiPathBind) ParseEndpoint(s string) (Endpoint, error) {
	mpb.mu.RLock()
	defer mpb.mu.RUnlock()

	primary := mpb.primaryBind
	return primary.ParseEndpoint(s)
}
// BatchSize reports the smallest batch size among the underlying binds.
// Send passes the same bufs to every bind, so taking the minimum keeps a
// batch within every bind's limit. It returns 1 when no binds are configured.
func (mpb *MultiPathBind) BatchSize() int {
	mpb.mu.RLock()
	defer mpb.mu.RUnlock()

	smallest := 1
	for idx, path := range mpb.binds {
		size := path.BatchSize()
		// The first bind seeds the minimum; later ones can only lower it.
		if idx == 0 || size < smallest {
			smallest = size
		}
	}
	return smallest
}
// BindToInterface asks the bind at bindIndex to bind its sockets to the
// network interface with index interfaceIndex, via the bind's
// BindSocketToInterface4/6 methods. blackhole is forwarded unchanged to
// both calls. This lets each path be configured with a different interface.
//
// Both the IPv4 and the IPv6 side are bound. The call succeeds if at least
// one address family could be bound. It returns an error when bindIndex is
// out of range, when the bind does not implement BindSocketToInterface, or
// when both families fail; in the last case the error wraps the IPv6 error
// and also includes the IPv4 error text.
func (mpb *MultiPathBind) BindToInterface(bindIndex int, interfaceIndex uint32, blackhole bool) error {
	mpb.mu.RLock()
	defer mpb.mu.RUnlock()

	// Reject negative indices too; they would otherwise panic on indexing.
	if bindIndex < 0 || bindIndex >= len(mpb.binds) {
		return fmt.Errorf("bind index %d out of range (have %d binds)", bindIndex, len(mpb.binds))
	}
	binder, ok := mpb.binds[bindIndex].(BindSocketToInterface)
	if !ok {
		return fmt.Errorf("bind %d does not support interface binding", bindIndex)
	}

	// Bind each family independently. Stopping after a successful IPv4 bind
	// would leave the IPv6 side of this path unbound.
	err4 := binder.BindSocketToInterface4(interfaceIndex, blackhole)
	err6 := binder.BindSocketToInterface6(interfaceIndex, blackhole)
	if err4 != nil && err6 != nil {
		return fmt.Errorf("failed to bind to interface %d: ipv4: %v; ipv6: %w", interfaceIndex, err4, err6)
	}
	return nil
}
// GetBindCount reports how many network paths (underlying binds) this
// MultiPathBind sends over, counting the primary bind.
func (mpb *MultiPathBind) GetBindCount() int {
	mpb.mu.RLock()
	count := len(mpb.binds)
	mpb.mu.RUnlock()
	return count
}