diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index beb86826d..a64c83138 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -712,8 +712,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP // Detect address family isV6 := src.To4() == nil - var networkLayer gopacket.SerializableLayer - var setChecksum func(gopacket.NetworkLayer) error var err error if isV6 { @@ -723,8 +721,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP SrcIP: src, DstIP: dst, } - networkLayer = ip6 - setChecksum = func(nl gopacket.NetworkLayer) error { return nil } switch proto { case fw.ProtocolTCP: @@ -747,7 +743,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP default: err = gopacket.SerializeLayers(buf, opts, ip6) } - _ = setChecksum } else { ip4 := &layers.IPv4{ Version: 4, @@ -755,7 +750,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP SrcIP: src, DstIP: dst, } - networkLayer = ip4 switch proto { case fw.ProtocolTCP: @@ -776,7 +770,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP err = gopacket.SerializeLayers(buf, opts, ip4) } } - _ = networkLayer require.NoError(t, err) return buf.Bytes() diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 3f1ba6a44..bec6fb3e5 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -54,8 +54,11 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) } raw := pkt.NetworkHeader().View().AsSlice() + if len(raw) == 0 { + continue + } var address tcpip.Address - if len(raw) > 0 && raw[0]>>4 == 6 { + if raw[0]>>4 == 6 { address = header.IPv6(raw).DestinationAddress() } else { address = header.IPv4(raw).DestinationAddress() diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index bbacb95f6..85c5bbc03 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -37,17 +37,18 @@ type Forwarder struct { logger *nblog.Logger flowLogger nftypes.FlowLogger // ruleIdMap is used to store the rule ID for a given connection - ruleIdMap sync.Map - stack *stack.Stack - endpoint *endpoint - udpForwarder *udpForwarder - ctx context.Context - cancel context.CancelFunc - ip tcpip.Address - ipv6 tcpip.Address - netstack bool - hasRawICMPAccess bool - pingSemaphore chan struct{} + ruleIdMap sync.Map + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip tcpip.Address + ipv6 tcpip.Address + netstack bool + hasRawICMPAccess bool + hasRawICMPv6Access bool + pingSemaphore chan struct{} } func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { @@ -157,18 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow udpForwarder := udp.NewForwarder(s, f.handleUDP) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) - // TODO: gvisor's IPv6 network layer (ipv6/icmp.go) replies to ICMPv6 echo - // requests at the network layer before our transport handler fires. Unlike - // IPv4, it has no localAddressTemporary check or DeliverTransportPacket call - // before replying. With promiscuous mode, this means gvisor replies to ALL - // ICMPv6 echo (including routed traffic) with local latency. - // Not fixed as of gvisor 20260320. - // Fix: handle ICMPv6 echo in the USP filter before passing to the forwarder, - // similar to how v4 ICMP worked before the forwarder existed. The forwarder - // is needed for TCP (full proxy) and UDP (endpoint tracking), but ICMP can - // be handled directly since it's stateless request/reply. - s.SetTransportProtocolHandler(icmp.ProtocolNumber6, f.handleICMPv6) + // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's + // network layer. This avoids duplicate echo replies (v4) and the v6 + // auto-reply bug where gVisor responds at the network layer before + // our transport handler fires. f.checkICMPCapability() @@ -187,11 +180,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { if len(payload) < header.IPv4MinimumSize { return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload)) } + if f.handleICMPDirect(payload) { + return nil + } protoNum = ipv4.ProtocolNumber case 6: if len(payload) < header.IPv6MinimumSize { return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload)) } + if f.handleICMPDirect(payload) { + return nil + } protoNum = ipv6.ProtocolNumber default: return fmt.Errorf("unknown IP version: %d", payload[0]>>4) @@ -208,6 +207,90 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { return nil } +// handleICMPDirect intercepts ICMP packets from raw IP payloads before they +// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that +// the existing handlers expect, then dispatches to handleICMP/handleICMPv6. +// This bypasses gVisor's network layer which causes duplicate v4 echo replies +// and auto-replies to all v6 echo requests in promiscuous mode. +// +// Unlike gVisor's network layer, this does not validate ICMP checksums or +// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. +func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { + ip := header.IPv4(payload) + if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { + return 0, src, dst, false + } + if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { + return 0, src, dst, false + } + ipHdrLen = int(ip.HeaderLength()) + if len(payload)-ipHdrLen < header.ICMPv4MinimumSize { + return 0, src, dst, false + } + return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { + ip := header.IPv6(payload) + if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { + return 0, src, dst, false + } + ipHdrLen = header.IPv6MinimumSize + if len(payload)-ipHdrLen < header.ICMPv6MinimumSize { + return 0, src, dst, false + } + return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func (f *Forwarder) handleICMPDirect(payload []byte) bool { + var ( + ipHdrLen int + srcAddr tcpip.Address + dstAddr tcpip.Address + ok bool + ) + switch payload[0] >> 4 { + case 4: + ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload) + case 6: + ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload) + } + if !ok { + return false + } + + // Let gVisor handle ICMP destined for our own addresses natively. + // Its network-layer auto-reply is correct and efficient for local traffic. + if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) { + return false + } + + id := stack.TransportEndpointID{ + LocalAddress: dstAddr, + RemoteAddress: srcAddr, + } + + // Build a PacketBuffer with headers consumed the same way gVisor would. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { + return false + } + + icmpPayload := payload[ipHdrLen:] + if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok { + return false + } + + if payload[0]>>4 == 6 { + return f.handleICMPv6(id, pkt) + } + return f.handleICMP(id, pkt) +} + // Stop gracefully shuts down the forwarder func (f *Forwarder) Stop() { f.cancel() @@ -274,29 +357,37 @@ func addrFromNetipAddr(addr netip.Addr) tcpip.Address { // addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating. func addrToNetipAddr(addr tcpip.Address) netip.Addr { - if addr.Len() == 4 { + switch addr.Len() { + case 4: return netip.AddrFrom4(addr.As4()) + case 16: + return netip.AddrFrom16(addr.As16()) + default: + return netip.Addr{} } - return netip.AddrFrom16(addr.As16()) } // checkICMPCapability tests whether we have raw ICMP socket access at startup. func (f *Forwarder) checkICMPCapability() { + f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger) + f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger) +} + +func probeRawICMP(network, addr string, logger *nblog.Logger) bool { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, addr) if err != nil { - f.hasRawICMPAccess = false - f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") - return + logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network) + return false } if err := conn.Close(); err != nil { - f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) + logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err) } - f.hasRawICMPAccess = true - f.logger.Debug("forwarder: Raw ICMP socket access available") + logger.Debug1("forwarder: raw %s socket access available", network) + return true } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index c60c22a38..f24ec987e 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu } icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() - conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond) if err != nil { f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) return true @@ -72,12 +72,17 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // The caller is responsible for closing the returned connection. -func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { +func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) { ctx, cancel := context.WithTimeout(f.ctx, timeout) defer cancel() + network, listenAddr := "ip4:icmp", "0.0.0.0" + if v6 { + network, listenAddr = "ip6:ipv6-icmp", "::" + } + lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, listenAddr) if err != nil { return nil, fmt.Errorf("create ICMP socket: %w", err) } @@ -102,7 +107,7 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { sendTime := time.Now() - conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) + conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, false, 5*time.Second) if err != nil { f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) return @@ -224,8 +229,23 @@ func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.Packet return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code())) } - f.logger.Debug2("forwarder: Unhandled ICMPv6 type %v for %v", icmpHdr.Type(), epID(id)) - return false + // For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting + if !f.hasRawICMPv6Access { + f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) + return false + } + + icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond) + if err != nil { + f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err) + return true + } + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err) + } + + return true } // handleICMPv6Echo handles ICMPv6 echo requests using the ping binary. @@ -260,9 +280,14 @@ func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndp } rtt := time.Since(pingStart).Round(10 * time.Microsecond) - f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v)", epID(id), icmpType, icmpCode, rtt) + f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v", + epID(id), icmpType, icmpCode) txBytes := f.synthesizeICMPv6EchoReply(id, icmpData) + + f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } @@ -274,7 +299,8 @@ func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmp replyHdr := header.ICMPv6(replyICMP) replyHdr.SetType(header.ICMPv6EchoReply) replyHdr.SetChecksum(0) - // ICMPv6 checksum includes a pseudo-header computed internally by ICMPv6Checksum + // ICMPv6Checksum computes the pseudo-header internally from Src/Dst. + // Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero. replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: replyHdr, Src: id.LocalAddress, diff --git a/client/firewall/uspfilter/localip_bench_test.go b/client/firewall/uspfilter/localip_bench_test.go index 11bdee7ab..14e12bd08 100644 --- a/client/firewall/uspfilter/localip_bench_test.go +++ b/client/firewall/uspfilter/localip_bench_test.go @@ -20,7 +20,9 @@ func setupManager(b *testing.B) *localIPManager { } }, } - _ = m.UpdateLocalIPs(mock) + if err := m.UpdateLocalIPs(mock); err != nil { + b.Fatalf("UpdateLocalIPs: %v", err) + } return m } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index 0bd7d1de8..8d030aff0 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -162,7 +162,7 @@ func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) case "udp": return p.buildUDPLayer(ipLayer) case "icmp": - return p.buildICMPLayer() + return p.buildICMPLayer(ipLayer) default: return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) } @@ -201,11 +201,14 @@ func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gop return []gopacket.SerializableLayer{udp}, nil } -func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { if p.SrcIP.Is6() || p.DstIP.Is6() { icmp := &layers.ICMPv6{ TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode), } + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + _ = icmp.SetNetworkLayerForChecksum(nl) + } return []gopacket.SerializableLayer{icmp}, nil } icmp := &layers.ICMPv4{