Compare commits

...

33 Commits

Author SHA1 Message Date
Zoltan Papp
38f2a59d1b Add comment 2024-06-12 10:56:21 +02:00
Zoltan Papp
9504012920 Set the proper buffer size in the client code 2024-06-09 21:10:57 +02:00
Zoltan Papp
5e93d117cf Use buf pool
- eliminate reader function generation
- fix write to closed channel panic
2024-06-09 20:33:35 +02:00
Zoltan Papp
8c70b7d7ff Replace ws lib on client side 2024-06-09 12:41:52 +02:00
Zoltan Papp
ed8def4d9b Protect ws writing in Gorilla ws 2024-06-07 16:07:35 +02:00
Zoltán Papp
1e115e3893 Merge branch 'main' into feature/relay 2024-06-06 13:38:40 +02:00
Zoltan Papp
fed9e587af Add close message type 2024-06-05 19:49:30 +02:00
Zoltan Papp
a40d4d2f32 - add comments
- avoid double closing messages
- add cleanup routine for relay manager
2024-06-04 14:40:35 +02:00
Zoltán Papp
15818b72c6 Add alternative ws server implementation 2024-06-03 21:38:37 +02:00
Zoltán Papp
0556dc1860 Avoid nil pointer exception in test in case of err 2024-06-03 21:36:46 +02:00
Zoltán Papp
2b369cd28f Add quic transporter 2024-06-03 20:17:43 +02:00
Zoltán Papp
9d44a476c6 Fix double unlock in client.go 2024-06-03 20:14:39 +02:00
Zoltán Papp
57ddb5f262 Add comment 2024-06-03 11:22:16 +02:00
Zoltan Papp
4ced07dd8d Fix close conn threading issue 2024-06-03 01:37:56 +02:00
Zoltán Papp
3430b81622 Add relay server tracking 2024-06-01 11:48:15 +02:00
Zoltán Papp
fd4ad15c83 Move reconnection logic to separated struct 2024-06-01 11:25:00 +02:00
Zoltán Papp
4ff069a102 Support multiple server 2024-05-29 16:40:26 +02:00
Zoltán Papp
7cc3964a4d Use mux for http server
Without it can not start multiple http
server instances for unit tests
2024-05-29 16:11:58 +02:00
Zoltan Papp
6d627f1923 Code cleaning 2024-05-28 01:27:53 +02:00
Zoltan Papp
076ce69a24 Add reconnect logic 2024-05-28 01:00:25 +02:00
Zoltán Papp
645a1f31a7 Fix writing/reading to a closed conn 2024-05-27 10:25:08 +02:00
Zoltán Papp
b4aa7e50f9 Close sockets on server cmd 2024-05-27 09:42:27 +02:00
Zoltán Papp
173ca25dac Fix in client the close event 2024-05-26 22:14:33 +02:00
Zoltán Papp
36b2cd16cc Remove channel binding logic 2024-05-23 13:24:02 +02:00
Zoltán Papp
0a05f8b4d4 Use buffer pool and protect exported functions 2024-05-22 00:38:41 +02:00
Zoltán Papp
e82c0a55a3 Set to blocking the message queue 2024-05-21 16:21:29 +02:00
Zoltán Papp
13eb457132 Add registration response message to the communication 2024-05-21 15:51:37 +02:00
Zoltan Papp
1c9c9ae47e Remove sync.pool 2024-05-20 11:38:23 +02:00
Zoltan Papp
9ac5a1ed3f Add udp listener and did some change for debug purpose. 2024-05-19 12:41:06 +02:00
Zoltan Papp
d4eaec5cbd Followup messages modification 2024-05-17 23:41:47 +02:00
Zoltan Papp
6ae7a790f2 Fix buffer handling 2024-05-17 23:29:47 +02:00
Zoltan Papp
49dfbc82d9 Add relay cmd 2024-05-17 20:24:06 +02:00
Zoltan Papp
57a89cf0cc Add initial relay code 2024-05-17 17:43:28 +02:00
33 changed files with 3152 additions and 5 deletions

11
go.mod
View File

@@ -12,7 +12,7 @@ require (
github.com/gorilla/mux v1.8.0
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.23.0
github.com/onsi/gomega v1.27.6
github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3
@@ -47,6 +47,7 @@ require (
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1
github.com/gorilla/websocket v1.5.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
@@ -67,6 +68,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.45.0
github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.9.0
@@ -91,6 +93,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
nhooyr.io/websocket v1.8.11
)
require (
@@ -132,10 +135,12 @@ require (
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
@@ -161,6 +166,7 @@ require (
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
@@ -188,9 +194,12 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/image v0.10.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect

30
go.sum
View File

@@ -50,6 +50,9 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@@ -154,6 +157,8 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
@@ -213,6 +218,8 @@ github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQy
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -230,6 +237,8 @@ github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0 h1:fWY+zXdWhvWnd
github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo=
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms=
@@ -247,6 +256,7 @@ github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mO
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -356,14 +366,14 @@ github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs=
github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys=
github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
@@ -415,6 +425,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
@@ -462,6 +474,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -522,6 +535,8 @@ go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2L
go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -554,6 +569,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -608,6 +625,7 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -676,6 +694,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -770,5 +790,7 @@ k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
k8s.io/kube-openapi v0.0.0-20191107075043-30be4d16710a/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=

419
relay/client/client.go Normal file
View File

@@ -0,0 +1,419 @@
package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 8820
serverResponseTimeout = 8 * time.Second
)
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
type Msg struct {
Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
type connContainer struct {
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
conn: conn,
messages: messages,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established.
type Client struct {
log *log.Entry
parentCtx context.Context
ctxCancel context.CancelFunc
serverAddress string
hashedID []byte
bufPool *sync.Pool
relayConn net.Conn
conns map[string]*connContainer
serviceIsRunning bool
mu sync.Mutex
readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup
remoteAddr net.Addr
onDisconnectListener func()
listenerMutex sync.Mutex
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
ctxCancel: func() {},
serverAddress: serverAddress,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[string]*connContainer),
}
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.serviceIsRunning {
return nil
}
err := c.connect()
if err != nil {
return err
}
c.serviceIsRunning = true
var ctx context.Context
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
context.AfterFunc(ctx, func() {
cErr := c.close(false)
if cErr != nil {
log.Errorf("failed to close relay connection: %s", cErr)
}
})
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
return nil
}
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately.
// todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.serviceIsRunning {
return nil, fmt.Errorf("relay connection is not established")
}
hashedID, hashedStringID := messages.HashID(dstPeerID)
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
return conn, nil
}
// RelayRemoteAddress returns the IP address of the relay server. It could change after the close and reopen the connection.
func (c *Client) RelayRemoteAddress() (net.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.remoteAddr == nil {
return nil, fmt.Errorf("relay connection is not established")
}
return c.remoteAddr, nil
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
}
func (c *Client) HasConns() bool {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.conns) > 0
}
// Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error {
return c.close(false)
}
func (c *Client) close(byServer bool) error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
return nil
}
c.serviceIsRunning = false
c.closeAllConns()
if !byServer {
c.writeCloseMsg()
err = c.relayConn.Close()
}
c.mu.Unlock()
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.serverAddress)
c.ctxCancel()
return err
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.serverAddress)
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
}
c.relayConn = nil
return err
}
c.remoteAddr = conn.RemoteAddr()
return nil
}
func (c *Client) handShake() error {
defer func() {
err := c.relayConn.SetReadDeadline(time.Time{})
if err != nil {
log.Errorf("failed to reset read deadline: %s", err)
}
}()
msg, err := messages.MarshalHelloMsg(c.hashedID)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
return err
}
err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout))
if err != nil {
log.Errorf("failed to set read deadline: %s", err)
return err
}
buf := make([]byte, 1500) // todo: optimise buffer size
n, err := c.relayConn.Read(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
return err
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
return nil
}
func (c *Client) readLoop(relayConn net.Conn) {
var (
errExit error
n int
closedByServer bool
)
for {
bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.mu.Lock()
if c.serviceIsRunning {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
goto Exit
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
continue
}
switch msgType {
case messages.MsgTypeTransport:
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
if err != nil {
c.log.Errorf("failed to parse transport message: %v", err)
continue
}
stringID := messages.HashIDToString(peerID)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
goto Exit
}
container, ok := c.conns[stringID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", stringID)
continue
}
container.writeMsg(Msg{
bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload})
case messages.MsgClose:
closedByServer = true
log.Debugf("relay connection close by server")
goto Exit
}
}
Exit:
c.notifyDisconnected()
c.wgReadLoop.Done()
_ = c.close(closedByServer)
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
// conn, ok := c.conns[id]
_, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
}
/*
if conn != clientRef {
return 0, io.EOF
}
*/
msg := messages.MarshalTransportMsg(dstID, payload)
n, err := c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
}
return n, err
}
func (c *Client) closeAllConns() {
for _, container := range c.conns {
container.close()
}
c.conns = make(map[string]*connContainer)
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) closeConn(id string) error {
c.mu.Lock()
defer c.mu.Unlock()
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
}
container.close()
delete(c.conns, id)
return nil
}
func (c *Client) onDisconnect() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
c.onDisconnectListener()
}
func (c *Client) notifyDisconnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener()
}
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}

523
relay/client/client_test.go Normal file
View File

@@ -0,0 +1,523 @@
package client
import (
"context"
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestClient(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder")
err = clientPlaceHolder.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
clientBob := NewClient(ctx, addr, "bob")
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()
connAliceToBob, err := clientAlice.OpenConn("bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn("alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
log.Debugf("alice sent message to bob")
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
log.Debugf("on new message from alice to bob")
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestRegistration(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
_ = srv.Close()
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}
func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
err = clientAlice.Connect()
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
log.Debugf("%s", err)
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}
func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, idAlice)
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, addr, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
clientBob := NewClient(ctx, addr, "bob")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client Alice")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(ctx, addr, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chAlice, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err != nil {
t.Errorf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := chAlice.Read(buf)
if err != nil {
t.Errorf("failed to read from channel: %s", err)
}
if testString != string(buf[:n]) {
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestCloseConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = clientAlice.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}

67
relay/client/conn.go Normal file
View File

@@ -0,0 +1,67 @@
package client
import (
"io"
"net"
"time"
)
type Conn struct {
client *Client
dstID []byte
dstStringID string
messageChan chan Msg
}
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
c := &Conn{
client: client,
dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan,
}
return c
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c.dstStringID, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
}
func (c *Conn) Close() error {
return c.client.closeConn(c.dstStringID)
}
func (c *Conn) LocalAddr() net.Addr {
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.client.relayConn.RemoteAddr()
}
func (c *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (c *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,52 @@
package quic
import (
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Conn struct {
quic.Stream
qConn quic.Connection
}
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
return &Conn{
Stream: stream,
qConn: qConn,
}
}
func (q *Conn) Write(b []byte) (n int, err error) {
log.Debugf("writing: %d, %x\n", len(b), b)
n, err = q.Stream.Write(b)
if n != len(b) {
log.Errorf("failed to write out the full message")
}
return
}
func (q *Conn) Close() error {
err := q.Stream.Close()
if err != nil {
log.Errorf("failed to close stream: %s", err)
return err
}
err = q.qConn.CloseWithError(0, "")
if err != nil {
log.Errorf("failed to close connection: %s", err)
return err
}
return err
}
func (c *Conn) LocalAddr() net.Addr {
return c.qConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.qConn.RemoteAddr()
}

View File

@@ -0,0 +1,32 @@
package quic
import (
"context"
"crypto/tls"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
func Dial(address string) (net.Conn, error) {
tlsConf := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"quic-echo-example"},
}
qConn, err := quic.DialAddr(context.Background(), address, tlsConf, &quic.Config{
EnableDatagrams: true,
})
if err != nil {
log.Errorf("dial quic address %s failed: %s", address, err)
return nil, err
}
stream, err := qConn.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
conn := NewConn(stream, qConn)
return conn, nil
}

View File

@@ -0,0 +1,7 @@
package tcp
import "net"
func Dial(address string) (net.Conn, error) {
return net.Dial("tcp", address)
}

View File

@@ -0,0 +1,14 @@
package udp
import (
"net"
)
func Dial(address string) (net.Conn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, err
}
return net.DialUDP("udp", nil, udpAddr)
}

View File

@@ -0,0 +1,59 @@
package ws
import (
"fmt"
"net"
"sync"
"time"
"github.com/gorilla/websocket"
)
type Conn struct {
*websocket.Conn
mu sync.Mutex
}
func NewConn(wsConn *websocket.Conn) net.Conn {
return &Conn{
Conn: wsConn,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.NextReader()
if err != nil {
return 0, err
}
if t != websocket.BinaryMessage {
return 0, fmt.Errorf("unexpected message type")
}
return r.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.mu.Lock()
err := c.WriteMessage(websocket.BinaryMessage, b)
c.mu.Unlock()
return len(b), err
}
func (c *Conn) SetDeadline(t time.Time) error {
errR := c.SetReadDeadline(t)
errW := c.SetWriteDeadline(t)
if errR != nil {
return errR
}
if errW != nil {
return errW
}
return nil
}
func (c *Conn) Close() error {
return c.Conn.Close()
}

View File

@@ -0,0 +1,22 @@
package ws
import (
"fmt"
"net"
"time"
"github.com/gorilla/websocket"
)
func Dial(address string) (net.Conn, error) {
addr := fmt.Sprintf("ws://" + address)
wsDialer := websocket.Dialer{
HandshakeTimeout: 3 * time.Second,
}
wsConn, _, err := wsDialer.Dial(addr, nil)
if err != nil {
return nil, err
}
conn := NewConn(wsConn)
return conn, nil
}

View File

@@ -0,0 +1,79 @@
package wsnhooyr
import (
"context"
"fmt"
"net"
"time"
"nhooyr.io/websocket"
)
type Conn struct {
*websocket.Conn
ctx context.Context
}
func NewConn(wsConn *websocket.Conn) net.Conn {
return &Conn{
Conn: wsConn,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
return 0, err
}
if t != websocket.MessageBinary {
return 0, fmt.Errorf("unexpected message type")
}
return ioReader.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) RemoteAddr() net.Addr {
// todo: implement me
return nil
}
func (c *Conn) LocalAddr() net.Addr {
// todo: implement me
return nil
}
func (c *Conn) SetReadDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) SetDeadline(t time.Time) error {
// todo: implement me
errR := c.SetReadDeadline(t)
errW := c.SetWriteDeadline(t)
if errR != nil {
return errR
}
if errW != nil {
return errW
}
return nil
}
func (c *Conn) Close() error {
return c.Conn.CloseNow()
}

View File

@@ -0,0 +1,22 @@
package wsnhooyr
import (
"context"
"fmt"
"net"
"nhooyr.io/websocket"
)
func Dial(address string) (net.Conn, error) {
addr := fmt.Sprintf("ws://" + address)
wsConn, _, err := websocket.Dial(context.Background(), addr, nil)
if err != nil {
return nil, err
}
conn := NewConn(wsConn)
return conn, nil
}

33
relay/client/guard.go Normal file
View File

@@ -0,0 +1,33 @@
package client
import (
"context"
"time"
)
var (
reconnectingTimeout = 5 * time.Second
)
type Guard struct {
ctx context.Context
relayClient *Client
}
func NewGuard(context context.Context, relayClient *Client) *Guard {
g := &Guard{
ctx: context,
relayClient: relayClient,
}
return g
}
func (g *Guard) OnDisconnected() {
select {
case <-time.After(time.Second):
_ = g.relayClient.Connect()
case <-g.ctx.Done():
return
}
}

208
relay/client/manager.go Normal file
View File

@@ -0,0 +1,208 @@
package client
import (
"context"
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
relayCleanupInterval = 60 * time.Second
)
// RelayTrack hold the relay clients for the foregin relay servers.
// With the mutex can ensure we can open new connection in case the relay connection has been established with
// the relay server.
type RelayTrack struct {
sync.RWMutex
relayClient *Client
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
// Manager is a manager for the relay client. It establish one persistent connection to the given relay server. In case
// of network error the manager will try to reconnect to the server.
// The manager also manage temproary relay connection. If a client wants to communicate with an another client on a
// different relay server, the manager will establish a new connection to the relay server. The connection with these
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
srvAddress string
peerID string
relayClient *Client
reconnectGuard *Guard
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
}
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
return &Manager{
ctx: ctx,
srvAddress: serverAddress,
peerID: peerID,
relayClients: make(map[string]*RelayTrack),
}
}
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop.
// todo: consider to return an error if the initial connection to the relay server is not established.
func (m *Manager) Serve() {
m.relayClient = NewClient(m.ctx, m.srvAddress, m.peerID)
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(m.reconnectGuard.OnDisconnected)
err := m.relayClient.Connect()
if err != nil {
log.Errorf("failed to connect to relay server, keep try to reconnect: %s", err)
return
}
m.startCleanupLoop()
return
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
if m.relayClient == nil {
return nil, fmt.Errorf("relay client not connected")
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return nil, err
}
if !foreign {
log.Debugf("open connection to permanent server: %s", peerKey)
return m.relayClient.OpenConn(peerKey)
} else {
log.Debugf("open connection to foreign server: %s", serverAddress)
return m.openConnVia(serverAddress, peerKey)
}
}
// RelayAddress returns the address of the permanent relay server. It could change if the network connection is lost.
// This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayAddress() (net.Addr, error) {
if m.relayClient == nil {
return nil, fmt.Errorf("relay client not connected")
}
return m.relayClient.RelayRemoteAddress()
}
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
m.relayClientsMutex.RUnlock()
// if not, establish a new connection but check it again (because changed the lock type) before starting the
// connection
m.relayClientsMutex.Lock()
rt, ok = m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
// create a new relay client and store it in the relayClients map
rt = NewRelayTrack()
rt.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
err := relayClient.Connect()
if err != nil {
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() {
m.deleteRelayConn(serverAddress)
})
rt.relayClient = relayClient
rt.Unlock()
conn, err := relayClient.OpenConn(peerKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (m *Manager) deleteRelayConn(address string) {
log.Infof("deleting relay client for %s", address)
m.relayClientsMutex.Lock()
delete(m.relayClients, address)
m.relayClientsMutex.Unlock()
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.RelayRemoteAddress()
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
return rAddr.String() != address, nil
}
func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}
}()
}
func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
for addr, rt := range m.relayClients {
rt.Lock()
if rt.relayClient.HasConns() {
rt.Unlock()
continue
}
rt.relayClient.SetOnDisconnectListener(nil)
go func() {
_ = rt.relayClient.Close()
}()
log.Debugf("clean up relay client: %s", addr)
delete(m.relayClients, addr)
rt.Unlock()
}
}

View File

@@ -0,0 +1,271 @@
package client
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server"
)
func TestForeignConn(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr1, idAlice)
clientAlice.Serve()
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(mCtx, addr2, idBob)
clientBob.Serve()
bobsSrvAddr, err := clientBob.RelayAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr.Serve()
conn, err := mgr.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
t.Log("binding server 1.")
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
t.Logf("closing server 1.")
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
t.Log("binding server 2.")
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr.Serve()
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr, "alice")
clientAlice.Serve()
ra, err := clientAlice.RelayAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra.String(), "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra.String(), "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}

40
relay/cmd/main.go Normal file
View File

@@ -0,0 +1,40 @@
package main
import (
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/util"
)
func init() {
util.InitLog("trace", "console")
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
_ = <-osSigs
}
func main() {
address := "10.145.236.1:1235"
srv := server.NewServer()
err := srv.Listen(address)
if err != nil {
log.Errorf("failed to bind server: %s", err)
os.Exit(1)
}
waitForExitSignal()
err = srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
os.Exit(1)
}
}

20
relay/messages/id.go Normal file
View File

@@ -0,0 +1,20 @@
package messages
import (
"crypto/sha256"
"encoding/base64"
)
const (
IDSize = sha256.Size
)
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID))
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
return idHash[:], idHashString
}
func HashIDToString(idHash []byte) string {
return base64.StdEncoding.EncodeToString(idHash[:])
}

137
relay/messages/message.go Normal file
View File

@@ -0,0 +1,137 @@
package messages
import (
"fmt"
log "github.com/sirupsen/logrus"
)
const (
MsgTypeHello MsgType = 0
MsgTypeHelloResponse MsgType = 1
MsgTypeTransport MsgType = 2
MsgClose MsgType = 3
)
var (
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
)
type MsgType byte
func (m MsgType) String() string {
switch m {
case MsgTypeHello:
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeTransport:
return "transport"
case MsgClose:
return "close"
default:
return "unknown"
}
}
func DetermineClientMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHello:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
case MsgClose:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
}
}
func DetermineServerMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHelloResponse:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
case MsgClose:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
}
}
// MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length")
}
msg := make([]byte, 1, 1+len(peerID))
msg[0] = byte(MsgTypeHello)
msg = append(msg, peerID...)
return msg, nil
}
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
if len(msg) < 2 {
return nil, fmt.Errorf("invalid 'hello' messge")
}
return msg[1:], nil
}
func MarshalHelloResponse() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgTypeHelloResponse)
return msg
}
// Close message
func MarshalCloseMsg() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgClose)
return msg
}
// Transport message
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
if len(peerID) != IDSize {
return nil
}
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
msg[0] = byte(MsgTypeTransport)
copy(msg[1:], peerID)
msg = append(msg, payload...)
return msg
}
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
headerSize := 1 + IDSize
if len(buf) < headerSize {
return nil, nil, ErrInvalidMessageLength
}
return buf[1:headerSize], buf[headerSize:], nil
}
func UnmarshalTransportID(buf []byte) ([]byte, error) {
headerSize := 1 + IDSize
if len(buf) < headerSize {
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
return nil, ErrInvalidMessageLength
}
return buf[1:headerSize], nil
}
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < 1+len(peerID) {
return ErrInvalidMessageLength
}
copy(msg[1:], peerID)
return nil
}

View File

@@ -0,0 +1,9 @@
package listener
import "net"
type Listener interface {
Listen(func(conn net.Conn)) error
Close() error
WaitForExitAcceptedConns()
}

View File

@@ -0,0 +1,36 @@
package quic
import (
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type QuicConn struct {
quic.Stream
qConn quic.Connection
}
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
return &QuicConn{
Stream: stream,
qConn: qConn,
}
}
func (q QuicConn) Write(b []byte) (n int, err error) {
n, err = q.Stream.Write(b)
if n != len(b) {
log.Errorf("failed to write out the full message")
}
return
}
func (q QuicConn) LocalAddr() net.Addr {
return q.qConn.LocalAddr()
}
func (q QuicConn) RemoteAddr() net.Addr {
return q.qConn.RemoteAddr()
}

View File

@@ -0,0 +1,111 @@
package quic
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"net"
"sync"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
onAcceptFn func(conn net.Conn)
listener *quic.Listener
quit chan struct{}
wg sync.WaitGroup
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
ql, err := quic.ListenAddr(l.address, generateTLSConfig(), &quic.Config{
EnableDatagrams: true,
})
if err != nil {
return err
}
l.listener = ql
l.quit = make(chan struct{})
log.Infof("quic server is listening on address: %s", l.address)
l.wg.Add(1)
go l.acceptLoop(onAcceptFn)
<-l.quit
return nil
}
func (l *Listener) Close() error {
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) acceptLoop(acceptFn func(conn net.Conn)) {
defer l.wg.Done()
for {
qConn, err := l.listener.Accept(context.Background())
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
log.Infof("new connection from: %s", qConn.RemoteAddr())
stream, err := qConn.AcceptStream(context.Background())
if err != nil {
log.Errorf("failed to open stream: %s", err)
continue
}
conn := NewConn(stream, qConn)
go acceptFn(conn)
}
}
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"quic-echo-example"},
}
}

View File

@@ -0,0 +1,80 @@
package tcp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
// Listener
// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, adn
// it does not handle framing.
type Listener struct {
address string
onAcceptFn func(conn net.Conn)
wg sync.WaitGroup
quit chan struct{}
listener net.Listener
lock sync.Mutex
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
li, err := net.Listen("tcp", l.address)
if err != nil {
log.Errorf("failed to listen on address: %s, %s", l.address, err)
l.lock.Unlock()
return err
}
log.Debugf("TCP server is listening on address: %s", l.address)
l.listener = li
l.wg.Add(1)
go l.acceptLoop()
l.lock.Unlock()
<-l.quit
return nil
}
// Close todo: prevent multiple call (do not close two times the channel)
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) acceptLoop() {
defer l.wg.Done()
for {
conn, err := l.listener.Accept()
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
go l.onAcceptFn(conn)
}
}

View File

@@ -0,0 +1,68 @@
package udp
import (
"io"
"net"
"time"
)
type UDPConn struct {
*net.UDPConn
addr *net.UDPAddr
msgChannel chan []byte
}
func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *UDPConn {
return &UDPConn{
UDPConn: conn,
addr: addr,
msgChannel: make(chan []byte),
}
}
func (u *UDPConn) Read(b []byte) (n int, err error) {
msg, ok := <-u.msgChannel
if !ok {
return 0, io.EOF
}
n = copy(b, msg)
return n, nil
}
func (u *UDPConn) Write(b []byte) (n int, err error) {
return u.UDPConn.WriteTo(b, u.addr)
}
func (u *UDPConn) Close() error {
//TODO implement me
//panic("implement me")
return nil
}
func (u *UDPConn) LocalAddr() net.Addr {
return u.UDPConn.LocalAddr()
}
func (u *UDPConn) RemoteAddr() net.Addr {
return u.addr
}
func (u *UDPConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) onNewMsg(b []byte) {
u.msgChannel <- b
}

View File

@@ -0,0 +1,109 @@
package udp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
conns map[string]*UDPConn
onAcceptFn func(conn net.Conn)
listener *net.UDPConn
wg sync.WaitGroup
quit chan struct{}
lock sync.Mutex
}
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
return
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
conns: make(map[string]*UDPConn),
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
addr, err := net.ResolveUDPAddr("udp", l.address)
if err != nil {
log.Errorf("invalid listen address '%s': %s", l.address, err)
l.lock.Unlock()
return err
}
li, err := net.ListenUDP("udp", addr)
if err != nil {
log.Fatalf("%s", err)
l.lock.Unlock()
return err
}
log.Debugf("udp server is listening on address: %s", addr.String())
l.listener = li
l.wg.Add(1)
go l.readLoop()
l.lock.Unlock()
<-l.quit
return nil
}
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
if l.listener == nil {
return nil
}
log.Infof("closing UDP listener")
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
l.listener = nil
return err
}
func (l *Listener) readLoop() {
defer l.wg.Done()
for {
buf := make([]byte, 1500)
n, addr, err := l.listener.ReadFromUDP(buf)
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
pConn, ok := l.conns[addr.String()]
if ok {
pConn.onNewMsg(buf[:n])
continue
}
pConn = NewConn(l.listener, addr)
log.Infof("new connection from: %s", pConn.RemoteAddr())
l.conns[addr.String()] = pConn
go l.onAcceptFn(pConn)
pConn.onNewMsg(buf[:n])
}
}

View File

@@ -0,0 +1,74 @@
package ws
import (
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
type Conn struct {
*websocket.Conn
mu sync.Mutex
}
func NewConn(wsConn *websocket.Conn) *Conn {
return &Conn{
Conn: wsConn,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.NextReader()
if err != nil {
return 0, ioErrHandling(err)
}
if t != websocket.BinaryMessage {
log.Errorf("unexpected message type: %d", t)
return 0, fmt.Errorf("unexpected message type")
}
n, err = r.Read(b)
if err != nil {
return 0, ioErrHandling(err)
}
return n, err
}
func (c *Conn) Write(b []byte) (int, error) {
c.mu.Lock()
err := c.WriteMessage(websocket.BinaryMessage, b)
c.mu.Unlock()
return len(b), err
}
func (c *Conn) SetDeadline(t time.Time) error {
errR := c.SetReadDeadline(t)
errW := c.SetWriteDeadline(t)
if errR != nil {
return errR
}
if errW != nil {
return errW
}
return nil
}
func ioErrHandling(err error) error {
var wErr *websocket.CloseError
if !errors.As(err, &wErr) {
return err
}
if wErr.Code == websocket.CloseNormalClosure {
return io.EOF
}
return err
}

View File

@@ -0,0 +1,92 @@
package ws
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
var (
upgrader = websocket.Upgrader{} // use default options
)
type Listener struct {
address string
wg sync.WaitGroup
server *http.Server
acceptFn func(conn net.Conn)
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
if l.server != nil {
return errors.New("server is already running")
}
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc("/", l.onAccept)
l.server = &http.Server{
Addr: l.address,
Handler: mux,
}
log.Infof("WS server is listening on address: %s", l.address)
err := l.server.ListenAndServe()
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
func (l *Listener) Close() error {
if l.server == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Debugf("closing WS server")
if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err)
}
l.wg.Wait()
return nil
}
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
}
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
l.wg.Add(1)
defer l.wg.Done()
wsConn, err := upgrader.Upgrade(writer, request, nil)
if err != nil {
log.Errorf("failed to upgrade connection: %s", err)
return
}
conn := NewConn(wsConn)
log.Infof("new connection from: %s", conn.RemoteAddr())
l.acceptFn(conn)
return
}

View File

@@ -0,0 +1,107 @@
package ws
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
)
type Conn struct {
*websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr
closed bool
closedMu sync.Mutex
ctx context.Context
}
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
return &Conn{
Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx)
if err != nil {
return 0, c.ioErrHandling(err)
}
if t != websocket.MessageBinary {
log.Errorf("unexpected message type: %d", t)
return 0, fmt.Errorf("unexpected message type")
}
n, err = r.Read(b)
if err != nil {
return 0, c.ioErrHandling(err)
}
return n, err
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}
func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
func (c *Conn) SetReadDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) SetDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true
c.closedMu.Unlock()
return c.Conn.CloseNow()
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
}
var wErr *websocket.CloseError
if !errors.As(err, &wErr) {
return err
}
if wErr.Code == websocket.StatusNormalClosure {
return io.EOF
}
return err
}

View File

@@ -0,0 +1,96 @@
package ws
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
wg sync.WaitGroup
server *http.Server
acceptFn func(conn net.Conn)
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
// Listen todo: prevent multiple call
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc("/", l.onAccept)
l.server = &http.Server{
Addr: l.address,
Handler: mux,
}
log.Infof("WS server is listening on address: %s", l.address)
err := l.server.ListenAndServe()
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
func (l *Listener) Close() error {
if l.server == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Infof("stop WS listener")
if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err)
}
log.Infof("WS listener stopped")
return nil
}
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
}
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
l.wg.Add(1)
defer l.wg.Done()
wsConn, err := websocket.Accept(w, r, nil)
if err != nil {
log.Errorf("failed to accept ws connection: %s", err)
return
}
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
return
}

33
relay/server/peer.go Normal file
View File

@@ -0,0 +1,33 @@
package server
import (
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
type Peer struct {
Log *log.Entry
idS string
idB []byte
conn net.Conn
}
func NewPeer(id []byte, conn net.Conn) *Peer {
stringID := messages.HashIDToString(id)
return &Peer{
Log: log.WithField("peer_id", stringID),
idB: id,
idS: stringID,
conn: conn,
}
}
func (p *Peer) ID() []byte {
return p.idB
}
func (p *Peer) String() string {
return p.idS
}

192
relay/server/server.go Normal file
View File

@@ -0,0 +1,192 @@
package server
import (
"errors"
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp"
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
)
type Server struct {
store *Store
storeMu sync.RWMutex
UDPListener listener.Listener
WSListener listener.Listener
}
func NewServer() *Server {
return &Server{
store: NewStore(),
storeMu: sync.RWMutex{},
}
}
func (r *Server) Listen(address string) error {
wg := sync.WaitGroup{}
wg.Add(2)
r.WSListener = ws.NewListener(address)
var wslErr error
go func() {
defer wg.Done()
wslErr = r.WSListener.Listen(r.accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
}
}()
r.UDPListener = udp.NewListener(address)
var udpLErr error
go func() {
defer wg.Done()
udpLErr = r.UDPListener.Listen(r.accept)
if udpLErr != nil {
log.Errorf("failed to bind ws server: %s", udpLErr)
}
}()
err := errors.Join(wslErr, udpLErr)
return err
}
func (r *Server) Close() error {
var wErr error
if r.WSListener != nil {
wErr = r.WSListener.Close()
}
var uErr error
if r.UDPListener != nil {
uErr = r.UDPListener.Close()
}
r.sendCloseMsgs()
r.WSListener.WaitForExitAcceptedConns()
err := errors.Join(wErr, uErr)
return err
}
func (r *Server) accept(conn net.Conn) {
peer, err := handShake(conn)
if err != nil {
log.Errorf("failed to handshake wiht %s: %s", conn.RemoteAddr(), err)
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
}
peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer)
defer func() {
r.store.DeletePeer(peer)
peer.Log.Infof("relay connection closed")
}()
for {
buf := make([]byte, 1500) // todo: optimize buffer size
n, err := conn.Read(buf)
if err != nil {
if err != io.EOF {
peer.Log.Errorf("failed to read message: %s", err)
}
return
}
msgType, err := messages.DetermineClientMsgType(buf[:n])
if err != nil {
peer.Log.Errorf("failed to determine message type: %s", err)
return
}
switch msgType {
case messages.MsgTypeTransport:
msg := buf[:n]
peerID, err := messages.UnmarshalTransportID(msg)
if err != nil {
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
continue
}
go func() {
stringPeerID := messages.HashIDToString(peerID)
dp, ok := r.store.Peer(stringPeerID)
if !ok {
peer.Log.Errorf("peer not found: %s", stringPeerID)
return
}
err := messages.UpdateTransportMsg(msg, peer.ID())
if err != nil {
peer.Log.Errorf("failed to update transport message: %s", err)
return
}
_, err = dp.conn.Write(msg)
if err != nil {
peer.Log.Errorf("failed to write transport message to: %s", dp.String())
}
return
}()
case messages.MsgClose:
peer.Log.Infof("peer disconnected gracefully")
_ = conn.Close()
return
}
}
}
func (r *Server) sendCloseMsgs() {
msg := messages.MarshalCloseMsg()
r.storeMu.Lock()
log.Debugf("sending close messages to %d peers", len(r.store.peers))
for _, p := range r.store.peers {
_, err := p.conn.Write(msg)
if err != nil {
log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
log.Errorf("failed to close connection to peer: %s", err)
}
}
r.storeMu.Unlock()
}
func handShake(conn net.Conn) (*Peer, error) {
buf := make([]byte, 1500)
n, err := conn.Read(buf)
if err != nil {
log.Errorf("failed to read message: %s", err)
return nil, err
}
msgType, err := messages.DetermineClientMsgType(buf[:n])
if err != nil {
return nil, err
}
if msgType != messages.MsgTypeHello {
tErr := fmt.Errorf("invalid message type")
log.Errorf("failed to handshake: %s", tErr)
return nil, tErr
}
peerId, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil {
log.Errorf("failed to handshake: %s", err)
return nil, err
}
p := NewPeer(peerId, conn)
msg := messages.MarshalHelloResponse()
_, err = conn.Write(msg)
return p, err
}

47
relay/server/store.go Normal file
View File

@@ -0,0 +1,47 @@
package server
import (
"sync"
)
type Store struct {
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
peersLock sync.RWMutex
}
func NewStore() *Store {
return &Store{
peers: make(map[string]*Peer),
}
}
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
s.peers[peer.String()] = peer
}
func (s *Store) DeletePeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
delete(s.peers, peer.String())
}
func (s *Store) Peer(id string) (*Peer, bool) {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
p, ok := s.peers[id]
return p, ok
}
func (s *Store) Peers() []*Peer {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers))
for _, p := range s.peers {
peers = append(peers, p)
}
return peers
}

View File

@@ -0,0 +1,57 @@
package test
import (
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/relay/server"
)
func TestManager(t *testing.T) {
addr := "localhost:1239"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cm := client.NewManager(ctx, addr, "me")
cm.Serve()
// wait for the relay handshake to complete
time.Sleep(1 * time.Second)
conn, err := cm.OpenConn("remotepeer")
if err != nil {
t.Errorf("failed to open connection: %s", err)
}
readCtx, readCancel := context.WithCancel(context.Background())
defer readCancel()
go func() {
_, _ = conn.Read(make([]byte, 1))
readCancel()
}()
cancel()
select {
case <-time.After(2 * time.Second):
t.Errorf("client peer conn did not close automatically")
case <-readCtx.Done():
// conn exited well
}
}