www.average.org Git - loctrkd.git/blob - test/common.py
test: better acquisition of free ports
[loctrkd.git] / test / common.py
1 """ Common housekeeping for tests that rely on daemons """
2
from configparser import ConfigParser, SectionProxy
from contextlib import closing, ExitStack
from importlib import import_module
from multiprocessing import Process
from os import close, kill, unlink
from signal import SIGINT
from socket import (
    AF_INET6,
    MSG_DONTWAIT,
    SOCK_DGRAM,
    SOL_SOCKET,
    SO_REUSEADDR,
    socket,
    SocketType,
)
from tempfile import mkstemp
from time import sleep
from typing import Optional
from unittest import TestCase
21 from unittest import TestCase
22
# How many free port numbers TestWithServers.setUp reserves; only the first
# two are currently handed out (collector port and wsgateway port).
NUMPORTS: int = 3
24
25
class TestWithServers(TestCase):
    """Base class for tests that need project daemons running in subprocesses.

    ``setUp`` picks free port numbers and builds ``self.conf``, pointing
    every service at paths derived from one temporary file name. It then
    starts one ``multiprocessing.Process`` per requested service.
    ``tearDown`` stops those processes with SIGINT and removes the temporary
    files. Only then does it assert that every child exited with status 0.
    """

    def setUp(self, *args: str, httpd: bool = False) -> None:
        """Build the test configuration and start the requested daemons.

        :param args: names of modules in the ``gps303`` package; each
            module's ``runserver(conf, **kwargs)`` is run in its own process.
        :param httpd: accepted but currently unused.
        """
        freeports = []
        # Keep every probe socket open until all port numbers have been
        # read, so each probe is bound to a distinct port. Leaving the
        # ExitStack closes them all, releasing the ports for the daemons.
        with ExitStack() as stack:
            for _ in range(NUMPORTS):
                sk = stack.enter_context(closing(socket(AF_INET6, SOCK_DGRAM)))
                sk.bind(("", 0))
                # NOTE(review): set after bind(), so it cannot influence the
                # port picked above; verify carefully before moving it earlier.
                sk.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
                freeports.append(sk.getsockname()[1])
        # mkstemp() also returns an open OS-level descriptor. Only the name
        # is used here (as a base for the file names below), so close the
        # descriptor rather than leaking one per test.
        fd, self.tmpfilebase = mkstemp()
        close(fd)
        self.conf = ConfigParser()
        self.conf["collector"] = {
            "port": str(freeports[0]),
            "publishurl": "ipc://" + self.tmpfilebase + ".pub",
            "listenurl": "ipc://" + self.tmpfilebase + ".pul",
        }
        self.conf["storage"] = {
            "dbfn": self.tmpfilebase + ".storage.sqlite",
        }
        self.conf["opencellid"] = {
            "dbfn": self.tmpfilebase + ".opencellid.sqlite",
        }
        self.conf["lookaside"] = {
            "backend": "opencellid",
        }
        self.conf["wsgateway"] = {
            "port": str(freeports[1]),
        }
        # (service name, Process) pairs, in start order; used by tearDown.
        self.children = []
        for srvname in args:
            # The collector alone is started with handle_hibernate=False.
            if srvname == "collector":
                kwargs = {"handle_hibernate": False}
            else:
                kwargs = {}
            cls = import_module("gps303." + srvname, package=".")
            p = Process(target=cls.runserver, args=(self.conf,), kwargs=kwargs)
            p.start()
            self.children.append((srvname, p))
        if httpd:
            pass
        # Presumably gives the daemons time to start up before tests run.
        sleep(1)

    def tearDown(self) -> None:
        """Stop all child daemons, remove temporary files, check exit codes.

        Every child is signalled and joined, and every temporary file is
        removed, before any exit code is asserted. A single failing daemon
        therefore no longer leaves the remaining processes running or the
        files behind.

        :raises AssertionError: if a child exited with a non-zero status.
        """
        exitcodes = []
        for srvname, p in self.children:
            if p.pid is not None:
                kill(p.pid, SIGINT)
            p.join()
            exitcodes.append((srvname, p.exitcode))
        for sfx in (
            "",
            ".pub",
            ".pul",
            ".storage.sqlite",
            ".opencellid.sqlite",
        ):
            try:
                unlink(self.tmpfilebase + sfx)
            except OSError:
                # Best effort: some of these files may never have been created.
                pass
        for srvname, exitcode in exitcodes:
            self.assertEqual(
                exitcode,
                0,
                srvname + " terminated with non-zero return code",
            )
89
90
def send_and_drain(sock: SocketType, buf: Optional[bytes]) -> None:
    """Optionally send ``buf`` on ``sock``, then discard one pending read.

    :param sock: connected socket to use.
    :param buf: data to send, or ``None`` to skip sending and only drain.

    At most one non-blocking ``recv(4096)`` is performed. If nothing is
    waiting to be read, the call returns quietly.
    """
    if buf is not None:
        # sendall() keeps sending until the whole buffer is out; plain send()
        # may transmit only part of it on a stream socket.
        sock.sendall(buf)
    try:
        sock.recv(4096, MSG_DONTWAIT)
    except BlockingIOError:
        pass  # nothing pending to drain