www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
15a6517d0c61ac14535f346545b1c03607abcc1a
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from logging import getLogger
4 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
5 from time import time
6 from wsproto import ConnectionType, WSConnection
7 from wsproto.events import (
8     AcceptConnection,
9     CloseConnection,
10     Message,
11     Ping,
12     Request,
13     TextMessage,
14 )
15 from wsproto.utilities import RemoteProtocolError
16 import zmq
17
18 from . import common
19 from .zmsg import LocEvt
20
# Page served to plain-HTTP clients; None until runserver() reads the
# configured HTML file.
htmldata = None
# Logger used throughout this module.
log = getLogger("gps303/wsgateway")
23
24
def _http_response(proto, status, content_type, body):
    """Return a complete HTTP response as bytes.

    The status line echoes the client's protocol token ``proto``.  ``body``
    is a ``str``; it is UTF-8 encoded and its byte length is announced in a
    Content-Length header, so every response is framed the same way.
    """
    payload = body.encode("utf-8")
    header = (
        f"{proto} {status}\r\n"
        f"Content-Type: {content_type}\r\n"
        f"Content-Length: {len(payload):d}\r\n"
        "\r\n"
    )
    return header.encode("utf-8") + payload


def try_http(data, fd, e):
    """Answer a plain (non-websocket) HTTP request.

    Used when the websocket parser rejected ``data``: the peer may be a
    browser fetching the page rather than opening a websocket.

    :param data: raw bytes received from the client socket.
    :param fd: file descriptor of the client socket (used for logging only).
    :param e: the exception raised by the websocket parser; it is re-raised
        unchanged when ``data`` is not an HTTP request line either.
    :returns: the complete HTTP response as ``bytes``.
    :raises: ``e`` when ``data`` cannot be decoded or its first line does not
        split into exactly three space-separated fields
        (``UnicodeDecodeError`` is a ``ValueError`` subclass, so undecodable
        bytes land here too).
    """
    try:
        lines = data.decode().split("\r\n")
        request, headers = lines[0], lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", "Bad request\r\n"
        )
    if htmldata is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            "HTML data not configured on the server\r\n",
        )
    if resource != "/":
        return _http_response(
            proto,
            "404 File not found",
            "text/plain",
            'We can only serve "/"\r\n',
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
69
70
class Client:
    """Websocket connection to the client.

    Wraps one accepted TCP socket together with a server-side wsproto
    connection.  Outgoing bytes accumulate in ``ws_data`` and are flushed
    to the socket by :meth:`write`.
    """

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes produced by the protocol layer but not yet written out.
        self.ws_data = b""

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read up to 4096 bytes from the socket and process them.

        :returns: a list of text payloads received over the websocket
            (possibly empty), or ``None`` when the caller should drop this
            client.  That happens on a read error, on EOF, after a
            plain-HTTP request has been answered, or when the data is
            neither websocket nor HTTP.
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)", fd, e
            )
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # try_http re-raises the websocket error for data that is
                # not HTTP either; drop this client instead of letting the
                # exception escape into the server loop.
                return None
            log.debug("Sending HTTP response to %d", fd)
            msgs = None
        else:
            msgs = []
            for event in self.ws.events():
                if isinstance(event, Request):
                    log.debug("WebSocket upgrade on fd %d", fd)
                    # Accept the upgrade unconditionally.
                    self.ws_data += self.ws.send(AcceptConnection())
                elif isinstance(event, (CloseConnection, Ping)):
                    log.debug("%s on fd %d", event, fd)
                    # Reply with the event's own response (Close/Pong).
                    self.ws_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    # TODO: save imei "subscription"
                    log.debug("%s on fd %d", event, fd)
                    msgs.append(event.data)
                else:
                    log.warning("%s on fd %d", event, fd)
        if self.ws_data:
            # Temp hack: flush right away instead of waiting for POLLOUT.
            self.write()
        return msgs

    def send(self, imei, message):
        """Queue ``message`` as a websocket message for this client.

        ``imei`` is currently ignored.  Call :meth:`write` to flush.
        """
        # TODO: filter only wanted imei got from the client
        # NOTE(review): presumably wsproto requires ``message`` to be str or
        # bytes; runserver forwards LocEvt objects here - confirm how they
        # should be serialized.
        self.ws_data += self.ws.send(Message(data=message))

    def write(self):
        """Send as much pending output as the socket accepts.

        Unsent bytes stay in ``ws_data`` for the next call.  On a socket
        error the pending output is logged and discarded.
        """
        try:
            sent = self.sock.send(self.ws_data)
        except OSError as e:
            log.error("Sending to fd %d: %s", self.sock.fileno(), e)
            self.ws_data = b""
        else:
            self.ws_data = self.ws_data[sent:]
148
149
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start serving an accepted socket; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Forget the client on ``fd`` and close its socket."""
        clnt = self.by_fd.pop(fd)
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()

    def recv(self, fd):
        """Process incoming data on ``fd``.

        :returns: the list of messages received from that client, or
            ``None`` when the client should be stopped.
        """
        msgs = self.by_fd[fd].recv()
        if msgs is None:
            return None
        for msg in msgs:
            log.debug("Received: %s", msg)
        return list(msgs)

    def send(self, msgs):
        """Forward one message to every connected client and flush.

        Despite its name, ``msgs`` is a single message object.
        """
        for clnt in self.by_fd.values():
            # Client.send takes (imei, message); imei filtering is not
            # implemented yet, so no imei is passed.
            clnt.send(None, msgs)
            clnt.write()
180
181
def runserver(conf):
    """Run the websocket gateway until interrupted with Ctrl-C.

    Reads the page for plain-HTTP clients from ``[wsgateway] htmlfile``,
    subscribes to everything published at ``[lookaside] publishurl`` and
    listens for TCP clients on ``[wsgateway] port``.  Every message received
    from the subscription is forwarded to all connected clients.  Client
    sockets, the listening socket and the ZeroMQ context are released on
    exit.
    """
    global htmldata
    htmlfile = conf.get("wsgateway", "htmlfile")
    try:
        with open(htmlfile, encoding="utf-8") as fl:
            htmldata = fl.read()
    except OSError as e:
        # Not fatal: htmldata stays None and plain-HTTP requests get an
        # error response, but the failure should not go unnoticed.
        log.warning("Cannot read HTML file %s: %s", htmlfile, e)
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    zsub.setsockopt(zmq.SUBSCRIBE, b"")
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        while True:
            tosend = []
            topoll = []
            tostop = []
            events = poller.poll(1000)
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything queued on the subscription.
                    while True:
                        try:
                            msg = zsub.recv(zmq.NOBLOCK)
                        except zmq.Again:
                            break
                        tosend.append(LocEvt(msg))
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # The listener is non-blocking: the pending
                        # connection may be gone by the time we accept it.
                        log.warning("Failed to accept a connection: %s", e)
                    else:
                        topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the client: %s", zmsg)
                clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # TODO: Handle write overruns (register for POLLOUT)
    except KeyboardInterrupt:
        pass
    finally:
        for fd in list(clients.by_fd):
            clients.stop(fd)
        tcpl.close()
        # linger=0 so term() cannot block on undeliverable queued messages.
        zsub.close(linger=0)
        zctx.term()
245
246
if __name__.endswith("__main__"):
    # Script entry point: obtain the configuration from common.init() and
    # run the gateway with it.
    _conf = common.init(log)
    runserver(_conf)