diff --git a/client.go b/client.go index 136b2ca..659e206 100644 --- a/client.go +++ b/client.go @@ -1,21 +1,26 @@ package main import ( + "bufio" "encoding/gob" "errors" "fmt" + "io" "log" "net" + "os" "github.com/tardisx/netgiv/secure" ) type Client struct { - port int + address string + port int } func (c *Client) Connect() error { - address := fmt.Sprintf("127.0.0.1:%d", c.port) + address := fmt.Sprintf("%s:%d", c.address, c.port) + serverAddress, _ := net.ResolveTCPAddr("tcp", address) conn, err := net.DialTCP("tcp", nil, serverAddress) @@ -50,13 +55,45 @@ func (c *Client) Connect() error { data := secure.PacketSendDataStart{ Filename: "foobar", TotalSize: 3, - Data: []byte{0x20, 0x21, 0x22}, } err = enc.Encode(data) if err != nil { panic(err) } log.Print("done that") + + nBytes, nChunks := int64(0), int64(0) + reader := bufio.NewReader(os.Stdin) + buf := make([]byte, 0, 1024) + + for { + n, err := reader.Read(buf[:cap(buf)]) + buf = buf[:n] + if n == 0 { + if err == nil { + continue + } + if err == io.EOF { + break + } + log.Fatal(err) + } + nChunks++ + nBytes += int64(len(buf)) + // process buf + + send := secure.PacketSendDataNext{ + Size: 5000, + Data: buf, + } + enc.Encode(send) + // time.Sleep(time.Second) + if err != nil { + log.Fatal(err) + } + } + log.Println("Bytes:", nBytes, "Chunks:", nChunks) + conn.Close() break diff --git a/main.go b/main.go index 63adf09..b037dbf 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( func main() { log.SetFlags(log.Lshortfile) port := flag.Int("p", 9000, "Port to run server/client on.") + addr := flag.String("a", "61.245.149.58", "address to connect to.") isServer := flag.Bool("s", false, "Set if running the server.") flag.Parse() @@ -18,7 +19,7 @@ func main() { s.Run() } else { fmt.Printf("Client running on %d\n", *port) - c := Client{port: *port} + c := Client{port: *port, address: *addr} err := c.Connect() if err != nil { fmt.Print(err) diff --git a/secure/secure.go b/secure/secure.go index 3985da7..84fa0f6 100644 --- a/secure/secure.go +++ b/secure/secure.go @@ -29,12 +29,10 @@ func (s *SecureMessage) toByteArray() []byte { func DeterminePacketSize(data []byte) uint16 { // first 24 bytes are the nonce, then the size if len(data) < 26 { - log.Printf("packet is too small to be complete - %d bytes", len(data)) return 0 } size := binary.BigEndian.Uint16(data[24:26]) size += 26 // add the length header and the nonce - log.Printf("size of packet inside the %d bytes is %d bytes", len(data), size) return size } @@ -57,11 +55,13 @@ type SecureConnection struct { } func (s *SecureConnection) Read(p []byte) (int, error) { - message := make([]byte, 20408) + message := make([]byte, 2048) // Read the message from the buffer eof := false - log.Printf("READ: Start, buffer contains %d bytes", s.Buffer.Len()) + outputBytes := make([]byte, 0) + + // log.Printf("READ: start, p %d/%d, buffer contains currently contains %d bytes", len(p), cap(p), s.Buffer.Len()) n, err := s.Conn.Read(message) @@ -82,51 +82,71 @@ func (s *SecureConnection) Read(p []byte) (int, error) { s.Buffer.Write(message[:n]) // log.Printf("read: appended them to the buffer which is now %d bytes", len(s.Buffer.Bytes())) - actualPacketEnd := DeterminePacketSize(s.Buffer.Bytes()) - if actualPacketEnd == 0 { - log.Printf("packet too small?") - // panic("small") - return 0, io.EOF + for { + + actualPacketEnd := DeterminePacketSize(s.Buffer.Bytes()) + if actualPacketEnd == 0 { + // log.Printf("packet too small?") + break + return 0, io.EOF + } + + if int(actualPacketEnd) > len(s.Buffer.Bytes()) { + // we must have half a packet + // log.Print("partial packet detected") + break + } + + secureMessage := ConstructSecureMessage(s.Buffer.Bytes()[:actualPacketEnd]) + // log.Printf("Secure message from wire bytes: \n nonce: %v\n msg: %v\n size: %d\n", secureMessage.Nonce, secureMessage.Msg, secureMessage.Size) + decryptedMessage, ok := box.OpenAfterPrecomputation(nil, secureMessage.Msg, &secureMessage.Nonce, s.SharedKey) + + if !ok { + return 0, errors.New("problem decrypting the message") + } + + outputBytes = append(outputBytes, decryptedMessage...) + + // log.Printf("OUT now: %d bytes", len(outputBytes)) + // copy(p, decryptedMessage) + + // trim what we used off the buffer + newBuffer := s.Buffer.Bytes()[actualPacketEnd:] + s.Buffer = bytes.NewBuffer(newBuffer) + + if eof && s.Buffer.Len() == 0 { + log.Printf("returning the final packet") + break + } + } - secureMessage := ConstructSecureMessage(s.Buffer.Bytes()[:actualPacketEnd]) - // log.Printf("Secure message from wire bytes: \n nonce: %v\n msg: %v\n size: %d\n", secureMessage.Nonce, secureMessage.Msg, secureMessage.Size) - decryptedMessage, ok := box.OpenAfterPrecomputation(nil, secureMessage.Msg, &secureMessage.Nonce, s.SharedKey) - - if !ok { - return 0, errors.New("problem decrypting the message") + err = io.EOF + if !eof { + err = nil } - copy(p, decryptedMessage) + copy(p, outputBytes) - // trim what we used off the buffer - newBuffer := s.Buffer.Bytes()[actualPacketEnd:] - s.Buffer = bytes.NewBuffer(newBuffer) + // log.Printf("returning %d decrypted bytes with err: %w", len(outputBytes), err) + // log.Printf("READ: end, p %d/%d, buffer contains currently contains %d bytes", len(p), cap(p), s.Buffer.Len()) - if eof && s.Buffer.Len() == 0 { - log.Printf("returning the final packet") - return len(decryptedMessage), io.EOF - } - - log.Printf("successfully read %d bytes", len(decryptedMessage)) - return len(decryptedMessage), nil + return len(outputBytes), err } func (s *SecureConnection) Write(p []byte) (int, error) { // func (s *SecureConnection) Write(o encoding.BinaryMarshaler) (int, error) { var nonce [24]byte - log.Printf("clear bytes: %v", p) - // Create a new nonce for each message sent rand.Read(nonce[:]) - log.Printf("before encryption it is %d bytes", len(p)) + // log.Printf("before encryption it is %d bytes", len(p)) encryptedMessage := box.SealAfterPrecomputation(nil, p, &nonce, s.SharedKey) sm := SecureMessage{Msg: encryptedMessage, Nonce: nonce} // Write it to the connection wireBytes := sm.toByteArray() - log.Printf("putting %d bytes on the wire\n nonce: %v\n bytes: %v", len(wireBytes), nonce, wireBytes) + // log.Printf("putting %d bytes on the wire\n nonce: %v\n bytes: %v", len(wireBytes), nonce, wireBytes) return s.Conn.Write(wireBytes) } @@ -167,7 +187,6 @@ type PacketStart struct { type PacketSendDataStart struct { Filename string TotalSize uint32 - Data []byte } type PacketSendDataNext struct { diff --git a/server.go b/server.go index 4f5a2ef..3d53f79 100644 --- a/server.go +++ b/server.go @@ -18,7 +18,7 @@ type Server struct { } func (s *Server) Run() { - address := fmt.Sprintf("127.0.0.1:%d", s.port) + address := fmt.Sprintf(":%d", s.port) networkAddress, _ := net.ResolveTCPAddr("tcp", address) listener, err := net.ListenTCP("tcp", networkAddress) @@ -41,7 +41,7 @@ func (s *Server) Run() { func handleConnection(conn *net.TCPConn) { defer conn.Close() - conn.SetDeadline(time.Now().Add(time.Second)) + conn.SetDeadline(time.Now().Add(time.Second * 5)) sharedKey := secure.Handshake(conn) secureConnection := secure.SecureConnection{Conn: conn, SharedKey: sharedKey, Buffer: &bytes.Buffer{}} @@ -51,35 +51,56 @@ func handleConnection(conn *net.TCPConn) { dec := gob.NewDecoder(&secureConnection) - // At this point we are in - for { - - p1 := secure.PacketStart{} - - log.Print("trying to decode something from wire") - err := dec.Decode(&p1) - if err == io.EOF { - log.Printf("connection has been closed") - return - } - if err != nil { - panic(err) - } - - log.Printf("Decoded packet:\n%#v", p1) - - p2 := secure.PacketSendDataStart{} - - err = dec.Decode(&p2) - if err == io.EOF { - log.Printf("connection has been closed") - return - } - if err != nil { - panic(err) - } - - log.Printf("Decoded packet:\n%#v", p2) + // Get the start packet + start := secure.PacketStart{} + err := dec.Decode(&start) + if err == io.EOF { + log.Printf("connection has been closed after start packet") + return } + if err != nil { + log.Printf("some error with start packet: %w", err) + return + } + + log.Printf("Decoded packet:\n%#v", start) + conn.SetDeadline(time.Now().Add(time.Second * 5)) + + if start.OperationType == secure.OperationTypeSend { + log.Printf("client wants to send us something, expecting a send start") + sendStart := secure.PacketSendDataStart{} + + err = dec.Decode(&sendStart) + if err != nil { + log.Printf("error at send data start: %w", err) + return + } + log.Printf("send start looks like: %v", sendStart) + file, err := os.CreateTemp("", "netgiv_") + if err != nil { + log.Printf("got error with temp file: %w", err) + return + } + log.Printf("writing data to file: %s", file.Name()) + sendData := secure.PacketSendDataNext{} + for { + conn.SetDeadline(time.Now().Add(time.Second * 5)) + err = dec.Decode(&sendData) + if err == io.EOF { + log.Printf("WE ARE DONE writing to: %s", file.Name()) + break + } + if err != nil { + log.Printf("error decoding data next: %s", err) + return + } + file.Write(sendData.Data) + } + return + } else { + log.Printf("bad operation") + return + } + }