From: Eugene Crosser Date: Fri, 6 May 2022 10:50:28 +0000 (+0200) Subject: wsgateway properly handle write-busy websockets X-Git-Tag: 0.01~9 X-Git-Url: http://www.average.org/gitweb/?p=loctrkd.git;a=commitdiff_plain;h=7dfa156ec288ef302e945f1ad2a66b91f7ce5934 wsgateway properly handle write-busy websockets --- diff --git a/gps303/wsgateway.py b/gps303/wsgateway.py index a4a0b6d..127dbb4 100644 --- a/gps303/wsgateway.py +++ b/gps303/wsgateway.py @@ -60,7 +60,7 @@ def try_http(data, fd, e): return ( f"{proto} 404 File not found\r\n" f"Content-Type: text/plain\r\n\r\n" - f"We can only serve \"/\"\r\n".encode() + f'We can only serve "/"\r\n'.encode() ) else: return ( @@ -143,16 +143,18 @@ class Client: self.ws_data += self.ws.send(Message(data=message.json)) def write(self): - try: - sent = self.sock.send(self.ws_data) - self.ws_data = self.ws_data[sent:] - except OSError as e: - log.error( - "Sending to fd %d: %s", - self.sock.fileno(), - e, - ) - self.ws_data = b"" + if self.ws_data: + try: + sent = self.sock.send(self.ws_data) + self.ws_data = self.ws_data[sent:] + except OSError as e: + log.error( + "Sending to fd %d: %s", + self.sock.fileno(), + e, + ) + self.ws_data = b"" + return bool(self.ws_data) class Clients: @@ -182,18 +184,25 @@ class Clients: return result def send(self, msg): - for clnt in self.by_fd.values(): + towrite = set() + for fd, clnt in self.by_fd.items(): if clnt.wants(msg.imei): clnt.send(msg) - clnt.write() + towrite.add(fd) + return towrite + + def write(self, towrite): + waiting = set() + for fd, clnt in [(fd, self.by_fd.get(fd)) for fd in towrite]: + if clnt.write(): + waiting.add(fd) + return waiting def runserver(conf): global htmldata try: - with open( - conf.get("wsgateway", "htmlfile"), encoding="utf-8" - ) as fl: + with open(conf.get("wsgateway", "htmlfile"), encoding="utf-8") as fl: htmldata = fl.read() except OSError: pass @@ -212,10 +221,12 @@ def runserver(conf): poller.register(tcpfd, flags=zmq.POLLIN) clients = Clients() try: + towait = set() while True: tosend = [] topoll = [] tostop = [] + towrite = set() events = poller.poll(5000) for sk, fl in events: if sk is zsub: @@ -236,6 +247,10 @@ def runserver(conf): else: for msg in received: log.debug("Received from %d: %s", sk, msg) + elif fl & zmq.POLLOUT: + log.debug("Write now open for fd %d", sk) + towrite.add(sk) + towait.discard(sk) else: log.debug("Stray event: %s on socket %s", fl, sk) # poll queue consumed, make changes now @@ -243,12 +258,26 @@ def runserver(conf): poller.unregister(fd) clients.stop(fd) for zmsg in tosend: - log.debug("Sending to the client: %s", zmsg) - clients.send(zmsg) + log.debug("Sending to the clients: %s", zmsg) + towrite |= clients.send(zmsg) for clntsock, clntaddr in topoll: fd = clients.add(clntsock, clntaddr) poller.register(fd, flags=zmq.POLLIN) - # TODO: Handle write overruns (register for POLLOUT) + # Deal with actually writing the data out + trywrite = towrite - towait + morewait = clients.write(trywrite) + log.debug( + "towait %s, tried %s, still busy %s", + towait, + trywrite, + morewait, + ) + for fd in morewait - trywrite: # new fds waiting for write + poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT) + for fd in trywrite - morewait: # no longer waiting for write + poller.modify(fd, flags=zmq.POLLIN) + towait &= trywrite + towait |= morewait except KeyboardInterrupt: pass