mirror of
https://github.com/maxkratz/edgedb.git
synced 2024-09-16 18:59:05 +00:00
Implement base64 encoding and decoding functions (#5963)
This adds the new `std::encoding::base64_encode` and `std::encoding::base64_decode` functions to encode to and decode from base64-encoded strings. There is support for RFC 4648 §4 standard alphabet as well as RFC 4648 §5 URL- and filename-safe alphabet. Padding requirement can be controlled by the `padding` named only argument.
This commit is contained in:
parent
9467f88e08
commit
c1781cf6bc
4 changed files with 263 additions and 13 deletions
135
edb/lib/enc.edgeql
Normal file
135
edb/lib/enc.edgeql
Normal file
|
@ -0,0 +1,135 @@
|
|||
#
|
||||
# This source file is part of the EdgeDB open source project.
|
||||
#
|
||||
# Copyright EdgeDB Inc. and the EdgeDB authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
CREATE MODULE std::enc;
|
||||
|
||||
|
||||
CREATE SCALAR TYPE
|
||||
std::enc::Base64Alphabet EXTENDING enum<standard, urlsafe>;
|
||||
|
||||
|
||||
CREATE FUNCTION
|
||||
std::enc::base64_encode(
|
||||
data: std::bytes,
|
||||
NAMED ONLY alphabet: std::enc::Base64Alphabet =
|
||||
std::enc::Base64Alphabet.standard,
|
||||
NAMED ONLY padding: std::bool = true,
|
||||
) -> std::str
|
||||
{
|
||||
CREATE ANNOTATION std::description :=
|
||||
'Encode given data as a base64 string';
|
||||
SET volatility := 'Immutable';
|
||||
USING SQL $$
|
||||
SELECT
|
||||
CASE
|
||||
WHEN "alphabet" = 'standard' AND "padding" THEN
|
||||
pg_catalog.translate(
|
||||
pg_catalog.encode("data", 'base64'),
|
||||
E'\n',
|
||||
''
|
||||
)
|
||||
WHEN "alphabet" = 'standard' AND NOT "padding" THEN
|
||||
pg_catalog.translate(
|
||||
pg_catalog.rtrim(
|
||||
pg_catalog.encode("data", 'base64'),
|
||||
'='
|
||||
),
|
||||
E'\n',
|
||||
''
|
||||
)
|
||||
WHEN "alphabet" = 'urlsafe' AND "padding" THEN
|
||||
pg_catalog.translate(
|
||||
pg_catalog.encode("data", 'base64'),
|
||||
E'+/\n',
|
||||
'-_'
|
||||
)
|
||||
WHEN "alphabet" = 'urlsafe' AND NOT "padding" THEN
|
||||
pg_catalog.translate(
|
||||
pg_catalog.rtrim(
|
||||
pg_catalog.encode("data", 'base64'),
|
||||
'='
|
||||
),
|
||||
E'+/\n',
|
||||
'-_'
|
||||
)
|
||||
ELSE
|
||||
edgedb.raise(
|
||||
NULL::text,
|
||||
'invalid_parameter_value',
|
||||
msg => (
|
||||
'invalid alphabet for std::enc::base64_encode: '
|
||||
|| pg_catalog.quote_literal("alphabet")
|
||||
),
|
||||
detail => (
|
||||
'{"hint":"Supported alphabets: standard, urlsafe."}'
|
||||
)
|
||||
)
|
||||
END
|
||||
$$;
|
||||
};
|
||||
|
||||
|
||||
CREATE FUNCTION
|
||||
std::enc::base64_decode(
|
||||
data: std::str,
|
||||
NAMED ONLY alphabet: std::enc::Base64Alphabet =
|
||||
std::enc::Base64Alphabet.standard,
|
||||
NAMED ONLY padding: std::bool = true,
|
||||
) -> std::bytes
|
||||
{
|
||||
CREATE ANNOTATION std::description :=
|
||||
'Decode the byte64-encoded byte string and return decoded bytes.';
|
||||
SET volatility := 'Immutable';
|
||||
USING SQL $$
|
||||
SELECT
|
||||
CASE
|
||||
WHEN "alphabet" = 'standard' AND "padding" THEN
|
||||
pg_catalog.decode("data", 'base64')
|
||||
WHEN "alphabet" = 'standard' AND NOT "padding" THEN
|
||||
pg_catalog.decode(
|
||||
edgedb.pad_base64_string("data"),
|
||||
'base64'
|
||||
)
|
||||
WHEN "alphabet" = 'urlsafe' AND "padding" THEN
|
||||
pg_catalog.decode(
|
||||
pg_catalog.translate("data", '-_', '+/'),
|
||||
'base64'
|
||||
)
|
||||
WHEN "alphabet" = 'urlsafe' AND NOT "padding" THEN
|
||||
pg_catalog.decode(
|
||||
edgedb.pad_base64_string(
|
||||
pg_catalog.translate("data", '-_', '+/')
|
||||
),
|
||||
'base64'
|
||||
)
|
||||
ELSE
|
||||
edgedb.raise(
|
||||
NULL::bytea,
|
||||
'invalid_parameter_value',
|
||||
msg => (
|
||||
'invalid alphabet for std::enc::base64_decode: '
|
||||
|| pg_catalog.quote_literal("alphabet")
|
||||
),
|
||||
detail => (
|
||||
'{"hint":"Supported alphabets: standard, urlsafe."}'
|
||||
)
|
||||
)
|
||||
END
|
||||
$$;
|
||||
};
|
|
@ -4187,6 +4187,38 @@ class UuidGenerateV5Function(dbops.Function):
|
|||
)
|
||||
|
||||
|
||||
class PadBase64StringFunction(dbops.Function):
|
||||
text = r"""
|
||||
WITH
|
||||
l AS (SELECT pg_catalog.length("s") % 4 AS r),
|
||||
p AS (
|
||||
SELECT
|
||||
(CASE WHEN l.r > 0 THEN repeat('=', (4 - l.r))
|
||||
ELSE '' END) AS p
|
||||
FROM
|
||||
l
|
||||
)
|
||||
SELECT
|
||||
"s" || p.p
|
||||
FROM
|
||||
p
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
name=('edgedb', 'pad_base64_string'),
|
||||
args=[
|
||||
('s', ('text',)),
|
||||
],
|
||||
returns=('text',),
|
||||
volatility='immutable',
|
||||
language='sql',
|
||||
strict=True,
|
||||
parallel_safe=True,
|
||||
text=self.text,
|
||||
)
|
||||
|
||||
|
||||
async def bootstrap(
|
||||
conn: PGConnection,
|
||||
config_spec: edbconfig.Spec,
|
||||
|
@ -4303,6 +4335,7 @@ async def bootstrap(
|
|||
dbops.CreateFunction(FTSParseQueryFunction()),
|
||||
dbops.CreateFunction(FTSNormalizeWeightFunction()),
|
||||
dbops.CreateFunction(FTSNormalizeDocFunction()),
|
||||
dbops.CreateFunction(PadBase64StringFunction()),
|
||||
]
|
||||
commands = dbops.CommandGroup()
|
||||
commands.add_commands(cmds)
|
||||
|
|
|
@ -66,6 +66,7 @@ STD_MODULES = (
|
|||
sn.UnqualName('std::_test'),
|
||||
sn.UnqualName('fts'),
|
||||
sn.UnqualName('ext'),
|
||||
sn.UnqualName('std::enc'),
|
||||
)
|
||||
|
||||
# Specifies the order of processing of files and directories in lib/
|
||||
|
@ -77,6 +78,7 @@ STD_SOURCES = (
|
|||
sn.UnqualName('cfg'),
|
||||
sn.UnqualName('cal'),
|
||||
sn.UnqualName('ext'),
|
||||
sn.UnqualName('enc'),
|
||||
sn.UnqualName('pg'),
|
||||
)
|
||||
TESTMODE_SOURCES = (
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
#
|
||||
|
||||
|
||||
import base64
|
||||
import decimal
|
||||
import json
|
||||
import os.path
|
||||
import random
|
||||
|
||||
import edgedb
|
||||
|
||||
|
@ -1132,19 +1134,6 @@ class TestEdgeQLFunctions(tb.QueryTestCase):
|
|||
[['Ab'], ['a'], ['a'], ['aB'], ['ab']],
|
||||
)
|
||||
|
||||
async def test_edgeql_functions_re_match_all_02(self):
|
||||
await self.assert_query_result(
|
||||
r'''
|
||||
WITH
|
||||
MODULE schema,
|
||||
C2 := ScalarType
|
||||
SELECT
|
||||
count(re_match_all('(\\w+)', ScalarType.name)) =
|
||||
2 * count(C2);
|
||||
''',
|
||||
[True],
|
||||
)
|
||||
|
||||
async def test_edgeql_functions_re_test_01(self):
|
||||
await self.assert_query_result(
|
||||
r'''SELECT re_test('ac', 'AbabaB');''',
|
||||
|
@ -7505,3 +7494,94 @@ class TestEdgeQLFunctions(tb.QueryTestCase):
|
|||
],
|
||||
json_only=True,
|
||||
)
|
||||
|
||||
async def test_edgeql_functions_encoding_base64_fuzz(self):
|
||||
for _ in range(10):
|
||||
value = random.randbytes(random.randrange(0, 1000))
|
||||
await self.assert_query_result(
|
||||
r"""
|
||||
WITH
|
||||
MODULE std::enc,
|
||||
value := <bytes>$value,
|
||||
standard_encoded := base64_encode(
|
||||
value),
|
||||
standard_decoded := base64_decode(
|
||||
standard_encoded),
|
||||
standard_unpadded_encoded := base64_encode(
|
||||
value,
|
||||
padding := false),
|
||||
standard_unpadded_decoded := base64_decode(
|
||||
standard_unpadded_encoded,
|
||||
padding := false),
|
||||
urlsafe_encoded := base64_encode(
|
||||
value,
|
||||
alphabet := Base64Alphabet.urlsafe),
|
||||
urlsafe_decoded := base64_decode(
|
||||
urlsafe_encoded,
|
||||
alphabet := Base64Alphabet.urlsafe),
|
||||
urlsafe_unpadded_encoded := base64_encode(
|
||||
value,
|
||||
alphabet := Base64Alphabet.urlsafe,
|
||||
padding := false),
|
||||
urlsafe_unpadded_decoded := base64_decode(
|
||||
urlsafe_unpadded_encoded,
|
||||
alphabet := Base64Alphabet.urlsafe,
|
||||
padding := false),
|
||||
SELECT {
|
||||
standard_encoded :=
|
||||
standard_encoded,
|
||||
standard_crosscheck :=
|
||||
standard_decoded = value,
|
||||
standard_unpadded_encoded :=
|
||||
standard_unpadded_encoded,
|
||||
standard_unpadded_crosscheck :=
|
||||
standard_unpadded_decoded = value,
|
||||
urlsafe_encoded :=
|
||||
urlsafe_encoded,
|
||||
urlsafe_crosscheck :=
|
||||
urlsafe_decoded = value,
|
||||
urlsafe_unpadded_encoded :=
|
||||
urlsafe_unpadded_encoded,
|
||||
urlsafe_unpadded_crosscheck :=
|
||||
urlsafe_unpadded_decoded = value,
|
||||
}
|
||||
""",
|
||||
[{
|
||||
"standard_encoded":
|
||||
base64.b64encode(value)
|
||||
.decode("utf-8"),
|
||||
"standard_crosscheck": True,
|
||||
"standard_unpadded_encoded":
|
||||
base64.b64encode(value)
|
||||
.decode("utf-8").rstrip('='),
|
||||
"standard_unpadded_crosscheck": True,
|
||||
"urlsafe_encoded":
|
||||
base64.urlsafe_b64encode(value)
|
||||
.decode("utf-8"),
|
||||
"urlsafe_crosscheck": True,
|
||||
"urlsafe_unpadded_encoded":
|
||||
base64.urlsafe_b64encode(value)
|
||||
.decode("utf-8").rstrip('='),
|
||||
"urlsafe_unpadded_crosscheck": True,
|
||||
}],
|
||||
variables={
|
||||
"value": value,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_edgeql_functions_encoding_base64_bad(self):
|
||||
async with self.assertRaisesRegexTx(
|
||||
edgedb.InvalidValueError,
|
||||
r'invalid symbol "~" found while decoding base64 sequence',
|
||||
):
|
||||
await self.con.execute(
|
||||
'select std::enc::base64_decode("~")'
|
||||
)
|
||||
|
||||
async with self.assertRaisesRegexTx(
|
||||
edgedb.InvalidValueError,
|
||||
r'invalid base64 end sequence',
|
||||
):
|
||||
await self.con.execute(
|
||||
'select std::enc::base64_decode("AA")'
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue