Make consul watching more robust (#6148)

* Retry if failed to start watching
* Reset backoff after recovered
* Add rate limit, and always retry after response
* Add HA event metrics
* Allow passing through connparams over HA DSN query
Fantix King 2023-10-05 02:25:17 +09:00 committed by GitHub
parent 0dcfdd0589
commit e6a6d69b9e
12 changed files with 199 additions and 14 deletions

View file

@@ -20,10 +20,14 @@ from __future__ import annotations
from typing import *
import asyncio
import logging
from . import retryloop
logger = logging.getLogger("edb.server.asyncwatcher")
class AsyncWatcherProtocol(asyncio.Protocol):
def __init__(
self,
@@ -37,6 +41,7 @@ class AsyncWatcherProtocol(asyncio.Protocol):
self.request()
def connection_lost(self, exc: Optional[Exception]) -> None:
self._watcher.incr_metrics_counter("watch-disconnect")
self._watcher.on_connection_lost()
def request(self) -> None:
@@ -62,6 +67,7 @@ class AsyncWatcher:
self._protocol = await self._start_watching()
return True
except BaseException:
self.incr_metrics_counter("watch-start-err")
self._watching = False
raise
return False
@@ -70,7 +76,15 @@ class AsyncWatcher:
self._retry_attempt += 1
delay = self._backoff(self._retry_attempt)
await asyncio.sleep(delay)
- await self.start_watching()
+ try:
+     await self.start_watching()
+ except Exception:
+     logger.warning(
+         "%s failed to start watching, will retry.",
+         type(self).__name__,
+         exc_info=True,
+     )
+     asyncio.create_task(self.retry_watching())
def stop_watching(self) -> None:
self._watching = False
@@ -94,7 +108,19 @@ class AsyncWatcher:
waiter.set_result(None)
def on_update(self, data: bytes) -> None:
self._retry_attempt = 0
self._on_update(data)
def _on_update(self, data: bytes) -> None:
raise NotImplementedError
async def _start_watching(self) -> AsyncWatcherProtocol:
raise NotImplementedError
def consume_tokens(self, tokens: int) -> float:
# For rate limiting: try to consume the given number of tokens. Returns
# 0 on success, or the number of seconds to wait before retrying.
return 0
def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
pass
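
Taken together, `AsyncWatcher` now exposes two optional hooks, `consume_tokens()` for rate limiting and `incr_metrics_counter()` for metrics, both no-ops by default, plus a retry loop that backs off and re-schedules itself on failure. A minimal sketch of how a subclass might wire them up (the class name and logging are hypothetical, and it assumes the base constructor takes no arguments; a real `_start_watching()` returns a connected `AsyncWatcherProtocol`, as the Consul code below does):

import logging

from edb.common import asyncwatcher, token_bucket

logger = logging.getLogger(__name__)


class ExampleWatcher(asyncwatcher.AsyncWatcher):
    # Hypothetical subclass illustrating the new hooks.

    def __init__(self) -> None:
        super().__init__()
        # Allow a burst of 5 requests, then at most 1 every 2 seconds.
        self._bucket = token_bucket.TokenBucket(5, 0.5)

    def consume_tokens(self, tokens: int) -> float:
        # Rate limiting: a non-zero return means "wait this many seconds".
        return self._bucket.consume(tokens)

    def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
        # A real subclass bumps a labeled counter, e.g. ha_events_total.
        logger.debug("watch event: %s (+%s)", event, value)

    def _on_update(self, data: bytes) -> None:
        # Called on every successful watch response; on_update() has
        # already reset the retry backoff at this point.
        logger.info("watched data changed: %r", data)

    async def _start_watching(self) -> asyncwatcher.AsyncWatcherProtocol:
        # A real implementation opens a connection and returns an
        # AsyncWatcherProtocol bound to this watcher.
        raise NotImplementedError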

View file

@@ -145,7 +145,7 @@ class Registry:
desc: str,
/,
*,
- labels: tuple[str],
+ labels: tuple[str, ...],
unit: Unit | None = None,
) -> LabeledCounter:
counter = LabeledCounter(self, name, desc, unit, labels=labels)

View file

@@ -0,0 +1,46 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2023-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
class TokenBucket:
_capacity: float
_token_per_sec: float
_tokens: float
_last_fill_time: float
def __init__(self, capacity: float, token_per_sec: float):
self._capacity = capacity
self._token_per_sec = token_per_sec
self._tokens = capacity
self._last_fill_time = time.monotonic()
def consume(self, tokens: int) -> float:
if tokens <= 0:
return 0
now = time.monotonic()
tokens_to_add = (now - self._last_fill_time) * self._token_per_sec
self._tokens = min(self._capacity, self._tokens + tokens_to_add)
self._last_fill_time = now
left = self._tokens - tokens
if left >= 0:
self._tokens -= tokens
return 0
else:
return -left / (tokens * self._token_per_sec)
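
The bucket refills continuously, so a failed `consume()` reports a fractional number of seconds after which the call should succeed. A quick usage sketch (the capacity and rate mirror the Consul defaults introduced later in `stolon.py`, where they are overridable via `EDGEDB_SERVER_CONSUL_TOKEN_CAPACITY` and `EDGEDB_SERVER_CONSUL_TOKEN_RATE`):

import time

from edb.common.token_bucket import TokenBucket

# Capacity 10, 0.1 tokens/sec: a burst of 10 immediate requests,
# then roughly 1 request every 10 seconds.
bucket = TokenBucket(10, 0.1)

for i in range(12):
    delay = bucket.consume(1)
    if delay > 0:
        # Out of tokens; sleep for the suggested time, plus a small
        # cushion against float rounding, before trying again.
        time.sleep(delay + 0.01)
        bucket.consume(1)
    print(f"request {i} sent")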

View file

@@ -19,8 +19,10 @@
from __future__ import annotations
from typing import *
import asyncio
import json
import logging
import random
import urllib.parse
import httptools
@@ -66,17 +68,31 @@ class ConsulKVProtocol(asyncwatcher.AsyncWatcherProtocol):
def on_message_complete(self) -> None:
try:
- if self._parser.get_status_code() == 200:
+ code = self._parser.get_status_code()
+ if code == 200:
self._watcher.incr_metrics_counter("watch-update")
payload = json.loads(b"".join(self._buffers))[0]
last_modify_index = payload["ModifyIndex"]
self._watcher.on_update(payload["Value"])
- if self._last_modify_index != last_modify_index:
+ if self._last_modify_index == last_modify_index:
+     self._watcher.incr_metrics_counter("watch-timeout")
+     self._last_modify_index = None
+ else:
self._last_modify_index = last_modify_index
self.request()
else:
self._watcher.incr_metrics_counter(f"watch-err-{code}")
self.request()
finally:
self._buffers.clear()
def request(self) -> None:
delay = self._watcher.consume_tokens(1)
if delay > 0:
asyncio.get_running_loop().call_later(
delay + random.random() * 0.1, self.request
)
return
uri = urllib.parse.urljoin("/v1/kv/", self._key)
if self._last_modify_index is not None:
uri += f"?index={self._last_modify_index}"

View file

@@ -23,6 +23,8 @@ import enum
import logging
import os
from edb.server import metrics
from . import base
@@ -95,25 +97,31 @@ class AdaptiveHASupport:
_state: State
_unhealthy_timer_handle: Optional[asyncio.TimerHandle]
- def __init__(self, cluster_protocol: base.ClusterProtocol):
+ def __init__(self, cluster_protocol: base.ClusterProtocol, tag: str):
self._cluster_protocol = cluster_protocol
self._state = State.UNHEALTHY
self._pgcon_count = 0
self._unexpected_disconnects = 0
self._unhealthy_timer_handle = None
self._sys_pgcon_healthy = False
self._tag = tag
def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
metrics.ha_events_total.inc(value, f"adaptive://{self._tag}", event)
def set_state_failover(self, *, call_on_switch_over=True):
self._state = State.FAILOVER
self._reset()
if call_on_switch_over:
logger.critical("adaptive: HA failover detected")
self.incr_metrics_counter("failover")
self._cluster_protocol.on_switch_over()
def on_pgcon_broken(self, is_sys_pgcon: bool):
if is_sys_pgcon:
self._sys_pgcon_healthy = False
if self._state == State.HEALTHY:
self.incr_metrics_counter("unhealthy")
self._state = State.UNHEALTHY
self._unexpected_disconnects = 1
self._unhealthy_timer_handle = (
@@ -148,11 +156,13 @@ class AdaptiveHASupport:
if is_sys_pgcon:
self._sys_pgcon_healthy = True
if self._state == State.UNHEALTHY:
self.incr_metrics_counter("healthy")
self._state = State.HEALTHY
logger.info("adaptive: Backend HA cluster is healthy")
self._reset()
elif self._state == State.FAILOVER:
if self._sys_pgcon_healthy:
self.incr_metrics_counter("healthy")
self._state = State.HEALTHY
logger.info(
"adaptive: Backend HA cluster has recovered from failover"

View file

@@ -22,6 +22,7 @@ from typing import *
import urllib.parse
from edb.common import asyncwatcher
from edb.server import metrics
class ClusterProtocol:
@@ -50,6 +51,9 @@ class HABackend(asyncwatcher.AsyncWatcher):
def dsn(self) -> str:
raise NotImplementedError
def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
metrics.ha_events_total.inc(value, self.dsn, event)
def get_backend(parsed_dsn: urllib.parse.ParseResult) -> Optional[HABackend]:
backend, _, sub_scheme = parsed_dsn.scheme.partition("+")

View file

@@ -24,10 +24,12 @@ import base64
import functools
import json
import logging
import os
import ssl
import urllib.parse
from edb.common import asyncwatcher
from edb.common import token_bucket
from edb.server import consul
from . import base
@@ -59,7 +61,7 @@ class StolonBackend(base.HABackend):
def get_master_addr(self) -> Optional[Tuple[str, int]]:
return self._master_addr
- def on_update(self, payload: bytes) -> None:
+ def _on_update(self, payload: bytes) -> None:
try:
data = json.loads(base64.b64decode(payload))
except (TypeError, ValueError):
@@ -103,6 +105,7 @@ class StolonBackend(base.HABackend):
)
self._master_addr = master_addr
if self._failover_cb is not None:
self.incr_metrics_counter("failover")
self._failover_cb()
if self._waiter is not None:
@@ -126,6 +129,13 @@ class StolonConsulBackend(StolonBackend):
self._port = port
self._ssl = ssl
# This means we can issue 10 consecutive requests immediately, without
# delay after each response, and beyond that we're capped at 0.1
# request (token) per second, i.e. 1 request every 10 seconds.
cap = float(os.environ.get("EDGEDB_SERVER_CONSUL_TOKEN_CAPACITY", 10))
rate = float(os.environ.get("EDGEDB_SERVER_CONSUL_TOKEN_RATE", 0.1))
self._token_bucket = token_bucket.TokenBucket(cap, rate)
async def _start_watching(self) -> asyncwatcher.AsyncWatcherProtocol:
_, pr = await asyncio.get_running_loop().create_connection(
functools.partial(
@@ -140,7 +150,7 @@ class StolonConsulBackend(StolonBackend):
)
return pr # type: ignore [return-value]
- @property
+ @functools.cached_property
def dsn(self) -> str:
proto = "http" if self._ssl is None else "https"
return (
@@ -148,6 +158,9 @@ class StolonConsulBackend(StolonBackend):
f"{self._host}:{self._port}/{self._cluster_name}"
)
def consume_tokens(self, tokens: int) -> float:
return self._token_bucket.consume(tokens)
def get_backend(
sub_scheme: str, parsed_dsn: urllib.parse.ParseResult

View file

@@ -97,3 +97,9 @@ background_errors = registry.new_labeled_counter(
'Number of unhandled errors in background server routines.',
labels=('source',)
)
ha_events_total = registry.new_labeled_counter(
"ha_events_total",
"Number of each high-availability watch event.",
labels=("dsn", "event"),
)

View file

@@ -885,6 +885,24 @@ async def get_remote_pg_cluster(
addr = await ha_backend.get_cluster_consensus()
dsn = 'postgresql://{}:{}'.format(*addr)
if parsed.query:
# Allow passing through Postgres connection parameters from the HA
# backend DSN as "pg"-prefixed query string parameters. For example,
# an HA backend DSN with `?pgpassword=123` will result in an actual
# backend DSN with `?password=123`. These take higher priority than
# the `PG`-prefixed environment variables like `PGPASSWORD`.
pq = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
query = {}
for k, v in pq.items():
if k.startswith("pg") and k not in ["pghost", "pgport"]:
if isinstance(v, list):
val = v[-1]
else:
val = cast(str, v)
query[k[2:]] = val
if query:
dsn += f"?{urllib.parse.urlencode(query)}"
addrs, params = pgconnparams.parse_dsn(dsn)
if len(addrs) > 1:
raise ValueError('multiple hosts in Postgres DSN are not supported')
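
A standalone sketch of this rewriting step (the address and credentials are made up): an HA DSN carrying `?pguser=suname&pgpassword=supass&pgdatabase=postgres` yields a backend DSN with `?user=...&password=...&database=...`, while `pghost`/`pgport` are excluded because the actual address comes from the HA consensus:

import urllib.parse

ha_dsn = (
    "stolon+consul+http://127.0.0.1:8500/mycluster"
    "?pguser=suname&pgpassword=supass&pgdatabase=postgres"
)
addr = ("10.0.0.5", 5432)  # pretend get_cluster_consensus() returned this

dsn = "postgresql://{}:{}".format(*addr)
parsed = urllib.parse.urlparse(ha_dsn)
pq = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
query = {
    k[2:]: v[-1]  # strip the "pg" prefix, keep the last value
    for k, v in pq.items()
    if k.startswith("pg") and k not in ("pghost", "pgport")
}
if query:
    dsn += f"?{urllib.parse.urlencode(query)}"

print(dsn)
# postgresql://10.0.0.5:5432?user=suname&password=supass&database=postgres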

View file

@@ -135,7 +135,9 @@ class Tenant(ha_base.ClusterProtocol):
# Increase-only counter to reject outdated attempts to connect
self._ha_master_serial = 0
if backend_adaptive_ha:
- self._backend_adaptive_ha = adaptive_ha.AdaptiveHASupport(self)
+ self._backend_adaptive_ha = adaptive_ha.AdaptiveHASupport(
+     self, self._instance_name
+ )
else:
self._backend_adaptive_ha = None
self._readiness_state_file = readiness_state_file

View file

@@ -0,0 +1,48 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2023-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import unittest.mock
from edb.common.token_bucket import TokenBucket
class ManualClock:
def __init__(self, value: float) -> None:
self.value = value
def __call__(self) -> float:
return self.value
class TokenBucketTests(unittest.TestCase):
def test_common_token_bucket(self) -> None:
monotonic = ManualClock(0)
with unittest.mock.patch("time.monotonic", monotonic):
tb = TokenBucket(10, 0.1)
self.assertEqual(tb.consume(5), 0)
monotonic.value += 12
self.assertEqual(tb.consume(6), 0)
self.assertGreater(tb.consume(1), 0)
self.assertGreater(tb.consume(2), tb.consume(1))
monotonic.value += 30
self.assertEqual(tb.consume(2), 0)
self.assertEqual(tb.consume(1), 0)
self.assertGreater(tb.consume(1), 0)

View file

@@ -503,13 +503,9 @@ class TestBackendHA(tb.TestCase):
backend_dsn=(
f"stolon+consul+http://127.0.0.1:{consul.http_port}"
f"/{pg1.cluster_name}"
f"?pguser=suname&pgpassword=supass&pgdatabase=postgres"
),
runstate_dir=str(pathlib.Path(consul.tmp_dir.name) / "edb"),
- env=dict(
-     PGUSER="suname",
-     PGPASSWORD="supass",
-     PGDATABASE="postgres",
- ),
reset_auth=True,
debug=debug,
) as sd: