]> www.average.org Git - loctrkd.git/blob - test/common.py
test: opencellid downloader
[loctrkd.git] / test / common.py
1 """ Common housekeeping for tests that rely on daemons """
2
3 from configparser import ConfigParser, SectionProxy
4 from contextlib import closing, ExitStack
5 from http.server import HTTPServer, SimpleHTTPRequestHandler
6 from importlib import import_module
7 from multiprocessing import Process
8 from os import kill, unlink
9 from signal import SIGINT
10 from socket import (
11     AF_INET6,
12     MSG_DONTWAIT,
13     SOCK_DGRAM,
14     SOL_SOCKET,
15     SO_REUSEADDR,
16     socket,
17     SocketType,
18 )
19 from sys import exit
20 from tempfile import mkstemp
21 from time import sleep
22 from typing import Optional
23 from unittest import TestCase
24
# Number of free ports reserved in setUp: [0] collector, [1] wsgateway,
# [2] the local HTTP server referenced by opencellid.downloadurl.
NUMPORTS = 3


class TestWithServers(TestCase):
    """Base test case that runs servers in child processes around each test.

    Subclasses call ``setUp(*server_names, httpd=...)`` to start what they
    need; ``tearDown`` stops every child, removes the temporary files and
    checks that each child exited with status 0.
    """

    def setUp(self, *args: str, httpd: bool = False) -> None:
        """Build a throw-away configuration and start the requested servers.

        :param args: names of modules in the ``gps303`` package; each one's
            ``runserver(conf, **kwargs)`` is run in its own child process.
        :param httpd: if true, also start an HTTP server (serving the current
            working directory) on the port used in ``opencellid.downloadurl``.
        """
        # Names not provided by the module-level imports.
        from os import close as close_fd
        from socket import SOCK_STREAM

        # Reserve NUMPORTS distinct free ports by holding them all bound at
        # once. TCP is probed because the ports are later listened on over
        # TCP (HTTPServer certainly, presumably the daemons as well); a free
        # UDP port says nothing about the TCP port with the same number.
        freeports = []
        with ExitStack() as stack:
            for _ in range(NUMPORTS):
                sk = stack.enter_context(closing(socket(AF_INET6, SOCK_STREAM)))
                # SO_REUSEADDR only affects a bind() made after it is set.
                sk.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
                sk.bind(("", 0))
                freeports.append(sk.getsockname()[1])
        # mkstemp() returns an open descriptor, but only the unique path is
        # used (as a prefix for the other temp files), so close it at once.
        fd, self.tmpfilebase = mkstemp()
        close_fd(fd)
        self.conf = ConfigParser()
        self.conf["collector"] = {
            "port": str(freeports[0]),
            "publishurl": "ipc://" + self.tmpfilebase + ".pub",
            "listenurl": "ipc://" + self.tmpfilebase + ".pul",
        }
        self.conf["storage"] = {
            "dbfn": self.tmpfilebase + ".storage.sqlite",
        }
        self.conf["opencellid"] = {
            "dbfn": self.tmpfilebase + ".opencellid.sqlite",
            "downloadurl": f"http://localhost:{freeports[2]}/test/262.csv.gz",
        }
        self.conf["lookaside"] = {
            "backend": "opencellid",
        }
        self.conf["wsgateway"] = {
            "port": str(freeports[1]),
        }
        self.children = []
        for srvname in args:
            if srvname == "collector":
                kwargs = {"handle_hibernate": False}
            else:
                kwargs = {}
            cls = import_module("gps303." + srvname, package=".")
            p = Process(target=cls.runserver, args=(self.conf,), kwargs=kwargs)
            p.start()
            self.children.append((srvname, p))
        if httpd:
            server = HTTPServer(("", freeports[2]), SimpleHTTPRequestHandler)

            def run(server: HTTPServer) -> None:
                try:
                    server.serve_forever()
                except KeyboardInterrupt:
                    pass  # SIGINT sent by tearDown() is the normal way to stop
                finally:
                    # serve_forever() has already returned here, so there is
                    # nothing to shut down; just release the listening socket.
                    server.server_close()

            p = Process(target=run, args=(server,))
            p.start()
            # ``run`` is a local function, so this relies on the fork start
            # method: the forked child holds its own copy of the listening
            # socket and the parent's copy is no longer needed.
            server.server_close()
            self.children.append(("httpd", p))
        # Presumably gives the children time to start up before the test.
        sleep(1)

    def tearDown(self) -> None:
        """Stop all children, remove temp files, then check exit statuses.

        Every child is signalled and joined, and every temp file is removed,
        before any assertion runs. A failing child therefore cannot leave the
        other children running or leave files behind.
        """
        for _, p in self.children:
            # Skip children that were never started or have already been
            # reaped: kill() would fail, or could hit a reused pid.
            if p.pid is not None and p.is_alive():
                kill(p.pid, SIGINT)
            p.join()
        for sfx in (
            "",
            ".pub",
            ".pul",
            ".storage.sqlite",
            ".opencellid.sqlite",
        ):
            try:
                unlink(self.tmpfilebase + sfx)
            except OSError:
                pass  # not every run creates every file
        for srvname, p in self.children:
            self.assertEqual(
                p.exitcode,
                0,
                srvname + " terminated with non-zero return code",
            )
100                 unlink(self.tmpfilebase + sfx)
101             except OSError:
102                 pass
103
104
def send_and_drain(sock: SocketType, buf: Optional[bytes]) -> None:
    """Optionally send ``buf`` on ``sock``, then discard one pending read.

    :param sock: connected socket to talk on.
    :param buf: data to send, or ``None`` to only drain.

    At most one ``recv`` of up to 4096 bytes is done, without blocking; if
    nothing is waiting to be read, nothing is consumed.
    """
    if buf is not None:
        # send() may transmit only part of the buffer on a stream socket;
        # sendall() keeps sending until all of it is out (or raises).
        sock.sendall(buf)
    try:
        sock.recv(4096, MSG_DONTWAIT)
    except BlockingIOError:
        pass  # nothing pending to drain