Apply listen_addresses correctly with wildcard (#4137)

* Shutdown servers first if old/new addresses has wildcard
* Fix ISE with IPv6 addresses
* Raise a proper error if the newly-configured listen_* doesn't work

Fixes #3971
This commit is contained in:
Fantix King 2022-07-25 17:25:05 -04:00 committed by GitHub
parent 52734404e0
commit af09a24315
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 98 additions and 37 deletions

View file

@ -896,8 +896,11 @@ class Server(ha_base.ClusterProtocol):
async def _restart_servers_new_addr(self, nethosts, netport):
if not netport:
raise RuntimeError('cannot restart without network port specified')
nethosts = await _resolve_interfaces(nethosts)
nethosts, has_ipv4_wc, has_ipv6_wc = await _resolve_interfaces(
nethosts
)
servers_to_stop = []
servers_to_stop_early = []
servers = {}
if self._listen_port == netport:
hosts_to_start = [
@ -906,7 +909,25 @@ class Server(ha_base.ClusterProtocol):
for host, srv in self._servers.items():
if host == ADMIN_PLACEHOLDER or host in nethosts:
servers[host] = srv
elif host in ['::', '0.0.0.0']:
servers_to_stop_early.append(srv)
else:
if has_ipv4_wc:
try:
ipaddress.IPv4Address(host)
except ValueError:
pass
else:
servers_to_stop_early.append(srv)
continue
if has_ipv6_wc:
try:
ipaddress.IPv6Address(host)
except ValueError:
pass
else:
servers_to_stop_early.append(srv)
continue
servers_to_stop.append(srv)
admin = False
else:
@ -914,17 +935,29 @@ class Server(ha_base.ClusterProtocol):
servers_to_stop = self._servers.values()
admin = True
if servers_to_stop_early:
await self._stop_servers_with_logging(servers_to_stop_early)
if hosts_to_start:
new_servers, *_ = await self._start_servers(
hosts_to_start,
netport,
admin=admin,
)
servers.update(new_servers)
try:
new_servers, *_ = await self._start_servers(
hosts_to_start,
netport,
admin=admin,
)
servers.update(new_servers)
except StartupError:
raise errors.ConfigurationError(
'Server updated its config but cannot serve on requested '
'address/port, please see server log for more information.'
)
self._servers = servers
self._listen_hosts = nethosts
self._listen_port = netport
await self._stop_servers_with_logging(servers_to_stop)
async def _stop_servers_with_logging(self, servers_to_stop):
addrs = []
unix_addr = None
port = None
@ -932,7 +965,7 @@ class Server(ha_base.ClusterProtocol):
for s in srv.sockets:
addr = s.getsockname()
if isinstance(addr, tuple):
addrs.append(addr)
addrs.append(addr[:2])
if port is None:
port = addr[1]
elif port != addr[1]:
@ -1634,7 +1667,7 @@ class Server(ha_base.ClusterProtocol):
)
self._servers, actual_port, listen_addrs = await self._start_servers(
await _resolve_interfaces(self._listen_hosts),
(await _resolve_interfaces(self._listen_hosts))[0],
self._listen_port,
sockets=self._listen_sockets,
)
@ -1851,7 +1884,7 @@ class Server(ha_base.ClusterProtocol):
def _cleanup_wildcard_addrs(
hosts: Sequence[str]
) -> tuple[list[str], list[str]]:
) -> tuple[list[str], list[str], bool, bool]:
"""Filter out conflicting addresses in presence of INADDR_ANY wildcards.
Attempting to bind to 0.0.0.0 (or ::) _and_ a non-wildcard address will
@ -1894,28 +1927,36 @@ def _cleanup_wildcard_addrs(
named_hosts.add(host)
if not ipv4_hosts and not ipv6_hosts:
return (list(hosts), [])
return (list(hosts), [], False, False)
if ipv4_wc not in ipv4_hosts and ipv6_wc not in ipv6_hosts:
return (list(hosts), [])
return (list(hosts), [], False, False)
if ipv4_wc in ipv4_hosts and ipv6_wc in ipv6_hosts:
return (
['0.0.0.0', '::'],
[str(a) for a in
((named_hosts | ipv4_hosts | ipv6_hosts) - {ipv4_wc, ipv6_wc})]
[
str(a) for a in
((named_hosts | ipv4_hosts | ipv6_hosts) - {ipv4_wc, ipv6_wc})
],
True,
True,
)
if ipv4_wc in ipv4_hosts:
return (
[str(a) for a in ({ipv4_wc} | ipv6_hosts)],
[str(a) for a in ((named_hosts | ipv4_hosts) - {ipv4_wc})]
[str(a) for a in ((named_hosts | ipv4_hosts) - {ipv4_wc})],
True,
False,
)
if ipv6_wc in ipv6_hosts:
return (
[str(a) for a in ({ipv6_wc} | ipv4_hosts)],
[str(a) for a in ((named_hosts | ipv6_hosts) - {ipv6_wc})]
[str(a) for a in ((named_hosts | ipv6_hosts) - {ipv6_wc})],
False,
True,
)
raise AssertionError('unreachable')
@ -1937,7 +1978,9 @@ async def _resolve_host(host: str) -> list[str] | Exception:
return [addr[4][0] for addr in addrinfo]
async def _resolve_interfaces(hosts: Sequence[str]) -> Sequence[str]:
async def _resolve_interfaces(
hosts: Sequence[str]
) -> Tuple[Sequence[str], bool, bool]:
async with taskgroup.TaskGroup() as g:
resolve_tasks = {
@ -1954,7 +1997,9 @@ async def _resolve_interfaces(hosts: Sequence[str]) -> Sequence[str]:
else:
addrs.extend(result)
clean_addrs, rejected_addrs = _cleanup_wildcard_addrs(addrs)
(
clean_addrs, rejected_addrs, has_ipv4_wc, has_ipv6_wc
) = _cleanup_wildcard_addrs(addrs)
if rejected_addrs:
logger.warning(
@ -1963,4 +2008,4 @@ async def _resolve_interfaces(hosts: Sequence[str]) -> Sequence[str]:
", ".join(repr(h) for h in rejected_addrs)
)
return clean_addrs
return clean_addrs, has_ipv4_wc, has_ipv6_wc

View file

@ -1184,9 +1184,6 @@ class TestSeparateCluster(tb.TestCase):
for i, con in enumerate((con1, con2, con3)):
self.assertEqual(await con.query_single(f"SELECT {i}"), i)
await con1.execute("""
CONFIGURE INSTANCE SET listen_addresses := <str>{};
""")
await con1.execute("""
CONFIGURE INSTANCE SET listen_addresses := {
'0.0.0.0',

View file

@ -28,64 +28,83 @@ class TestServerUnittests(unittest.TestCase):
CASES = [
(
['*'],
(['0.0.0.0', '::'], [])
(['0.0.0.0', '::'], []),
(True, True),
),
(
['0.0.0.0', '::0'],
(['0.0.0.0', '::'], [])
(['0.0.0.0', '::'], []),
(True, True),
),
(
['0.0.0.0', '127.0.0.1', '2001:db8::8a2e:370:7334', '::0'],
(['0.0.0.0', '::'], ['127.0.0.1', '2001:db8::8a2e:370:7334'])
(['0.0.0.0', '::'], ['127.0.0.1', '2001:db8::8a2e:370:7334']),
(True, True),
),
(
['0.0.0.0', 'example.com', '2001:db8::8a2e:370:7334', '::0'],
(['0.0.0.0', '::'], ['example.com', '2001:db8::8a2e:370:7334'])
(
['0.0.0.0', '::'],
['example.com', '2001:db8::8a2e:370:7334'],
),
(True, True),
),
(
['127.0.0.1', 'example.com', '2001:db8::8a2e:370:7334', '::0'],
(
['127.0.0.1', '::'],
['example.com', '2001:db8::8a2e:370:7334']
)
),
(False, True),
),
(
['example.com', '2001:db8::8a2e:370:7334', '::0'],
(['::'], ['example.com', '2001:db8::8a2e:370:7334'])
(['::'], ['example.com', '2001:db8::8a2e:370:7334']),
(False, True),
),
(
['example.com', 'sub.example.com'],
(['example.com', 'sub.example.com'], [])
(['example.com', 'sub.example.com'], []),
(False, False),
),
(
['example.com', '127.0.0.1'],
(['example.com', '127.0.0.1'], [])
(['example.com', '127.0.0.1'], []),
(False, False),
),
(
['example.com', '::1'],
(['::1', 'example.com'], [])
(['::1', 'example.com'], []),
(False, False),
),
(
['example.com', '::'],
(['::'], ['example.com'])
(['::'], ['example.com']),
(False, True),
),
(
['example.com', '::1', '127.0.0.1'],
(['example.com', '::1', '127.0.0.1'], [])
(['example.com', '::1', '127.0.0.1'], []),
(False, False),
),
(
['0.0.0.0', '2001:db8::8a2e:370:7334'],
(['0.0.0.0', '2001:db8::8a2e:370:7334'], [])
(['0.0.0.0', '2001:db8::8a2e:370:7334'], []),
(True, False),
),
(
['127.0.0.1', '2001:db8::8a2e:370:7334', '::'],
(['127.0.0.1', '::'], ['2001:db8::8a2e:370:7334'])
(['127.0.0.1', '::'], ['2001:db8::8a2e:370:7334']),
(False, True),
),
]
for hosts, expected in CASES:
new_hosts, rej_hosts = server._cleanup_wildcard_addrs(hosts)
for hosts, expected, expected_wildcard in CASES:
(
new_hosts, rej_hosts, *has_wildcards
) = server._cleanup_wildcard_addrs(hosts)
self.assertEqual(
(set(new_hosts), set(rej_hosts)),
(set(expected[0]), set(expected[1]))
)
self.assertEqual(tuple(has_wildcards), expected_wildcard)