]> www.average.org Git - loctrkd.git/blobdiff - gps303/collector.py
collector: prevent two active clients share IMEI
[loctrkd.git] / gps303 / collector.py
index f4d960eb228d7cb65c5249b90c7bb394160a438f..ce41cb57b1c91cdf7b3d1f656a1fe98057a672b7 100644 (file)
@@ -3,7 +3,14 @@
 from configparser import ConfigParser
 from logging import getLogger
 from os import umask
-from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
+from socket import (
+    socket,
+    AF_INET6,
+    SOCK_STREAM,
+    SOL_SOCKET,
+    SO_KEEPALIVE,
+    SO_REUSEADDR,
+)
 from struct import pack
 from time import time
 from typing import Dict, List, Optional, Tuple
@@ -21,6 +28,8 @@ from .zmsg import Bcast, Resp
 
 log = getLogger("gps303/collector")
 
+MAXBUFFER: int = 4096
+
 
 class Client:
     """Connected socket to the terminal plus buffer and metadata"""
@@ -39,7 +48,7 @@ class Client:
     def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]:
         """Read from the socket and parse complete messages"""
         try:
-            segment = self.sock.recv(4096)
+            segment = self.sock.recv(MAXBUFFER)
         except OSError as e:
             log.warning(
                 "Reading from fd %d (IMEI %s): %s",
@@ -57,6 +66,10 @@ class Client:
             return None
         when = time()
         self.buffer += segment
+        if len(self.buffer) > MAXBUFFER:
+            # We are receiving junk. Let's drop it or we run out of memory.
+            log.warning("More than %d unparseable data, dropping", MAXBUFFER)
+            self.buffer = b""
         msgs = []
         while True:
             framestart = self.buffer.find(b"xx")
@@ -64,8 +77,9 @@ class Client:
                 break
             if framestart > 0:  # Should not happen, report
                 log.warning(
-                    'Undecodable data "%s" from fd %d (IMEI %s)',
-                    self.buffer[:framestart].hex(),
+                    'Undecodable data (%d) "%s" from fd %d (IMEI %s)',
+                    framestart,
+                    self.buffer[:framestart][:64].hex(),
                     self.sock.fileno(),
                     self.imei,
                 )
@@ -82,18 +96,18 @@ class Client:
             # Do this embarrassing hack to avoid accidental match
             # of some binary data in the packet against '\r\n'.
             while True:
-                frameend = self.buffer.find(b"\r\n", frameend)
-                if frameend >= (exp_end - 3):  # Found realistic match
+                frameend = self.buffer.find(b"\r\n", frameend + 1)
+                if frameend == -1 or frameend >= (
+                    exp_end - 3
+                ):  # Found realistic match or none
                     break
             if frameend == -1:  # Incomplete frame, return what we have
                 break
             packet = self.buffer[2:frameend]
             self.buffer = self.buffer[frameend + 2 :]
-            if proto_of_message(packet) == LOGIN.PROTO:
-                self.imei = parse_message(packet).imei
-                log.info(
-                    "LOGIN from fd %d (IMEI %s)", self.sock.fileno(), self.imei
-                )
+            if len(packet) < 2:  # frameend comes too early
+                log.warning("Packet too short: %s", packet)
+                break
             msgs.append((when, self.addr, packet))
         return msgs
 
@@ -103,7 +117,7 @@ class Client:
         except OSError as e:
             log.error(
                 "Sending to fd %d (IMEI %s): %s",
-                self.sock.fileno,
+                self.sock.fileno(),
                 self.imei,
                 e,
             )
@@ -128,18 +142,39 @@ class Clients:
             del self.by_imei[clnt.imei]
         del self.by_fd[fd]
 
-    def recv(self, fd: int) -> Optional[List[Tuple[Optional[str], float, Tuple[str, int], bytes]]]:
+    def recv(
+        self, fd: int
+    ) -> Optional[List[Tuple[Optional[str], float, Tuple[str, int], bytes]]]:
         clnt = self.by_fd[fd]
         msgs = clnt.recv()
         if msgs is None:
             return None
         result = []
         for when, peeraddr, packet in msgs:
-            if proto_of_message(packet) == LOGIN.PROTO:  # Could do blindly...
-                if clnt.imei:
+            if proto_of_message(packet) == LOGIN.PROTO:
+                msg = parse_message(packet)
+                if isinstance(msg, LOGIN):  # Can be unparseable
+                    if clnt.imei is None:
+                        clnt.imei = msg.imei
+                        log.info(
+                            "LOGIN from fd %d (IMEI %s)",
+                            clnt.sock.fileno(),
+                            clnt.imei,
+                        )
+                        oldclnt = self.by_imei.get(clnt.imei)
+                        if oldclnt is not None:
+                            log.info(
+                                "Orphaning fd %d with the same IMEI",
+                                oldclnt.sock.fileno(),
+                            )
+                            oldclnt.imei = None
                     self.by_imei[clnt.imei] = clnt
                 else:
-                    log.warning("Login message from %s: %s, but client imei unfilled", peeraddr, packet)
+                    log.warning(
+                        "Login message from %s: %s, but client imei unfilled",
+                        peeraddr,
+                        packet,
+                    )
             result.append((clnt.imei, when, peeraddr, packet))
             log.debug(
                 "Received from %s (IMEI %s): %s",
@@ -156,7 +191,7 @@ class Clients:
             log.info("Not connected (IMEI %s)", resp.imei)
 
 
-def runserver(conf: ConfigParser) -> None:
+def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
     zctx = zmq.Context()  # type: ignore
     zpub = zctx.socket(zmq.PUB)  # type: ignore
@@ -191,6 +226,7 @@ def runserver(conf: ConfigParser) -> None:
                             break
                 elif sk == tcpfd:
                     clntsock, clntaddr = tcpl.accept()
+                    clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
                     topoll.append((clntsock, clntaddr))
                 elif fl & zmq.POLLIN:
                     received = clients.recv(sk)
@@ -209,7 +245,7 @@ def runserver(conf: ConfigParser) -> None:
                                     packet=packet,
                                 ).packed
                             )
-                            if proto == HIBERNATION.PROTO:
+                            if proto == HIBERNATION.PROTO and handle_hibernate:
                                 log.debug(
                                     "HIBERNATION from fd %d (IMEI %s)",
                                     sk,
@@ -224,9 +260,6 @@ def runserver(conf: ConfigParser) -> None:
                 else:
                     log.debug("Stray event: %s on socket %s", fl, sk)
             # poll queue consumed, make changes now
-            for fd in tostop:
-                poller.unregister(fd)  # type: ignore
-                clients.stop(fd)
             for zmsg in tosend:
                 zpub.send(
                     Bcast(
@@ -239,11 +272,17 @@ def runserver(conf: ConfigParser) -> None:
                 )
                 log.debug("Sending to the client: %s", zmsg)
                 clients.response(zmsg)
+            for fd in tostop:
+                poller.unregister(fd)  # type: ignore
+                clients.stop(fd)
             for clntsock, clntaddr in topoll:
                 fd = clients.add(clntsock, clntaddr)
                 poller.register(fd, flags=zmq.POLLIN)
     except KeyboardInterrupt:
-        pass
+        zpub.close()
+        zpull.close()
+        zctx.destroy()  # type: ignore
+        tcpl.close()
 
 
 if __name__.endswith("__main__"):