diff --git a/flow/client/client_test.go b/flow/client/client_test.go index 300d19b05..98ce5aa91 100644 --- a/flow/client/client_test.go +++ b/flow/client/client_test.go @@ -29,6 +29,7 @@ type testServer struct { addr string listener *connTrackListener closeStream chan struct{} // signal server to close the stream + handlerDone chan struct{} // signaled each time Events() exits } // connTrackListener wraps a net.Listener to track accepted connections @@ -102,6 +103,7 @@ func newTestServer(t *testing.T) *testServer { addr: rawListener.Addr().String(), listener: listener, closeStream: make(chan struct{}, 1), + handlerDone: make(chan struct{}, 10), } proto.RegisterFlowServiceServer(s.grpcSrv, s) @@ -120,6 +122,13 @@ func newTestServer(t *testing.T) *testServer { } func (s *testServer) Events(stream proto.FlowService_EventsServer) error { + defer func() { + select { + case s.handlerDone <- struct{}{}: + default: + } + }() + err := stream.Send(&proto.FlowEventAck{IsInitiator: true}) if err != nil { return err @@ -475,6 +484,15 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { // This produces the exact error the client sees in production: // "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR" server.listener.sendRSTStream(1) + + // Wait for the old Events() handler to fully exit so it can no longer + // drain s.acks and drop our injected ack on a broken stream. + select { + case <-server.handlerDone: + case <-time.After(5 * time.Second): + t.Fatal("old Events() handler did not exit after RST_STREAM") + } + require.Eventually(t, func() bool { return server.listener.connCount() > connsBefore }, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")