]> www.average.org Git - loctrkd.git/blob - gps303/collector.py
collector: fix problems found by fuzzer test
[loctrkd.git] / gps303 / collector.py
1 """ TCP server that communicates with terminals """
2
3 from configparser import ConfigParser
4 from logging import getLogger
5 from os import umask
6 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
7 from struct import pack
8 from time import time
9 from typing import Dict, List, Optional, Tuple
10 import zmq
11
12 from . import common
13 from .gps303proto import (
14     HIBERNATION,
15     LOGIN,
16     inline_response,
17     parse_message,
18     proto_of_message,
19 )
20 from .zmsg import Bcast, Resp
21
# Logger used by every class and function in this module.
log = getLogger(name="gps303/collector")

# Size, in bytes, of a single socket read and the limit for data kept
# between reads in a client's reassembly buffer (see Client.recv).
MAXBUFFER: int = 4 * 1024
25
26
class Client:
    """Connected socket to the terminal plus buffer and metadata"""

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        # Bytes received from the socket but not yet split into frames.
        self.buffer = b""
        # Set once a LOGIN message has been seen on this connection.
        self.imei: Optional[str] = None

    def close(self) -> None:
        """Close the socket and discard any partially received data."""
        log.debug("Closing fd %d (IMEI %s)", self.sock.fileno(), self.imei)
        self.sock.close()
        self.buffer = b""

    def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]:
        """Read from the socket and parse complete messages

        Returns None when the read fails or the peer has closed the
        connection. Otherwise returns a possibly empty list of
        ``(when, addr, packet)`` tuples. ``packet`` is the frame content
        between the leading b"xx" marker and the CRLF terminator.
        Incomplete trailing data stays in ``self.buffer`` for the next call.
        """
        try:
            segment = self.sock.recv(MAXBUFFER)
        except OSError as e:
            log.warning(
                "Reading from fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
            return None
        if not segment:  # Terminal has closed connection
            log.info(
                "EOF reading from fd %d (IMEI %s)",
                self.sock.fileno(),
                self.imei,
            )
            return None
        when = time()
        self.buffer += segment
        msgs = []
        while True:
            framestart = self.buffer.find(b"xx")
            if framestart == -1:  # No frames, return whatever we have
                break
            if framestart > 0:  # Should not happen, report
                log.warning(
                    'Undecodable data (%d) "%s" from fd %d (IMEI %s)',
                    framestart,
                    self.buffer[:framestart][:64].hex(),
                    self.sock.fileno(),
                    self.imei,
                )
                self.buffer = self.buffer[framestart:]
            # At this point, buffer starts with a packet
            if len(self.buffer) < 6:  # no len and proto - cannot proceed
                break
            exp_end = self.buffer[2] + 3  # Expect '\r\n' here
            frameend = 0
            # Length field can legitimately be much less than the
            # length of the packet (e.g. WiFi positioning), but
            # it _should not_ be greater. Still sometimes it is.
            # Luckily, not by too much: by maybe two or three bytes?
            # Do this embarrassing hack to avoid accidental match
            # of some binary data in the packet against '\r\n'.
            while True:
                frameend = self.buffer.find(b"\r\n", frameend + 1)
                if frameend == -1 or frameend >= (
                    exp_end - 3
                ):  # Found realistic match or none
                    break
            if frameend == -1:  # Incomplete frame, return what we have
                break
            packet = self.buffer[2:frameend]
            self.buffer = self.buffer[frameend + 2 :]
            # A small length byte (0..3) lets the terminator be accepted
            # right after it, leaving a packet without the protocol byte.
            # Such a packet cannot be classified, so skip it.
            if len(packet) < 2:
                log.warning(
                    "Packet too short (%d) from fd %d (IMEI %s): %s",
                    len(packet),
                    self.sock.fileno(),
                    self.imei,
                    packet.hex(),
                )
                continue
            if proto_of_message(packet) == LOGIN.PROTO:
                # NOTE(review): assumes parse_message() of a LOGIN packet
                # returns an object with .imei; confirm it cannot raise
                # on malformed input.
                self.imei = parse_message(packet).imei
                log.info(
                    "LOGIN from fd %d (IMEI %s)", self.sock.fileno(), self.imei
                )
            msgs.append((when, self.addr, packet))
        # Check the size only after extracting complete frames. Whatever is
        # left over really is unparseable, and valid frames that arrived in a
        # large segment are not thrown away.
        if len(self.buffer) > MAXBUFFER:
            # We are receiving junk. Let's drop it or we run out of memory.
            log.warning("More than %d unparseable data, dropping", MAXBUFFER)
            self.buffer = b""
        return msgs

    def send(self, buffer: bytes) -> None:
        """Wrap ``buffer`` in the b"xx" ... CRLF framing and send it.

        Send errors are logged and not propagated to the caller.
        """
        try:
            # sendall() retries until the whole frame is written; send()
            # may write only part of it.
            self.sock.sendall(b"xx" + buffer + b"\r\n")
        except OSError as e:
            log.error(
                "Sending to fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
120
class Clients:
    """Registry of connected terminals, indexed by fd and by IMEI."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}
        self.by_imei: Dict[str, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Start tracking a freshly accepted socket and return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd: int) -> None:
        """Close the client on ``fd`` and drop it from the registry."""
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
        clnt.close()
        # recv() overwrites by_imei on every LOGIN, so another connection
        # may now own this IMEI. Removing that entry would make its
        # responses undeliverable, and stopping it later would hit KeyError.
        if clnt.imei and self.by_imei.get(clnt.imei) is clnt:
            del self.by_imei[clnt.imei]
        del self.by_fd[fd]

    def recv(
        self, fd: int
    ) -> Optional[List[Tuple[Optional[str], float, Tuple[str, int], bytes]]]:
        """Read complete packets from the client on ``fd``.

        Returns None if the connection is gone. Otherwise returns a list of
        ``(imei, when, peeraddr, packet)`` tuples. A LOGIN packet registers
        the client under its IMEI.
        """
        clnt = self.by_fd[fd]
        msgs = clnt.recv()
        if msgs is None:
            return None
        result = []
        for when, peeraddr, packet in msgs:
            if proto_of_message(packet) == LOGIN.PROTO:  # Could do blindly...
                if clnt.imei:
                    self.by_imei[clnt.imei] = clnt
                else:
                    log.warning(
                        "Login message from %s: %s, but client imei unfilled",
                        peeraddr,
                        packet,
                    )
            result.append((clnt.imei, when, peeraddr, packet))
            log.debug(
                "Received from %s (IMEI %s): %s",
                peeraddr,
                clnt.imei,
                packet.hex(),
            )
        return result

    def response(self, resp: Resp) -> None:
        """Send ``resp.packet`` to the terminal registered as ``resp.imei``."""
        if resp.imei in self.by_imei:
            self.by_imei[resp.imei].send(resp.packet)
        else:
            log.info("Not connected (IMEI %s)", resp.imei)
170         else:
171             log.info("Not connected (IMEI %s)", resp.imei)
172
173
def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
    """Serve terminals over TCP and relay their traffic over zmq.

    Every packet received from a terminal is published on the
    ``publishurl`` PUB socket. Responses pulled from ``listenurl`` are
    published as outgoing Bcast messages and sent to the matching
    terminal. When ``handle_hibernate`` is true, a terminal that sends
    HIBERNATION is disconnected. Runs until KeyboardInterrupt. The zmq
    and listening sockets are released on any exit.
    """
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zpub = zctx.socket(zmq.PUB)  # type: ignore
    zpull = zctx.socket(zmq.PULL)  # type: ignore
    oldmask = umask(0o117)
    try:
        zpub.bind(conf.get("collector", "publishurl"))
        zpull.bind(conf.get("collector", "listenurl"))
    finally:
        # Restore the process umask even if a bind fails.
        umask(oldmask)
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("collector", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zpull, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        while True:
            tosend: List[Resp] = []
            topoll = []
            # Each fd must appear at most once: unregistering or stopping
            # the same fd twice raises KeyError.
            tostop: List[int] = []
            events = poller.poll(1000)
            for sk, fl in events:
                if sk is zpull:
                    while True:
                        try:
                            msg = zpull.recv(zmq.NOBLOCK)
                            zmsg = Resp(msg)
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # e.g. the peer aborted before accept(); keep serving
                        log.warning("Accepting connection: %s", e)
                        continue
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Terminal gone from fd %d", sk)
                        if sk not in tostop:
                            tostop.append(sk)
                    else:
                        for imei, when, peeraddr, packet in received:
                            proto = proto_of_message(packet)
                            zpub.send(
                                Bcast(
                                    proto=proto,
                                    imei=imei,
                                    when=when,
                                    peeraddr=peeraddr,
                                    packet=packet,
                                ).packed
                            )
                            if proto == HIBERNATION.PROTO and handle_hibernate:
                                log.debug(
                                    "HIBERNATION from fd %d (IMEI %s)",
                                    sk,
                                    imei,
                                )
                                # One read may carry several HIBERNATION packets
                                if sk not in tostop:
                                    tostop.append(sk)
                            respmsg = inline_response(packet)
                            if respmsg is not None:
                                tosend.append(
                                    Resp(imei=imei, when=when, packet=respmsg)
                                )
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for zmsg in tosend:
                zpub.send(
                    Bcast(
                        is_incoming=False,
                        proto=proto_of_message(zmsg.packet),
                        when=zmsg.when,
                        imei=zmsg.imei,
                        packet=zmsg.packet,
                    ).packed
                )
                log.debug("Sending to the client: %s", zmsg)
                clients.response(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
    except KeyboardInterrupt:
        pass  # Normal way to stop the server
    finally:
        zpub.close()
        zpull.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
265
266
if __name__.endswith("__main__"):
    # Script entry point: get the configuration from common.init()
    # and hand it to the server loop.
    config = common.init(log)
    runserver(config)