Add clean shutdown methods to miniircd for running under unit tests

This commit is contained in:
dave 2017-12-03 17:50:45 -08:00
parent 20d1b18248
commit a6de322a00
1 changed files with 22 additions and 9 deletions

View File

@ -683,6 +683,8 @@ class Client(object):
class Server(object):
def __init__(self, options):
self.alive = True
self.serversockets = []
self.ports = options.ports
self.password = options.password
self.ssl_pem_file = options.ssl_pem_file
@ -826,7 +828,6 @@ class Server(object):
del self.channels[irc_lower(channel.name)]
def start(self):
serversockets = []
for port in self.ports:
s = socket.socket(socket.AF_INET6 if self.ipv6 else socket.AF_INET,
socket.SOCK_STREAM)
@ -837,7 +838,7 @@ class Server(object):
self.print_error("Could not bind port %s: %s." % (port, e))
sys.exit(1)
s.listen(5)
serversockets.append(s)
self.serversockets.append(s)
del s
self.print_info("Listening on port %d." % port)
if self.chroot:
@ -852,12 +853,21 @@ class Server(object):
self.init_logging()
try:
self.run(serversockets)
self.run()
except:
if self.logger:
self.logger.exception("Fatal exception")
raise
def stop(self):
self.alive = False
for s in self.serversockets + [x.socket for x in self.clients.values()]:
try:
s.shutdown(socket.SHUT_RDWR)
s.close()
except OSError:
pass
def init_logging(self):
if not self.log_file:
return
@ -878,11 +888,11 @@ class Server(object):
self.logger.setLevel(log_level)
self.logger.addHandler(fh)
def run(self, serversockets):
def run(self):
last_aliveness_check = time.time()
while True:
while self.alive:
(iwtd, owtd, ewtd) = select.select(
serversockets + [x.socket for x in self.clients.values()],
self.serversockets + [x.socket for x in self.clients.values()],
[x.socket for x in self.clients.values()
if x.write_queue_size() > 0],
[],
@ -891,7 +901,10 @@ class Server(object):
if x in self.clients:
self.clients[x].socket_readable_notification()
else:
(conn, addr) = x.accept()
try:
(conn, addr) = x.accept()
except OSError: # Socket likely closed
break
if self.ssl_pem_file:
try:
conn = self.ssl.wrap_socket(
@ -1067,5 +1080,5 @@ def main(argv):
except KeyboardInterrupt:
server.print_error("Interrupted.")
main(sys.argv)
if __name__ == '__main__':
main(sys.argv)