/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn import ( "fmt" "net" "sync" ) // MultiPathBind implements Bind interface but sends packets through multiple network paths type MultiPathBind struct { mu sync.RWMutex binds []Bind // Store the primary bind for receive operations (only one bind receives) primaryBind Bind } // NewMultiPathBind creates a new multi-path bind with multiple underlying binds func NewMultiPathBind(binds []Bind) *MultiPathBind { if len(binds) == 0 { panic("MultiPathBind requires at least one bind") } return &MultiPathBind{ binds: binds, primaryBind: binds[0], // Use first bind as primary for receiving } } // Open puts all binds into listening state func (mpb *MultiPathBind) Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) { mpb.mu.Lock() defer mpb.mu.Unlock() // Open primary bind first to get the actual port and receive functions fns, actualPort, err = mpb.primaryBind.Open(port) if err != nil { return nil, 0, fmt.Errorf("failed to open primary bind: %w", err) } // Open additional binds on the same port for i, bind := range mpb.binds[1:] { _, bindPort, bindErr := bind.Open(actualPort) if bindErr != nil { // If any bind fails, close already opened binds mpb.primaryBind.Close() for j := 0; j < i; j++ { mpb.binds[j+1].Close() } return nil, 0, fmt.Errorf("failed to open bind %d: %w", i+1, bindErr) } // Verify all binds use the same port if bindPort != actualPort { mpb.primaryBind.Close() for j := 0; j <= i; j++ { mpb.binds[j+1].Close() } return nil, 0, fmt.Errorf("bind %d opened on different port %d vs %d", i+1, bindPort, actualPort) } } return fns, actualPort, nil } // Close closes all underlying binds func (mpb *MultiPathBind) Close() error { mpb.mu.Lock() defer mpb.mu.Unlock() var firstErr error for i, bind := range mpb.binds { if err := bind.Close(); err != nil && firstErr == nil { firstErr = fmt.Errorf("failed to close bind %d: %w", i, err) } } return firstErr } // SetMark sets the mark for all underlying binds func (mpb *MultiPathBind) SetMark(mark uint32) error { mpb.mu.RLock() defer mpb.mu.RUnlock() for i, bind := range mpb.binds { if err := bind.SetMark(mark); err != nil { return fmt.Errorf("failed to set mark on bind %d: %w", i, err) } } return nil } // Send sends the same packets through ALL configured network paths func (mpb *MultiPathBind) Send(bufs [][]byte, ep Endpoint) error { mpb.mu.RLock() defer mpb.mu.RUnlock() var firstErr error successCount := 0 // Send through all binds for i, bind := range mpb.binds { if err := bind.Send(bufs, ep); err != nil { if firstErr == nil { firstErr = fmt.Errorf("bind %d failed: %w", i, err) } } else { successCount++ } } // Consider it successful if at least one path succeeded if successCount > 0 { return nil } return firstErr } // ParseEndpoint uses the primary bind to parse endpoints func (mpb *MultiPathBind) ParseEndpoint(s string) (Endpoint, error) { mpb.mu.RLock() defer mpb.mu.RUnlock() return mpb.primaryBind.ParseEndpoint(s) } // BatchSize returns the minimum batch size among all binds func (mpb *MultiPathBind) BatchSize() int { mpb.mu.RLock() defer mpb.mu.RUnlock() if len(mpb.binds) == 0 { return 1 } minBatchSize := mpb.binds[0].BatchSize() for _, bind := range mpb.binds[1:] { if size := bind.BatchSize(); size < minBatchSize { minBatchSize = size } } return minBatchSize } // BindToInterface binds specific binds to specific interfaces // This is a helper method for configuring each bind to use different interfaces func (mpb *MultiPathBind) BindToInterface(bindIndex int, interfaceIndex uint32, blackhole bool) error { mpb.mu.RLock() defer mpb.mu.RUnlock() if bindIndex >= len(mpb.binds) { return fmt.Errorf("bind index %d out of range (have %d binds)", bindIndex, len(mpb.binds)) } bind := mpb.binds[bindIndex] if binder, ok := bind.(BindSocketToInterface); ok { // Try IPv4 first if err := binder.BindSocketToInterface4(interfaceIndex, blackhole); err != nil { // If IPv4 fails, try IPv6 if err := binder.BindSocketToInterface6(interfaceIndex, blackhole); err != nil { return fmt.Errorf("failed to bind to interface %d: %w", interfaceIndex, err) } } return nil } return fmt.Errorf("bind %d does not support interface binding", bindIndex) } // GetBindCount returns the number of configured network paths func (mpb *MultiPathBind) GetBindCount() int { mpb.mu.RLock() defer mpb.mu.RUnlock() return len(mpb.binds) }