From fcf8c4b30ee700e072999af75fad1efc88296595 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 27 Mar 2026 14:24:04 +0100 Subject: [PATCH 1/2] Handle ICMP directly in forwarder, bypassing gVisor network layer --- .../firewall/uspfilter/filter_filter_test.go | 7 -- .../firewall/uspfilter/forwarder/endpoint.go | 5 +- .../firewall/uspfilter/forwarder/forwarder.go | 88 ++++++++++++++++++- client/firewall/uspfilter/forwarder/icmp.go | 44 +++++++--- .../firewall/uspfilter/localip_bench_test.go | 4 +- client/firewall/uspfilter/tracer.go | 7 +- 6 files changed, 127 insertions(+), 28 deletions(-) diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index beb86826d..a64c83138 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -712,8 +712,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP // Detect address family isV6 := src.To4() == nil - var networkLayer gopacket.SerializableLayer - var setChecksum func(gopacket.NetworkLayer) error var err error if isV6 { @@ -723,8 +721,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP SrcIP: src, DstIP: dst, } - networkLayer = ip6 - setChecksum = func(nl gopacket.NetworkLayer) error { return nil } switch proto { case fw.ProtocolTCP: @@ -747,7 +743,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP default: err = gopacket.SerializeLayers(buf, opts, ip6) } - _ = setChecksum } else { ip4 := &layers.IPv4{ Version: 4, @@ -755,7 +750,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP SrcIP: src, DstIP: dst, } - networkLayer = ip4 switch proto { case fw.ProtocolTCP: @@ -776,7 +770,6 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP err = gopacket.SerializeLayers(buf, opts, ip4) } } - _ = networkLayer require.NoError(t, err) return buf.Bytes() diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 6dd2e8e84..45117f3a2 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -53,8 +53,11 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) } raw := pkt.NetworkHeader().View().AsSlice() + if len(raw) == 0 { + continue + } var address tcpip.Address - if len(raw) > 0 && raw[0]>>4 == 6 { + if raw[0]>>4 == 6 { address = header.IPv6(raw).DestinationAddress() } else { address = header.IPv4(raw).DestinationAddress() diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 84cc7ec44..770f3fbea 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -157,8 +157,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow udpForwarder := udp.NewForwarder(s, f.handleUDP) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) - s.SetTransportProtocolHandler(icmp.ProtocolNumber6, f.handleICMPv6) + // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's + // network layer. This avoids duplicate echo replies (v4) and the v6 + // auto-reply bug where gVisor responds at the network layer before + // our transport handler fires. f.checkICMPCapability() @@ -177,11 +179,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { if len(payload) < header.IPv4MinimumSize { return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload)) } + if f.handleICMPDirect(payload) { + return nil + } protoNum = ipv4.ProtocolNumber case 6: if len(payload) < header.IPv6MinimumSize { return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload)) } + if f.handleICMPDirect(payload) { + return nil + } protoNum = ipv6.ProtocolNumber default: return fmt.Errorf("unknown IP version: %d", payload[0]>>4) @@ -198,6 +206,74 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { return nil } +// handleICMPDirect intercepts ICMP packets from raw IP payloads before they +// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that +// the existing handlers expect, then dispatches to handleICMP/handleICMPv6. +// This bypasses gVisor's network layer which causes duplicate v4 echo replies +// and auto-replies to all v6 echo requests in promiscuous mode. +// +// Unlike gVisor's network layer, this does not validate ICMP checksums or +// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. +func (f *Forwarder) handleICMPDirect(payload []byte) bool { + var ( + ipHdrLen int + srcAddr tcpip.Address + dstAddr tcpip.Address + ) + + switch payload[0] >> 4 { + case 4: + ip := header.IPv4(payload) + if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { + return false + } + ipHdrLen = int(ip.HeaderLength()) + srcAddr = ip.SourceAddress() + dstAddr = ip.DestinationAddress() + case 6: + ip := header.IPv6(payload) + if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { + return false + } + ipHdrLen = header.IPv6MinimumSize + srcAddr = ip.SourceAddress() + dstAddr = ip.DestinationAddress() + default: + return false + } + + // Let gVisor handle ICMP destined for our own addresses natively. + // Its network-layer auto-reply is correct and efficient for local traffic. + if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) { + return false + } + + id := stack.TransportEndpointID{ + LocalAddress: dstAddr, + RemoteAddress: srcAddr, + } + + // Build a PacketBuffer with headers consumed the same way gVisor would. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { + return false + } + + icmpPayload := payload[ipHdrLen:] + if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok { + return false + } + + if payload[0]>>4 == 6 { + return f.handleICMPv6(id, pkt) + } + return f.handleICMP(id, pkt) +} + // Stop gracefully shuts down the forwarder func (f *Forwarder) Stop() { f.cancel() @@ -264,10 +340,14 @@ func addrFromNetipAddr(addr netip.Addr) tcpip.Address { // addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating. func addrToNetipAddr(addr tcpip.Address) netip.Addr { - if addr.Len() == 4 { + switch addr.Len() { + case 4: return netip.AddrFrom4(addr.As4()) + case 16: + return netip.AddrFrom16(addr.As16()) + default: + return netip.Addr{} } - return netip.AddrFrom16(addr.As16()) } // checkICMPCapability tests whether we have raw ICMP socket access at startup. diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 7df2fec82..4b3d01779 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -34,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu } icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() - conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond) if err != nil { f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) return true @@ -71,12 +71,17 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // The caller is responsible for closing the returned connection. -func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { +func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) { ctx, cancel := context.WithTimeout(f.ctx, timeout) defer cancel() + network, listenAddr := "ip4:icmp", "0.0.0.0" + if v6 { + network, listenAddr = "ip6:ipv6-icmp", "::" + } + lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, listenAddr) if err != nil { return nil, fmt.Errorf("create ICMP socket: %w", err) } @@ -101,7 +106,7 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { sendTime := time.Now() - conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) + conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, false, 5*time.Second) if err != nil { f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) return @@ -223,8 +228,23 @@ func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.Packet return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code())) } - f.logger.Debug2("forwarder: Unhandled ICMPv6 type %v for %v", icmpHdr.Type(), epID(id)) - return false + // For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting + if !f.hasRawICMPAccess { + f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) + return false + } + + icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond) + if err != nil { + f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err) + return true + } + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err) + } + + return true } // handleICMPv6Echo handles ICMPv6 echo requests using the ping binary. @@ -273,14 +293,12 @@ func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmp replyHdr := header.ICMPv6(replyICMP) replyHdr.SetType(header.ICMPv6EchoReply) replyHdr.SetChecksum(0) - // ICMPv6 checksum requires a pseudo-header - psum := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, id.LocalAddress, id.RemoteAddress, uint16(len(replyICMP))) + // ICMPv6Checksum computes the pseudo-header internally from Src/Dst. + // Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero. replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: replyHdr, - Src: id.LocalAddress, - Dst: id.RemoteAddress, - PayloadCsum: psum, - PayloadLen: len(replyICMP) - header.ICMPv6MinimumSize, + Header: replyHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, })) return f.injectICMPv6Reply(id, replyICMP) diff --git a/client/firewall/uspfilter/localip_bench_test.go b/client/firewall/uspfilter/localip_bench_test.go index 11bdee7ab..14e12bd08 100644 --- a/client/firewall/uspfilter/localip_bench_test.go +++ b/client/firewall/uspfilter/localip_bench_test.go @@ -20,7 +20,9 @@ func setupManager(b *testing.B) *localIPManager { } }, } - _ = m.UpdateLocalIPs(mock) + if err := m.UpdateLocalIPs(mock); err != nil { + b.Fatalf("UpdateLocalIPs: %v", err) + } return m } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index 0bd7d1de8..8d030aff0 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -162,7 +162,7 @@ func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) case "udp": return p.buildUDPLayer(ipLayer) case "icmp": - return p.buildICMPLayer() + return p.buildICMPLayer(ipLayer) default: return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) } @@ -201,11 +201,14 @@ func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gop return []gopacket.SerializableLayer{udp}, nil } -func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { if p.SrcIP.Is6() || p.DstIP.Is6() { icmp := &layers.ICMPv6{ TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode), } + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + _ = icmp.SetNetworkLayerForChecksum(nl) + } return []gopacket.SerializableLayer{icmp}, nil } icmp := &layers.ICMPv4{ From ed5cfa6dc57177cb60a85527592f919516b54481 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 27 Mar 2026 15:37:32 +0100 Subject: [PATCH 2/2] Fix CodeRabbit findings: fragment guard, v6 raw socket probe, v6 echo logging --- .../firewall/uspfilter/forwarder/forwarder.go | 89 ++++++++++++------- client/firewall/uspfilter/forwarder/icmp.go | 9 +- 2 files changed, 62 insertions(+), 36 deletions(-) diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 770f3fbea..8290e3d5c 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -37,17 +37,18 @@ type Forwarder struct { logger *nblog.Logger flowLogger nftypes.FlowLogger // ruleIdMap is used to store the rule ID for a given connection - ruleIdMap sync.Map - stack *stack.Stack - endpoint *endpoint - udpForwarder *udpForwarder - ctx context.Context - cancel context.CancelFunc - ip tcpip.Address - ipv6 tcpip.Address - netstack bool - hasRawICMPAccess bool - pingSemaphore chan struct{} + ruleIdMap sync.Map + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip tcpip.Address + ipv6 tcpip.Address + netstack bool + hasRawICMPAccess bool + hasRawICMPv6Access bool + pingSemaphore chan struct{} } func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { @@ -214,31 +215,47 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { // // Unlike gVisor's network layer, this does not validate ICMP checksums or // reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. +func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { + ip := header.IPv4(payload) + if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { + return 0, src, dst, false + } + if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { + return 0, src, dst, false + } + ipHdrLen = int(ip.HeaderLength()) + if len(payload)-ipHdrLen < header.ICMPv4MinimumSize { + return 0, src, dst, false + } + return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { + ip := header.IPv6(payload) + if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { + return 0, src, dst, false + } + ipHdrLen = header.IPv6MinimumSize + if len(payload)-ipHdrLen < header.ICMPv6MinimumSize { + return 0, src, dst, false + } + return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + func (f *Forwarder) handleICMPDirect(payload []byte) bool { var ( ipHdrLen int srcAddr tcpip.Address dstAddr tcpip.Address + ok bool ) - switch payload[0] >> 4 { case 4: - ip := header.IPv4(payload) - if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { - return false - } - ipHdrLen = int(ip.HeaderLength()) - srcAddr = ip.SourceAddress() - dstAddr = ip.DestinationAddress() + ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload) case 6: - ip := header.IPv6(payload) - if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { - return false - } - ipHdrLen = header.IPv6MinimumSize - srcAddr = ip.SourceAddress() - dstAddr = ip.DestinationAddress() - default: + ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload) + } + if !ok { return false } @@ -352,21 +369,25 @@ func addrToNetipAddr(addr tcpip.Address) netip.Addr { // checkICMPCapability tests whether we have raw ICMP socket access at startup. func (f *Forwarder) checkICMPCapability() { + f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger) + f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger) +} + +func probeRawICMP(network, addr string, logger *nblog.Logger) bool { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, addr) if err != nil { - f.hasRawICMPAccess = false - f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") - return + logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network) + return false } if err := conn.Close(); err != nil { - f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) + logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err) } - f.hasRawICMPAccess = true - f.logger.Debug("forwarder: Raw ICMP socket access available") + logger.Debug1("forwarder: raw %s socket access available", network) + return true } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 4b3d01779..2c35aa748 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -229,7 +229,7 @@ func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.Packet } // For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting - if !f.hasRawICMPAccess { + if !f.hasRawICMPv6Access { f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) return false } @@ -279,9 +279,14 @@ func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndp } rtt := time.Since(pingStart).Round(10 * time.Microsecond) - f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v)", epID(id), icmpType, icmpCode, rtt) + f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v", + epID(id), icmpType, icmpCode) txBytes := f.synthesizeICMPv6EchoReply(id, icmpData) + + f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) }