]> www.average.org Git - loctrkd.git/blobdiff - gps303/wsgateway.py
fix zmq subscription topics
[loctrkd.git] / gps303 / wsgateway.py
index 127dbb4b002fc8c62dce6be2cd42e35b6df54e3e..00770ebd88032c35f9b9e7172dba3c63233d7434 100644 (file)
@@ -1,5 +1,6 @@
 """ Websocket Gateway """
 
+from json import loads
 from logging import getLogger
 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
 from time import time
@@ -16,14 +17,20 @@ from wsproto.utilities import RemoteProtocolError
 import zmq
 
 from . import common
-from .zmsg import LocEvt
+from .backlog import blinit, backlog
+from .gps303proto import (
+    GPS_POSITIONING,
+    WIFI_POSITIONING,
+    parse_message,
+)
+from .zmsg import Bcast, topic
 
 log = getLogger("gps303/wsgateway")
-htmldata = None
+htmlfile = None
 
 
 def try_http(data, fd, e):
-    global htmldata
+    global htmlfile
     try:
         lines = data.decode().split("\r\n")
         request = lines[0]
@@ -43,19 +50,28 @@ def try_http(data, fd, e):
         except ValueError:
             pass
         if op == "GET":
-            if htmldata is None:
+            if htmlfile is None:
                 return (
                     f"{proto} 500 No data configured\r\n"
                     f"Content-Type: text/plain\r\n\r\n"
                     f"HTML data not configure on the server\r\n".encode()
                 )
             elif resource == "/":
-                length = len(htmldata.encode("utf-8"))
-                return (
-                    f"{proto} 200 Ok\r\n"
-                    f"Content-Type: text/html; charset=utf-8\r\n"
-                    f"Content-Length: {length:d}\r\n\r\n" + htmldata
-                ).encode("utf-8")
+                try:
+                    with open(htmlfile, "rb") as fl:
+                        htmldata = fl.read()
+                    length = len(htmldata)
+                    return (
+                        f"{proto} 200 Ok\r\n"
+                        f"Content-Type: text/html; charset=utf-8\r\n"
+                        f"Content-Length: {len(htmldata):d}\r\n\r\n"
+                    ).encode("utf-8") + htmldata
+                except OSError:
+                    return (
+                        f"{proto} 500 File not found\r\n"
+                        f"Content-Type: text/plain\r\n\r\n"
+                        f"HTML file could not be opened\r\n".encode()
+                    )
             else:
                 return (
                     f"{proto} 404 File not found\r\n"
@@ -81,6 +97,8 @@ class Client:
         self.addr = addr
         self.ws = WSConnection(ConnectionType.SERVER)
         self.ws_data = b""
+        self.ready = False
+        self.imeis = set()
 
     def close(self):
         log.debug("Closing fd %d", self.sock.fileno())
@@ -113,6 +131,7 @@ class Client:
                 e,
             )
             self.ws_data = try_http(data, self.sock.fileno(), e)
+            self.write()  # TODO this is a hack
             log.debug("Sending HTTP response to %d", self.sock.fileno())
             msgs = None
         else:
@@ -122,25 +141,38 @@ class Client:
                     log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                     # self.ws_data += self.ws.send(event.response())  # Why not?!
                     self.ws_data += self.ws.send(AcceptConnection())
+                    self.ready = True
                 elif isinstance(event, (CloseConnection, Ping)):
                     log.debug("%s on fd %d", event, self.sock.fileno())
                     self.ws_data += self.ws.send(event.response())
                 elif isinstance(event, TextMessage):
                     # TODO: save imei "subscription"
                     log.debug("%s on fd %d", event, self.sock.fileno())
-                    msgs.append(event.data)
+                    msg = loads(event.data)
+                    msgs.append(msg)
+                    if msg.get("type", None) == "subscribe":
+                        self.imeis = set(msg.get("imei", []))
+                        log.debug(
+                            "subs list on fd %s is %s",
+                            self.sock.fileno(),
+                            self.imeis,
+                        )
                 else:
                     log.warning("%s on fd %d", event, self.sock.fileno())
-        if self.ws_data:  # Temp hack
-            self.write()
         return msgs
 
     def wants(self, imei):
+        log.debug(
+            "wants %s? set is %s on fd %d",
+            imei,
+            self.imeis,
+            self.sock.fileno(),
+        )
         return True  # TODO: check subscriptions
 
     def send(self, message):
-        # TODO: filter only wanted imei got from the client
-        self.ws_data += self.ws.send(Message(data=message.json))
+        if self.ready and message.imei in self.imeis:
+            self.ws_data += self.ws.send(Message(data=message.json))
 
     def write(self):
         if self.ws_data:
@@ -175,13 +207,7 @@ class Clients:
 
     def recv(self, fd):
         clnt = self.by_fd[fd]
-        msgs = clnt.recv()
-        if msgs is None:
-            return None
-        result = []
-        for msg in msgs:
-            log.debug("Received: %s", msg)
-        return result
+        return clnt.recv()
 
     def send(self, msg):
         towrite = set()
@@ -198,18 +224,21 @@ class Clients:
                 waiting.add(fd)
         return waiting
 
+    def subs(self):
+        result = set()
+        for clnt in self.by_fd.values():
+            result |= clnt.imeis
+        return result
+
 
 def runserver(conf):
-    global htmldata
-    try:
-        with open(conf.get("wsgateway", "htmlfile"), encoding="utf-8") as fl:
-            htmldata = fl.read()
-    except OSError:
-        pass
+    global htmlfile
+
+    blinit(conf.get("storage", "dbfn"))
+    htmlfile = conf.get("wsgateway", "htmlfile")
     zctx = zmq.Context()
     zsub = zctx.socket(zmq.SUB)
-    zsub.connect(conf.get("lookaside", "publishurl"))
-    zsub.setsockopt(zmq.SUBSCRIBE, b"")
+    zsub.connect(conf.get("collector", "publishurl"))
     tcpl = socket(AF_INET6, SOCK_STREAM)
     tcpl.setblocking(False)
     tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
@@ -220,20 +249,48 @@ def runserver(conf):
     poller.register(zsub, flags=zmq.POLLIN)
     poller.register(tcpfd, flags=zmq.POLLIN)
     clients = Clients()
+    activesubs = set()
     try:
         towait = set()
         while True:
+            neededsubs = clients.subs()
+            for imei in neededsubs - activesubs:
+                log.debug("topics: %s", [tpc.hex() for tpc in [topic(GPS_POSITIONING.PROTO, True, imei), topic(WIFI_POSITIONING.PROTO, False, imei)]])
+                zsub.setsockopt(
+                    zmq.SUBSCRIBE,
+                    topic(GPS_POSITIONING.PROTO, True, imei),
+                )
+                zsub.setsockopt(
+                    zmq.SUBSCRIBE,
+                    topic(WIFI_POSITIONING.PROTO, False, imei),
+                )
+            for imei in activesubs - neededsubs:
+                zsub.setsockopt(
+                    zmq.UNSUBSCRIBE,
+                    topic(GPS_POSITIONING.PROTO, True, imei),
+                )
+                zsub.setsockopt(
+                    zmq.UNSUBSCRIBE,
+                    topic(WIFI_POSITIONING.PROTO, False, imei),
+                )
+            activesubs = neededsubs
+            log.debug("Subscribed to: %s", activesubs)
             tosend = []
             topoll = []
             tostop = []
             towrite = set()
-            events = poller.poll(5000)
+            events = poller.poll()
             for sk, fl in events:
                 if sk is zsub:
                     while True:
                         try:
-                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
+                            buf = zsub.recv(zmq.NOBLOCK)
+                            zmsg = Bcast(buf)
+                            log.debug("zmq packet: %s", buf.hex())
+                            # zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
+                            msg = parse_message(zmsg.packet)
                             tosend.append(zmsg)
+                            log.debug("Got %s", zmsg)
                         except zmq.Again:
                             break
                 elif sk == tcpfd:
@@ -244,9 +301,16 @@ def runserver(conf):
                     if received is None:
                         log.debug("Client gone from fd %d", sk)
                         tostop.append(sk)
+                        towait.discard(fd)
                     else:
                         for msg in received:
                             log.debug("Received from %d: %s", sk, msg)
+                            if msg.get("type", None) == "subscribe":
+                                imei = msg.get("imei")
+                                numback = msg.get("backlog", 5)
+                                for elem in imei:
+                                    tosend.extend(backlog(elem, numback))
+                        towrite.add(sk)
                 elif fl & zmq.POLLOUT:
                     log.debug("Write now open for fd %d", sk)
                     towrite.add(sk)