Implement base64 encoding and decoding functions (#5963)

This adds the new `std::encoding::base64_encode` and
`std::encoding::base64_decode` functions to encode to and decode from
base64-encoded strings.  There is support for RFC 4648 §4 standard
alphabet as well as RFC 4648 §5 URL- and filename-safe alphabet.
Padding requirement can be controlled by the `padding` named only
argument.
This commit is contained in:
Elvis Pranskevichus 2023-08-28 16:21:56 -07:00 committed by GitHub
parent 9467f88e08
commit c1781cf6bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 263 additions and 13 deletions

135
edb/lib/enc.edgeql Normal file
View file

@ -0,0 +1,135 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright EdgeDB Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
CREATE MODULE std::enc;
CREATE SCALAR TYPE
std::enc::Base64Alphabet EXTENDING enum<standard, urlsafe>;
CREATE FUNCTION
std::enc::base64_encode(
data: std::bytes,
NAMED ONLY alphabet: std::enc::Base64Alphabet =
std::enc::Base64Alphabet.standard,
NAMED ONLY padding: std::bool = true,
) -> std::str
{
CREATE ANNOTATION std::description :=
'Encode given data as a base64 string';
SET volatility := 'Immutable';
USING SQL $$
SELECT
CASE
WHEN "alphabet" = 'standard' AND "padding" THEN
pg_catalog.translate(
pg_catalog.encode("data", 'base64'),
E'\n',
''
)
WHEN "alphabet" = 'standard' AND NOT "padding" THEN
pg_catalog.translate(
pg_catalog.rtrim(
pg_catalog.encode("data", 'base64'),
'='
),
E'\n',
''
)
WHEN "alphabet" = 'urlsafe' AND "padding" THEN
pg_catalog.translate(
pg_catalog.encode("data", 'base64'),
E'+/\n',
'-_'
)
WHEN "alphabet" = 'urlsafe' AND NOT "padding" THEN
pg_catalog.translate(
pg_catalog.rtrim(
pg_catalog.encode("data", 'base64'),
'='
),
E'+/\n',
'-_'
)
ELSE
edgedb.raise(
NULL::text,
'invalid_parameter_value',
msg => (
'invalid alphabet for std::enc::base64_encode: '
|| pg_catalog.quote_literal("alphabet")
),
detail => (
'{"hint":"Supported alphabets: standard, urlsafe."}'
)
)
END
$$;
};
CREATE FUNCTION
std::enc::base64_decode(
data: std::str,
NAMED ONLY alphabet: std::enc::Base64Alphabet =
std::enc::Base64Alphabet.standard,
NAMED ONLY padding: std::bool = true,
) -> std::bytes
{
CREATE ANNOTATION std::description :=
'Decode the byte64-encoded byte string and return decoded bytes.';
SET volatility := 'Immutable';
USING SQL $$
SELECT
CASE
WHEN "alphabet" = 'standard' AND "padding" THEN
pg_catalog.decode("data", 'base64')
WHEN "alphabet" = 'standard' AND NOT "padding" THEN
pg_catalog.decode(
edgedb.pad_base64_string("data"),
'base64'
)
WHEN "alphabet" = 'urlsafe' AND "padding" THEN
pg_catalog.decode(
pg_catalog.translate("data", '-_', '+/'),
'base64'
)
WHEN "alphabet" = 'urlsafe' AND NOT "padding" THEN
pg_catalog.decode(
edgedb.pad_base64_string(
pg_catalog.translate("data", '-_', '+/')
),
'base64'
)
ELSE
edgedb.raise(
NULL::bytea,
'invalid_parameter_value',
msg => (
'invalid alphabet for std::enc::base64_decode: '
|| pg_catalog.quote_literal("alphabet")
),
detail => (
'{"hint":"Supported alphabets: standard, urlsafe."}'
)
)
END
$$;
};

View file

@ -4187,6 +4187,38 @@ class UuidGenerateV5Function(dbops.Function):
)
class PadBase64StringFunction(dbops.Function):
text = r"""
WITH
l AS (SELECT pg_catalog.length("s") % 4 AS r),
p AS (
SELECT
(CASE WHEN l.r > 0 THEN repeat('=', (4 - l.r))
ELSE '' END) AS p
FROM
l
)
SELECT
"s" || p.p
FROM
p
"""
def __init__(self) -> None:
super().__init__(
name=('edgedb', 'pad_base64_string'),
args=[
('s', ('text',)),
],
returns=('text',),
volatility='immutable',
language='sql',
strict=True,
parallel_safe=True,
text=self.text,
)
async def bootstrap(
conn: PGConnection,
config_spec: edbconfig.Spec,
@ -4303,6 +4335,7 @@ async def bootstrap(
dbops.CreateFunction(FTSParseQueryFunction()),
dbops.CreateFunction(FTSNormalizeWeightFunction()),
dbops.CreateFunction(FTSNormalizeDocFunction()),
dbops.CreateFunction(PadBase64StringFunction()),
]
commands = dbops.CommandGroup()
commands.add_commands(cmds)

View file

@ -66,6 +66,7 @@ STD_MODULES = (
sn.UnqualName('std::_test'),
sn.UnqualName('fts'),
sn.UnqualName('ext'),
sn.UnqualName('std::enc'),
)
# Specifies the order of processing of files and directories in lib/
@ -77,6 +78,7 @@ STD_SOURCES = (
sn.UnqualName('cfg'),
sn.UnqualName('cal'),
sn.UnqualName('ext'),
sn.UnqualName('enc'),
sn.UnqualName('pg'),
)
TESTMODE_SOURCES = (

View file

@ -17,9 +17,11 @@
#
import base64
import decimal
import json
import os.path
import random
import edgedb
@ -1132,19 +1134,6 @@ class TestEdgeQLFunctions(tb.QueryTestCase):
[['Ab'], ['a'], ['a'], ['aB'], ['ab']],
)
async def test_edgeql_functions_re_match_all_02(self):
await self.assert_query_result(
r'''
WITH
MODULE schema,
C2 := ScalarType
SELECT
count(re_match_all('(\\w+)', ScalarType.name)) =
2 * count(C2);
''',
[True],
)
async def test_edgeql_functions_re_test_01(self):
await self.assert_query_result(
r'''SELECT re_test('ac', 'AbabaB');''',
@ -7505,3 +7494,94 @@ class TestEdgeQLFunctions(tb.QueryTestCase):
],
json_only=True,
)
async def test_edgeql_functions_encoding_base64_fuzz(self):
for _ in range(10):
value = random.randbytes(random.randrange(0, 1000))
await self.assert_query_result(
r"""
WITH
MODULE std::enc,
value := <bytes>$value,
standard_encoded := base64_encode(
value),
standard_decoded := base64_decode(
standard_encoded),
standard_unpadded_encoded := base64_encode(
value,
padding := false),
standard_unpadded_decoded := base64_decode(
standard_unpadded_encoded,
padding := false),
urlsafe_encoded := base64_encode(
value,
alphabet := Base64Alphabet.urlsafe),
urlsafe_decoded := base64_decode(
urlsafe_encoded,
alphabet := Base64Alphabet.urlsafe),
urlsafe_unpadded_encoded := base64_encode(
value,
alphabet := Base64Alphabet.urlsafe,
padding := false),
urlsafe_unpadded_decoded := base64_decode(
urlsafe_unpadded_encoded,
alphabet := Base64Alphabet.urlsafe,
padding := false),
SELECT {
standard_encoded :=
standard_encoded,
standard_crosscheck :=
standard_decoded = value,
standard_unpadded_encoded :=
standard_unpadded_encoded,
standard_unpadded_crosscheck :=
standard_unpadded_decoded = value,
urlsafe_encoded :=
urlsafe_encoded,
urlsafe_crosscheck :=
urlsafe_decoded = value,
urlsafe_unpadded_encoded :=
urlsafe_unpadded_encoded,
urlsafe_unpadded_crosscheck :=
urlsafe_unpadded_decoded = value,
}
""",
[{
"standard_encoded":
base64.b64encode(value)
.decode("utf-8"),
"standard_crosscheck": True,
"standard_unpadded_encoded":
base64.b64encode(value)
.decode("utf-8").rstrip('='),
"standard_unpadded_crosscheck": True,
"urlsafe_encoded":
base64.urlsafe_b64encode(value)
.decode("utf-8"),
"urlsafe_crosscheck": True,
"urlsafe_unpadded_encoded":
base64.urlsafe_b64encode(value)
.decode("utf-8").rstrip('='),
"urlsafe_unpadded_crosscheck": True,
}],
variables={
"value": value,
},
)
async def test_edgeql_functions_encoding_base64_bad(self):
async with self.assertRaisesRegexTx(
edgedb.InvalidValueError,
r'invalid symbol "~" found while decoding base64 sequence',
):
await self.con.execute(
'select std::enc::base64_decode("~")'
)
async with self.assertRaisesRegexTx(
edgedb.InvalidValueError,
r'invalid base64 end sequence',
):
await self.con.execute(
'select std::enc::base64_decode("AA")'
)