mirror of
https://github.com/maxkratz/edgedb.git
synced 2024-09-16 18:59:05 +00:00
Make consul watching more robust (#6148)
* Retry if failed to start watching * Reset backoff after recovered * Add rate limit, and always retry after response * Add HA event metrics * Allow passing through connparams over HA DSN query
This commit is contained in:
parent
0dcfdd0589
commit
e6a6d69b9e
12 changed files with 199 additions and 14 deletions
|
@ -20,10 +20,14 @@ from __future__ import annotations
|
|||
from typing import *
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from . import retryloop
|
||||
|
||||
|
||||
logger = logging.getLogger("edb.server.asyncwatcher")
|
||||
|
||||
|
||||
class AsyncWatcherProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -37,6 +41,7 @@ class AsyncWatcherProtocol(asyncio.Protocol):
|
|||
self.request()
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
self._watcher.incr_metrics_counter("watch-disconnect")
|
||||
self._watcher.on_connection_lost()
|
||||
|
||||
def request(self) -> None:
|
||||
|
@ -62,6 +67,7 @@ class AsyncWatcher:
|
|||
self._protocol = await self._start_watching()
|
||||
return True
|
||||
except BaseException:
|
||||
self.incr_metrics_counter("watch-start-err")
|
||||
self._watching = False
|
||||
raise
|
||||
return False
|
||||
|
@ -70,7 +76,15 @@ class AsyncWatcher:
|
|||
self._retry_attempt += 1
|
||||
delay = self._backoff(self._retry_attempt)
|
||||
await asyncio.sleep(delay)
|
||||
await self.start_watching()
|
||||
try:
|
||||
await self.start_watching()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s failed to start watching, will retry.",
|
||||
type(self).__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
asyncio.create_task(self.retry_watching())
|
||||
|
||||
def stop_watching(self) -> None:
|
||||
self._watching = False
|
||||
|
@ -94,7 +108,19 @@ class AsyncWatcher:
|
|||
waiter.set_result(None)
|
||||
|
||||
def on_update(self, data: bytes) -> None:
|
||||
self._retry_attempt = 0
|
||||
self._on_update(data)
|
||||
|
||||
def _on_update(self, data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _start_watching(self) -> AsyncWatcherProtocol:
|
||||
raise NotImplementedError
|
||||
|
||||
def consume_tokens(self, tokens: int) -> float:
|
||||
# For rate limit - tries to consume the given number of tokens, returns
|
||||
# non-zero values as seconds to wait if unsuccessful
|
||||
return 0
|
||||
|
||||
def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
|
||||
pass
|
||||
|
|
|
@ -145,7 +145,7 @@ class Registry:
|
|||
desc: str,
|
||||
/,
|
||||
*,
|
||||
labels: tuple[str],
|
||||
labels: tuple[str, ...],
|
||||
unit: Unit | None = None,
|
||||
) -> LabeledCounter:
|
||||
counter = LabeledCounter(self, name, desc, unit, labels=labels)
|
||||
|
|
46
edb/common/token_bucket.py
Normal file
46
edb/common/token_bucket.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
#
|
||||
# This source file is part of the EdgeDB open source project.
|
||||
#
|
||||
# Copyright 2023-present MagicStack Inc. and the EdgeDB authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
_capacity: float
|
||||
_token_per_sec: float
|
||||
_tokens: float
|
||||
_last_fill_time: float
|
||||
|
||||
def __init__(self, capacity: float, token_per_sec: float):
|
||||
self._capacity = capacity
|
||||
self._token_per_sec = token_per_sec
|
||||
self._tokens = capacity
|
||||
self._last_fill_time = time.monotonic()
|
||||
|
||||
def consume(self, tokens: int) -> float:
|
||||
if tokens <= 0:
|
||||
return True
|
||||
now = time.monotonic()
|
||||
tokens_to_add = (now - self._last_fill_time) * self._token_per_sec
|
||||
self._tokens = min(self._capacity, self._tokens + tokens_to_add)
|
||||
self._last_fill_time = now
|
||||
left = self._tokens - tokens
|
||||
if left >= 0:
|
||||
self._tokens -= tokens
|
||||
return 0
|
||||
else:
|
||||
return -left / (tokens * self._token_per_sec)
|
|
@ -19,8 +19,10 @@
|
|||
from __future__ import annotations
|
||||
from typing import *
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import urllib.parse
|
||||
|
||||
import httptools
|
||||
|
@ -66,17 +68,31 @@ class ConsulKVProtocol(asyncwatcher.AsyncWatcherProtocol):
|
|||
|
||||
def on_message_complete(self) -> None:
|
||||
try:
|
||||
if self._parser.get_status_code() == 200:
|
||||
code = self._parser.get_status_code()
|
||||
if code == 200:
|
||||
self._watcher.incr_metrics_counter("watch-update")
|
||||
payload = json.loads(b"".join(self._buffers))[0]
|
||||
last_modify_index = payload["ModifyIndex"]
|
||||
self._watcher.on_update(payload["Value"])
|
||||
if self._last_modify_index != last_modify_index:
|
||||
if self._last_modify_index == last_modify_index:
|
||||
self._watcher.incr_metrics_counter("watch-timeout")
|
||||
self._last_modify_index = None
|
||||
else:
|
||||
self._last_modify_index = last_modify_index
|
||||
self.request()
|
||||
else:
|
||||
self._watcher.incr_metrics_counter(f"watch-err-{code}")
|
||||
self.request()
|
||||
|
||||
finally:
|
||||
self._buffers.clear()
|
||||
|
||||
def request(self) -> None:
|
||||
delay = self._watcher.consume_tokens(1)
|
||||
if delay > 0:
|
||||
asyncio.get_running_loop().call_later(
|
||||
delay + random.random() * 0.1, self.request
|
||||
)
|
||||
return
|
||||
uri = urllib.parse.urljoin("/v1/kv/", self._key)
|
||||
if self._last_modify_index is not None:
|
||||
uri += f"?index={self._last_modify_index}"
|
||||
|
|
|
@ -23,6 +23,8 @@ import enum
|
|||
import logging
|
||||
import os
|
||||
|
||||
from edb.server import metrics
|
||||
|
||||
from . import base
|
||||
|
||||
|
||||
|
@ -95,25 +97,31 @@ class AdaptiveHASupport:
|
|||
_state: State
|
||||
_unhealthy_timer_handle: Optional[asyncio.TimerHandle]
|
||||
|
||||
def __init__(self, cluster_protocol: base.ClusterProtocol):
|
||||
def __init__(self, cluster_protocol: base.ClusterProtocol, tag: str):
|
||||
self._cluster_protocol = cluster_protocol
|
||||
self._state = State.UNHEALTHY
|
||||
self._pgcon_count = 0
|
||||
self._unexpected_disconnects = 0
|
||||
self._unhealthy_timer_handle = None
|
||||
self._sys_pgcon_healthy = False
|
||||
self._tag = tag
|
||||
|
||||
def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
|
||||
metrics.ha_events_total.inc(value, f"adaptive://{self._tag}", event)
|
||||
|
||||
def set_state_failover(self, *, call_on_switch_over=True):
|
||||
self._state = State.FAILOVER
|
||||
self._reset()
|
||||
if call_on_switch_over:
|
||||
logger.critical("adaptive: HA failover detected")
|
||||
self.incr_metrics_counter("failover")
|
||||
self._cluster_protocol.on_switch_over()
|
||||
|
||||
def on_pgcon_broken(self, is_sys_pgcon: bool):
|
||||
if is_sys_pgcon:
|
||||
self._sys_pgcon_healthy = False
|
||||
if self._state == State.HEALTHY:
|
||||
self.incr_metrics_counter("unhealthy")
|
||||
self._state = State.UNHEALTHY
|
||||
self._unexpected_disconnects = 1
|
||||
self._unhealthy_timer_handle = (
|
||||
|
@ -148,11 +156,13 @@ class AdaptiveHASupport:
|
|||
if is_sys_pgcon:
|
||||
self._sys_pgcon_healthy = True
|
||||
if self._state == State.UNHEALTHY:
|
||||
self.incr_metrics_counter("healthy")
|
||||
self._state = State.HEALTHY
|
||||
logger.info("adaptive: Backend HA cluster is healthy")
|
||||
self._reset()
|
||||
elif self._state == State.FAILOVER:
|
||||
if self._sys_pgcon_healthy:
|
||||
self.incr_metrics_counter("healthy")
|
||||
self._state = State.HEALTHY
|
||||
logger.info(
|
||||
"adaptive: Backend HA cluster has recovered from failover"
|
||||
|
|
|
@ -22,6 +22,7 @@ from typing import *
|
|||
import urllib.parse
|
||||
|
||||
from edb.common import asyncwatcher
|
||||
from edb.server import metrics
|
||||
|
||||
|
||||
class ClusterProtocol:
|
||||
|
@ -50,6 +51,9 @@ class HABackend(asyncwatcher.AsyncWatcher):
|
|||
def dsn(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def incr_metrics_counter(self, event: str, value: float = 1.0) -> None:
|
||||
metrics.ha_events_total.inc(value, self.dsn, event)
|
||||
|
||||
|
||||
def get_backend(parsed_dsn: urllib.parse.ParseResult) -> Optional[HABackend]:
|
||||
backend, _, sub_scheme = parsed_dsn.scheme.partition("+")
|
||||
|
|
|
@ -24,10 +24,12 @@ import base64
|
|||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import urllib.parse
|
||||
|
||||
from edb.common import asyncwatcher
|
||||
from edb.common import token_bucket
|
||||
from edb.server import consul
|
||||
|
||||
from . import base
|
||||
|
@ -59,7 +61,7 @@ class StolonBackend(base.HABackend):
|
|||
def get_master_addr(self) -> Optional[Tuple[str, int]]:
|
||||
return self._master_addr
|
||||
|
||||
def on_update(self, payload: bytes) -> None:
|
||||
def _on_update(self, payload: bytes) -> None:
|
||||
try:
|
||||
data = json.loads(base64.b64decode(payload))
|
||||
except (TypeError, ValueError):
|
||||
|
@ -103,6 +105,7 @@ class StolonBackend(base.HABackend):
|
|||
)
|
||||
self._master_addr = master_addr
|
||||
if self._failover_cb is not None:
|
||||
self.incr_metrics_counter("failover")
|
||||
self._failover_cb()
|
||||
|
||||
if self._waiter is not None:
|
||||
|
@ -126,6 +129,13 @@ class StolonConsulBackend(StolonBackend):
|
|||
self._port = port
|
||||
self._ssl = ssl
|
||||
|
||||
# This means we can request for 10 consecutive requests immediately
|
||||
# after each response without delay, and then we're capped to 0.1
|
||||
# request(token) per second, or 1 request per 10 seconds.
|
||||
cap = float(os.environ.get("EDGEDB_SERVER_CONSUL_TOKEN_CAPACITY", 10))
|
||||
rate = float(os.environ.get("EDGEDB_SERVER_CONSUL_TOKEN_RATE", 0.1))
|
||||
self._token_bucket = token_bucket.TokenBucket(cap, rate)
|
||||
|
||||
async def _start_watching(self) -> asyncwatcher.AsyncWatcherProtocol:
|
||||
_, pr = await asyncio.get_running_loop().create_connection(
|
||||
functools.partial(
|
||||
|
@ -140,7 +150,7 @@ class StolonConsulBackend(StolonBackend):
|
|||
)
|
||||
return pr # type: ignore [return-value]
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def dsn(self) -> str:
|
||||
proto = "http" if self._ssl is None else "https"
|
||||
return (
|
||||
|
@ -148,6 +158,9 @@ class StolonConsulBackend(StolonBackend):
|
|||
f"{self._host}:{self._port}/{self._cluster_name}"
|
||||
)
|
||||
|
||||
def consume_tokens(self, tokens: int) -> float:
|
||||
return self._token_bucket.consume(tokens)
|
||||
|
||||
|
||||
def get_backend(
|
||||
sub_scheme: str, parsed_dsn: urllib.parse.ParseResult
|
||||
|
|
|
@ -97,3 +97,9 @@ background_errors = registry.new_labeled_counter(
|
|||
'Number of unhandled errors in background server routines.',
|
||||
labels=('source',)
|
||||
)
|
||||
|
||||
ha_events_total = registry.new_labeled_counter(
|
||||
"ha_events_total",
|
||||
"Number of each high-availability watch event.",
|
||||
labels=("dsn", "event"),
|
||||
)
|
||||
|
|
|
@ -885,6 +885,24 @@ async def get_remote_pg_cluster(
|
|||
addr = await ha_backend.get_cluster_consensus()
|
||||
dsn = 'postgresql://{}:{}'.format(*addr)
|
||||
|
||||
if parsed.query:
|
||||
# Allow passing through Postgres connection parameters from the HA
|
||||
# backend DSN as "pg" prefixed query strings. For example, an HA
|
||||
# backend DSN with `?pgpassword=123` will result an actual backend
|
||||
# DSN with `?password=123`. They have higher priority than the `PG`
|
||||
# prefixed environment variables like `PGPASSWORD`.
|
||||
pq = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
|
||||
query = {}
|
||||
for k, v in pq.items():
|
||||
if k.startswith("pg") and k not in ["pghost", "pgport"]:
|
||||
if isinstance(v, list):
|
||||
val = v[-1]
|
||||
else:
|
||||
val = cast(str, v)
|
||||
query[k[2:]] = val
|
||||
if query:
|
||||
dsn += f"?{urllib.parse.urlencode(query)}"
|
||||
|
||||
addrs, params = pgconnparams.parse_dsn(dsn)
|
||||
if len(addrs) > 1:
|
||||
raise ValueError('multiple hosts in Postgres DSN are not supported')
|
||||
|
|
|
@ -135,7 +135,9 @@ class Tenant(ha_base.ClusterProtocol):
|
|||
# Increase-only counter to reject outdated attempts to connect
|
||||
self._ha_master_serial = 0
|
||||
if backend_adaptive_ha:
|
||||
self._backend_adaptive_ha = adaptive_ha.AdaptiveHASupport(self)
|
||||
self._backend_adaptive_ha = adaptive_ha.AdaptiveHASupport(
|
||||
self, self._instance_name
|
||||
)
|
||||
else:
|
||||
self._backend_adaptive_ha = None
|
||||
self._readiness_state_file = readiness_state_file
|
||||
|
|
48
tests/common/test_token_bucket.py
Normal file
48
tests/common/test_token_bucket.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
#
|
||||
# This source file is part of the EdgeDB open source project.
|
||||
#
|
||||
# Copyright 2023-present MagicStack Inc. and the EdgeDB authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from edb.common.token_bucket import TokenBucket
|
||||
|
||||
|
||||
class ManualClock:
|
||||
def __init__(self, value: float) -> None:
|
||||
self.value = value
|
||||
|
||||
def __call__(self) -> float:
|
||||
return self.value
|
||||
|
||||
|
||||
class WindowedSumTests(unittest.TestCase):
|
||||
def test_common_token_bucket(self) -> None:
|
||||
monotonic = ManualClock(0)
|
||||
with unittest.mock.patch("time.monotonic", monotonic):
|
||||
tb = TokenBucket(10, 0.1)
|
||||
self.assertEqual(tb.consume(5), 0)
|
||||
|
||||
monotonic.value += 12
|
||||
self.assertEqual(tb.consume(6), 0)
|
||||
self.assertGreater(tb.consume(1), 0)
|
||||
self.assertGreater(tb.consume(2), tb.consume(1))
|
||||
|
||||
monotonic.value += 30
|
||||
self.assertEqual(tb.consume(2), 0)
|
||||
self.assertEqual(tb.consume(1), 0)
|
||||
self.assertGreater(tb.consume(1), 0)
|
|
@ -503,13 +503,9 @@ class TestBackendHA(tb.TestCase):
|
|||
backend_dsn=(
|
||||
f"stolon+consul+http://127.0.0.1:{consul.http_port}"
|
||||
f"/{pg1.cluster_name}"
|
||||
f"?pguser=suname&pgpassword=supass&pgdatabase=postgres"
|
||||
),
|
||||
runstate_dir=str(pathlib.Path(consul.tmp_dir.name) / "edb"),
|
||||
env=dict(
|
||||
PGUSER="suname",
|
||||
PGPASSWORD="supass",
|
||||
PGDATABASE="postgres",
|
||||
),
|
||||
reset_auth=True,
|
||||
debug=debug,
|
||||
) as sd:
|
||||
|
|
Loading…
Reference in a new issue