X-Git-Url: http://www.average.org/gitweb/?a=blobdiff_plain;f=gps303%2Fwsgateway.py;h=127dbb4b002fc8c62dce6be2cd42e35b6df54e3e;hb=7dfa156ec288ef302e945f1ad2a66b91f7ce5934;hp=cc1cb7ec08dc0db5baafbfd89545e60bdced85d2;hpb=21ab3e6ccd337ef7c69e149a4f21d6a6139557d8;p=loctrkd.git diff --git a/gps303/wsgateway.py b/gps303/wsgateway.py index cc1cb7e..127dbb4 100644 --- a/gps303/wsgateway.py +++ b/gps303/wsgateway.py @@ -4,13 +4,73 @@ from logging import getLogger from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR from time import time from wsproto import ConnectionType, WSConnection -from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request, TextMessage +from wsproto.events import ( + AcceptConnection, + CloseConnection, + Message, + Ping, + Request, + TextMessage, +) +from wsproto.utilities import RemoteProtocolError import zmq from . import common from .zmsg import LocEvt log = getLogger("gps303/wsgateway") +htmldata = None + + +def try_http(data, fd, e): + global htmldata + try: + lines = data.decode().split("\r\n") + request = lines[0] + headers = lines[1:] + op, resource, proto = request.split(" ") + log.debug( + "HTTP %s for %s, proto %s from fd %d, headers: %s", + op, + resource, + proto, + fd, + headers, + ) + try: + pos = resource.index("?") + resource = resource[:pos] + except ValueError: + pass + if op == "GET": + if htmldata is None: + return ( + f"{proto} 500 No data configured\r\n" + f"Content-Type: text/plain\r\n\r\n" + f"HTML data not configure on the server\r\n".encode() + ) + elif resource == "/": + length = len(htmldata.encode("utf-8")) + return ( + f"{proto} 200 Ok\r\n" + f"Content-Type: text/html; charset=utf-8\r\n" + f"Content-Length: {length:d}\r\n\r\n" + htmldata + ).encode("utf-8") + else: + return ( + f"{proto} 404 File not found\r\n" + f"Content-Type: text/plain\r\n\r\n" + f'We can only serve "/"\r\n'.encode() + ) + else: + return ( + f"{proto} 400 Bad request\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad request\r\n".encode() + ) + except ValueError: + log.warning("Unparseable data from fd %d: %s", fd, data) + raise e class Client: @@ -44,42 +104,57 @@ class Client: ) self.ws.receive_data(None) return None - self.ws.receive_data(data) - msgs = [] - for event in self.ws.events(): - if isinstance(event, Request): - log.debug("WebSocket upgrade on fd %d", self.sock.fileno()) - #self.ws_data += self.ws.send(event.response()) # Why not?! - self.ws_data += self.ws.send(AcceptConnection()) - elif isinstance(event, (CloseConnection, Ping)): - log.debug("%s on fd %d", event, self.sock.fileno()) - self.ws_data += self.ws.send(event.response()) - elif isinstance(event, TextMessage): - # TODO: save imei "subscription" - log.debug("%s on fd %d", event, self.sock.fileno()) - msgs.append(event.data) - else: - log.warning("%s on fd %d", event, self.sock.fileno()) + try: + self.ws.receive_data(data) + except RemoteProtocolError as e: + log.debug( + "Websocket error on fd %d, try plain http (%s)", + self.sock.fileno(), + e, + ) + self.ws_data = try_http(data, self.sock.fileno(), e) + log.debug("Sending HTTP response to %d", self.sock.fileno()) + msgs = None + else: + msgs = [] + for event in self.ws.events(): + if isinstance(event, Request): + log.debug("WebSocket upgrade on fd %d", self.sock.fileno()) + # self.ws_data += self.ws.send(event.response()) # Why not?! + self.ws_data += self.ws.send(AcceptConnection()) + elif isinstance(event, (CloseConnection, Ping)): + log.debug("%s on fd %d", event, self.sock.fileno()) + self.ws_data += self.ws.send(event.response()) + elif isinstance(event, TextMessage): + # TODO: save imei "subscription" + log.debug("%s on fd %d", event, self.sock.fileno()) + msgs.append(event.data) + else: + log.warning("%s on fd %d", event, self.sock.fileno()) if self.ws_data: # Temp hack self.write() return msgs - def send(self, imei, message): + def wants(self, imei): + return True # TODO: check subscriptions + + def send(self, message): # TODO: filter only wanted imei got from the client - self.ws_data += self.ws.send(Message(data=message)) + self.ws_data += self.ws.send(Message(data=message.json)) def write(self): - try: - sent = self.sock.send(self.ws_data) - self.ws_data = self.ws_data[sent:] - except OSError as e: - log.error( - "Sending to fd %d (IMEI %s): %s", - self.sock.fileno(), - self.imei, - e, - ) - self.ws_data = b"" + if self.ws_data: + try: + sent = self.sock.send(self.ws_data) + self.ws_data = self.ws_data[sent:] + except OSError as e: + log.error( + "Sending to fd %d: %s", + self.sock.fileno(), + e, + ) + self.ws_data = b"" + return bool(self.ws_data) class Clients: @@ -108,12 +183,29 @@ class Clients: log.debug("Received: %s", msg) return result - def send(self, msgs): - for clnt in self.by_fd.values(): - clnt.send(msgs) - clnt.write() + def send(self, msg): + towrite = set() + for fd, clnt in self.by_fd.items(): + if clnt.wants(msg.imei): + clnt.send(msg) + towrite.add(fd) + return towrite + + def write(self, towrite): + waiting = set() + for fd, clnt in [(fd, self.by_fd.get(fd)) for fd in towrite]: + if clnt.write(): + waiting.add(fd) + return waiting + def runserver(conf): + global htmldata + try: + with open(conf.get("wsgateway", "htmlfile"), encoding="utf-8") as fl: + htmldata = fl.read() + except OSError: + pass zctx = zmq.Context() zsub = zctx.socket(zmq.SUB) zsub.connect(conf.get("lookaside", "publishurl")) @@ -129,17 +221,19 @@ def runserver(conf): poller.register(tcpfd, flags=zmq.POLLIN) clients = Clients() try: + towait = set() while True: tosend = [] topoll = [] tostop = [] - events = poller.poll(1000) + towrite = set() + events = poller.poll(5000) for sk, fl in events: if sk is zsub: while True: try: - msg = zsub.recv(zmq.NOBLOCK) - tosend.append(LocEvt(msg)) + zmsg = LocEvt(zsub.recv(zmq.NOBLOCK)) + tosend.append(zmsg) except zmq.Again: break elif sk == tcpfd: @@ -148,13 +242,15 @@ def runserver(conf): elif fl & zmq.POLLIN: received = clients.recv(sk) if received is None: - log.debug( - "Client gone from fd %d", sk - ) + log.debug("Client gone from fd %d", sk) tostop.append(sk) else: for msg in received: log.debug("Received from %d: %s", sk, msg) + elif fl & zmq.POLLOUT: + log.debug("Write now open for fd %d", sk) + towrite.add(sk) + towait.discard(sk) else: log.debug("Stray event: %s on socket %s", fl, sk) # poll queue consumed, make changes now @@ -162,12 +258,26 @@ def runserver(conf): poller.unregister(fd) clients.stop(fd) for zmsg in tosend: - log.debug("Sending to the client: %s", zmsg) - clients.send(zmsg) + log.debug("Sending to the clients: %s", zmsg) + towrite |= clients.send(zmsg) for clntsock, clntaddr in topoll: fd = clients.add(clntsock, clntaddr) poller.register(fd, flags=zmq.POLLIN) - # TODO: Handle write overruns (register for POLLOUT) + # Deal with actually writing the data out + trywrite = towrite - towait + morewait = clients.write(trywrite) + log.debug( + "towait %s, tried %s, still busy %s", + towait, + trywrite, + morewait, + ) + for fd in morewait - trywrite: # new fds waiting for write + poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT) + for fd in trywrite - morewait: # no longer waiting for write + poller.modify(fd, flags=zmq.POLLIN) + towait &= trywrite + towait |= morewait except KeyboardInterrupt: pass