Make buffer not eat all memory.
This commit is contained in:
parent
f872453f16
commit
0261886b43
41
client.go
41
client.go
@ -1,21 +1,26 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/tardisx/netgiv/secure"
|
"github.com/tardisx/netgiv/secure"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
address string
|
||||||
port int
|
port int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
address := fmt.Sprintf("127.0.0.1:%d", c.port)
|
address := fmt.Sprintf("%s:%d", c.address, c.port)
|
||||||
|
|
||||||
serverAddress, _ := net.ResolveTCPAddr("tcp", address)
|
serverAddress, _ := net.ResolveTCPAddr("tcp", address)
|
||||||
|
|
||||||
conn, err := net.DialTCP("tcp", nil, serverAddress)
|
conn, err := net.DialTCP("tcp", nil, serverAddress)
|
||||||
@ -50,13 +55,45 @@ func (c *Client) Connect() error {
|
|||||||
data := secure.PacketSendDataStart{
|
data := secure.PacketSendDataStart{
|
||||||
Filename: "foobar",
|
Filename: "foobar",
|
||||||
TotalSize: 3,
|
TotalSize: 3,
|
||||||
Data: []byte{0x20, 0x21, 0x22},
|
|
||||||
}
|
}
|
||||||
err = enc.Encode(data)
|
err = enc.Encode(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
log.Print("done that")
|
log.Print("done that")
|
||||||
|
|
||||||
|
nBytes, nChunks := int64(0), int64(0)
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
buf := make([]byte, 0, 1024)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(buf[:cap(buf)])
|
||||||
|
buf = buf[:n]
|
||||||
|
if n == 0 {
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
nChunks++
|
||||||
|
nBytes += int64(len(buf))
|
||||||
|
// process buf
|
||||||
|
|
||||||
|
send := secure.PacketSendDataNext{
|
||||||
|
Size: 5000,
|
||||||
|
Data: buf,
|
||||||
|
}
|
||||||
|
enc.Encode(send)
|
||||||
|
// time.Sleep(time.Second)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Println("Bytes:", nBytes, "Chunks:", nChunks)
|
||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
break
|
break
|
||||||
|
3
main.go
3
main.go
@ -9,6 +9,7 @@ import (
|
|||||||
func main() {
|
func main() {
|
||||||
log.SetFlags(log.Lshortfile)
|
log.SetFlags(log.Lshortfile)
|
||||||
port := flag.Int("p", 9000, "Port to run server/client on.")
|
port := flag.Int("p", 9000, "Port to run server/client on.")
|
||||||
|
addr := flag.String("a", "61.245.149.58", "address to connect to.")
|
||||||
isServer := flag.Bool("s", false, "Set if running the server.")
|
isServer := flag.Bool("s", false, "Set if running the server.")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ func main() {
|
|||||||
s.Run()
|
s.Run()
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("Client running on %d\n", *port)
|
fmt.Printf("Client running on %d\n", *port)
|
||||||
c := Client{port: *port}
|
c := Client{port: *port, address: *addr}
|
||||||
err := c.Connect()
|
err := c.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Print(err)
|
fmt.Print(err)
|
||||||
|
@ -29,12 +29,10 @@ func (s *SecureMessage) toByteArray() []byte {
|
|||||||
func DeterminePacketSize(data []byte) uint16 {
|
func DeterminePacketSize(data []byte) uint16 {
|
||||||
// first 24 bytes are the nonce, then the size
|
// first 24 bytes are the nonce, then the size
|
||||||
if len(data) < 26 {
|
if len(data) < 26 {
|
||||||
log.Printf("packet is too small to be complete - %d bytes", len(data))
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
size := binary.BigEndian.Uint16(data[24:26])
|
size := binary.BigEndian.Uint16(data[24:26])
|
||||||
size += 26 // add the length header and the nonce
|
size += 26 // add the length header and the nonce
|
||||||
log.Printf("size of packet inside the %d bytes is %d bytes", len(data), size)
|
|
||||||
return size
|
return size
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,11 +55,13 @@ type SecureConnection struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SecureConnection) Read(p []byte) (int, error) {
|
func (s *SecureConnection) Read(p []byte) (int, error) {
|
||||||
message := make([]byte, 20408)
|
message := make([]byte, 2048)
|
||||||
// Read the message from the buffer
|
// Read the message from the buffer
|
||||||
eof := false
|
eof := false
|
||||||
|
|
||||||
log.Printf("READ: Start, buffer contains %d bytes", s.Buffer.Len())
|
outputBytes := make([]byte, 0)
|
||||||
|
|
||||||
|
// log.Printf("READ: start, p %d/%d, buffer contains currently contains %d bytes", len(p), cap(p), s.Buffer.Len())
|
||||||
|
|
||||||
n, err := s.Conn.Read(message)
|
n, err := s.Conn.Read(message)
|
||||||
|
|
||||||
@ -82,13 +82,21 @@ func (s *SecureConnection) Read(p []byte) (int, error) {
|
|||||||
s.Buffer.Write(message[:n])
|
s.Buffer.Write(message[:n])
|
||||||
// log.Printf("read: appended them to the buffer which is now %d bytes", len(s.Buffer.Bytes()))
|
// log.Printf("read: appended them to the buffer which is now %d bytes", len(s.Buffer.Bytes()))
|
||||||
|
|
||||||
|
for {
|
||||||
|
|
||||||
actualPacketEnd := DeterminePacketSize(s.Buffer.Bytes())
|
actualPacketEnd := DeterminePacketSize(s.Buffer.Bytes())
|
||||||
if actualPacketEnd == 0 {
|
if actualPacketEnd == 0 {
|
||||||
log.Printf("packet too small?")
|
// log.Printf("packet too small?")
|
||||||
// panic("small")
|
break
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if int(actualPacketEnd) > len(s.Buffer.Bytes()) {
|
||||||
|
// we must have half a packet
|
||||||
|
// log.Print("partial packet detected")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
secureMessage := ConstructSecureMessage(s.Buffer.Bytes()[:actualPacketEnd])
|
secureMessage := ConstructSecureMessage(s.Buffer.Bytes()[:actualPacketEnd])
|
||||||
// log.Printf("Secure message from wire bytes: \n nonce: %v\n msg: %v\n size: %d\n", secureMessage.Nonce, secureMessage.Msg, secureMessage.Size)
|
// log.Printf("Secure message from wire bytes: \n nonce: %v\n msg: %v\n size: %d\n", secureMessage.Nonce, secureMessage.Msg, secureMessage.Size)
|
||||||
decryptedMessage, ok := box.OpenAfterPrecomputation(nil, secureMessage.Msg, &secureMessage.Nonce, s.SharedKey)
|
decryptedMessage, ok := box.OpenAfterPrecomputation(nil, secureMessage.Msg, &secureMessage.Nonce, s.SharedKey)
|
||||||
@ -97,7 +105,10 @@ func (s *SecureConnection) Read(p []byte) (int, error) {
|
|||||||
return 0, errors.New("problem decrypting the message")
|
return 0, errors.New("problem decrypting the message")
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(p, decryptedMessage)
|
outputBytes = append(outputBytes, decryptedMessage...)
|
||||||
|
|
||||||
|
// log.Printf("OUT now: %d bytes", len(outputBytes))
|
||||||
|
// copy(p, decryptedMessage)
|
||||||
|
|
||||||
// trim what we used off the buffer
|
// trim what we used off the buffer
|
||||||
newBuffer := s.Buffer.Bytes()[actualPacketEnd:]
|
newBuffer := s.Buffer.Bytes()[actualPacketEnd:]
|
||||||
@ -105,28 +116,37 @@ func (s *SecureConnection) Read(p []byte) (int, error) {
|
|||||||
|
|
||||||
if eof && s.Buffer.Len() == 0 {
|
if eof && s.Buffer.Len() == 0 {
|
||||||
log.Printf("returning the final packet")
|
log.Printf("returning the final packet")
|
||||||
return len(decryptedMessage), io.EOF
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("successfully read %d bytes", len(decryptedMessage))
|
}
|
||||||
return len(decryptedMessage), nil
|
|
||||||
|
err = io.EOF
|
||||||
|
if !eof {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(p, outputBytes)
|
||||||
|
|
||||||
|
// log.Printf("returning %d decrypted bytes with err: %w", len(outputBytes), err)
|
||||||
|
// log.Printf("READ: end, p %d/%d, buffer contains currently contains %d bytes", len(p), cap(p), s.Buffer.Len())
|
||||||
|
|
||||||
|
return len(outputBytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SecureConnection) Write(p []byte) (int, error) {
|
func (s *SecureConnection) Write(p []byte) (int, error) {
|
||||||
// func (s *SecureConnection) Write(o encoding.BinaryMarshaler) (int, error) {
|
// func (s *SecureConnection) Write(o encoding.BinaryMarshaler) (int, error) {
|
||||||
var nonce [24]byte
|
var nonce [24]byte
|
||||||
|
|
||||||
log.Printf("clear bytes: %v", p)
|
|
||||||
|
|
||||||
// Create a new nonce for each message sent
|
// Create a new nonce for each message sent
|
||||||
rand.Read(nonce[:])
|
rand.Read(nonce[:])
|
||||||
log.Printf("before encryption it is %d bytes", len(p))
|
// log.Printf("before encryption it is %d bytes", len(p))
|
||||||
encryptedMessage := box.SealAfterPrecomputation(nil, p, &nonce, s.SharedKey)
|
encryptedMessage := box.SealAfterPrecomputation(nil, p, &nonce, s.SharedKey)
|
||||||
sm := SecureMessage{Msg: encryptedMessage, Nonce: nonce}
|
sm := SecureMessage{Msg: encryptedMessage, Nonce: nonce}
|
||||||
|
|
||||||
// Write it to the connection
|
// Write it to the connection
|
||||||
wireBytes := sm.toByteArray()
|
wireBytes := sm.toByteArray()
|
||||||
log.Printf("putting %d bytes on the wire\n nonce: %v\n bytes: %v", len(wireBytes), nonce, wireBytes)
|
// log.Printf("putting %d bytes on the wire\n nonce: %v\n bytes: %v", len(wireBytes), nonce, wireBytes)
|
||||||
return s.Conn.Write(wireBytes)
|
return s.Conn.Write(wireBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,7 +187,6 @@ type PacketStart struct {
|
|||||||
type PacketSendDataStart struct {
|
type PacketSendDataStart struct {
|
||||||
Filename string
|
Filename string
|
||||||
TotalSize uint32
|
TotalSize uint32
|
||||||
Data []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketSendDataNext struct {
|
type PacketSendDataNext struct {
|
||||||
|
71
server.go
71
server.go
@ -18,7 +18,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Run() {
|
func (s *Server) Run() {
|
||||||
address := fmt.Sprintf("127.0.0.1:%d", s.port)
|
address := fmt.Sprintf(":%d", s.port)
|
||||||
networkAddress, _ := net.ResolveTCPAddr("tcp", address)
|
networkAddress, _ := net.ResolveTCPAddr("tcp", address)
|
||||||
|
|
||||||
listener, err := net.ListenTCP("tcp", networkAddress)
|
listener, err := net.ListenTCP("tcp", networkAddress)
|
||||||
@ -41,7 +41,7 @@ func (s *Server) Run() {
|
|||||||
func handleConnection(conn *net.TCPConn) {
|
func handleConnection(conn *net.TCPConn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
conn.SetDeadline(time.Now().Add(time.Second))
|
conn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
|
|
||||||
sharedKey := secure.Handshake(conn)
|
sharedKey := secure.Handshake(conn)
|
||||||
secureConnection := secure.SecureConnection{Conn: conn, SharedKey: sharedKey, Buffer: &bytes.Buffer{}}
|
secureConnection := secure.SecureConnection{Conn: conn, SharedKey: sharedKey, Buffer: &bytes.Buffer{}}
|
||||||
@ -51,35 +51,56 @@ func handleConnection(conn *net.TCPConn) {
|
|||||||
|
|
||||||
dec := gob.NewDecoder(&secureConnection)
|
dec := gob.NewDecoder(&secureConnection)
|
||||||
|
|
||||||
// At this point we are in
|
// Get the start packet
|
||||||
|
start := secure.PacketStart{}
|
||||||
|
|
||||||
|
err := dec.Decode(&start)
|
||||||
|
if err == io.EOF {
|
||||||
|
log.Printf("connection has been closed after start packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("some error with start packet: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Decoded packet:\n%#v", start)
|
||||||
|
conn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
|
|
||||||
|
if start.OperationType == secure.OperationTypeSend {
|
||||||
|
log.Printf("client wants to send us something, expecting a send start")
|
||||||
|
sendStart := secure.PacketSendDataStart{}
|
||||||
|
|
||||||
|
err = dec.Decode(&sendStart)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error at send data start: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("send start looks like: %v", sendStart)
|
||||||
|
file, err := os.CreateTemp("", "netgiv_")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("got error with temp file: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("writing data to file: %s", file.Name())
|
||||||
|
sendData := secure.PacketSendDataNext{}
|
||||||
for {
|
for {
|
||||||
|
conn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
p1 := secure.PacketStart{}
|
err = dec.Decode(&sendData)
|
||||||
|
|
||||||
log.Print("trying to decode something from wire")
|
|
||||||
err := dec.Decode(&p1)
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
log.Printf("connection has been closed")
|
log.Printf("WE ARE DONE writing to: %s", file.Name())
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Printf("error decoding data next: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Decoded packet:\n%#v", p1)
|
|
||||||
|
|
||||||
p2 := secure.PacketSendDataStart{}
|
|
||||||
|
|
||||||
err = dec.Decode(&p2)
|
|
||||||
if err == io.EOF {
|
|
||||||
log.Printf("connection has been closed")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
file.Write(sendData.Data)
|
||||||
panic(err)
|
}
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
log.Printf("bad operation")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Decoded packet:\n%#v", p2)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user