www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
f2c6596c3e6a584c1280a9d3c2025e6ee265ebae
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from json import loads
4 from logging import getLogger
5 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
6 from time import time
7 from wsproto import ConnectionType, WSConnection
8 from wsproto.events import (
9     AcceptConnection,
10     CloseConnection,
11     Message,
12     Ping,
13     Request,
14     TextMessage,
15 )
16 from wsproto.utilities import RemoteProtocolError
17 import zmq
18
19 from . import common
20 from .zmsg import LocEvt
21
# Path of the page served for plain-HTTP "GET /" requests; runserver()
# assigns it from the configuration.
htmlfile = None
log = getLogger("gps303/wsgateway")
24
25
def _http_response(proto, status, ctype, body):
    """Build a complete HTTP response with status line, headers and body.

    :param proto: protocol string echoed from the request line
    :param status: status code and reason phrase, e.g. ``"404 File not found"``
    :param ctype: value for the Content-Type header
    :param body: response body as bytes
    :return: the whole response as bytes, always with a Content-Length
    """
    return (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(body):d}\r\n\r\n"
    ).encode("utf-8") + body


def try_http(data, fd, e):
    """Answer data that the websocket layer rejected as a plain HTTP request.

    Only ``GET /`` is served; it returns the contents of the module-level
    ``htmlfile``. Any query string is ignored. Every other request gets a
    plain-text error response.

    :param data: raw bytes received from the client
    :param fd: client file descriptor, used only for logging
    :param e: the exception raised by the websocket parser
    :return: the HTTP response as bytes
    :raises: ``e`` itself when ``data`` is not a parseable HTTP request line
    """
    try:
        lines = data.decode().split("\r\n")
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:  # also covers UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    resource = resource.partition("?")[0]  # drop the query string
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", b"Bad request\r\n"
        )
    if htmlfile is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            b"HTML data not configured on the server\r\n",
        )
    if resource != "/":
        return _http_response(
            proto,
            "404 File not found",
            "text/plain",
            b'We can only serve "/"\r\n',
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError as err:
        log.warning("Cannot open %s: %s", htmlfile, err)
        return _http_response(
            proto,
            "500 File not found",
            "text/plain",
            b"HTML file could not be opened\r\n",
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
84
85
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes queued for the socket; flushed by write().
        self.ws_data = b""
        # Set once the websocket upgrade request has been accepted.
        self.ready = False
        # IMEIs requested by the client's latest "subscribe" message.
        self.imeis = set()

    def close(self):
        """Close the client socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process the resulting websocket events.

        Handshake, close and ping responses are queued in ``ws_data``.

        :return: list of JSON values decoded from text messages (possibly
            empty), or ``None`` when the connection should be dropped. That
            happens on a read error, on EOF, after a plain-HTTP request has
            been answered, or when the data is neither websocket nor HTTP.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            try:
                self.ws_data = try_http(data, self.sock.fileno(), e)
            except RemoteProtocolError:
                # try_http re-raises the websocket error when the data is
                # not an HTTP request either: drop just this client instead
                # of letting the exception escape the server loop.
                return None
            self.write()  # TODO this is a hack
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            return None
        msgs = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                # self.ws_data += self.ws.send(event.response())  # Why not?!
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, self.sock.fileno())
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, self.sock.fileno())
                try:
                    msg = loads(event.data)
                except ValueError as err:  # JSONDecodeError is a ValueError
                    log.warning(
                        "Malformed JSON on fd %d: %s",
                        self.sock.fileno(),
                        err,
                    )
                    continue
                msgs.append(msg)
                # Only JSON objects can carry a "type"; other values are
                # passed through without affecting subscriptions.
                if isinstance(msg, dict) and msg.get("type", None) == "subscribe":
                    self._subscribe(msg.get("imei", []))
            else:
                log.warning("%s on fd %d", event, self.sock.fileno())
        return msgs

    def _subscribe(self, imeis):
        """Replace the subscription set with the values of the list *imeis*.

        A value that is not a list, or contains unhashable items, is logged
        and ignored, keeping the previous subscriptions.
        """
        if not isinstance(imeis, list):
            log.warning(
                "Ignoring non-list subscription on fd %d: %r",
                self.sock.fileno(),
                imeis,
            )
            return
        try:
            self.imeis = set(imeis)
        except TypeError:
            log.warning(
                "Ignoring unhashable subscription on fd %d: %r",
                self.sock.fileno(),
                imeis,
            )
            return
        log.debug(
            "subs list on fd %s is %s",
            self.sock.fileno(),
            self.imeis,
        )

    def wants(self, imei):
        """Report whether this client should be offered events for *imei*.

        Currently always True; the actual filtering happens in send().
        """
        log.debug("wants %s? set is %s on fd %d", imei, self.imeis, self.sock.fileno())
        return True  # TODO: check subscriptions

    def send(self, message):
        """Queue *message* as a websocket message, if the handshake is done
        and the client subscribed to ``message.imei``."""
        if self.ready and message.imei in self.imeis:
            self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Try to send the queued bytes once.

        :return: True if some data is still queued (partial send).
            On a send error the queue is discarded and False is returned.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
179
180
class Clients:
    """Registry of connected Client objects, keyed by socket fd."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start tracking a freshly accepted socket; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on *fd* and forget it."""
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Read from the client on *fd*.

        :return: list of messages received, or None if the client is gone
        """
        msgs = self.by_fd[fd].recv()
        if msgs is None:
            return None
        for msg in msgs:
            log.debug("Received: %s", msg)
        return list(msgs)

    def send(self, msg):
        """Offer *msg* to every interested client.

        :return: set of fds that were handed the message and need a write
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Flush queued data for the clients on the fds in *towrite*.

        Fds that do not belong to a known client are skipped.

        :return: set of fds that still have unsent data
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:
                # Not (or no longer) registered, e.g. already stopped.
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting
222
223
def runserver(conf):
    """Run the websocket gateway until interrupted with KeyboardInterrupt.

    Subscribes (to all topics) to the ZeroMQ publisher at
    ``[lookaside] publishurl``. Accepts TCP connections on
    ``[wsgateway] port`` and forwards every LocEvt received from ZeroMQ
    to the connected clients. ``[wsgateway] htmlfile`` is stored in the
    module-level ``htmlfile`` for serving plain HTTP.
    """
    global htmlfile

    htmlfile = conf.get("wsgateway", "htmlfile")
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    zsub.setsockopt(zmq.SUBSCRIBE, b"")
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        # Client fds registered for POLLOUT because they hold unsent data;
        # they are not written to again until the poller says they can be.
        towait = set()
        while True:
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    while True:
                        try:
                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except (BlockingIOError, ConnectionAbortedError) as e:
                        # The pending connection vanished between poll()
                        # and accept(); nothing to do.
                        log.warning("Accepting a connection: %s", e)
                        continue
                    topoll.append((clntsock, clntaddr))
                else:
                    # POLLIN and POLLOUT may be reported together: handle both.
                    if fl & zmq.POLLIN:
                        received = clients.recv(sk)
                        if received is None:
                            log.debug("Client gone from fd %d", sk)
                            tostop.append(sk)
                            towait.discard(sk)
                            continue
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                        towrite.add(sk)
                    if fl & zmq.POLLOUT:
                        log.debug("Write now open for fd %d", sk)
                        towrite.add(sk)
                        towait.discard(sk)
                    if not fl & (zmq.POLLIN | zmq.POLLOUT):
                        log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            # morewait is a subset of trywrite, which excludes towait, so
            # every fd in it has just started waiting for writability.
            for fd in morewait:
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            towait |= morewait
    except KeyboardInterrupt:
        pass
304
305
if __name__.endswith("__main__"):
    # Executed as the main module: build the configuration with
    # common.init() and hand it to the server loop.
    runserver(conf=common.init(log))