]> www.average.org Git - loctrkd.git/blob - loctrkd/collector.py
collector: streamline tracking of polled fd-s
[loctrkd.git] / loctrkd / collector.py
1 """ TCP server that communicates with terminals """
2
3 from configparser import ConfigParser
4 from importlib import import_module
5 from logging import getLogger
6 from os import umask
7 from socket import (
8     socket,
9     AF_INET6,
10     SOCK_STREAM,
11     SOL_SOCKET,
12     SO_KEEPALIVE,
13     SO_REUSEADDR,
14 )
15 from struct import pack
16 from time import time
17 from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
18 import zmq
19
20 from . import common
21 from .zmsg import Bcast, Resp
22
log = getLogger("loctrkd/collector")

# Upper bound on the number of bytes requested from a terminal socket in a
# single recv() call (see Client.recv)
MAXBUFFER: int = 4096
26
27
class ProtoModule:
    """Interface expected from every protocol module.

    runserver() imports the modules named in the configuration and casts
    them to this type; nothing here is meant to be called directly, and
    every member below is a stub that returns None.
    """

    class Stream:
        """Per-connection parser for the terminal's byte stream."""

        def recv(self, segment: bytes) -> List[Union[bytes, str]]:
            """Consume a received segment.

            Client.recv() treats ``bytes`` items of the result as complete
            packets and logs ``str`` items as warnings.
            """

        def close(self) -> bytes:
            """Finish parsing and hand back any bytes left unconsumed."""

    @staticmethod
    def enframe(buffer: bytes, imei: Optional[str] = None) -> bytes:
        """Wrap an outgoing packet into the wire format for sending."""

    @staticmethod
    def probe_buffer(buffer: bytes) -> bool:
        """Tell whether the first data of a connection is this protocol."""

    @staticmethod
    def parse_message(packet: bytes, is_incoming: bool = True) -> Any:
        """Not called in this file; presumably decodes a packet (confirm)."""

    @staticmethod
    def inline_response(packet: bytes) -> Optional[bytes]:
        """Return a packet to send straight back to the terminal, or None."""

    @staticmethod
    def is_goodbye_packet(packet: bytes) -> bool:
        """Tell whether the terminal announces it is going away."""

    @staticmethod
    def imei_from_packet(packet: bytes) -> Optional[str]:
        """Extract the terminal IMEI from a (login) packet, if present."""

    @staticmethod
    def proto_of_message(packet: bytes) -> str:
        """Name the protocol/message kind, used for the broadcast ``proto``."""
63
64
# Protocol modules, tried in this order by Client.recv() on a connection's
# first data; filled in by runserver() from the comma-separated "protocols"
# option of the [collector] configuration section
pmods: List[ProtoModule] = []
66
67
class Client:
    """Connected socket to the terminal plus buffer and metadata"""

    def __init__(self, sock: socket, addr: Any) -> None:
        """Wrap an accepted terminal connection.

        :param sock: connected socket returned by ``accept()``.
        :param addr: peer address, attached to every received message.
        """
        self.sock = sock
        self.addr = addr
        # Protocol module and its per-connection stream parser; both stay
        # None until recv() finds a module whose probe_buffer() accepts the
        # first received segment.
        self.pmod: Optional[ProtoModule] = None
        self.stream: Optional[ProtoModule.Stream] = None
        # Set by the owner once a packet reveals the terminal's IMEI.
        self.imei: Optional[str] = None

    def close(self) -> None:
        """Close the socket and warn about bytes the parser never consumed."""
        log.debug("Closing fd %d (IMEI %s)", self.sock.fileno(), self.imei)
        self.sock.close()
        if self.stream:
            rest = self.stream.close()
        else:
            rest = b""
        if rest:
            log.warning(
                "%d bytes in buffer on close: %s", len(rest), rest[:64].hex()
            )

    def recv(self) -> Optional[List[Tuple[float, Any, bytes]]]:
        """Read from the socket and parse complete messages.

        On the first data of the connection the segment is offered to each
        module in ``pmods`` in turn; the first one whose ``probe_buffer()``
        accepts it becomes this connection's protocol.

        :returns: ``None`` when the connection should be dropped (read error
            or EOF). Otherwise a possibly empty list of
            ``(receive_time, peer_address, packet)`` tuples. Data that no
            protocol module recognizes is logged and discarded.
        """
        try:
            segment = self.sock.recv(MAXBUFFER)
        except OSError as e:
            log.warning(
                "Reading from fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
            return None
        if not segment:  # Terminal has closed connection
            log.info(
                "EOF reading from fd %d (IMEI %s)",
                self.sock.fileno(),
                self.imei,
            )
            return None
        if self.stream is None:
            for pmod in pmods:
                if pmod.probe_buffer(segment):
                    self.pmod = pmod
                    self.stream = pmod.Stream()
                    break
        if self.stream is None:
            log.info(
                "unrecognizable %d bytes of data %s from fd %d",
                len(segment),
                segment[:32].hex(),
                self.sock.fileno(),
            )
            return []
        when = time()
        msgs = []
        for elem in self.stream.recv(segment):
            if isinstance(elem, bytes):
                msgs.append((when, self.addr, elem))
            else:
                # Non-bytes items are diagnostics from the stream parser
                log.warning(
                    "%s from fd %d (IMEI %s)",
                    elem,
                    self.sock.fileno(),
                    self.imei,
                )
        return msgs

    def send(self, buffer: bytes) -> None:
        """Frame ``buffer`` with this connection's protocol and send it.

        Must only be called after the protocol has been identified.
        ``sendall()`` is used because a single ``send()`` on a stream socket
        may transmit only part of the data. Socket errors are logged, not
        raised.
        """
        assert self.stream is not None and self.pmod is not None
        try:
            self.sock.sendall(self.pmod.enframe(buffer, imei=self.imei))
        except OSError as e:
            log.error(
                "Sending to fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
148
149
class Clients:
    """Registry of connected terminals, indexed by fd and by IMEI"""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}
        # Only clients whose IMEI is known; a newer login with the same IMEI
        # replaces the older entry (see recv()).
        self.by_imei: Dict[str, Client] = {}

    def fds(self) -> Set[int]:
        """Return a snapshot (a new set) of the fds currently served."""
        return set(self.by_fd.keys())

    def add(self, clntsock: socket, clntaddr: Any) -> int:
        """Start serving an accepted socket; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd: int) -> None:
        """Close the client on ``fd`` and forget it; unknown fds are ignored."""
        clnt = self.by_fd.get(fd)
        if clnt is None:
            log.debug("Fd %d is not served, ignore stop", fd)
            return
        log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
        clnt.close()
        # The IMEI slot could already have been taken over by a newer
        # connection; only drop it if it still points at this client.
        if clnt.imei and self.by_imei.get(clnt.imei) is clnt:
            del self.by_imei[clnt.imei]
        del self.by_fd[fd]

    def recv(
        self, fd: int
    ) -> Optional[List[Tuple[ProtoModule, Optional[str], float, Any, bytes]]]:
        """Read from the client on ``fd`` and return its parsed packets.

        The first packet that yields an IMEI registers the client in
        ``by_imei``, stopping any older connection with the same IMEI.

        :returns: ``None`` if the fd is not served or the connection should
            be dropped; otherwise a list of
            ``(pmod, imei, when, peeraddr, packet)`` tuples.
        """
        if fd not in self.by_fd:
            log.debug("Client at fd %d gone, ignore event", fd)
            return None
        clnt = self.by_fd[fd]
        msgs = clnt.recv()
        if msgs is None:
            return None
        result = []
        for when, peeraddr, packet in msgs:
            assert clnt.pmod is not None
            if clnt.imei is None:
                imei = clnt.pmod.imei_from_packet(packet)
                if imei is not None:
                    log.info("LOGIN from fd %d (IMEI %s)", fd, imei)
                    clnt.imei = imei
                    oldclnt = self.by_imei.get(clnt.imei)
                    if oldclnt is not None:
                        oldfd = oldclnt.sock.fileno()
                        log.info("Removing stale connection on fd %d", oldfd)
                        # Clear the IMEI first so that stop() leaves the
                        # by_imei slot alone; it is overwritten just below.
                        oldclnt.imei = None
                        self.stop(oldfd)
                    self.by_imei[clnt.imei] = clnt
            result.append((clnt.pmod, clnt.imei, when, peeraddr, packet))
            log.debug(
                "Received from %s (IMEI %s): %s",
                peeraddr,
                clnt.imei,
                packet.hex(),
            )
        return result

    def response(self, resp: Resp) -> Optional[ProtoModule]:
        """Send ``resp.packet`` to the client with ``resp.imei``.

        :returns: the client's protocol module if it is connected, else None.
        """
        if resp.imei in self.by_imei:
            clnt = self.by_imei[resp.imei]
            clnt.send(resp.packet)
            return clnt.pmod
        else:
            log.info("Not connected (IMEI %s)", resp.imei)
            return None
217
218
def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
    """Run the collector loop until interrupted with KeyboardInterrupt.

    Terminals connect over TCP (an AF_INET6 socket) on ``[collector] port``.
    Every packet they send is published on the ZeroMQ PUB socket bound to
    ``[collector] publishurl``. Responses pulled from the PULL socket bound
    to ``[collector] listenurl``, plus any inline responses produced by the
    protocol module, are sent to the terminal with the matching IMEI. When
    that terminal is connected, they are also published as outgoing.

    :param conf: parsed configuration; the ``collector`` section is read.
    :param handle_hibernate: when true, the connection is closed after a
        packet for which the protocol module's ``is_goodbye_packet()``
        returns true.
    """
    global pmods
    pmods = [
        cast(ProtoModule, import_module("." + modnm, __package__))
        for modnm in conf.get("collector", "protocols").split(",")
    ]
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zpub = zctx.socket(zmq.PUB)  # type: ignore
    zpull = zctx.socket(zmq.PULL)  # type: ignore
    # Restrictive umask while binding; presumably so that ipc:// endpoint
    # files get no execute bits and no access for others (confirm)
    oldmask = umask(0o117)
    zpub.bind(conf.get("collector", "publishurl"))
    zpull.bind(conf.get("collector", "listenurl"))
    umask(oldmask)
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("collector", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zpull, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    # Client fds currently registered with the poller
    pollingfds: Set[int] = set()
    try:
        while True:
            tosend: List[Resp] = []
            toadd: List[Tuple[socket, Any]] = []
            events = poller.poll(1000)
            for sk, fl in events:
                if sk is zpull:
                    # Drain everything queued for sending
                    while True:
                        try:
                            msg = zpull.recv(zmq.NOBLOCK)
                            zmsg = Resp(msg)
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
                    toadd.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Terminal gone from fd %d", sk)
                        clients.stop(sk)
                    else:
                        for pmod, imei, when, peeraddr, packet in received:
                            proto = pmod.proto_of_message(packet)
                            zpub.send(
                                Bcast(
                                    proto=proto,
                                    imei=imei,
                                    when=when,
                                    peeraddr=peeraddr,
                                    packet=packet,
                                ).packed
                            )
                            if (
                                pmod.is_goodbye_packet(packet)
                                and handle_hibernate
                            ):
                                log.debug(
                                    "Goodbye from fd %d (IMEI %s)",
                                    sk,
                                    imei,
                                )
                                clients.stop(sk)
                            respmsg = pmod.inline_response(packet)
                            if respmsg is not None:
                                tosend.append(
                                    Resp(imei=imei, when=when, packet=respmsg)
                                )
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for zmsg in tosend:
                log.debug("Sending to the client: %s", zmsg)
                rpmod = clients.response(zmsg)
                if rpmod is not None:
                    zpub.send(
                        Bcast(
                            is_incoming=False,
                            proto=rpmod.proto_of_message(zmsg.packet),
                            when=zmsg.when,
                            imei=zmsg.imei,
                            packet=zmsg.packet,
                        ).packed
                    )
            # Unregister fds of clients stopped during this iteration and
            # forget them *before* registering new clients. A socket
            # accepted in this same iteration may have been given the number
            # of an fd closed earlier in it, and must be registered anew.
            gone = pollingfds - clients.fds()
            for fd in gone:
                poller.unregister(fd)  # type: ignore
            pollingfds -= gone
            for clntsock, clntaddr in toadd:
                clients.add(clntsock, clntaddr)
            for fd in clients.fds() - pollingfds:
                poller.register(fd, flags=zmq.POLLIN)
            pollingfds = clients.fds()
    except KeyboardInterrupt:
        zpub.close()
        zpull.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
321
322
if __name__.endswith("__main__"):
    # common.init() yields the configuration object that runserver() reads
    config = common.init(log)
    runserver(config)