www.average.org Git - loctrkd.git/blob - loctrkd/wsgateway.py
Send backlog only to the ws client that requested it
[loctrkd.git] / loctrkd / wsgateway.py
1 """ Websocket Gateway """
2
3 from configparser import ConfigParser
4 from datetime import datetime, timezone
5 from importlib import import_module
6 from json import dumps, loads
7 from logging import getLogger
8 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
9 from time import time
10 from typing import Any, cast, Dict, List, Optional, Set, Tuple
11 from wsproto import ConnectionType, WSConnection
12 from wsproto.events import (
13     AcceptConnection,
14     CloseConnection,
15     Event,
16     Message,
17     Ping,
18     Request,
19     TextMessage,
20 )
21 from wsproto.utilities import RemoteProtocolError
22 import zmq
23
24 from . import common
25 from .evstore import initdb, fetch
26 from .protomodule import ProtoModule
27 from .zmsg import Rept, rtopic
28
log = getLogger("loctrkd/wsgateway")

# Path of the HTML page returned for plain HTTP GET requests; runserver()
# fills it from the [wsgateway] htmlfile option. None means "not configured".
htmlfile: Optional[str] = None
32
33
def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
    """Load stored reports for `imei` and shape them like "location" messages.

    `numback` is passed straight to `fetch`.  Each returned record is the
    dict produced by `fetch`, modified in place: "type" is set to
    "location" and the "devtime" key is renamed to "timestamp".
    """

    def as_location(record: Dict[str, Any]) -> Dict[str, Any]:
        record["type"] = "location"
        record["timestamp"] = record.pop("devtime")
        return record

    return [as_location(record) for record in fetch(imei, numback)]
42
43
def try_http(data: bytes, fd: int, e: Exception) -> bytes:
    """Answer data that is not a websocket handshake as a plain HTTP request.

    Args:
        data: bytes read from the client socket.
        fd: file descriptor of that socket (used for logging only).
        e: the websocket protocol error that triggered this fallback.

    Returns:
        The complete HTTP response.  A GET is answered with the contents
        of the module-level `htmlfile` (500 if it is not configured or
        cannot be read); any other method gets 400.

    Raises:
        `e` itself, when `data` does not even look like an HTTP request
        line ("METHOD RESOURCE PROTOCOL").
    """
    try:
        # Accept bare LF as a line terminator as well as CRLF.
        lines = [line.rstrip("\r") for line in data.decode().split("\n")]
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:  # also covers UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return (
            f"{proto} 400 Bad request\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "Bad request\r\n"
        ).encode()
    if htmlfile is None:
        return (
            f"{proto} 500 No data configured\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "HTML data not configured on the server\r\n"
        ).encode()
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return (
            f"{proto} 500 File not found\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "HTML file could not be opened\r\n"
        ).encode()
    head = (
        f"{proto} 200 Ok\r\n"
        "Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    )
    return head.encode("utf-8") + htmldata
91
92
class Client:
    """Connection to one client: websocket, or plain HTTP as a fallback.

    Outgoing bytes accumulate in `ws_data` until `write()` manages to push
    them into the socket; `imeis` holds the IMEIs named in the client's
    most recent "subscribe" message.
    """

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes waiting to be sent to the client.
        self.ws_data = b""
        # Set once the websocket upgrade has been accepted; send() queues
        # nothing before that.
        self.ready = False
        self.imeis: Set[str] = set()
        # Text of a websocket message whose remaining fragments have not
        # arrived yet.
        self._text_buf = ""

    def close(self) -> None:
        """Close the client socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self) -> Optional[List[Dict[str, Any]]]:
        """Read from the socket and process whatever arrived.

        Returns the list of JSON objects received from the client (possibly
        empty), or None when the connection should be dropped: on a read
        error, on EOF, after an HTTP response has been sent, or when the
        data is neither websocket nor parseable HTTP.
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug("Websocket error on fd %d, try plain http (%s)", fd, e)
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # try_http re-raises the websocket error when the data is
                # not HTTP either; drop this client rather than let the
                # exception escape into the server loop.
                log.warning("Dropping fd %d: neither websocket nor http", fd)
                return None
            # this `write` is a hack - writing _ought_ to be done at the
            # stage when all other writes are performed. But I could not
            # arrange it so in a logical way. Let it stay this way. The
            # whole http server affair is a hack anyway.
            self.write()
            log.debug("Sending HTTP response to %d", fd)
            return None
        msgs: List[Dict[str, Any]] = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", fd)
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, fd)
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, fd)
                self._text_buf += event.data
                if not event.message_finished:
                    continue  # wait for the rest of a fragmented message
                text, self._text_buf = self._text_buf, ""
                msg = self._parse_msg(text)
                if msg is not None:
                    msgs.append(msg)
            else:
                log.warning("%s on fd %d", event, fd)
        return msgs

    def _parse_msg(self, text: str) -> Optional[Dict[str, Any]]:
        """Decode one complete client message; track "subscribe" requests.

        Client input is untrusted: text that is not a JSON object is logged
        and None is returned.  For a "subscribe" message the "imei" field
        is normalized in place to a list of strings (a lone string becomes
        a one-element list, anything else an empty list) and becomes the
        client's new subscription set.
        """
        fd = self.sock.fileno()
        try:
            msg = loads(text)
        except ValueError as e:
            log.warning("Malformed message on fd %d: %s", fd, e)
            return None
        if not isinstance(msg, dict):
            log.warning("Non-object message on fd %d: %s", fd, msg)
            return None
        if msg.get("type", None) == "subscribe":
            raw = msg.get("imei", [])
            if isinstance(raw, str):
                raw = [raw]
            elif not isinstance(raw, list):
                raw = []
            imeis = [imei for imei in raw if isinstance(imei, str)]
            msg["imei"] = imeis
            self.imeis = set(imeis)
            log.debug("subs list on fd %s is %s", fd, self.imeis)
        return msg

    def wants(self, imei: str) -> bool:
        """Tell whether the client is subscribed to `imei`."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message: Dict[str, Any]) -> None:
        """Queue `message` as a JSON text frame if the client wants its IMEI.

        Nothing is queued before the websocket handshake completes;
        after that, `message` must carry an "imei" key.
        """
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(TextMessage(data=dumps(message)))

    def write(self) -> bool:
        """Try to flush queued output; return True if some is still pending."""
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
            except BlockingIOError:
                # Socket not writable right now: keep the data and retry
                # when the poller reports the fd writable.
                pass
            except OSError as e:
                log.error("Sending to fd %d: %s", self.sock.fileno(), e)
                self.ws_data = b""
            else:
                self.ws_data = self.ws_data[sent:]
        return bool(self.ws_data)
194
195
class Clients:
    """Registry of the connected clients, keyed by socket file descriptor."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Start tracking a freshly accepted socket and return its fd."""
        newfd = clntsock.fileno()
        log.info("Start serving fd %d from %s", newfd, clntaddr)
        self.by_fd[newfd] = Client(clntsock, clntaddr)
        return newfd

    def stop(self, fd: int) -> None:
        """Close the client on `fd` and forget it."""
        victim = self.by_fd[fd]
        log.info("Stop serving fd %d", victim.sock.fileno())
        victim.close()
        self.by_fd.pop(fd)

    def recv(self, fd: int) -> Tuple[Client, Optional[List[Dict[str, Any]]]]:
        """Let the client on `fd` read; return it with what it received."""
        reader = self.by_fd[fd]
        return reader, reader.recv()

    def send(self, clnt: Optional[Client], msg: Dict[str, Any]) -> Set[int]:
        """Queue `msg` for `clnt`, or for every interested client if None.

        Returns the set of fds that were handed the message and so may
        need a write.
        """
        if clnt is not None:
            fd = clnt.sock.fileno()
            if self.by_fd.get(fd, None) is clnt:
                clnt.send(msg)
                return {fd}
            log.info(
                "Trying to send %s to client at %d, not in service",
                msg,
                fd,
            )
            return set()
        targets: Set[int] = set()
        for fd, candidate in self.by_fd.items():
            if not candidate.wants(msg["imei"]):
                continue
            candidate.send(msg)
            targets.add(fd)
        return targets

    def write(self, towrite: Set[int]) -> Set[int]:
        """Flush output for the fds in `towrite`; return those still busy."""
        busy: Set[int] = set()
        for fd in towrite:
            writer = self.by_fd.get(fd)
            if writer is not None and writer.write():
                busy.add(fd)
        return busy

    def subs(self) -> Set[str]:
        """Union of the IMEIs all connected clients are subscribed to."""
        return set().union(*(member.imeis for member in self.by_fd.values()))
248
249
250 def runserver(conf: ConfigParser) -> None:
251     global htmlfile
252     initdb(conf.get("storage", "dbfn"))
253     htmlfile = conf.get("wsgateway", "htmlfile", fallback=None)
254     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
255     zctx = zmq.Context()  # type: ignore
256     zsub = zctx.socket(zmq.SUB)  # type: ignore
257     zsub.connect(conf.get("rectifier", "publishurl"))
258     tcpl = socket(AF_INET6, SOCK_STREAM)
259     tcpl.setblocking(False)
260     tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
261     tcpl.bind(("", conf.getint("wsgateway", "port")))
262     tcpl.listen(5)
263     tcpfd = tcpl.fileno()
264     poller = zmq.Poller()  # type: ignore
265     poller.register(zsub, flags=zmq.POLLIN)
266     poller.register(tcpfd, flags=zmq.POLLIN)
267     clients = Clients()
268     activesubs: Set[str] = set()
269     try:
270         towait: Set[int] = set()
271         while True:
272             neededsubs = clients.subs()
273             for imei in neededsubs - activesubs:
274                 zsub.setsockopt(zmq.SUBSCRIBE, rtopic(imei))
275             for imei in activesubs - neededsubs:
276                 zsub.setsockopt(zmq.UNSUBSCRIBE, rtopic(imei))
277             activesubs = neededsubs
278             log.debug("Subscribed to: %s", activesubs)
279             tosend: List[Tuple[Optional[Client], Dict[str, Any]]] = []
280             topoll = []
281             tostop = []
282             towrite = set()
283             events = poller.poll()
284             for sk, fl in events:
285                 if sk is zsub:
286                     while True:
287                         try:
288                             zmsg = Rept(zsub.recv(zmq.NOBLOCK))
289                             msg = loads(zmsg.payload)
290                             msg["imei"] = zmsg.imei
291                             log.debug("Got %s, sending %s", zmsg, msg)
292                             tosend.append((None, msg))
293                         except zmq.Again:
294                             break
295                 elif sk == tcpfd:
296                     clntsock, clntaddr = tcpl.accept()
297                     topoll.append((clntsock, clntaddr))
298                 elif fl & zmq.POLLIN:
299                     clnt, received = clients.recv(sk)
300                     if received is None:
301                         log.debug("Client gone from fd %d", sk)
302                         tostop.append(sk)
303                         towait.discard(sk)
304                     else:
305                         for wsmsg in received:
306                             log.debug("Received from %d: %s", sk, wsmsg)
307                             if wsmsg.get("type", None) == "subscribe":
308                                 # Have to live w/o typeckeding from json
309                                 imeis = cast(List[str], wsmsg.get("imei"))
310                                 numback: int = wsmsg.get("backlog", 5)
311                                 for imei in imeis:
312                                     tosend.extend(
313                                         [
314                                             (clnt, msg)
315                                             for msg in backlog(imei, numback)
316                                         ]
317                                     )
318                         towrite.add(sk)
319                 elif fl & zmq.POLLOUT:
320                     log.debug("Write now open for fd %d", sk)
321                     towrite.add(sk)
322                     towait.discard(sk)
323                 else:
324                     log.debug("Stray event: %s on socket %s", fl, sk)
325             # poll queue consumed, make changes now
326             for fd in tostop:
327                 poller.unregister(fd)  # type: ignore
328                 clients.stop(fd)
329             for towhom, wsmsg in tosend:
330                 log.debug("Sending to the client %s: %s", towhom, wsmsg)
331                 towrite |= clients.send(towhom, wsmsg)
332             for clntsock, clntaddr in topoll:
333                 fd = clients.add(clntsock, clntaddr)
334                 poller.register(fd, flags=zmq.POLLIN)
335             # Deal with actually writing the data out
336             trywrite = towrite - towait
337             morewait = clients.write(trywrite)
338             log.debug(
339                 "towait %s, tried %s, still busy %s",
340                 towait,
341                 trywrite,
342                 morewait,
343             )
344             for fd in morewait - trywrite:  # new fds waiting for write
345                 poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
346             for fd in trywrite - morewait:  # no longer waiting for write
347                 poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
348             towait &= trywrite
349             towait |= morewait
350     except KeyboardInterrupt:
351         zsub.close()
352         zctx.destroy()  # type: ignore
353         tcpl.close()
354
355
356 if __name__.endswith("__main__"):
357     runserver(common.init(log))