diff --git a/client.py b/client.py index b4cc65f57e6ec77d7bbfea9b5f079412e49af455..1651d305d5629a36c6ea0d761049488350d39966 100644 --- a/client.py +++ b/client.py @@ -33,15 +33,35 @@ class Client(): def broadcast(self): self.send(bytes(), (HOST_ADDRESS, BROADCAST_PORT)) log("!", "Message sent to broadcast address") + log("!", "Listening for three-way handshake...") def handshake(self): # To-Do - data, address = self.recv() - log("!", f"Server address : {address}") + rcvd = Segment() + while 1: + byte, server_addr = self.recv() + rcvd.from_bytes(byte) + if not rcvd.test_checksum(): + raise ConnectionError + if rcvd.getSYN(): + break; + + self.DST_ADDR = server_addr seg = Segment() - seg.from_bytes(data) - log("!", seg.bytes[:100]) + seg.set_seq_no(300) + seg.set_ack_no(rcvd.get_seq_no()) + seg.setSYN() + seg.setACK() + self.send(seg.to_bytes(), server_addr) + log(f"Segment SEQ={rcvd.get_seq_no()}, SYN", "Received, Ack sent") + log(f"Segment SEQ={seg.get_seq_no()}, SYN", "Sent") + + byte, _ = self.recv() + rcvd.from_bytes(byte) + if (not rcvd.test_checksum()) or (not rcvd.getACK()) or (rcvd.get_ack_no() != seg.get_seq_no()): + raise ConnectionError + log(f"Segment SEQ={rcvd.get_ack_no()}", "Acked") return diff --git a/segment.py b/segment.py index 87ea3686497fa53124d96f46b8b194e23222c482..41f9ce033f2e30613e7343a2334b46c8c7c469dc 100644 --- a/segment.py +++ b/segment.py @@ -14,21 +14,21 @@ class Segment(): self.bytes[3] = val & 0xff def get_seq_no(self): - return self.bytes[0] << 24 + \ - self.bytes[1] << 16 + \ - self.bytes[2] << 8 + \ + return (self.bytes[0] << 24) + \ + (self.bytes[1] << 16) + \ + (self.bytes[2] << 8) + \ self.bytes[3] - def set_seq_no(self, val): + def set_ack_no(self, val): self.bytes[4] = val >> 24 & 0xff self.bytes[5] = val >> 16 & 0xff self.bytes[6] = val >> 8 & 0xff self.bytes[7] = val & 0xff - def get_seq_no(self): - return self.bytes[4] << 24 + \ - self.bytes[5] << 16 + \ - self.bytes[6] << 8 + \ + def get_ack_no(self): + return (self.bytes[4] << 24) + \ + (self.bytes[5] << 16) + \ + (self.bytes[6] << 8) + \ self.bytes[7] def setSYN(self): @@ -62,13 +62,18 @@ class Segment(): def count_checksum(self): # To-Do # ... - return + return 0 + + def test_checksum(self): + # return self.count_checksum == self.get_checksum + return True def to_bytes(self): return bytes(self.bytes) def from_bytes(self, byte:bytes): self.bytes = list(byte) + return self if __name__=="__main__": seg = Segment() diff --git a/server.py b/server.py index 8fc96ebf2aecb25910f59fc92e50c13c1828f518..49e7111af8b0da27525fbb78846a719023e32de8 100644 --- a/server.py +++ b/server.py @@ -46,15 +46,37 @@ class Server(): log("?", "Listen more? (y/n) ", end="") choice = input() if choice == "n": + print() break def handshake(self, client_addr): # To-DO + log("!", "Commencing three-way handshake... ") segment = Segment() + segment.set_seq_no(100) segment.setSYN() self.send(segment.to_bytes(), client_addr) + log(f"Segment SEQ={segment.get_seq_no()}, SYN", "Sent") - return + byte, _ = self.recv() + recved = Segment().from_bytes(byte) + if (not recved.test_checksum()) or (not recved.getACK()) \ + or (recved.get_ack_no() != segment.get_seq_no()): + # print(recved.getACK()) + # print(recved.get_ack_no()) + # print(segment.get_seq_no()) + # print(recved.bytes[:12]) + raise ConnectionError + log(f"Segment SEQ={segment.get_seq_no()}, SYN", "Acked") + + segment = Segment() + segment.set_seq_no(recved.get_ack_no()+1) + segment.set_ack_no(recved.get_seq_no()) + segment.setACK() + self.send(segment.to_bytes(), client_addr) + log(f"Segment SEQ={recved.get_seq_no()}, SYN", "Received, Ack sent") + + return True def send_file(self, client_addr): # To-Do