# www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
# WIP on supporting multiple markers
# [loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from json import loads
4 from logging import getLogger
5 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
6 from time import time
7 from wsproto import ConnectionType, WSConnection
8 from wsproto.events import (
9     AcceptConnection,
10     CloseConnection,
11     Message,
12     Ping,
13     Request,
14     TextMessage,
15 )
16 from wsproto.utilities import RemoteProtocolError
17 import zmq
18
19 from . import common
20 from .backlog import blinit, backlog
21 from .zmsg import LocEvt
22
# Module-wide logger used by every function and class in this file.
log = getLogger("gps303/wsgateway")
# Path of the HTML page served in reply to a plain HTTP "GET /".
# Stays None until runserver() fills it in from the configuration.
htmlfile = None
25
26
def try_http(data, fd, e):
    """Build a plain HTTP response for data that was not a websocket handshake.

    Only ``GET /`` (with an optional query string, which is ignored) is
    served: the reply is the content of the file named by the module-level
    ``htmlfile``.  Everything else gets an error response.

    Parameters:
        data: raw bytes received from the client.
        fd: file descriptor of the client socket; used for logging only.
        e: the exception that made the caller fall back to plain HTTP.  It is
            re-raised when ``data`` cannot be parsed as an HTTP request.

    Returns:
        The complete HTTP response (status line, headers and body) as bytes.

    Raises:
        e: if ``data`` is not decodable text or its first line does not
            consist of exactly three space-separated fields.
    """
    try:
        lines = data.decode().split("\r\n")
        # UnicodeDecodeError is a ValueError, as is a failed 3-way unpacking.
        op, resource, proto = lines[0].split(" ")
    except ValueError:
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    headers = lines[1:]
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    # Drop the query string, if any; only the path selects the resource.
    resource = resource.partition("?")[0]

    if op != "GET":
        return (
            f"{proto} 400 Bad request\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "Bad request\r\n".encode()
        )
    if htmlfile is None:
        return (
            f"{proto} 500 No data configured\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f"HTML data not configure on the server\r\n".encode()
        )
    if resource != "/":
        return (
            f"{proto} 404 File not found\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f'We can only serve "/"\r\n'.encode()
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return (
            f"{proto} 500 File not found\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f"HTML file could not be opened\r\n".encode()
        )
    return (
        f"{proto} 200 Ok\r\n"
        f"Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    ).encode("utf-8") + htmldata
85
86
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes produced by the websocket layer, not yet written to the socket
        self.ws_data = b""
        # Becomes True once the websocket upgrade request has been accepted
        self.ready = False
        # IMEIs this client has subscribed to
        self.imeis = set()

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def _peer_gone(self):
        """Tell the websocket state machine that no more data will arrive."""
        try:
            self.ws.receive_data(None)
        except RemoteProtocolError as e:
            # The peer left in the middle of a message; nothing to salvage.
            log.debug(
                "Websocket error at EOF on fd %d: %s", self.sock.fileno(), e
            )

    def recv(self):
        """Read from the socket and process whatever has arrived.

        Returns:
            A list of decoded JSON objects (dicts) received as websocket text
            messages, possibly empty, or None when the connection must be
            dropped: read error, EOF, a plain HTTP request that has just been
            answered, or data that is neither websocket nor HTTP.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self._peer_gone()
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self._peer_gone()
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            try:
                self.ws_data = try_http(data, self.sock.fileno(), e)
            except RemoteProtocolError:
                # Not HTTP either (try_http has logged it): drop the client
                # rather than let the exception take the whole server down.
                return None
            self.write()  # TODO this is a hack
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            return None

        msgs = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                # self.ws_data += self.ws.send(event.response())  # Why not?!
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, self.sock.fileno())
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, self.sock.fileno())
                try:
                    msg = loads(event.data)
                except ValueError as e:  # includes json.JSONDecodeError
                    log.warning(
                        "Bad JSON on fd %d: %s", self.sock.fileno(), e
                    )
                    continue
                if not isinstance(msg, dict):
                    log.warning(
                        "Non-object message on fd %d: %s",
                        self.sock.fileno(),
                        msg,
                    )
                    continue
                if msg.get("type", None) == "subscribe":
                    imeis = msg.get("imei", [])
                    if not isinstance(imeis, list) or not all(
                        isinstance(imei, str) for imei in imeis
                    ):
                        log.warning(
                            "Bad subs list on fd %d: %s",
                            self.sock.fileno(),
                            imeis,
                        )
                        continue
                    self.imeis = set(imeis)
                    log.debug(
                        "subs list on fd %s is %s",
                        self.sock.fileno(),
                        self.imeis,
                    )
                msgs.append(msg)
            else:
                log.warning("%s on fd %d", event, self.sock.fileno())
        return msgs

    def wants(self, imei):
        """Return True if this client is subscribed to `imei`."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message):
        """Queue `message` for sending if the client is ready and subscribed.

        `message` must have `imei` and `json` attributes.
        """
        if self.ready and message.imei in self.imeis:
            self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Write as much queued data to the socket as it accepts.

        Returns:
            True if unsent data remains in the queue, False otherwise.  On a
            socket error the queue is discarded and False is returned.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
180
181
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start tracking a newly accepted connection; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on `fd` and forget about it.

        Raises KeyError if `fd` is not a known client.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Let the client on `fd` read its socket; return what it returns."""
        return self.by_fd[fd].recv()

    def send(self, msg):
        """Offer `msg` to every client that wants its IMEI.

        Returns the set of fds of the clients that the message was offered
        to, i.e. those that may now have data to write.
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Flush queued data of the clients whose fds are in `towrite`.

        Returns the subset of `towrite` that still has unsent data.  Fds that
        no longer belong to a known client are skipped.
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:  # already stopped, nothing left to write
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting

    def subs(self):
        """Return the union of the IMEI subscriptions of all clients."""
        result = set()
        for clnt in self.by_fd.values():
            result |= clnt.imeis
        return result
222
223
def runserver(conf):
    """Run the websocket gateway until interrupted from the keyboard.

    Listens for TCP connections on the configured port, subscribes to the
    zmq publisher at the configured url for the IMEIs that connected clients
    ask for, and forwards location events to the clients.

    Parameters:
        conf: configuration object providing ``get(section, option)`` and
            ``getint(section, option)``; the options read are
            ``storage/dbfn``, ``opencellid/dbfn``, ``wsgateway/htmlfile``,
            ``wsgateway/port`` and ``lookaside/publishurl``.

    Side effects:
        Sets the module-level ``htmlfile``.
    """
    global htmlfile

    blinit(conf.get("storage", "dbfn"), conf.get("opencellid", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile")
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs = set()
    try:
        # fds with unsent data, registered for POLLOUT
        towait = set()
        while True:
            # Bring zmq subscriptions in line with what the clients want
            neededsubs = clients.subs()
            for imei in neededsubs - activesubs:
                zsub.setsockopt(zmq.SUBSCRIBE, imei.encode())
            for imei in activesubs - neededsubs:
                zsub.setsockopt(zmq.UNSUBSCRIBE, imei.encode())
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything the publisher has for us
                    while True:
                        try:
                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # The listener is non-blocking; the pending
                        # connection may have vanished before we got to it.
                        log.warning("Accepting connection: %s", e)
                    else:
                        topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        towait.discard(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                            if msg.get("type", None) == "subscribe":
                                # "imei" may be absent from the message
                                imei = msg.get("imei") or []
                                numback = msg.get("backlog", 5)
                                for elem in imei:
                                    tosend.extend(backlog(elem, numback))
                        towrite.add(sk)
                        if fl & zmq.POLLOUT:
                            # Readable and writable at once: stop waiting
                            towait.discard(sk)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out.  Do not touch the fds
            # that are known to be blocked until the poller reports POLLOUT.
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            # morewait is always a subset of trywrite
            for fd in morewait:  # could not write everything, wait for room
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            towait |= morewait
    except KeyboardInterrupt:
        pass
    finally:
        tcpl.close()
        zsub.close()
317
318
# Run the gateway when executed as a program.  The test matches any module
# name that ends with "__main__", not only the exact string "__main__".
if __name__.endswith("__main__"):
    runserver(common.init(log))