www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
wsgateway read html file every time
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from logging import getLogger
4 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
5 from time import time
6 from wsproto import ConnectionType, WSConnection
7 from wsproto.events import (
8     AcceptConnection,
9     CloseConnection,
10     Message,
11     Ping,
12     Request,
13     TextMessage,
14 )
15 from wsproto.utilities import RemoteProtocolError
16 import zmq
17
18 from . import common
19 from .zmsg import LocEvt
20
# Module-wide logger used by everything in the gateway.
log = getLogger("gps303/wsgateway")
# Path of the HTML page returned for "GET /". It stays None until
# runserver() fills it in from the "wsgateway" / "htmlfile" config option.
htmlfile = None
23
24
def _plain_response(proto, status, text):
    """Build a text/plain HTTP response (status line, header, body) as bytes."""
    return (
        f"{proto} {status}\r\n"
        "Content-Type: text/plain\r\n\r\n"
        f"{text}\r\n"
    ).encode()


def try_http(data, fd, e):
    """Answer a plain HTTP request that arrived instead of a websocket upgrade.

    Only ``GET /`` is served: the file named by the module-level ``htmlfile``
    is read from disk on every request and returned as ``text/html``. Any
    query string after ``?`` is ignored when matching the resource.

    Args:
        data: raw bytes received from the client.
        fd: file descriptor of the client socket, used for logging only.
        e: exception to re-raise when ``data`` cannot be parsed as an HTTP
            request line (the caller passes the websocket protocol error
            that made it fall back to plain HTTP).

    Returns:
        The complete HTTP response as bytes. The protocol token from the
        request line is echoed back in the status line.

    Raises:
        e: when ``data`` is not valid text or its first line does not consist
            of exactly three space-separated tokens.
    """
    try:
        # UnicodeDecodeError is a ValueError, as is a failed 3-way unpacking.
        lines = data.decode().split("\r\n")
        op, resource, proto = lines[0].split(" ")
    except ValueError:
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        lines[1:],
    )
    # Drop the query string, if any; only the path selects what is served.
    resource = resource.partition("?")[0]

    if op != "GET":
        return _plain_response(proto, "400 Bad request", "Bad request")
    if htmlfile is None:
        return _plain_response(
            proto,
            "500 No data configured",
            "HTML data not configured on the server",
        )
    if resource != "/":
        return _plain_response(
            proto, "404 File not found", 'We can only serve "/"'
        )
    try:
        # Read the file anew for each request so edits show up immediately.
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return _plain_response(
            proto, "500 File not found", "HTML file could not be opened"
        )
    return (
        f"{proto} 200 Ok\r\n"
        f"Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    ).encode("utf-8") + htmldata
83
84
class Client:
    """Websocket connection to the client

    Wraps one accepted TCP socket together with the wsproto state machine
    for it, and buffers outgoing bytes in ``ws_data`` until they are written.
    """

    def __init__(self, sock, addr):
        """Remember the socket and peer address, start a server-side session."""
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        self.ws_data = b""  # bytes queued for sending to the client

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process whatever arrived.

        Returns:
            None when the connection is finished (read error, EOF, or a plain
            HTTP request that has just been answered); otherwise the list of
            text message payloads received in this read (possibly empty).

        Raises:
            Whatever ``try_http`` re-raises when the data is neither a valid
            websocket exchange nor a parseable HTTP request.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:  # bind the error: it is reported in the log
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            # Replaces (not appends to) the buffer with the HTTP response.
            self.ws_data = try_http(data, self.sock.fileno(), e)
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            # None makes the caller drop the client after the response.
            msgs = None
        else:
            msgs = []
            for event in self.ws.events():
                if isinstance(event, Request):
                    log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                    # self.ws_data += self.ws.send(event.response())  # Why not?!
                    self.ws_data += self.ws.send(AcceptConnection())
                elif isinstance(event, (CloseConnection, Ping)):
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    self.ws_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    # TODO: save imei "subscription"
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    msgs.append(event.data)
                else:
                    log.warning("%s on fd %d", event, self.sock.fileno())
        if self.ws_data:  # Temp hack
            self.write()
        return msgs

    def wants(self, imei):
        """Tell whether this client should receive data for ``imei``."""
        return True  # TODO: check subscriptions

    def send(self, message):
        """Queue ``message`` (its ``json`` attribute) as a websocket message."""
        # TODO: filter only wanted imei got from the client
        self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Try to send the queued bytes with a single ``send()`` call.

        Returns:
            True if unsent data remains in the buffer, False otherwise.
            On a send error the buffer is discarded and False is returned.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
167
168
class Clients:
    """Registry of the connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}  # fd -> Client

    def add(self, clntsock, clntaddr):
        """Start serving a newly accepted socket; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on ``fd`` and forget it.

        Raises:
            KeyError: if ``fd`` is not being served.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Let the client on ``fd`` read and process incoming data.

        Returns:
            None if the client reported that the connection is finished,
            otherwise a list. The list is currently always empty: received
            messages are only logged here, not passed on.

        Raises:
            KeyError: if ``fd`` is not being served.
        """
        msgs = self.by_fd[fd].recv()
        if msgs is None:
            return None
        for msg in msgs:
            log.debug("Received: %s", msg)
        return []

    def send(self, msg):
        """Queue ``msg`` for every client that wants ``msg.imei``.

        Returns:
            The set of fds that now have data to write.
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Attempt to write queued data for each fd in ``towrite``.

        Fds that are not (or no longer) being served are skipped.

        Returns:
            The subset of ``towrite`` whose clients still have unsent data.
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:  # client already gone, nothing to write
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting
209
210
def runserver(conf):
    """Run the gateway loop until interrupted from the keyboard.

    Subscribes to the zmq publisher at the "lookaside" / "publishurl" config
    option, listens for TCP clients on the "wsgateway" / "port" option, and
    relays every received location event to the connected clients. Also sets
    the module-level ``htmlfile`` from the "wsgateway" / "htmlfile" option.

    Args:
        conf: configuration object providing ``get(section, option)`` and
            ``getint(section, option)`` (presumably a ConfigParser; confirm
            against ``common.init``).
    """
    global htmlfile

    htmlfile = conf.get("wsgateway", "htmlfile")
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    zsub.setsockopt(zmq.SUBSCRIBE, b"")
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        # fds registered for POLLOUT whose "writable" event is still awaited
        towait = set()
        while True:
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything the publisher has queued so far.
                    while True:
                        try:
                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
                # The fd number may be reused by a later connection; a stale
                # entry here would keep that new client from being written to.
                towait.discard(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out
            # Skip fds that are still waiting for their POLLOUT event.
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            # morewait is always a subset of trywrite, so every fd in it
            # must be (or stay) registered for POLLOUT.
            for fd in morewait:  # fds waiting for write
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            # Earlier waiters stay until their POLLOUT event arrives.
            towait |= morewait
    except KeyboardInterrupt:
        pass
289
290
# Script entry point. endswith() also matches a dotted module name that ends
# in "__main__" -- presumably to support package-style invocation; confirm.
if __name__.endswith("__main__"):
    runserver(common.init(log))