www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
WIP on websocket gateway
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from logging import getLogger
4 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
5 from time import time
6 from wsproto import ConnectionType, WSConnection
7 from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request, TextMessage
8 import zmq
9
10 from . import common
11 from .zmsg import LocEvt
12
# Module-level logger used by every class and function in this file.
log = getLogger("gps303/wsgateway")
14
15
class Client:
    """Websocket connection to the client.

    Pairs one accepted TCP socket with the wsproto state machine that
    speaks the WebSocket protocol over it.  Bytes produced by wsproto are
    accumulated in ``ws_data`` until ``write()`` pushes them to the socket.
    """

    def __init__(self, sock, addr):
        """Wrap the accepted socket `sock` whose peer address is `addr`."""
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        self.ws_data = b""
        # True only between the completed upgrade handshake and the close
        # handshake: wsproto rejects data messages outside that window.
        self.ready = False

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process the resulting websocket events.

        Upgrade requests are accepted, pings and close requests are
        answered, and the payloads of text messages are collected.

        Returns the list of text payloads received (possibly empty), or
        None when the read failed or the peer closed the connection.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        self.ws.receive_data(data)
        msgs = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                # The upgrade is accepted explicitly with AcceptConnection.
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, self.sock.fileno())
                self.ws_data += self.ws.send(event.response())
                if isinstance(event, CloseConnection):
                    # Once the close is acknowledged no more data messages
                    # may be sent on this connection.
                    self.ready = False
            elif isinstance(event, TextMessage):
                # TODO: save imei "subscription"
                log.debug("%s on fd %d", event, self.sock.fileno())
                msgs.append(event.data)
            else:
                log.warning("%s on fd %d", event, self.sock.fileno())
        if self.ws_data:  # Temp hack: flush protocol replies right away
            self.write()
        return msgs

    def send(self, imei, message):
        """Queue `message` for delivery to this client.

        `imei` is currently unused (see the TODO below).  The message is
        silently dropped when the websocket is not open, because wsproto
        raises LocalProtocolError for data sent before the handshake has
        completed or after the connection was closed.

        NOTE(review): wsproto expects the payload to be str or bytes;
        verify what the caller passes.
        """
        # TODO: filter only wanted imei got from the client
        if not self.ready:
            return
        self.ws_data += self.ws.send(Message(data=message))

    def write(self):
        """Push pending outgoing bytes to the socket.

        Whatever the socket did not accept stays in ``ws_data`` for the
        next call.  On a socket error the pending data is discarded.
        """
        if not self.ws_data:
            return
        try:
            sent = self.sock.send(self.ws_data)
            self.ws_data = self.ws_data[sent:]
        except OSError as e:
            # Identify the peer by its address; a Client has no imei.
            log.error(
                "Sending to fd %d (%s): %s",
                self.sock.fileno(),
                self.addr,
                e,
            )
            self.ws_data = b""
83
84
class Clients:
    """Registry of connected websocket clients, keyed by socket fd."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start tracking a newly accepted connection and return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client registered under `fd` and forget it.

        Raises KeyError if `fd` is not registered.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Let the client registered under `fd` read from its socket.

        Returns the list of messages the client received, or None when
        the client reports that its connection is gone.
        """
        clnt = self.by_fd[fd]
        msgs = clnt.recv()
        if msgs is None:
            return None
        result = []
        for msg in msgs:
            log.debug("Received: %s", msg)
            # Hand the message on to the caller; previously the result
            # list was never filled and every message was dropped here.
            result.append(msg)
        return result

    def send(self, msgs):
        """Broadcast one event to every connected client.

        Despite the plural name, `msgs` is a single event: the caller
        invokes this once per event.

        NOTE(review): presumably the event carries an ``imei`` attribute
        for the per-client filtering that Client.send is meant to do;
        None is passed when it does not.  Confirm against the event class.
        """
        imei = getattr(msgs, "imei", None)
        for clnt in self.by_fd.values():
            # Client.send takes (imei, message); passing the event alone
            # raised TypeError.
            clnt.send(imei, msgs)
            clnt.write()
115
def runserver(conf):
    """Run the websocket gateway event loop until interrupted.

    Subscribes to the zmq publisher at the "publishurl" option of the
    "lookaside" section of `conf`, listens for TCP connections on the
    "port" option of the "wsgateway" section, and forwards every event
    received from zmq to all connected websocket clients.

    `conf` must provide ``get(section, option)`` and
    ``getint(section, option)``.

    Returns on KeyboardInterrupt.  All sockets and the zmq context are
    released on the way out, whatever the reason for leaving the loop.
    """
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    zsub.setsockopt(zmq.SUBSCRIBE, b"")
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        while True:
            # Changes to the set of polled sockets are collected here and
            # applied only after the whole batch of events is processed.
            tosend = []
            topoll = []
            tostop = []
            events = poller.poll(1000)
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything that is queued on the subscription.
                    while True:
                        try:
                            msg = zsub.recv(zmq.NOBLOCK)
                            tosend.append(LocEvt(msg))
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # The pending connection may be gone by the time
                        # it is accepted; that must not stop the gateway.
                        log.warning("Accepting connection: %s", e)
                    else:
                        topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug(
                            "Client gone from fd %d", sk
                        )
                        tostop.append(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            # Stop before adding: a new connection may reuse the fd number
            # of one that has just gone away.
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the client: %s", zmsg)
                clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # TODO: Handle write overruns (register for POLLOUT)
    except KeyboardInterrupt:
        pass
    finally:
        # Copy the keys: stop() removes entries while we iterate.
        for fd in list(clients.by_fd):
            clients.stop(fd)
        tcpl.close()
        # linger=0 so that terminating the context cannot block.
        zsub.close(linger=0)
        zctx.term()
173
174
# Entry point: the result of common.init(log) is handed to runserver()
# as its configuration.
if __name__.endswith("__main__"):
    runserver(common.init(log))