]> www.average.org Git - loctrkd.git/blobdiff - gps303/collector.py
Rename gps303proto to zx303proto
[loctrkd.git] / gps303 / collector.py
index 9f305e55f60197f38d84e820a793192bb93abbc2..cb45d8144bb9ec132265f8bed3d18881b6dab6de 100644 (file)
 """ TCP server that communicates with terminals """
 
+from configparser import ConfigParser
+from importlib import import_module
 from logging import getLogger
 from os import umask
-from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
-from time import time
+from socket import (
+    socket,
+    AF_INET6,
+    SOCK_STREAM,
+    SOL_SOCKET,
+    SO_KEEPALIVE,
+    SO_REUSEADDR,
+)
 from struct import pack
+from time import time
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
 import zmq
 
 from . import common
-from .gps303proto import (
-    HIBERNATION,
-    LOGIN,
-    inline_response,
-    parse_message,
-    proto_of_message,
-)
 from .zmsg import Bcast, Resp
 
 log = getLogger("gps303/collector")
 
+MAXBUFFER: int = 4096
+
+
+class ProtoModule:
+    class Stream:
+        @staticmethod
+        def enframe(buffer: bytes) -> bytes:
+            ...
+
+        def recv(self, segment: bytes) -> List[Union[bytes, str]]:
+            ...
+
+        def close(self) -> bytes:
+            ...
+
+    @staticmethod
+    def probe_buffer(buffer: bytes) -> bool:
+        ...
+
+    @staticmethod
+    def parse_message(packet: bytes, is_incoming: bool = True) -> Any:
+        ...
+
+    @staticmethod
+    def inline_response(packet: bytes) -> Optional[bytes]:
+        ...
+
+    @staticmethod
+    def is_goodbye_packet(packet: bytes) -> bool:
+        ...
+
+    @staticmethod
+    def imei_from_packet(packet: bytes) -> Optional[str]:
+        ...
+
+    @staticmethod
+    def proto_of_message(packet: bytes) -> str:
+        ...
+
+    @staticmethod
+    def proto_by_name(name: str) -> int:
+        ...
+
+
+pmods: List[ProtoModule] = []
+
 
 class Client:
     """Connected socket to the terminal plus buffer and metadata"""
 
-    def __init__(self, sock, addr):
+    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
         self.sock = sock
         self.addr = addr
-        self.buffer = b""
-        self.imei = None
+        self.pmod: Optional[ProtoModule] = None
+        self.stream: Optional[ProtoModule.Stream] = None
+        self.imei: Optional[str] = None
 
-    def close(self):
+    def close(self) -> None:
         log.debug("Closing fd %d (IMEI %s)", self.sock.fileno(), self.imei)
         self.sock.close()
-        self.buffer = b""
+        if self.stream:
+            rest = self.stream.close()
+        else:
+            rest = b""
+        if rest:
+            log.warning(
+                "%d bytes in buffer on close: %s", len(rest), rest[:64].hex()
+            )
 
-    def recv(self):
+    def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]:
         """Read from the socket and parse complete messages"""
         try:
-            segment = self.sock.recv(4096)
-        except OSError:
+            segment = self.sock.recv(MAXBUFFER)
+        except OSError as e:
             log.warning(
                 "Reading from fd %d (IMEI %s): %s",
                 self.sock.fileno(),
@@ -53,59 +110,59 @@ class Client:
                 self.imei,
             )
             return None
+        if self.stream is None:
+            for pmod in pmods:
+                if pmod.probe_buffer(segment):
+                    self.pmod = pmod
+                    self.stream = pmod.Stream()
+                    break
+        if self.stream is None:
+            log.info(
+                "unrecognizable %d bytes of data %s from fd %d",
+                len(segment),
+                segment[:32].hex(),
+                self.sock.fileno(),
+            )
+            return []
         when = time()
-        self.buffer += segment
         msgs = []
-        while True:
-            framestart = self.buffer.find(b"xx")
-            if framestart == -1:  # No frames, return whatever we have
-                break
-            if framestart > 0:  # Should not happen, report
+        for elem in self.stream.recv(segment):
+            if isinstance(elem, bytes):
+                msgs.append((when, self.addr, elem))
+            else:
                 log.warning(
-                    'Undecodable data "%s" from fd %d (IMEI %s)',
-                    self.buffer[:framestart].hex(),
+                    "%s from fd %d (IMEI %s)",
+                    elem,
                     self.sock.fileno(),
                     self.imei,
                 )
-                self.buffer = self.buffer[framestart:]
-            # At this point, buffer starts with a packet
-            frameend = self.buffer.find(b"\r\n", 4)
-            if frameend == -1:  # Incomplete frame, return what we have
-                break
-            packet = self.buffer[2:frameend]
-            self.buffer = self.buffer[frameend + 2 :]
-            if proto_of_message(packet) == LOGIN.PROTO:
-                self.imei = parse_message(packet).imei
-                log.info(
-                    "LOGIN from fd %d (IMEI %s)", self.sock.fileno(), self.imei
-                )
-            msgs.append((when, self.addr, packet))
         return msgs
 
-    def send(self, buffer):
+    def send(self, buffer: bytes) -> None:
+        assert self.stream is not None
         try:
-            self.sock.send(b"xx" + buffer + b"\r\n")
+            self.sock.send(self.stream.enframe(buffer))
         except OSError as e:
             log.error(
                 "Sending to fd %d (IMEI %s): %s",
-                self.sock.fileno,
+                self.sock.fileno(),
                 self.imei,
                 e,
             )
 
 
 class Clients:
-    def __init__(self):
-        self.by_fd = {}
-        self.by_imei = {}
+    def __init__(self) -> None:
+        self.by_fd: Dict[int, Client] = {}
+        self.by_imei: Dict[str, Client] = {}
 
-    def add(self, clntsock, clntaddr):
+    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
         fd = clntsock.fileno()
         log.info("Start serving fd %d from %s", fd, clntaddr)
         self.by_fd[fd] = Client(clntsock, clntaddr)
         return fd
 
-    def stop(self, fd):
+    def stop(self, fd: int) -> None:
         clnt = self.by_fd[fd]
         log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
         clnt.close()
@@ -113,16 +170,38 @@ class Clients:
             del self.by_imei[clnt.imei]
         del self.by_fd[fd]
 
-    def recv(self, fd):
+    def recv(
+        self, fd: int
+    ) -> Optional[
+        List[Tuple[ProtoModule, Optional[str], float, Tuple[str, int], bytes]]
+    ]:
         clnt = self.by_fd[fd]
         msgs = clnt.recv()
         if msgs is None:
             return None
         result = []
         for when, peeraddr, packet in msgs:
-            if proto_of_message(packet) == LOGIN.PROTO:  # Could do blindly...
-                self.by_imei[clnt.imei] = clnt
-            result.append((clnt.imei, when, peeraddr, packet))
+            assert clnt.pmod is not None
+            if clnt.imei is None:
+                imei = clnt.pmod.imei_from_packet(packet)
+                if imei is not None:
+                    log.info("LOGIN from fd %d (IMEI %s)", fd, imei)
+                    clnt.imei = imei
+                    oldclnt = self.by_imei.get(clnt.imei)
+                    if oldclnt is not None:
+                        log.info(
+                            "Orphaning fd %d with the same IMEI",
+                            oldclnt.sock.fileno(),
+                        )
+                        oldclnt.imei = None
+                    self.by_imei[clnt.imei] = clnt
+                else:
+                    log.warning(
+                        "Login message from %s: %s, but client imei unfilled",
+                        peeraddr,
+                        packet,
+                    )
+            result.append((clnt.pmod, clnt.imei, when, peeraddr, packet))
             log.debug(
                 "Received from %s (IMEI %s): %s",
                 peeraddr,
@@ -131,17 +210,26 @@ class Clients:
             )
         return result
 
-    def response(self, resp):
+    def response(self, resp: Resp) -> Optional[ProtoModule]:
         if resp.imei in self.by_imei:
-            self.by_imei[resp.imei].send(resp.packet)
+            clnt = self.by_imei[resp.imei]
+            clnt.send(resp.packet)
+            return clnt.pmod
         else:
             log.info("Not connected (IMEI %s)", resp.imei)
+            return None
 
 
-def runserver(conf):
-    zctx = zmq.Context()
-    zpub = zctx.socket(zmq.PUB)
-    zpull = zctx.socket(zmq.PULL)
+def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
+    global pmods
+    pmods = [
+        cast(ProtoModule, import_module("." + modnm, __package__))
+        for modnm in conf.get("collector", "protocols").split(",")
+    ]
+    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
+    zctx = zmq.Context()  # type: ignore
+    zpub = zctx.socket(zmq.PUB)  # type: ignore
+    zpull = zctx.socket(zmq.PULL)  # type: ignore
     oldmask = umask(0o117)
     zpub.bind(conf.get("collector", "publishurl"))
     zpull.bind(conf.get("collector", "listenurl"))
@@ -151,7 +239,7 @@ def runserver(conf):
     tcpl.bind(("", conf.getint("collector", "port")))
     tcpl.listen(5)
     tcpfd = tcpl.fileno()
-    poller = zmq.Poller()
+    poller = zmq.Poller()  # type: ignore
     poller.register(zpull, flags=zmq.POLLIN)
     poller.register(tcpfd, flags=zmq.POLLIN)
     clients = Clients()
@@ -166,22 +254,22 @@ def runserver(conf):
                     while True:
                         try:
                             msg = zpull.recv(zmq.NOBLOCK)
-                            tosend.append(Resp(msg))
+                            zmsg = Resp(msg)
+                            tosend.append(zmsg)
                         except zmq.Again:
                             break
                 elif sk == tcpfd:
                     clntsock, clntaddr = tcpl.accept()
+                    clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
                     topoll.append((clntsock, clntaddr))
                 elif fl & zmq.POLLIN:
                     received = clients.recv(sk)
                     if received is None:
-                        log.debug(
-                            "Terminal gone from fd %d (IMEI %s)", sk, imei
-                        )
+                        log.debug("Terminal gone from fd %d", sk)
                         tostop.append(sk)
                     else:
-                        for imei, when, peeraddr, packet in received:
-                            proto = proto_of_message(packet)
+                        for pmod, imei, when, peeraddr, packet in received:
+                            proto = pmod.proto_of_message(packet)
                             zpub.send(
                                 Bcast(
                                     proto=proto,
@@ -191,32 +279,48 @@ def runserver(conf):
                                     packet=packet,
                                 ).packed
                             )
-                            if proto == HIBERNATION.PROTO:
+                            if (
+                                pmod.is_goodbye_packet(packet)
+                                and handle_hibernate
+                            ):
                                 log.debug(
-                                    "HIBERNATION from fd %d (IMEI %s)",
+                                    "Goodbye from fd %d (IMEI %s)",
                                     sk,
                                     imei,
                                 )
                                 tostop.append(sk)
-                            respmsg = inline_response(packet)
+                            respmsg = pmod.inline_response(packet)
                             if respmsg is not None:
-                                clients.response(
-                                    Resp(imei=imei, packet=respmsg)
+                                tosend.append(
+                                    Resp(imei=imei, when=when, packet=respmsg)
                                 )
                 else:
                     log.debug("Stray event: %s on socket %s", fl, sk)
             # poll queue consumed, make changes now
-            for fd in tostop:
-                poller.unregister(fd)
-                clients.stop(fd)
             for zmsg in tosend:
                 log.debug("Sending to the client: %s", zmsg)
-                clients.response(zmsg)
+                rpmod = clients.response(zmsg)
+                if rpmod is not None:
+                    zpub.send(
+                        Bcast(
+                            is_incoming=False,
+                            proto=rpmod.proto_of_message(zmsg.packet),
+                            when=zmsg.when,
+                            imei=zmsg.imei,
+                            packet=zmsg.packet,
+                        ).packed
+                    )
+            for fd in tostop:
+                poller.unregister(fd)  # type: ignore
+                clients.stop(fd)
             for clntsock, clntaddr in topoll:
                 fd = clients.add(clntsock, clntaddr)
                 poller.register(fd, flags=zmq.POLLIN)
     except KeyboardInterrupt:
-        pass
+        zpub.close()
+        zpull.close()
+        zctx.destroy()  # type: ignore
+        tcpl.close()
 
 
 if __name__.endswith("__main__"):