]> www.average.org Git - loctrkd.git/blob - gps303/collector.py
collector: prevent two active clients from sharing an IMEI
[loctrkd.git] / gps303 / collector.py
1 """ TCP server that communicates with terminals """
2
3 from configparser import ConfigParser
4 from logging import getLogger
5 from os import umask
6 from socket import (
7     socket,
8     AF_INET6,
9     SOCK_STREAM,
10     SOL_SOCKET,
11     SO_KEEPALIVE,
12     SO_REUSEADDR,
13 )
14 from struct import pack
15 from time import time
16 from typing import Dict, List, Optional, Tuple
17 import zmq
18
19 from . import common
20 from .gps303proto import (
21     HIBERNATION,
22     LOGIN,
23     inline_response,
24     parse_message,
25     proto_of_message,
26 )
27 from .zmsg import Bcast, Resp
28
# Logger used by everything in this module.
log = getLogger("gps303/collector")

# Maximum number of bytes requested by a single recv(), and the most
# not-yet-framed data a client may accumulate before it is discarded.
MAXBUFFER: int = 0x1000
32
33
class Client:
    """Connected socket to the terminal plus buffer and metadata.

    Holds the accepted TCP socket, the peer address, the bytes received
    but not yet split into frames, and the terminal's IMEI. The IMEI is
    ``None`` until a LOGIN is seen; ``Clients.recv`` also resets it to
    ``None`` when another connection logs in with the same IMEI.
    """

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        # Received bytes that do not (yet) form a complete frame.
        self.buffer = b""
        self.imei: Optional[str] = None

    def close(self) -> None:
        """Close the socket and discard any buffered, unframed data."""
        log.debug("Closing fd %d (IMEI %s)", self.sock.fileno(), self.imei)
        self.sock.close()
        self.buffer = b""

    def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]:
        """Read from the socket and parse complete messages.

        Frames on the wire are ``b"xx" + packet + b"\\r\\n"``. The byte
        right after ``b"xx"`` is a length field that is not always accurate.

        Returns:
            ``None`` if reading failed or the peer closed the connection;
            otherwise a (possibly empty) list of
            ``(receive_time, peer_address, packet)`` tuples, where
            ``packet`` excludes the ``b"xx"`` prefix and the ``b"\\r\\n"``
            terminator. An incomplete trailing frame stays in
            ``self.buffer`` for the next call.
        """
        try:
            segment = self.sock.recv(MAXBUFFER)
        except OSError as e:
            log.warning(
                "Reading from fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
            return None
        if not segment:  # Terminal has closed connection
            log.info(
                "EOF reading from fd %d (IMEI %s)",
                self.sock.fileno(),
                self.imei,
            )
            return None
        when = time()
        self.buffer += segment
        msgs = []
        while True:
            framestart = self.buffer.find(b"xx")
            if framestart == -1:  # No frames, return whatever we have
                break
            if framestart > 0:  # Should not happen, report
                log.warning(
                    'Undecodable data (%d) "%s" from fd %d (IMEI %s)',
                    framestart,
                    self.buffer[:framestart][:64].hex(),
                    self.sock.fileno(),
                    self.imei,
                )
                self.buffer = self.buffer[framestart:]
            # At this point, buffer starts with a packet
            if len(self.buffer) < 6:  # no len and proto - cannot proceed
                break
            exp_end = self.buffer[2] + 3  # Expect '\r\n' here
            frameend = 0
            # Length field can legitimately be much less than the
            # length of the packet (e.g. WiFi positioning), but
            # it _should not_ be greater. Still sometimes it is.
            # Luckily, not by too much: by maybe two or three bytes?
            # Do this embarrassing hack to avoid accidental match
            # of some binary data in the packet against '\r\n'.
            while True:
                frameend = self.buffer.find(b"\r\n", frameend + 1)
                if frameend == -1 or frameend >= (
                    exp_end - 3
                ):  # Found realistic match or none
                    break
            if frameend == -1:  # Incomplete frame, return what we have
                break
            packet = self.buffer[2:frameend]
            self.buffer = self.buffer[frameend + 2 :]
            if len(packet) < 2:  # frameend comes too early
                log.warning("Packet too short: %s", packet)
                # The bogus frame is already consumed from the buffer, so
                # keep going: frames that follow it may well be complete.
                continue
            msgs.append((when, self.addr, packet))
        # Check the size only after extracting complete frames, so that
        # valid frames in a full-size read are not thrown away together
        # with a leftover partial frame.
        if len(self.buffer) > MAXBUFFER:
            # We are receiving junk. Let's drop it or we run out of memory.
            log.warning("More than %d unparseable data, dropping", MAXBUFFER)
            self.buffer = b""
        return msgs

    def send(self, buffer: bytes) -> None:
        """Frame ``buffer`` as ``b"xx" + buffer + b"\\r\\n"`` and send it.

        Uses ``sendall`` so the whole frame is written even when the
        kernel accepts only part of it at once. Errors are logged, not
        raised.
        """
        try:
            self.sock.sendall(b"xx" + buffer + b"\r\n")
        except OSError as e:
            log.error(
                "Sending to fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
124
125
class Clients:
    """Registry of connected terminals, indexed by fd and by IMEI.

    At most one client is registered under a given IMEI. When a new
    connection logs in with an IMEI that is already registered, the
    older connection is orphaned (its ``imei`` is reset to ``None``), and
    responses for that IMEI go to the newer connection.
    """

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}
        self.by_imei: Dict[str, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Start tracking a freshly accepted socket and return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd: int) -> None:
        """Close the client on ``fd`` and forget about it.

        Raises:
            KeyError: if ``fd`` is not being served.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
        clnt.close()
        # Compare with ``is not None``, as ``recv`` does when assigning the
        # IMEI, so a falsy IMEI (e.g. "") is also unregistered. Remove the
        # mapping only if it still points at *this* client, so we never
        # unregister a different connection that owns the IMEI now.
        if clnt.imei is not None and self.by_imei.get(clnt.imei) is clnt:
            del self.by_imei[clnt.imei]
        del self.by_fd[fd]

    def recv(
        self, fd: int
    ) -> Optional[List[Tuple[Optional[str], float, Tuple[str, int], bytes]]]:
        """Read from ``fd`` and tag each received packet with the IMEI.

        A parseable LOGIN message assigns its IMEI to this client (if it
        has none yet) and orphans any other client registered under that
        IMEI.

        Returns:
            ``None`` when the client should be stopped; otherwise a list
            of ``(imei, when, peeraddr, packet)`` tuples, where ``imei``
            may be ``None`` for a client that has not logged in.
        """
        clnt = self.by_fd[fd]
        msgs = clnt.recv()
        if msgs is None:
            return None
        result = []
        for when, peeraddr, packet in msgs:
            if proto_of_message(packet) == LOGIN.PROTO:
                msg = parse_message(packet)
                if isinstance(msg, LOGIN):  # Can be unparseable
                    if clnt.imei is None:
                        clnt.imei = msg.imei
                        log.info(
                            "LOGIN from fd %d (IMEI %s)",
                            clnt.sock.fileno(),
                            clnt.imei,
                        )
                        oldclnt = self.by_imei.get(clnt.imei)
                        if oldclnt is not None:
                            log.info(
                                "Orphaning fd %d with the same IMEI",
                                oldclnt.sock.fileno(),
                            )
                            oldclnt.imei = None
                    self.by_imei[clnt.imei] = clnt
                else:
                    log.warning(
                        "Login message from %s: %s, but client imei unfilled",
                        peeraddr,
                        packet,
                    )
            result.append((clnt.imei, when, peeraddr, packet))
            log.debug(
                "Received from %s (IMEI %s): %s",
                peeraddr,
                clnt.imei,
                packet.hex(),
            )
        return result

    def response(self, resp: Resp) -> None:
        """Send ``resp.packet`` to the client registered under ``resp.imei``.

        If no client is registered under that IMEI, the response is
        dropped and the fact is logged.
        """
        if resp.imei in self.by_imei:
            self.by_imei[resp.imei].send(resp.packet)
        else:
            log.info("Not connected (IMEI %s)", resp.imei)
192
193
def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
    """Serve terminals until interrupted.

    Accepts TCP connections on ``[collector] port``. Every packet received
    from a terminal, and every packet sent back to one, is published as a
    ``Bcast`` on the ZeroMQ PUB socket bound to ``[collector] publishurl``.
    ``Resp`` messages arriving on the PULL socket bound to
    ``[collector] listenurl`` are forwarded to the terminal with the
    matching IMEI. Replies that ``inline_response`` produces are sent
    right away.

    Args:
        conf: configuration containing a ``[collector]`` section.
        handle_hibernate: when true, drop a terminal's connection after it
            sends a HIBERNATION message.
    """
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zpub = zctx.socket(zmq.PUB)  # type: ignore
    zpull = zctx.socket(zmq.PULL)  # type: ignore
    # Tighten the umask while binding (presumably for files created by
    # ipc:// endpoints); restore it even if a bind fails.
    oldmask = umask(0o117)
    try:
        zpub.bind(conf.get("collector", "publishurl"))
        zpull.bind(conf.get("collector", "listenurl"))
    finally:
        umask(oldmask)
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("collector", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zpull, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        while True:
            tosend = []
            topoll = []
            tostop = []
            events = poller.poll(1000)
            for sk, fl in events:
                if sk is zpull:
                    while True:
                        try:
                            msg = zpull.recv(zmq.NOBLOCK)
                            zmsg = Resp(msg)
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # E.g. the peer went away before we accepted, or we
                        # ran out of descriptors: keep serving the others.
                        log.warning("Accepting connection: %s", e)
                        continue
                    clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Terminal gone from fd %d", sk)
                        tostop.append(sk)
                    else:
                        for imei, when, peeraddr, packet in received:
                            proto = proto_of_message(packet)
                            zpub.send(
                                Bcast(
                                    proto=proto,
                                    imei=imei,
                                    when=when,
                                    peeraddr=peeraddr,
                                    packet=packet,
                                ).packed
                            )
                            if proto == HIBERNATION.PROTO and handle_hibernate:
                                log.debug(
                                    "HIBERNATION from fd %d (IMEI %s)",
                                    sk,
                                    imei,
                                )
                                tostop.append(sk)
                            respmsg = inline_response(packet)
                            if respmsg is not None:
                                tosend.append(
                                    Resp(imei=imei, when=when, packet=respmsg)
                                )
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for zmsg in tosend:
                zpub.send(
                    Bcast(
                        is_incoming=False,
                        proto=proto_of_message(zmsg.packet),
                        when=zmsg.when,
                        imei=zmsg.imei,
                        packet=zmsg.packet,
                    ).packed
                )
                log.debug("Sending to the client: %s", zmsg)
                clients.response(zmsg)
            # The same fd may be queued more than once (e.g. several
            # HIBERNATION packets in one read). Stopping it twice would
            # raise KeyError from unregister()/stop(), so deduplicate
            # while keeping the order.
            for fd in dict.fromkeys(tostop):
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
    except KeyboardInterrupt:
        pass
    finally:
        # Release the sockets on any exit, not only on Ctrl-C.
        zpub.close()
        zpull.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
286
287
if __name__.endswith("__main__"):
    # Started as a program: obtain the configuration through the shared
    # init helper, then serve until interrupted.
    runserver(common.init(log), handle_hibernate=True)