www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
typechecking: annotate wsgateway.py
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from configparser import ConfigParser
4 from datetime import datetime, timezone
5 from json import dumps, loads
6 from logging import getLogger
7 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
8 from time import time
9 from typing import Any, cast, Dict, List, Optional, Set, Tuple
10 from wsproto import ConnectionType, WSConnection
11 from wsproto.events import (
12     AcceptConnection,
13     CloseConnection,
14     Event,
15     Message,
16     Ping,
17     Request,
18     TextMessage,
19 )
20 from wsproto.utilities import RemoteProtocolError
21 import zmq
22
23 from . import common
24 from .evstore import initdb, fetch
25 from .gps303proto import (
26     GPS_POSITIONING,
27     STATUS,
28     WIFI_POSITIONING,
29     parse_message,
30 )
31 from .zmsg import Bcast, topic
32
log = getLogger("gps303/wsgateway")
# Path of the HTML page returned to plain-HTTP GET requests. It is set by
# runserver() from the [wsgateway] "htmlfile" option; while it is None,
# try_http() answers GET requests with a "500 No data configured" reply.
htmlfile: Optional[str] = None
35
36
def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
    """Build "location" messages from stored positioning packets of *imei*.

    The rows come from ``fetch()``, asked for incoming GPS_POSITIONING and
    outgoing WIFI_POSITIONING packets. *numback* is passed straight through
    to ``fetch()`` (presumably the maximum number of rows; confirm in
    evstore). Each message carries the packet time converted to UTC. Its
    ``accuracy`` is "gps" for GPS fixes and "approximate" for everything else.
    """
    wanted = [(True, GPS_POSITIONING.PROTO), (False, WIFI_POSITIONING.PROTO)]
    messages: List[Dict[str, Any]] = []
    for incoming, when, raw in fetch(imei, wanted, numback):
        parsed = parse_message(raw, is_incoming=incoming)
        utc_time = datetime.fromtimestamp(when).astimezone(tz=timezone.utc)
        is_gps = isinstance(parsed, GPS_POSITIONING)
        messages.append(
            {
                "type": "location",
                "imei": imei,
                "timestamp": str(utc_time),
                "longitude": parsed.longitude,
                "latitude": parsed.latitude,
                "accuracy": "gps" if is_gps else "approximate",
            }
        )
    return messages
62
63
def _http_response(proto: str, status: str, ctype: str, body: bytes) -> bytes:
    """Assemble a complete HTTP response: status line, headers and *body*."""
    return (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(body):d}\r\n"
        # The gateway drops the connection once an HTTP reply is sent.
        f"Connection: close\r\n\r\n"
    ).encode("utf-8") + body


def try_http(data: bytes, fd: int, e: Exception) -> bytes:
    """Answer *data* as a plain HTTP request and return the response bytes.

    Used for data that wsproto rejected as a websocket handshake (see
    Client.recv). A GET request is answered with the contents of the
    module-level ``htmlfile``. The reply is a 500 error when no file is
    configured or it cannot be read. Any other method gets a 400 reply.

    Args:
        data: raw bytes received from the client.
        fd: socket file descriptor, used only for logging.
        e: the exception that made the websocket parse fail.

    Raises:
        e: re-raised unchanged when *data* is not a parseable HTTP
            request line.
    """
    try:
        lines = data.decode().split("\r\n")
        op, resource, proto = lines[0].split(" ")
    except ValueError:  # also covers UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        lines[1:],
    )
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", b"Bad request\r\n"
        )
    if htmlfile is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            b"HTML data not configured on the server\r\n",
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except (OSError, ValueError):  # ValueError: e.g. NUL byte in the path
        return _http_response(
            proto,
            "500 File not found",
            "text/plain",
            b"HTML file could not be opened\r\n",
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
111
112
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes produced by wsproto (or an HTTP reply) not yet written out
        self.ws_data = b""
        # True from the accepted handshake until the client starts closing;
        # data frames may only be queued while it is set.
        self.ready = False
        # IMEIs named in this client's most recent "subscribe" message
        self.imeis: Set[str] = set()

    def close(self) -> None:
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self) -> Optional[List[Dict[str, Any]]]:
        """Read from the socket and return the JSON objects received.

        Returns None when the connection should be dropped: on a read
        error, on EOF, or after a non-websocket request has been handled.
        A plain HTTP request is answered via try_http(); a request that is
        neither websocket nor HTTP is just dropped.
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)", fd, e
            )
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # try_http() re-raises for unparseable data. Drop this
                # client instead of letting the error kill the server loop.
                return None
            # TODO this is a hack: the caller drops the connection as soon
            # as we return None, so flush the whole reply right here.
            while self.write():
                pass
            log.debug("Sending HTTP response to %d", fd)
            return None
        return self._handle_events(fd)

    def _handle_events(self, fd: int) -> List[Dict[str, Any]]:
        """Process pending wsproto events; return the JSON objects received."""
        msgs: List[Dict[str, Any]] = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", fd)
                # The original author noted that sending event.response()
                # did not work here, so the handshake is accepted explicitly.
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, fd)
                self.ws_data += self.ws.send(event.response())
                if isinstance(event, CloseConnection):
                    # wsproto refuses to send messages on a closing
                    # connection (LocalProtocolError); stop send() from
                    # queueing any.
                    self.ready = False
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, fd)
                msg = self._parse_text(event.data, fd)
                if msg is not None:
                    msgs.append(msg)
            else:
                log.warning("%s on fd %d", event, fd)
        return msgs

    def _parse_text(self, text: str, fd: int) -> Optional[Dict[str, Any]]:
        """Decode one text frame as a JSON object; return None if invalid.

        A valid "subscribe" message replaces this client's IMEI set. Its
        "imei" entry is normalized to a list of strings so that consumers
        can iterate it safely.
        """
        try:
            msg = loads(text)
        except ValueError as e:  # json.JSONDecodeError
            log.warning("Malformed JSON on fd %d: %s", fd, e)
            return None
        if not isinstance(msg, dict):
            log.warning("Ignoring non-object message on fd %d: %s", fd, msg)
            return None
        if msg.get("type", None) == "subscribe":
            imeis = msg.get("imei", [])
            if not isinstance(imeis, list) or not all(
                isinstance(imei, str) for imei in imeis
            ):
                log.warning("Malformed subscription on fd %d: %s", fd, msg)
                return None
            msg["imei"] = imeis
            self.imeis = set(imeis)
            log.debug("subs list on fd %s is %s", fd, self.imeis)
        return msg

    def wants(self, imei: str) -> bool:
        """Return True if this client is subscribed to *imei*."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message: Dict[str, Any]) -> None:
        """Queue *message* as JSON if the connection is open and subscribed."""
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self) -> bool:
        """Send as much queued data as the socket takes.

        Returns True if data is still pending. On a send error the queue is
        discarded.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
210
211
class Clients:
    """Registry of connected websocket clients, keyed by socket fd."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Start tracking a freshly accepted socket; return its fd."""
        newfd = clntsock.fileno()
        log.info("Start serving fd %d from %s", newfd, clntaddr)
        self.by_fd[newfd] = Client(clntsock, clntaddr)
        return newfd

    def stop(self, fd: int) -> None:
        """Close the client on *fd* and forget it."""
        victim = self.by_fd[fd]
        log.info("Stop serving fd %d", victim.sock.fileno())
        victim.close()
        self.by_fd.pop(fd)

    def recv(self, fd: int) -> Optional[List[Dict[str, Any]]]:
        """Delegate reading to the client on *fd* (see Client.recv)."""
        return self.by_fd[fd].recv()

    def send(self, msg: Dict[str, Any]) -> Set[int]:
        """Queue *msg* on every interested client; return their fds."""
        queued: Set[int] = set()
        for fd, clnt in self.by_fd.items():
            if not clnt.wants(msg["imei"]):
                continue
            clnt.send(msg)
            queued.add(fd)
        return queued

    def write(self, towrite: Set[int]) -> Set[int]:
        """Flush the clients in *towrite*; return fds still holding data.

        Unknown fds (for example, already stopped clients) are skipped.
        """
        still_pending: Set[int] = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is not None and clnt.write():
                still_pending.add(fd)
        return still_pending

    def subs(self) -> Set[str]:
        """Return the union of all IMEIs any client is subscribed to."""
        wanted: Set[str] = set()
        for clnt in self.by_fd.values():
            wanted.update(clnt.imeis)
        return wanted
252
253
def _set_subscriptions(zsub: Any, option: int, imeis: Set[str]) -> None:
    """Apply zmq *option* (SUBSCRIBE/UNSUBSCRIBE) to the topics of *imeis*.

    For each IMEI the relayed topics are incoming GPS_POSITIONING, outgoing
    WIFI_POSITIONING and incoming STATUS packets.
    """
    for imei in imeis:
        for proto, is_incoming in (
            (GPS_POSITIONING.PROTO, True),
            (WIFI_POSITIONING.PROTO, False),
            (STATUS.PROTO, True),
        ):
            zsub.setsockopt(option, topic(proto, is_incoming, imei))


def _bcast_to_wsmsg(zmsg: "Bcast") -> Dict[str, Any]:
    """Convert one collector broadcast into a websocket message dict.

    STATUS packets become "status" messages with the battery level. Any
    other packet is treated as a position report, because only
    positioning and STATUS topics are subscribed.
    """
    msg = parse_message(zmsg.packet, zmsg.is_incoming)
    log.debug("Got %s with %s", zmsg, msg)
    timestamp = str(
        datetime.fromtimestamp(zmsg.when).astimezone(tz=timezone.utc)
    )
    if isinstance(msg, STATUS):
        return {
            "type": "status",
            "imei": zmsg.imei,
            "timestamp": timestamp,
            "battery": msg.batt,
        }
    return {
        "type": "location",
        "imei": zmsg.imei,
        "timestamp": timestamp,
        "longitude": msg.longitude,
        "latitude": msg.latitude,
        # Incoming positioning packets are GPS fixes. The only outgoing
        # positioning topic subscribed is WIFI_POSITIONING.
        "accuracy": "gps" if zmsg.is_incoming else "approximate",
    }


def _drain_zmq(zsub: Any) -> List[Dict[str, Any]]:
    """Read every broadcast queued on *zsub* without blocking and convert it."""
    result: List[Dict[str, Any]] = []
    while True:
        try:
            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
        except zmq.Again:
            return result
        result.append(_bcast_to_wsmsg(zmsg))


def _backlog_request(wsmsg: Any) -> List[Dict[str, Any]]:
    """Return the stored history requested by a "subscribe" message.

    Any other message yields []. Client JSON is untyped, so a malformed
    request is logged and ignored instead of crashing the main loop.
    """
    if not isinstance(wsmsg, dict) or wsmsg.get("type", None) != "subscribe":
        return []
    imeis = wsmsg.get("imei", [])
    numback = wsmsg.get("backlog", 5)
    if (
        not isinstance(imeis, list)
        or not all(isinstance(imei, str) for imei in imeis)
        or not isinstance(numback, int)
    ):
        log.warning("Malformed subscribe request: %s", wsmsg)
        return []
    result: List[Dict[str, Any]] = []
    for imei in imeis:
        result.extend(backlog(imei, numback))
    return result


def runserver(conf: ConfigParser) -> None:
    """Run the websocket gateway until interrupted.

    Reads ``[storage] dbfn``, ``[collector] publishurl``, ``[wsgateway]
    port`` and the optional ``[wsgateway] htmlfile`` from *conf*. Packets
    broadcast by the collector over zmq are relayed as JSON to the
    websocket clients subscribed to the device's IMEI. Plain HTTP requests
    on the same port are answered by try_http().
    """
    global htmlfile

    initdb(conf.get("storage", "dbfn"))
    # Optional: try_http() answers with a 500 error while it is None.
    htmlfile = conf.get("wsgateway", "htmlfile", fallback=None)
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zsub = zctx.socket(zmq.SUB)  # type: ignore
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs: Set[str] = set()
    # Client fds currently registered for POLLOUT because a write could not
    # flush all their data. They are written again only once the poller
    # reports them writable.
    towait: Set[int] = set()
    try:
        while True:
            neededsubs = clients.subs()
            _set_subscriptions(zsub, zmq.SUBSCRIBE, neededsubs - activesubs)
            _set_subscriptions(zsub, zmq.UNSUBSCRIBE, activesubs - neededsubs)
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend: List[Dict[str, Any]] = []
            topoll: List[Tuple[socket, Any]] = []
            tostop: List[int] = []
            towrite: Set[int] = set()  # fds that may have new data queued
            writable: Set[int] = set()  # fds reported POLLOUT this round
            for sk, fl in poller.poll():
                if sk is zsub:
                    tosend.extend(_drain_zmq(zsub))
                    continue
                if sk == tcpfd:
                    try:
                        topoll.append(tcpl.accept())
                    except OSError as e:
                        # e.g. the connection went away between poll()
                        # and accept(); not fatal for the server
                        log.warning("Accepting connection: %s", e)
                    continue
                if fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        towait.discard(sk)
                        continue
                    for wsmsg in received:
                        log.debug("Received from %d: %s", sk, wsmsg)
                        tosend.extend(_backlog_request(wsmsg))
                    towrite.add(sk)
                if fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    writable.add(sk)
                if not fl & (zmq.POLLIN | zmq.POLLOUT):
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for wsmsg in tosend:
                log.debug("Sending to the clients: %s", wsmsg)
                towrite |= clients.send(wsmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out. Blocked fds are
            # skipped unless the poller just reported them writable.
            trywrite = (towrite - towait) | writable
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            for fd in morewait - towait:  # newly blocked: watch for writes
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
            for fd in (trywrite & towait) - morewait:  # drained: stop watching
                poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
            towait = (towait - trywrite) | morewait
    except KeyboardInterrupt:
        pass
    finally:
        for fd in list(clients.by_fd):
            clients.stop(fd)
        tcpl.close()
        zsub.close()
402
403
# Entry point when run as a script: common.init(), given this module's
# logger, supplies the ConfigParser that runserver() reads its settings from.
# NOTE(review): endswith() rather than == "__main__" presumably also matches
# a dotted module name ending in "__main__"; confirm the intent.
if __name__.endswith("__main__"):
    runserver(common.init(log))