# server.py (5.38 KiB)
import socket
import sys
import os
from common import *
from buffer import FileBuffer
from typing import List

class Server:
    """UDP file server that discovers clients via broadcast and sends a file.

    Constructing a ``Server`` runs its whole lifecycle: bind the data
    socket, collect client addresses from the broadcast socket (prompting
    the operator after each one), then for every collected client perform
    the handshake, transmit the file in windows of ``N`` segments, and
    close the connection.  Both sockets are closed before the constructor
    returns or raises.

    Attributes:
        SERVER_PORT: port the data socket is bound to.
        FILE_PATH: path of the file handed to ``FileBuffer().read``.
        clients: addresses collected by :meth:`listen`, in arrival order.
        N: number of segments sent per round before waiting for ACKs.
    """

    def __init__(self, port=10000, filepath="data/calender.pdf"):
        """Run the server to completion.

        Args:
            port: UDP port for the data socket.
            filepath: path of the file to send to every client.

        Raises:
            ConnectionError: if a handshake, transfer or teardown step fails
                for any client; remaining clients are not served.
        """
        self.SERVER_PORT = port
        self.FILE_PATH = filepath
        self.clients = []
        self.N = 4

        try:
            self.init_socket()
            self.listen()
            for client in self.clients:
                self.handshake(client)
                self.send_file(client)
                self.close_con(client)
        finally:
            # Release both UDP sockets even when a client fails midway.
            self.close()

    def init_socket(self):
        """Create the UDP data socket and bind it to ``SERVER_PORT``."""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((HOST_NAME, self.SERVER_PORT))

    def close(self):
        """Close the data and broadcast sockets if they were created."""
        for attr in ("socket", "broadcast_socket"):
            sock = getattr(self, attr, None)
            if sock is not None:
                sock.close()

    def send(self, data, client_addr):
        """Send ``data`` to ``client_addr`` over the data socket.

        Args:
            data: ``str`` (encoded with the default codec before sending)
                or a bytes-like payload sent as-is.
            client_addr: destination address passed to ``sendto``.
        """
        # isinstance (not an exact type comparison) so str subclasses are
        # encoded too instead of being passed to sendto unencoded.
        payload = data.encode() if isinstance(data, str) else data
        self.socket.sendto(payload, client_addr)

    def recv(self):
        """Receive one datagram of up to ``SEG_SIZE`` bytes on the data socket.

        Returns:
            The ``(bytes, address)`` pair produced by ``recvfrom``.
        """
        return self.socket.recvfrom(SEG_SIZE)

    def recv_broadcast(self):
        """Receive one datagram of up to ``SEG_SIZE`` bytes on the broadcast socket.

        Returns:
            The ``(bytes, address)`` pair produced by ``recvfrom``.
        """
        return self.broadcast_socket.recvfrom(SEG_SIZE)

    def listen(self):
        """Collect client addresses until the operator answers ``n``.

        Binds a second UDP socket to ``BROADCAST_PORT``.  Every datagram
        received on it registers its sender in ``self.clients`` (the payload
        itself is ignored), after which the operator is asked on stdin
        whether to keep listening.  When listening stops, a 3 second timeout
        is set on the data socket for the exchanges that follow.
        """
        # init broadcast socket for listening broadcast message
        self.broadcast_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.broadcast_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.broadcast_socket.bind((HOST_NAME, BROADCAST_PORT))
        log("!", "Server listening on broadcast address\n")

        while True:
            _message, address = self.recv_broadcast()
            log("!", f"Client {address} found")
            self.clients.append(address)

            log("?", "Listen more? (y/n) ", end="")
            choice = input()
            if choice == "n":
                print()
                self.socket.settimeout(3)
                break

    def handshake(self, client_addr):
        """Perform the three-way handshake with ``client_addr``.

        Sends a SYN segment with sequence number 100, expects a reply that
        passes its checksum, carries the ACK flag and acknowledges that
        sequence number, then answers with an ACK segment.

        Returns:
            True once the final ACK has been sent.

        Raises:
            ConnectionError: if the reply fails any of the checks above.
            socket.timeout: if no reply arrives within the socket timeout.
        """
        log("!", "Commencing three-way handshake... ")
        syn = Segment()
        syn.set_seq_no(100)
        syn.setSYN()
        syn.set_checksum(syn.count_checksum())
        self.send(syn.to_bytes(), client_addr)
        log(f"Segment SEQ={syn.get_seq_no()}, SYN", "Sent")

        raw, _ = self.recv()
        reply = Segment().from_bytes(raw)
        reply_ok = (
            reply.test_checksum()
            and reply.getACK()
            and reply.get_ack_no() == syn.get_seq_no()
        )
        if not reply_ok:
            raise ConnectionError("handshake: invalid reply to SYN")
        log(f"Segment SEQ={syn.get_seq_no()}, SYN", "Acked")

        ack = Segment()
        ack.set_seq_no(reply.get_ack_no() + 1)
        ack.set_ack_no(reply.get_seq_no())
        ack.setACK()
        ack.set_checksum(ack.count_checksum())
        self.send(ack.to_bytes(), client_addr)
        log(f"Segment SEQ={reply.get_seq_no()}, SYN", "Received, Ack sent")

        return True

    def send_file(self, client_addr):
        """Send ``FILE_PATH`` to ``client_addr`` in windows of ``N`` segments.

        Each round sends every segment of the current window (numbered from
        0 by its index in the file), then reads one reply per segment sent.
        A reply advances the window base only when it passes its checksum
        and acknowledges exactly the current base; other replies are
        ignored.  If reading replies times out, the window is resent from
        the current base.

        Raises:
            ConnectionError: after three consecutive rounds ending in a
                timeout.
        """
        list_of_segments = FileBuffer().read(self.FILE_PATH)
        count = len(list_of_segments)
        log("!", f"File segmented into {count} segments")

        base = 0
        timeout = 0
        while base < count:
            # Fix the window size for this round *before* base moves, so the
            # number of ACKs awaited equals the number of segments sent.
            window = min(self.N, count - base)

            for offset in range(window):
                seg = list_of_segments[base + offset]
                seg.set_seq_no(base + offset)
                seg.set_checksum(seg.count_checksum())
                self.send(seg.to_bytes(), client_addr)
                log(f"Segment SEQ={seg.get_seq_no()}", "Sent")

            try:
                for _ in range(window):
                    raw, _addr = self.recv()
                    rcvd = Segment().from_bytes(raw)
                    if rcvd.test_checksum() and rcvd.get_ack_no() == base:
                        log(f"Segment SEQ={rcvd.get_ack_no()}", "Acked")
                        base += 1
                # A full round of replies arrived: reset the failure streak.
                timeout = 0
            except socket.timeout:
                log("!", "Connection timeout. ")
                timeout += 1
                if timeout == 3:
                    raise ConnectionError(
                        "send_file: three consecutive timeouts"
                    )

    def close_con(self, client_addr):
        """Tear down the connection with ``client_addr``.

        Sends a FIN segment with sequence number 100, expects an ACK for it,
        then expects the client's own FIN and acknowledges that.

        Returns:
            True once the final ACK has been sent.

        Raises:
            ConnectionError: if the ACK or the client's FIN fails its checks.
            socket.timeout: if a reply does not arrive within the socket
                timeout.
        """
        log("!", "Closing Connection ...")

        fin = Segment(100)
        fin.setFIN()
        fin.set_checksum(fin.count_checksum())
        self.send(fin.to_bytes(), client_addr)
        log(f"Segment SEQ={fin.get_seq_no()}, FIN", "Sent")

        raw, _ = self.recv()
        rcvd = Segment().from_bytes(raw)
        ack_ok = (
            rcvd.test_checksum()
            and rcvd.getACK()
            and rcvd.get_ack_no() == fin.get_seq_no()
        )
        if not ack_ok:
            raise ConnectionError("close_con: invalid reply to FIN")
        log(f"Segment SEQ={fin.get_seq_no()}, FIN", "Acked")

        raw, _ = self.recv()
        rcvd = Segment().from_bytes(raw)
        if not (rcvd.test_checksum() and rcvd.getFIN()):
            raise ConnectionError("close_con: expected FIN from client")
        last_ack = Segment(fin.get_seq_no() + 1, rcvd.get_seq_no())
        last_ack.setACK()
        last_ack.set_checksum(last_ack.count_checksum())
        self.send(last_ack.to_bytes(), client_addr)
        log(f"Segment SEQ={rcvd.get_seq_no()}, FIN", "Received, Ack sent.")
        return True

if __name__ == "__main__":
    # Usage: server.py [port [filepath]]
    # Only the arguments actually supplied are forwarded, so Server's own
    # defaults apply to anything omitted.
    cli_args = sys.argv[1:]
    options = {}
    if cli_args:
        options["port"] = int(cli_args[0])
    if len(cli_args) > 1:
        options["filepath"] = cli_args[1]
    Server(**options)