]> www.average.org Git - loctrkd.git/blob - loctrkd/collector.py
17c98d5d98326af76dad80c2dc760bb3fb19c4ce
[loctrkd.git] / loctrkd / collector.py
1 """ TCP server that communicates with terminals """
2
3 from configparser import ConfigParser
4 from importlib import import_module
5 from logging import getLogger
6 from os import umask
7 from socket import (
8     socket,
9     AF_INET6,
10     SOCK_STREAM,
11     SOL_SOCKET,
12     SO_KEEPALIVE,
13     SO_REUSEADDR,
14 )
15 from struct import pack
16 from time import time
17 from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
18 import zmq
19
20 from . import common
21 from .protomodule import ProtoModule
22 from .zmsg import Bcast, Resp
23
log = getLogger(name="loctrkd/collector")

# Number of bytes requested from a terminal socket in a single recv() call
MAXBUFFER: int = 4 * 1024

# Protocol modules, probed in this order against the first data from a
# terminal; (re)assigned by runserver() from the configuration.
pmods: List[ProtoModule] = list()
30
31
class Client:
    """Connected socket to the terminal plus buffer and metadata

    ``pmod`` (protocol module) and ``stream`` (its stream parser) start as
    ``None``. They are set from the first received segment that one of the
    modules in ``pmods`` accepts via ``probe_buffer()``; segments received
    before that are logged and discarded. ``imei`` starts as ``None`` and
    is assigned from outside this class.
    """

    def __init__(self, sock: socket, addr: Any) -> None:
        self.sock = sock
        self.addr = addr
        self.pmod: Optional[ProtoModule] = None
        self.stream: Optional[ProtoModule.Stream] = None
        self.imei: Optional[str] = None

    def close(self) -> None:
        """Close the socket, then the protocol stream if there is one.

        Whatever ``stream.close()`` returns as leftover bytes is logged
        (first 64 bytes, hex) as a warning.
        """
        # Log before closing: fileno() is not meaningful after close().
        log.debug("Closing fd %d (IMEI %s)", self.sock.fileno(), self.imei)
        self.sock.close()
        rest = self.stream.close() if self.stream else b""
        if rest:
            log.warning(
                "%d bytes in buffer on close: %s", len(rest), rest[:64].hex()
            )

    def recv(self) -> Optional[List[Tuple[float, Any, bytes]]]:
        """Read from the socket and parse complete messages

        :returns: ``None`` if reading failed or the peer closed the
            connection (the caller should drop the client); an empty list
            if no protocol module recognised the data; otherwise a list of
            ``(receive_time, peer_address, packet)`` tuples, one per complete
            packet produced by the stream parser.
        """
        try:
            segment = self.sock.recv(MAXBUFFER)
        except OSError as e:
            log.warning(
                "Reading from fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
            return None
        if not segment:  # Terminal has closed connection
            log.info(
                "EOF reading from fd %d (IMEI %s)",
                self.sock.fileno(),
                self.imei,
            )
            return None
        if self.stream is None:
            # Protocol not known yet: the first module whose probe accepts
            # this segment wins.
            for pmod in pmods:
                if pmod.probe_buffer(segment):
                    self.pmod = pmod
                    self.stream = pmod.Stream()
                    break
        if self.stream is None:
            log.info(
                "unrecognizable %d bytes of data %s from fd %d",
                len(segment),
                segment[:32].hex(),
                self.sock.fileno(),
            )
            return []
        when = time()
        msgs = []
        for elem in self.stream.recv(segment):
            # The stream yields complete packets as bytes; anything else is
            # treated as a diagnostic and only logged.
            if isinstance(elem, bytes):
                msgs.append((when, self.addr, elem))
            else:
                log.warning(
                    "%s from fd %d (IMEI %s)",
                    elem,
                    self.sock.fileno(),
                    self.imei,
                )
        return msgs

    def send(self, buffer: bytes) -> None:
        """Frame ``buffer`` with the client's protocol module and send it.

        The protocol must already be known (``pmod`` and ``stream`` set).
        Send errors are logged, not raised.
        """
        assert self.stream is not None and self.pmod is not None
        try:
            # sendall(): plain send() may transmit only part of the frame.
            self.sock.sendall(self.pmod.enframe(buffer, imei=self.imei))
        except OSError as e:
            log.error(
                "Sending to fd %d (IMEI %s): %s",
                self.sock.fileno(),
                self.imei,
                e,
            )
112
113
class Clients:
    """Registry of connected terminals, indexed by fd and by IMEI"""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}
        self.by_imei: Dict[str, Client] = {}

    def fds(self) -> Set[int]:
        """Return a snapshot (new set) of the file descriptors being served"""
        return set(self.by_fd.keys())

    def add(self, clntsock: socket, clntaddr: Any) -> int:
        """Start serving a newly accepted socket; return its fd"""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd: int) -> None:
        """Close the client on ``fd`` and forget it; unknown fds are ignored"""
        clnt = self.by_fd.get(fd)
        if clnt is None:
            log.debug("Fd %d is not served, ignore stop", fd)
            return
        log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
        clnt.close()
        # The IMEI may already be mapped to a newer connection from the same
        # terminal; drop the entry only if it still refers to this client.
        # "is not None" matches recv(), which registers any non-None IMEI.
        if clnt.imei is not None and self.by_imei.get(clnt.imei) is clnt:
            del self.by_imei[clnt.imei]
        del self.by_fd[fd]

    def recv(
        self, fd: int
    ) -> Optional[List[Tuple[ProtoModule, Optional[str], float, Any, bytes]]]:
        """Read packets from the client on ``fd``.

        The first packet from which the protocol module extracts an IMEI
        binds that IMEI to the client. Any older connection registered
        under the same IMEI is stopped.

        :returns: ``None`` if ``fd`` is not served or the client's
            connection should be dropped; otherwise a (possibly empty) list
            of ``(pmod, imei, when, peeraddr, packet)`` tuples.
        """
        clnt = self.by_fd.get(fd)
        if clnt is None:
            log.debug("Client at fd %d gone, ignore event", fd)
            return None
        msgs = clnt.recv()
        if msgs is None:
            return None
        result = []
        for when, peeraddr, packet in msgs:
            # Client.recv only returns packets once a protocol is chosen
            assert clnt.pmod is not None
            if clnt.imei is None:
                imei = clnt.pmod.imei_from_packet(packet)
                if imei is not None:
                    log.info("LOGIN from fd %d (IMEI %s)", fd, imei)
                    clnt.imei = imei
                    oldclnt = self.by_imei.get(clnt.imei)
                    if oldclnt is not None:
                        oldfd = oldclnt.sock.fileno()
                        log.info("Removing stale connection on fd %d", oldfd)
                        # Clear the old IMEI so stop() leaves by_imei alone;
                        # the entry is re-pointed to this client just below.
                        oldclnt.imei = None
                        self.stop(oldfd)
                    self.by_imei[clnt.imei] = clnt
            result.append((clnt.pmod, clnt.imei, when, peeraddr, packet))
            log.debug(
                "Received from %s (IMEI %s): %s",
                peeraddr,
                clnt.imei,
                packet.hex(),
            )
        return result

    def response(self, resp: Resp) -> Optional[ProtoModule]:
        """Send ``resp.packet`` to the client with IMEI ``resp.imei``.

        :returns: the client's protocol module, or ``None`` if no client
            with that IMEI is connected.
        """
        clnt = self.by_imei.get(resp.imei)
        if clnt is None:
            log.info("Not connected (IMEI %s)", resp.imei)
            return None
        clnt.send(resp.packet)
        return clnt.pmod
181
182
def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
    """Run the collector loop until interrupted with KeyboardInterrupt.

    Setup:
    - Load the protocol modules named (comma-separated) in
      ``[common] protocols``, relative to this package.
    - Bind a ZeroMQ PUB socket to ``[collector] publishurl`` and a PULL
      socket to ``[collector] listenurl``.
    - Accept terminal connections on TCP port ``[collector] port``.

    Loop:
    - Publish every packet received from a terminal as a ``Bcast``.
    - Send responses to the terminal with the matching IMEI. Responses
      come both from the PULL socket and from the protocol module's
      ``inline_response()``.
    - Publish each response that was delivered as an outgoing ``Bcast``.

    :param conf: parsed configuration.
    :param handle_hibernate: when true, close a connection after a packet
        that the protocol module reports via ``is_goodbye_packet()``.
    """
    global pmods
    # Tolerate whitespace around names and empty items (e.g. a trailing
    # comma) in the configured list of protocol modules.
    pmods = [
        cast(ProtoModule, import_module("." + modnm, __package__))
        for modnm in (
            name.strip() for name in conf.get("common", "protocols").split(",")
        )
        if modnm
    ]
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zpub = zctx.socket(zmq.PUB)  # type: ignore
    zpull = zctx.socket(zmq.PULL)  # type: ignore
    # Anything bind() creates on the filesystem gets at most rw-rw----;
    # restore the process umask even if a bind() fails.
    oldmask = umask(0o117)
    try:
        zpub.bind(conf.get("collector", "publishurl"))
        zpull.bind(conf.get("collector", "listenurl"))
    finally:
        umask(oldmask)
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("collector", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zpull, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    pollingfds: Set[int] = set()
    try:
        while True:
            tosend: List[Resp] = []
            toadd: List[Tuple[socket, Any]] = []
            events = poller.poll(1000)
            for sk, fl in events:
                if sk is zpull:
                    # Drain all queued responses without blocking
                    while True:
                        try:
                            msg = zpull.recv(zmq.NOBLOCK)
                        except zmq.Again:
                            break
                        tosend.append(Resp(msg))
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
                    toadd.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Terminal gone from fd %d", sk)
                        clients.stop(sk)
                    else:
                        for pmod, imei, when, peeraddr, packet in received:
                            proto = pmod.proto_of_message(packet)
                            zpub.send(
                                Bcast(
                                    proto=proto,
                                    imei=imei,
                                    when=when,
                                    peeraddr=peeraddr,
                                    packet=packet,
                                ).packed
                            )
                            if (
                                pmod.is_goodbye_packet(packet)
                                and handle_hibernate
                            ):
                                log.debug(
                                    "Goodbye from fd %d (IMEI %s)",
                                    sk,
                                    imei,
                                )
                                clients.stop(sk)
                            respmsg = pmod.inline_response(packet)
                            if respmsg is not None:
                                tosend.append(
                                    Resp(imei=imei, when=when, packet=respmsg)
                                )
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for zmsg in tosend:
                log.debug("Sending to the client: %s", zmsg)
                rpmod = clients.response(zmsg)
                if rpmod is not None:
                    zpub.send(
                        Bcast(
                            is_incoming=False,
                            proto=rpmod.proto_of_message(zmsg.packet),
                            when=zmsg.when,
                            imei=zmsg.imei,
                            packet=zmsg.packet,
                        ).packed
                    )
            # Bring the poller in sync with the set of served fds
            for fd in pollingfds - clients.fds():
                poller.unregister(fd)  # type: ignore
            for clntsock, clntaddr in toadd:
                clients.add(clntsock, clntaddr)
            for fd in clients.fds() - pollingfds:
                poller.register(fd, flags=zmq.POLLIN)
            pollingfds = clients.fds()
    except KeyboardInterrupt:
        pass
    finally:
        # Release terminal connections and the listening/ZeroMQ sockets on
        # every exit path, not only on KeyboardInterrupt.
        for fd in clients.fds():
            clients.stop(fd)
        zpub.close()
        zpull.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
285
286
if __name__.endswith("__main__"):
    # Build the configuration via common.init() and run the server loop.
    config = common.init(log)
    runserver(config)