www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
wsgateway: reclassify http write hack as permanent
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from configparser import ConfigParser
4 from datetime import datetime, timezone
5 from json import dumps, loads
6 from logging import getLogger
7 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
8 from time import time
9 from typing import Any, cast, Dict, List, Optional, Set, Tuple
10 from wsproto import ConnectionType, WSConnection
11 from wsproto.events import (
12     AcceptConnection,
13     CloseConnection,
14     Event,
15     Message,
16     Ping,
17     Request,
18     TextMessage,
19 )
20 from wsproto.utilities import RemoteProtocolError
21 import zmq
22
23 from . import common
24 from .evstore import initdb, fetch
25 from .gps303proto import (
26     GPS_POSITIONING,
27     STATUS,
28     WIFI_POSITIONING,
29     parse_message,
30 )
31 from .zmsg import Bcast, topic
32
# Module logger for the websocket gateway.
log = getLogger("gps303/wsgateway")
# Path of the HTML page returned to plain-HTTP GET requests by try_http().
# Assigned by runserver() from the [wsgateway] "htmlfile" config option;
# None until then, in which case GET requests get a 500 response.
htmlfile: Optional[str] = None
35
36
def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
    """Build "location" messages for *imei* from stored events.

    Asks ``fetch`` for incoming GPS_POSITIONING and outgoing WIFI_POSITIONING
    packets, passing *numback* through as its limit (presumably the number
    of most recent events -- confirm in evstore). Each packet is re-parsed
    and turned into a dict in the same shape that runserver() sends for live
    location updates. GPS fixes are labelled "gps", everything else
    "approximate".
    """
    wanted = [
        (True, GPS_POSITIONING.PROTO),
        (False, WIFI_POSITIONING.PROTO),
    ]
    locations: List[Dict[str, Any]] = []
    for incoming, when, raw in fetch(imei, wanted, numback):
        parsed = parse_message(raw, is_incoming=incoming)
        utc_when = datetime.fromtimestamp(when).astimezone(tz=timezone.utc)
        if isinstance(parsed, GPS_POSITIONING):
            accuracy = "gps"
        else:
            accuracy = "approximate"
        locations.append(
            {
                "type": "location",
                "imei": imei,
                "timestamp": str(utc_when),
                "longitude": parsed.longitude,
                "latitude": parsed.latitude,
                "accuracy": accuracy,
            }
        )
    return locations
62
63
def _http_response(proto: str, status: str, ctype: str, body: bytes) -> bytes:
    """Assemble a complete HTTP response: status line, headers and *body*.

    Content-Length is always included so the client can tell where the body
    ends, for error responses as well as for the HTML page.
    """
    return (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(body):d}\r\n\r\n"
    ).encode("utf-8") + body


def try_http(data: bytes, fd: int, e: Exception) -> bytes:
    """Answer *data* as a plain HTTP request after the websocket handshake failed.

    A GET request (for any resource) is answered with the contents of the
    module-level ``htmlfile``. The reply is a 500 if it is unset or cannot be
    read. Any other method gets a 400.

    :param data: raw bytes received from the client.
    :param fd: client socket file descriptor, used only for logging.
    :param e: the websocket error that triggered the HTTP fallback.
    :returns: the encoded HTTP response to send to the client.
    :raises Exception: *e* itself when *data* is not decodable UTF-8 or its
        first line is not a three-part HTTP request line.
    """
    try:
        lines = data.decode().split("\r\n")
        # Only the request line decides whether this is HTTP at all.
        op, resource, proto = lines[0].split(" ")
    except ValueError:  # includes UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    headers = lines[1:]
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", b"Bad request\r\n"
        )
    if htmlfile is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            b"HTML data not configured on the server\r\n",
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except (OSError, ValueError):  # ValueError: e.g. NUL byte in the path
        return _http_response(
            proto,
            "500 File not found",
            "text/plain",
            b"HTML file could not be opened\r\n",
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
111
112
class Client:
    """Websocket (or one-shot plain HTTP) connection to a single client."""

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes queued for the socket; flushed by write().
        self.ws_data = b""
        # True after the websocket upgrade is accepted, False again once the
        # peer starts the closing handshake; send() only queues while True.
        self.ready = False
        # IMEIs from this client's latest "subscribe" message.
        self.imeis: Set[str] = set()

    def close(self) -> None:
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def _parse_text(self, text: str) -> Optional[Dict[str, Any]]:
        """Decode one websocket text message as a JSON object.

        Returns None (after logging a warning) when the text is not valid
        JSON, is not a JSON object, or is a "subscribe" message whose "imei"
        is not a list of strings. This keeps a misbehaving client from
        raising out of recv() and stopping the server loop.
        """
        try:
            msg = loads(text)
        except ValueError as e:
            log.warning("Bad JSON on fd %d: %s", self.sock.fileno(), e)
            return None
        if not isinstance(msg, dict):
            log.warning(
                "Non-object JSON on fd %d: %s", self.sock.fileno(), msg
            )
            return None
        if msg.get("type", None) == "subscribe":
            imeis = msg.get("imei")
            if not (
                isinstance(imeis, list)
                and all(isinstance(imei, str) for imei in imeis)
            ):
                log.warning(
                    "Bad imei list in subscribe on fd %d: %s",
                    self.sock.fileno(),
                    imeis,
                )
                return None
        return msg

    def recv(self) -> Optional[List[Dict[str, Any]]]:
        """Read from the socket and process what arrived.

        Returns the JSON objects decoded from websocket text messages (possibly
        an empty list). Returns None when this client should be dropped. That
        happens on a read error, on EOF, after an HTTP response has been
        written, and when the data is neither websocket nor parseable HTTP.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            try:
                self.ws_data = try_http(data, self.sock.fileno(), e)
            except RemoteProtocolError:
                # try_http re-raises the websocket error when the data is not
                # an HTTP request either. Letting it escape would stop the
                # whole server loop; drop only this client instead.
                return None
            # this `write` is a hack - writing _ought_ to be done at the
            # stage when all other writes are performed. But I could not
            # arrange it so in a logical way. Let it stay this way. The
            # whole http server affair is a hack anyway.
            self.write()
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            return None
        msgs: List[Dict[str, Any]] = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                # self.ws_data += self.ws.send(event.response())  # Why not?!
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, self.sock.fileno())
                self.ws_data += self.ws.send(event.response())
                if isinstance(event, CloseConnection):
                    # The close handshake has started. wsproto does not allow
                    # data messages after this point, so stop send() from
                    # queuing any.
                    self.ready = False
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, self.sock.fileno())
                msg = self._parse_text(event.data)
                if msg is None:
                    continue
                msgs.append(msg)
                if msg.get("type", None) == "subscribe":
                    self.imeis = set(msg.get("imei", []))
                    log.debug(
                        "subs list on fd %s is %s",
                        self.sock.fileno(),
                        self.imeis,
                    )
            else:
                log.warning("%s on fd %d", event, self.sock.fileno())
        return msgs

    def wants(self, imei: str) -> bool:
        """Return True if this client is subscribed to *imei*."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message: Dict[str, Any]) -> None:
        """Queue *message* as JSON if the connection is open and subscribed.

        Only queues the data; write() actually sends it.
        """
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self) -> bool:
        """Make one send() attempt with the queued data.

        On OSError the error is logged and the queue is discarded. Returns
        True if unsent data remains.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
214
215
class Clients:
    """Registry of connected clients, indexed by socket file descriptor."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Wrap a newly accepted socket in a Client; return its fd."""
        newfd = clntsock.fileno()
        log.info("Start serving fd %d from %s", newfd, clntaddr)
        self.by_fd[newfd] = Client(clntsock, clntaddr)
        return newfd

    def stop(self, fd: int) -> None:
        """Close the client on *fd* and forget it (KeyError if unknown)."""
        victim = self.by_fd[fd]
        log.info("Stop serving fd %d", victim.sock.fileno())
        victim.close()
        del self.by_fd[fd]

    def recv(self, fd: int) -> Optional[List[Dict[str, Any]]]:
        """Let the client on *fd* read and decode its pending input."""
        return self.by_fd[fd].recv()

    def send(self, msg: Dict[str, Any]) -> Set[int]:
        """Queue *msg* for every client wanting its IMEI; return their fds."""
        queued: Set[int] = set()
        for fd, clnt in self.by_fd.items():
            if not clnt.wants(msg["imei"]):
                continue
            clnt.send(msg)
            queued.add(fd)
        return queued

    def write(self, towrite: Set[int]) -> Set[int]:
        """Try to flush output for each fd in *towrite*.

        Fds with no registered client are skipped. Returns the fds that still
        have unsent data.
        """
        pending: Set[int] = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is not None and clnt.write():
                pending.add(fd)
        return pending

    def subs(self) -> Set[str]:
        """Return the union of the IMEIs all clients are subscribed to."""
        return set().union(*(clnt.imeis for clnt in self.by_fd.values()))
256
257
def _zsub_setopt(zsub: Any, option: int, imei: str) -> None:
    """Apply zmq (UN)SUBSCRIBE *option* to every topic followed for *imei*.

    The topics are incoming GPS_POSITIONING, outgoing WIFI_POSITIONING and
    incoming STATUS packets.
    """
    for proto, is_incoming in (
        (GPS_POSITIONING.PROTO, True),
        (WIFI_POSITIONING.PROTO, False),
        (STATUS.PROTO, True),
    ):
        zsub.setsockopt(option, topic(proto, is_incoming, imei))


def _bcast_to_wsmsg(zmsg: Any, msg: Any) -> Dict[str, Any]:
    """Turn a collector broadcast and its parsed packet into a client message.

    STATUS packets become "status" messages carrying the battery level.
    Everything else becomes a "location" message, with accuracy "gps" for
    incoming packets and "approximate" otherwise.
    """
    timestamp = str(
        datetime.fromtimestamp(zmsg.when).astimezone(tz=timezone.utc)
    )
    if isinstance(msg, STATUS):
        return {
            "type": "status",
            "imei": zmsg.imei,
            "timestamp": timestamp,
            "battery": msg.batt,
        }
    return {
        "type": "location",
        "imei": zmsg.imei,
        "timestamp": timestamp,
        "longitude": msg.longitude,
        "latitude": msg.latitude,
        "accuracy": "gps" if zmsg.is_incoming else "approximate",
    }


def runserver(conf: ConfigParser) -> None:
    """Run the websocket gateway until interrupted with Ctrl-C.

    Uses [storage] dbfn, [wsgateway] htmlfile and port, and [collector]
    publishurl from *conf*. It listens on *port* on an AF_INET6 TCP socket.
    It relays location and status broadcasts from the collector's zmq
    publisher to the websocket clients subscribed to the corresponding IMEI.
    A client that subscribes also receives a backlog of stored locations,
    sent only to that client.
    """
    global htmlfile

    initdb(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile")
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zsub = zctx.socket(zmq.SUB)  # type: ignore
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs: Set[str] = set()
    try:
        # Client fds registered for POLLOUT because an earlier write left
        # data unsent. They are not written again until the poller reports
        # them writable.
        towait: Set[int] = set()
        while True:
            neededsubs = clients.subs()
            for imei in neededsubs - activesubs:
                _zsub_setopt(zsub, zmq.SUBSCRIBE, imei)
            for imei in activesubs - neededsubs:
                _zsub_setopt(zsub, zmq.UNSUBSCRIBE, imei)
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                            msg = parse_message(zmsg.packet, zmsg.is_incoming)
                            log.debug("Got %s with %s", zmsg, msg)
                            tosend.append(_bcast_to_wsmsg(zmsg, msg))
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        towait.discard(sk)
                    else:
                        for wsmsg in received:
                            log.debug("Received from %d: %s", sk, wsmsg)
                            if wsmsg.get("type", None) == "subscribe":
                                # Have to live w/o typechecking from json
                                imeis = cast(List[str], wsmsg.get("imei", []))
                                numback: int = wsmsg.get("backlog", 5)
                                # The backlog is history for the requesting
                                # client only. Routing it through `tosend`
                                # would broadcast it to every subscriber of
                                # these IMEIs.
                                requester = clients.by_fd[sk]
                                for imei in imeis:
                                    for bmsg in backlog(imei, numback):
                                        requester.send(bmsg)
                        towrite.add(sk)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for wsmsg in tosend:
                log.debug("Sending to the clients: %s", wsmsg)
                towrite |= clients.send(wsmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            for fd in morewait - towait:  # new fds waiting for write
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
            # Fds still waiting but not tried this round stay in towait.
            towait |= morewait
    except KeyboardInterrupt:
        pass
406
407
# Script entry point: run the gateway with the configuration returned by
# common.init(). NOTE(review): endswith() rather than == presumably lets a
# package-qualified "...__main__" module name trigger it too; confirm.
if __name__.endswith("__main__"):
    runserver(common.init(log))