]> www.average.org Git - loctrkd.git/blobdiff - gps303/wsgateway.py
wsgateway properly handle write-busy websockets
[loctrkd.git] / gps303 / wsgateway.py
index cc1cb7ec08dc0db5baafbfd89545e60bdced85d2..127dbb4b002fc8c62dce6be2cd42e35b6df54e3e 100644 (file)
@@ -4,13 +4,73 @@ from logging import getLogger
 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
 from time import time
 from wsproto import ConnectionType, WSConnection
-from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request, TextMessage
+from wsproto.events import (
+    AcceptConnection,
+    CloseConnection,
+    Message,
+    Ping,
+    Request,
+    TextMessage,
+)
+from wsproto.utilities import RemoteProtocolError
 import zmq
 
 from . import common
 from .zmsg import LocEvt
 
 log = getLogger("gps303/wsgateway")
+htmldata = None
+
+
+def try_http(data, fd, e):
+    global htmldata
+    try:
+        lines = data.decode().split("\r\n")
+        request = lines[0]
+        headers = lines[1:]
+        op, resource, proto = request.split(" ")
+        log.debug(
+            "HTTP %s for %s, proto %s from fd %d, headers: %s",
+            op,
+            resource,
+            proto,
+            fd,
+            headers,
+        )
+        try:
+            pos = resource.index("?")
+            resource = resource[:pos]
+        except ValueError:
+            pass
+        if op == "GET":
+            if htmldata is None:
+                return (
+                    f"{proto} 500 No data configured\r\n"
+                    f"Content-Type: text/plain\r\n\r\n"
+                    f"HTML data not configure on the server\r\n".encode()
+                )
+            elif resource == "/":
+                length = len(htmldata.encode("utf-8"))
+                return (
+                    f"{proto} 200 Ok\r\n"
+                    f"Content-Type: text/html; charset=utf-8\r\n"
+                    f"Content-Length: {length:d}\r\n\r\n" + htmldata
+                ).encode("utf-8")
+            else:
+                return (
+                    f"{proto} 404 File not found\r\n"
+                    f"Content-Type: text/plain\r\n\r\n"
+                    f'We can only serve "/"\r\n'.encode()
+                )
+        else:
+            return (
+                f"{proto} 400 Bad request\r\n"
+                "Content-Type: text/plain\r\n\r\n"
+                "Bad request\r\n".encode()
+            )
+    except ValueError:
+        log.warning("Unparseable data from fd %d: %s", fd, data)
+        raise e
 
 
 class Client:
@@ -44,42 +104,57 @@ class Client:
             )
             self.ws.receive_data(None)
             return None
-        self.ws.receive_data(data)
-        msgs = []
-        for event in self.ws.events():
-            if isinstance(event, Request):
-                log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
-                #self.ws_data += self.ws.send(event.response())  # Why not?!
-                self.ws_data += self.ws.send(AcceptConnection())
-            elif isinstance(event, (CloseConnection, Ping)):
-                log.debug("%s on fd %d", event, self.sock.fileno())
-                self.ws_data += self.ws.send(event.response())
-            elif isinstance(event, TextMessage):
-                # TODO: save imei "subscription"
-                log.debug("%s on fd %d", event, self.sock.fileno())
-                msgs.append(event.data)
-            else:
-                log.warning("%s on fd %d", event, self.sock.fileno())
+        try:
+            self.ws.receive_data(data)
+        except RemoteProtocolError as e:
+            log.debug(
+                "Websocket error on fd %d, try plain http (%s)",
+                self.sock.fileno(),
+                e,
+            )
+            self.ws_data = try_http(data, self.sock.fileno(), e)
+            log.debug("Sending HTTP response to %d", self.sock.fileno())
+            msgs = None
+        else:
+            msgs = []
+            for event in self.ws.events():
+                if isinstance(event, Request):
+                    log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
+                    # self.ws_data += self.ws.send(event.response())  # Why not?!
+                    self.ws_data += self.ws.send(AcceptConnection())
+                elif isinstance(event, (CloseConnection, Ping)):
+                    log.debug("%s on fd %d", event, self.sock.fileno())
+                    self.ws_data += self.ws.send(event.response())
+                elif isinstance(event, TextMessage):
+                    # TODO: save imei "subscription"
+                    log.debug("%s on fd %d", event, self.sock.fileno())
+                    msgs.append(event.data)
+                else:
+                    log.warning("%s on fd %d", event, self.sock.fileno())
         if self.ws_data:  # Temp hack
             self.write()
         return msgs
 
-    def send(self, imei, message):
+    def wants(self, imei):
+        return True  # TODO: check subscriptions
+
+    def send(self, message):
         # TODO: filter only wanted imei got from the client
-        self.ws_data += self.ws.send(Message(data=message))
+        self.ws_data += self.ws.send(Message(data=message.json))
 
     def write(self):
-        try:
-            sent = self.sock.send(self.ws_data)
-            self.ws_data = self.ws_data[sent:]
-        except OSError as e:
-            log.error(
-                "Sending to fd %d (IMEI %s): %s",
-                self.sock.fileno(),
-                self.imei,
-                e,
-            )
-            self.ws_data = b""
+        if self.ws_data:
+            try:
+                sent = self.sock.send(self.ws_data)
+                self.ws_data = self.ws_data[sent:]
+            except OSError as e:
+                log.error(
+                    "Sending to fd %d: %s",
+                    self.sock.fileno(),
+                    e,
+                )
+                self.ws_data = b""
+        return bool(self.ws_data)
 
 
 class Clients:
@@ -108,12 +183,29 @@ class Clients:
             log.debug("Received: %s", msg)
         return result
 
-    def send(self, msgs):
-        for clnt in self.by_fd.values():
-            clnt.send(msgs)
-            clnt.write()
+    def send(self, msg):
+        towrite = set()
+        for fd, clnt in self.by_fd.items():
+            if clnt.wants(msg.imei):
+                clnt.send(msg)
+                towrite.add(fd)
+        return towrite
+
+    def write(self, towrite):
+        waiting = set()
+        for fd, clnt in [(fd, self.by_fd.get(fd)) for fd in towrite]:
+            if clnt.write():
+                waiting.add(fd)
+        return waiting
+
 
 def runserver(conf):
+    global htmldata
+    try:
+        with open(conf.get("wsgateway", "htmlfile"), encoding="utf-8") as fl:
+            htmldata = fl.read()
+    except OSError:
+        pass
     zctx = zmq.Context()
     zsub = zctx.socket(zmq.SUB)
     zsub.connect(conf.get("lookaside", "publishurl"))
@@ -129,17 +221,19 @@ def runserver(conf):
     poller.register(tcpfd, flags=zmq.POLLIN)
     clients = Clients()
     try:
+        towait = set()
         while True:
             tosend = []
             topoll = []
             tostop = []
-            events = poller.poll(1000)
+            towrite = set()
+            events = poller.poll(5000)
             for sk, fl in events:
                 if sk is zsub:
                     while True:
                         try:
-                            msg = zsub.recv(zmq.NOBLOCK)
-                            tosend.append(LocEvt(msg))
+                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
+                            tosend.append(zmsg)
                         except zmq.Again:
                             break
                 elif sk == tcpfd:
@@ -148,13 +242,15 @@ def runserver(conf):
                 elif fl & zmq.POLLIN:
                     received = clients.recv(sk)
                     if received is None:
-                        log.debug(
-                            "Client gone from fd %d", sk
-                        )
+                        log.debug("Client gone from fd %d", sk)
                         tostop.append(sk)
                     else:
                         for msg in received:
                             log.debug("Received from %d: %s", sk, msg)
+                elif fl & zmq.POLLOUT:
+                    log.debug("Write now open for fd %d", sk)
+                    towrite.add(sk)
+                    towait.discard(sk)
                 else:
                     log.debug("Stray event: %s on socket %s", fl, sk)
             # poll queue consumed, make changes now
@@ -162,12 +258,26 @@ def runserver(conf):
                 poller.unregister(fd)
                 clients.stop(fd)
             for zmsg in tosend:
-                log.debug("Sending to the client: %s", zmsg)
-                clients.send(zmsg)
+                log.debug("Sending to the clients: %s", zmsg)
+                towrite |= clients.send(zmsg)
             for clntsock, clntaddr in topoll:
                 fd = clients.add(clntsock, clntaddr)
                 poller.register(fd, flags=zmq.POLLIN)
-            # TODO: Handle write overruns (register for POLLOUT)
+            # Deal with actually writing the data out
+            trywrite = towrite - towait
+            morewait = clients.write(trywrite)
+            log.debug(
+                "towait %s, tried %s, still busy %s",
+                towait,
+                trywrite,
+                morewait,
+            )
+            for fd in morewait - trywrite:  # new fds waiting for write
+                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
+            for fd in trywrite - morewait:  # no longer waiting for write
+                poller.modify(fd, flags=zmq.POLLIN)
+            towait &= trywrite
+            towait |= morewait
     except KeyboardInterrupt:
         pass