Add clean shutdown methods to miniircd for running under unit tests

This commit is contained in:
dave 2017-12-03 17:50:45 -08:00
parent 20d1b18248
commit a6de322a00
1 changed files with 22 additions and 9 deletions

View File

@ -683,6 +683,8 @@ class Client(object):
class Server(object): class Server(object):
def __init__(self, options): def __init__(self, options):
self.alive = True
self.serversockets = []
self.ports = options.ports self.ports = options.ports
self.password = options.password self.password = options.password
self.ssl_pem_file = options.ssl_pem_file self.ssl_pem_file = options.ssl_pem_file
@ -826,7 +828,6 @@ class Server(object):
del self.channels[irc_lower(channel.name)] del self.channels[irc_lower(channel.name)]
def start(self): def start(self):
serversockets = []
for port in self.ports: for port in self.ports:
s = socket.socket(socket.AF_INET6 if self.ipv6 else socket.AF_INET, s = socket.socket(socket.AF_INET6 if self.ipv6 else socket.AF_INET,
socket.SOCK_STREAM) socket.SOCK_STREAM)
@ -837,7 +838,7 @@ class Server(object):
self.print_error("Could not bind port %s: %s." % (port, e)) self.print_error("Could not bind port %s: %s." % (port, e))
sys.exit(1) sys.exit(1)
s.listen(5) s.listen(5)
serversockets.append(s) self.serversockets.append(s)
del s del s
self.print_info("Listening on port %d." % port) self.print_info("Listening on port %d." % port)
if self.chroot: if self.chroot:
@ -852,12 +853,21 @@ class Server(object):
self.init_logging() self.init_logging()
try: try:
self.run(serversockets) self.run()
except: except:
if self.logger: if self.logger:
self.logger.exception("Fatal exception") self.logger.exception("Fatal exception")
raise raise
def stop(self):
self.alive = False
for s in self.serversockets + [x.socket for x in self.clients.values()]:
try:
s.shutdown(socket.SHUT_RDWR)
s.close()
except OSError:
pass
def init_logging(self): def init_logging(self):
if not self.log_file: if not self.log_file:
return return
@ -878,11 +888,11 @@ class Server(object):
self.logger.setLevel(log_level) self.logger.setLevel(log_level)
self.logger.addHandler(fh) self.logger.addHandler(fh)
def run(self, serversockets): def run(self):
last_aliveness_check = time.time() last_aliveness_check = time.time()
while True: while self.alive:
(iwtd, owtd, ewtd) = select.select( (iwtd, owtd, ewtd) = select.select(
serversockets + [x.socket for x in self.clients.values()], self.serversockets + [x.socket for x in self.clients.values()],
[x.socket for x in self.clients.values() [x.socket for x in self.clients.values()
if x.write_queue_size() > 0], if x.write_queue_size() > 0],
[], [],
@ -891,7 +901,10 @@ class Server(object):
if x in self.clients: if x in self.clients:
self.clients[x].socket_readable_notification() self.clients[x].socket_readable_notification()
else: else:
(conn, addr) = x.accept() try:
(conn, addr) = x.accept()
except OSError: # Socket likely closed
break
if self.ssl_pem_file: if self.ssl_pem_file:
try: try:
conn = self.ssl.wrap_socket( conn = self.ssl.wrap_socket(
@ -1067,5 +1080,5 @@ def main(argv):
except KeyboardInterrupt: except KeyboardInterrupt:
server.print_error("Interrupted.") server.print_error("Interrupted.")
if __name__ == '__main__':
main(sys.argv) main(sys.argv)