-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathendpoint.go
74 lines (66 loc) · 1.94 KB
/
endpoint.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package netstack
import (
	"context"
	"fmt"
	"io"

	"github.com/clarkmcc/remotenetstack/utils"
	"go.uber.org/zap"
	"gvisor.dev/gvisor/pkg/bufferv2"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// MemoryPipe connects two channel endpoints back to back so that they can
// exchange packets. Each endpoint is adapted with WrapChannel (c1 first,
// then c2) and the two adapters are handed to utils.Join.
//
// NOTE(review): whether this call blocks, and whether traffic flows in both
// directions, is decided by utils.Join — confirm there before relying on it.
func MemoryPipe(c1, c2 *channel.Endpoint) {
	utils.Join(WrapChannel(c1), WrapChannel(c2))
}
// WrapChannel adapts a netstack channel-based endpoint so callers can move
// packets through it as raw []byte values using the io.Reader and io.Writer
// interfaces.
//
// The returned wrapper starts with a no-op zap logger. Replace its Logger
// field to see per-packet debug output.
func WrapChannel(ep *channel.Endpoint) *Endpoint {
	wrapped := new(Endpoint)
	wrapped.Endpoint = ep
	wrapped.Logger = zap.NewNop()
	return wrapped
}
// Endpoint is a wrapper around a channel.Endpoint that implements
// the io.Reader and io.Writer interfaces.
//
// Read takes packets from the embedded endpoint via ReadContext. Write hands
// packets to it via InjectInbound. Construct one with WrapChannel.
type Endpoint struct {
	// The wrapped netstack channel endpoint. Embedding it promotes its
	// methods (e.g. ReadContext, InjectInbound) onto the wrapper.
	*channel.Endpoint
	// Logger receives the per-packet debug messages emitted by Read and
	// Write. Both call it unconditionally, so keep it non-nil (WrapChannel
	// installs zap.NewNop()).
	Logger *zap.Logger
}
// Read blocks until the embedded channel endpoint yields a packet, then
// copies that packet's bytes into p and returns how many were copied.
//
// Each call consumes exactly one packet. If p is shorter than the packet,
// the remaining bytes are discarded, so callers should supply a buffer at
// least as large as the largest packet they expect. A zero-length p returns
// (0, nil) immediately without consuming a packet. When the endpoint has no
// more packets to deliver (ReadContext reports nil), Read returns io.EOF.
func (e *Endpoint) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		// Nothing could be copied anyway; avoid blocking and then silently
		// dropping a whole packet.
		return 0, nil
	}

	pkt := e.ReadContext(context.Background())
	if pkt == nil {
		// ReadContext returns nil instead of a packet once it can no longer
		// deliver one (e.g. the endpoint was closed). Report end-of-stream
		// rather than dereferencing a nil packet.
		return 0, io.EOF
	}
	defer pkt.DecRef()

	// ToBuffer hands back a caller-owned clone of the packet storage. It is
	// released independently of the packet's own reference.
	buf := pkt.ToBuffer()
	defer buf.Release()

	n = copy(p, buf.Flatten())
	e.Logger.Debug("read packet", zap.Int("bytes", n))
	return n, nil
}
// Write injects p into the embedded channel endpoint as a single inbound
// packet. The network protocol is chosen from the IP version field of p:
// IPv4 and IPv6 packets are accepted.
//
// An empty p is a no-op and returns (0, nil). If p carries any other IP
// version, nothing is injected. Write then returns 0 and a non-nil error,
// as the io.Writer contract requires whenever fewer than len(p) bytes are
// written. On success it returns len(p). p itself is never retained.
func (e *Endpoint) Write(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}

	// Validate before allocating a packet buffer so rejected input leaves
	// nothing behind that would need to be released.
	var proto tcpip.NetworkProtocolNumber
	switch v := header.IPVersion(p); v {
	case header.IPv4Version:
		proto = ipv4.ProtocolNumber
	case header.IPv6Version:
		proto = ipv6.ProtocolNumber
	default:
		return 0, fmt.Errorf("unsupported IP version %d", v)
	}

	// io.Writer implementations must not retain p (callers such as io.Copy
	// reuse their buffer), so the packet gets its own copy of the bytes.
	data := make([]byte, len(p))
	copy(data, p)
	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: bufferv2.MakeWithData(data),
	})
	// Drop the reference NewPacketBuffer gave us once delivery returns.
	// Per gVisor's PacketBuffer refcounting, code that keeps the packet
	// takes its own reference.
	defer pkt.DecRef()

	e.InjectInbound(proto, pkt)
	e.Logger.Debug("wrote packet", zap.Int("bytes", len(p)))
	return len(p), nil
}