Parse SQL with libpg_query and translate to our pgsql.ast (#4480)

This commit is contained in:
Aljaž Mur Eržen 2022-10-13 20:57:53 +02:00 committed by GitHub
parent b0a9f230dd
commit 59d37f9aa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 1802 additions and 39 deletions

View file

@ -167,6 +167,12 @@
fi
rsync -av ./build/rust_extensions/edb/ ./edb/
# Build libpg_query
- name: Build libpg_query
run: |
python setup.py build_libpg_query
# Build extensions
- name: Handle Cython extensions build cache

View file

@ -190,6 +190,12 @@ jobs:
fi
rsync -av ./build/rust_extensions/edb/ ./edb/
# Build libpg_query
- name: Build libpg_query
run: |
python setup.py build_libpg_query
# Build extensions
- name: Handle Cython extensions build cache

View file

@ -188,6 +188,12 @@ jobs:
fi
rsync -av ./build/rust_extensions/edb/ ./edb/
# Build libpg_query
- name: Build libpg_query
run: |
python setup.py build_libpg_query
# Build extensions
- name: Handle Cython extensions build cache

View file

@ -190,6 +190,12 @@ jobs:
fi
rsync -av ./build/rust_extensions/edb/ ./edb/
# Build libpg_query
- name: Build libpg_query
run: |
python setup.py build_libpg_query
# Build extensions
- name: Handle Cython extensions build cache

View file

@ -188,6 +188,12 @@ jobs:
fi
rsync -av ./build/rust_extensions/edb/ ./edb/
# Build libpg_query
- name: Build libpg_query
run: |
python setup.py build_libpg_query
# Build extensions
- name: Handle Cython extensions build cache

View file

@ -201,6 +201,12 @@ jobs:
fi
rsync -av ./build/rust_extensions/edb/ ./edb/
# Build libpg_query
- name: Build libpg_query
run: |
python setup.py build_libpg_query
# Build extensions
- name: Handle Cython extensions build cache

4
.gitmodules vendored
View file

@ -5,3 +5,7 @@
[submodule "edb/server/pgproto"]
path = edb/server/pgproto
url = https://github.com/MagicStack/py-pgproto.git
[submodule "edb/pgsql/parser/libpg_query"]
path = edb/pgsql/parser/libpg_query
url = https://github.com/pganalyze/libpg_query.git
branch = 13-latest

View file

@ -24,7 +24,7 @@ import dataclasses
import typing
import uuid
from edb.common import ast
from edb.common import ast, parsing
from edb.common import typeutils
from edb.edgeql import ast as qlast
from edb.ir import ast as irast
@ -39,6 +39,13 @@ from edb.ir import ast as irast
class Base(ast.AST):
__ast_hidden__ = {'context'}
context: typing.Optional[parsing.ParserContext] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __repr__(self):
return f'<pg.{self.__class__.__name__} at 0x{id(self):x}>'
@ -413,6 +420,13 @@ class ResTarget(ImmutableBaseExpr):
val: BaseExpr
class InsertTarget(ImmutableBaseExpr):
"""Column reference in INSERT."""
# Column name
name: str
class UpdateTarget(ImmutableBaseExpr):
"""Query update target."""
@ -420,6 +434,8 @@ class UpdateTarget(ImmutableBaseExpr):
name: str | typing.List[str]
# value expression to assign
val: BaseExpr
# subscripts, field names and '*'
indirection: typing.Optional[typing.List[IndirectionOp]] = None
class InferClause(ImmutableBaseExpr):
@ -436,7 +452,9 @@ class OnConflictClause(ImmutableBaseExpr):
action: str
infer: typing.Optional[InferClause]
target_list: typing.Optional[list] = None
target_list: typing.Optional[
typing.List[UpdateTarget | MultiAssignRef]
] = None
where: typing.Optional[BaseExpr] = None
@ -533,7 +551,7 @@ class DMLQuery(Query):
class InsertStmt(DMLQuery):
# (optional) list of target column names
cols: typing.Optional[typing.List[ColumnRef]] = None
cols: typing.Optional[typing.List[InsertTarget]] = None
# source SELECT/VALUES or None
select_stmt: typing.Optional[Query] = None
# ON CONFLICT clause
@ -543,7 +561,9 @@ class InsertStmt(DMLQuery):
class UpdateStmt(DMLQuery):
# The UPDATE target list
targets: typing.List[UpdateTarget] = ast.field(factory=list)
targets: typing.List[UpdateTarget | MultiAssignRef] = ast.field(
factory=list
)
# WHERE clause
where_clause: typing.Optional[BaseExpr] = None
# optional FROM clause
@ -746,13 +766,16 @@ class Slice(ImmutableBaseExpr):
ridx: typing.Optional[BaseExpr]
IndirectionOp = Slice | Index | ColumnRef | Star
class Indirection(ImmutableBaseExpr):
"""Field and/or array element indirection."""
# Indirection subject
arg: BaseExpr
# Subscripts and/or field names and/or '*'
indirection: list
indirection: typing.List[IndirectionOp]
class ArrayExpr(ImmutableBaseExpr):
@ -837,7 +860,7 @@ class JoinExpr(BaseRangeVar):
# Right subtree
rarg: BaseExpr
# USING clause, if any
using_clause: typing.Optional[typing.List[BaseExpr]] = None
using_clause: typing.Optional[typing.List[ColumnRef]] = None
# Qualifiers on join, if any
quals: typing.Optional[BaseExpr] = None
@ -858,6 +881,7 @@ class SubLinkType(enum.IntEnum):
NOT_EXISTS = enum.auto()
ALL = enum.auto()
ANY = enum.auto()
EXPR = enum.auto()
class SubLink(ImmutableBaseExpr):
@ -867,6 +891,8 @@ class SubLink(ImmutableBaseExpr):
type: SubLinkType
# Sublink expression
expr: BaseExpr
# Sublink expression
test_expr: typing.Optional[BaseExpr] = None
# Sublink is never NULL
nullable: bool = False
@ -907,6 +933,17 @@ class NullTest(ImmutableBaseExpr):
nullable: bool = False
class BooleanTest(ImmutableBaseExpr):
"""IS [NOT] {TRUE,FALSE}"""
# Input expression,
arg: BaseExpr
negated: bool = False
is_true: bool = False
# NullTest is never NULL
nullable: bool = False
class CaseWhen(ImmutableBase):
# Condition expression

View file

@ -18,6 +18,7 @@
from __future__ import annotations
from typing import Sequence
from edb import errors
@ -305,16 +306,17 @@ class SQLSourceGenerator(codegen.SourceGenerator):
self.indentation += 1
self.new_lines = 1
if node.select_stmt.values:
self.write('VALUES ')
self.new_lines = 1
self.indentation += 1
self.visit_list(node.select_stmt.values)
self.indentation -= 1
else:
self.write('(')
self.visit(node.select_stmt)
self.write(')')
if node.select_stmt:
if node.select_stmt.values:
self.write('VALUES ')
self.new_lines = 1
self.indentation += 1
self.visit_list(node.select_stmt.values)
self.indentation -= 1
else:
self.write('(')
self.visit(node.select_stmt)
self.write(')')
if node.on_conflict:
self.new_lines = 1
@ -441,9 +443,14 @@ class SQLSourceGenerator(codegen.SourceGenerator):
def visit_ResTarget(self, node):
self.visit(node.val)
if node.indirection:
self._visit_indirection_ops(node.indirection)
if node.name:
self.write(' AS ' + common.quote_ident(node.name))
def visit_InsertTarget(self, node: pgast.InsertTarget):
self.write(common.quote_ident(node.name))
def visit_UpdateTarget(self, node):
if isinstance(node.name, list):
self.write('(')
@ -451,6 +458,8 @@ class SQLSourceGenerator(codegen.SourceGenerator):
self.write(')')
else:
self.write(common.quote_ident(node.name))
if node.indirection:
self._visit_indirection_ops(node.indirection)
self.write(' = ')
self.visit(node.val)
@ -550,7 +559,11 @@ class SQLSourceGenerator(codegen.SourceGenerator):
self.visit(node.larg)
if node.rarg is not None:
self.new_lines = 1
self.write(node.type.upper() + ' JOIN ')
join_type = node.type.upper()
if join_type == 'INNER':
self.write('JOIN ')
else:
self.write(join_type + ' JOIN ')
nested_join = (
isinstance(node.rarg, pgast.JoinExpr) and
node.rarg.rarg is not None
@ -574,6 +587,10 @@ class SQLSourceGenerator(codegen.SourceGenerator):
self.visit(node.quals)
if not nested_join:
self.indentation -= 1
elif node.using_clause:
self.write(" USING (")
self.visit_list(node.using_clause)
self.write(")")
def visit_Expr(self, node):
self.write('(')
@ -685,19 +702,24 @@ class SQLSourceGenerator(codegen.SourceGenerator):
self.visit(node.val)
def visit_SubLink(self, node):
if node.type == pgast.SubLinkType.EXISTS:
self.write('EXISTS')
if node.test_expr and node.type == pgast.SubLinkType.ANY:
self.visit(node.test_expr)
self.write(' IN ')
elif node.type == pgast.SubLinkType.EXISTS:
self.write('EXISTS ')
elif node.type == pgast.SubLinkType.NOT_EXISTS:
self.write('NOT EXISTS')
self.write('NOT EXISTS ')
elif node.type == pgast.SubLinkType.ALL:
self.write('ALL')
self.write('ALL ')
elif node.type == pgast.SubLinkType.ANY:
self.write('ANY')
self.write('ANY ')
elif node.type == pgast.SubLinkType.EXPR:
pass
else:
raise SQLSourceGeneratorError(
'unexpected SubLinkType: {!r}'.format(node.type))
self.write(' (')
self.write('(')
self.new_lines = 1
self.indentation += 1
self.visit(node.expr)
@ -773,14 +795,30 @@ class SQLSourceGenerator(codegen.SourceGenerator):
self.write(' IS NULL')
self.write(')')
def visit_Indirection(self, node):
def visit_BooleanTest(self, node):
self.write("(")
self.visit(node.arg)
op = " IS"
if node.negated:
op += " NOT"
if node.is_true:
op += " TRUE"
else:
op += " FALSE"
self.write(op)
self.write(")")
def visit_Indirection(self, node: pgast.Indirection):
self.write('(')
self.visit(node.arg)
self.write(')')
for indirection in node.indirection:
if isinstance(indirection, (pgast.Star, pgast.ColumnRef)):
self._visit_indirection_ops(node.indirection)
def _visit_indirection_ops(self, ops: Sequence[pgast.IndirectionOp]):
for op in ops:
if isinstance(op, (pgast.Star, pgast.ColumnRef)):
self.write('.')
self.visit(indirection)
self.visit(op)
def visit_Index(self, node):
self.write('[')

View file

@ -145,9 +145,9 @@ def compile_ConfigSet(
]
),
cols=[
pgast.ColumnRef(name=['name']),
pgast.ColumnRef(name=['value']),
pgast.ColumnRef(name=['type']),
pgast.InsertTarget(name='name'),
pgast.InsertTarget(name='value'),
pgast.InsertTarget(name='type'),
],
on_conflict=pgast.OnConflictClause(
action='update',
@ -212,8 +212,8 @@ def compile_ConfigSet(
]
),
cols=[
pgast.ColumnRef(name=['name']),
pgast.ColumnRef(name=['value']),
pgast.InsertTarget(name='name'),
pgast.InsertTarget(name='value'),
],
on_conflict=pgast.OnConflictClause(
action='update',

View file

@ -35,7 +35,7 @@ from __future__ import annotations
from typing import *
from edb.common import uuidgen
from edb.common.typeutils import not_none
from edb.common.typeutils import downcast, not_none
from edb.edgeql import ast as qlast
from edb.edgeql import qltypes
@ -731,8 +731,7 @@ def process_insert_body(
# Populate the real insert statement based on the select we generated
insert_stmt.cols = [
pgast.ColumnRef(name=[not_none(value.name)])
for value in values
pgast.InsertTarget(name=not_none(value.name)) for value in values
]
insert_stmt.select_stmt = pgast.SelectStmt(
target_list=[
@ -2138,7 +2137,10 @@ def process_link_update(
query=pgast.InsertStmt(
relation=target_rvar,
select_stmt=data_select,
cols=cols,
cols=[
pgast.InsertTarget(name=downcast(col.name[0], str))
for col in cols
],
on_conflict=conflict_clause,
returning_list=[
pgast.ResTarget(

1
edb/pgsql/parser/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
*.c

View file

@ -0,0 +1,36 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2010-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import *
import json
from edb.pgsql import ast as pgast
from .exceptions import PSqlUnsupportedError
from .parser import pg_parse
from .ast_builder import build_queries
def parse(sql_query: str) -> List[pgast.Query]:
ast_json = pg_parse(bytes(sql_query, encoding="UTF8"))
try:
return build_queries(json.loads(ast_json), sql_query)
except IndexError:
raise PSqlUnsupportedError()

View file

@ -0,0 +1,677 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2010-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import *
from edb.common.parsing import ParserContext
from edb.pgsql import ast as pgast
from edb.edgeql import ast as qlast
from edb.pgsql.parser.exceptions import PSqlUnsupportedError
# Node = bool | str | int | float | List[Any] | dict[str, Any]
Node = Any
Context = Tuple[str]
T = TypeVar("T")
U = TypeVar("U")
Builder = Callable[[Node, Context], T]
def build_queries(node: Node, source_sql: str) -> List[pgast.Query]:
ctx = (source_sql,)
return [_build_query(node["stmt"], ctx) for node in node["stmts"]]
def _maybe(
node: Node, ctx: Context, name: str, builder: Builder
) -> Optional[T]:
if name in node:
return builder(node[name], ctx)
return None
def _ident(t: T) -> U:
return t # type: ignore
def _list(
node: Node,
ctx: Context,
name: str,
builder: Builder,
mapper: Callable[[T], U] = _ident,
) -> List[U]:
return [mapper(builder(n, ctx)) for n in node[name]]
def _maybe_list(
node: Node,
ctx: Context,
name: str,
builder: Builder,
mapper: Callable[[T], U] = _ident,
) -> Optional[List[U]]:
return _list(node, ctx, name, builder, mapper) if name in node else None
def _enum(
node: Node,
ctx: Context,
builders: dict[str, Builder],
fallbacks: Sequence[Builder] = (),
) -> T:
for name in builders:
if name in node:
builder = builders[name]
return builder(node[name], ctx)
for fallback in fallbacks:
try:
return fallback(node, ctx)
except PSqlUnsupportedError:
pass
raise PSqlUnsupportedError(f"unknown enum: {node}")
def _build_any(node: Node, _: Context) -> Any:
return node
def _build_str(node: Node, _: Context) -> str:
node = _unwrap(node, "String")
node = _unwrap(node, "str")
return str(node)
def _build_bool(node: Node, _: Context) -> bool:
assert isinstance(node, bool)
return node
def _bool_or_false(node: Node, name: str) -> bool:
return node[name] if name in node else False
def _unwrap(node: Node, name: str) -> pgast.Query:
if isinstance(node, dict) and name in node:
return node[name]
return node
def _probe(n: Node, keys: List[str | int]) -> bool:
for key in keys:
contained = key in n if isinstance(key, str) else key < len(n)
if contained:
n = n[key]
else:
return False
return True
def _as_column_ref(name: str) -> pgast.ColumnRef:
return pgast.ColumnRef(
name=(name,),
)
def _build_context(n: Node, c: Context) -> Optional[ParserContext]:
if 'location' not in n:
return None
return ParserContext(
name='<string>', buffer=c[0], start=n['location'], end=n['location']
)
def _build_query(node: Node, c: Context) -> pgast.Query:
return _enum(
node,
c,
{
"SelectStmt": _build_select_stmt,
"InsertStmt": _build_insert_stmt,
"UpdateStmt": _build_update_stmt,
"DeleteStmt": _build_delete_stmt,
},
)
def _build_select_stmt(n: Node, c: Context) -> pgast.SelectStmt:
op = _maybe(n, c, "op", _build_str)
if op:
op = op[6:]
if op == "NONE":
op = None
return pgast.SelectStmt(
distinct_clause=_maybe_list(n, c, "distinct_clause", _build_any),
target_list=_maybe_list(n, c, "targetList", _build_res_target) or [],
from_clause=_maybe_list(n, c, "fromClause", _build_base_range_var)
or [],
where_clause=_maybe(n, c, "whereClause", _build_base_expr),
group_clause=_maybe_list(n, c, "groupClause", _build_base),
having=_maybe(n, c, "having", _build_base_expr),
window_clause=_maybe_list(n, c, "windowClause", _build_base),
values=_maybe_list(n, c, "valuesLists", _build_base_expr),
sort_clause=_maybe_list(n, c, "sortClause", _build_sort_by),
limit_offset=_maybe(n, c, "limitOffset", _build_base_expr),
limit_count=_maybe(n, c, "limitCount", _build_base_expr),
locking_clause=_maybe_list(n, c, "sortClause", _build_any),
op=op,
all=n["all"] if "all" in n else False,
larg=_maybe(n, c, "larg", _build_select_stmt),
rarg=_maybe(n, c, "rarg", _build_select_stmt),
ctes=_maybe(n, c, "withClause", _build_ctes),
)
def _build_insert_stmt(n: Node, c: Context) -> pgast.InsertStmt:
return pgast.InsertStmt(
relation=_maybe(n, c, "relation", _build_rel_range_var),
returning_list=_maybe_list(n, c, "returningList", _build_res_target)
or [],
cols=_maybe_list(n, c, "cols", _build_insert_target),
select_stmt=_maybe(n, c, "selectStmt", _build_query),
on_conflict=_maybe(n, c, "on_conflict", _build_on_conflict),
ctes=_maybe(n, c, "withClause", _build_ctes),
)
def _build_update_stmt(n: Node, c: Context) -> pgast.UpdateStmt:
return pgast.UpdateStmt(
relation=_maybe(n, c, "relation", _build_rel_range_var),
targets=_build_targets(n, c, "targetList") or [],
where_clause=_maybe(n, c, "whereClause", _build_base_expr),
from_clause=_maybe_list(n, c, "fromClause", _build_base_range_var)
or [],
)
def _build_delete_stmt(n: Node, c: Context) -> pgast.DeleteStmt:
return pgast.DeleteStmt(
relation=_maybe(n, c, "relation", _build_rel_range_var),
returning_list=_maybe_list(n, c, "returningList", _build_res_target)
or [],
where_clause=_maybe(n, c, "whereClause", _build_base_expr),
using_clause=_maybe_list(n, c, "usingClause", _build_base_range_var)
or [],
)
def _build_base(n: Node, c: Context) -> pgast.Base:
return _enum(
n,
c,
{
"CommonTableExpr": _build_cte,
},
[_build_base_expr], # type: ignore
)
def _build_base_expr(node: Node, c: Context) -> pgast.BaseExpr:
return _enum(
node,
c,
{
"ResTarget": _build_res_target,
"FuncCall": _build_func_call,
"List": _build_implicit_row,
"A_Expr": _build_a_expr,
"A_ArrayExpr": _build_array_expr,
"A_Const": _build_const,
"BoolExpr": _build_bool_expr,
"CaseExpr": _build_case_expr,
"TypeCast": _build_type_cast,
"NullTest": _build_null_test,
"BooleanTest": _build_boolean_test,
"RowExpr": _build_row_expr,
"SubLink": _build_sub_link,
"ParamRef": _build_param_ref,
"SetToDefault": _build_keyword("DEFAULT"),
},
[_build_base_range_var, _build_indirection_op], # type: ignore
)
def _build_indirection_op(n: Node, c: Context) -> pgast.IndirectionOp:
return _enum(
n,
c,
{
'A_Indices': _build_index_or_slice,
'Star': _build_star,
'ColumnRef': _build_column_ref,
},
)
def _build_ctes(n: Node, c: Context) -> List[pgast.CommonTableExpr]:
return _list(n, c, "ctes", _build_cte)
def _build_cte(n: Node, c: Context) -> pgast.CommonTableExpr:
n = _unwrap(n, "CommonTableExpr")
materialized = None
if n["ctematerialized"] == "CTEMaterializeAlways":
materialized = True
elif n["ctematerialized"] == "CTEMaterializeNever":
materialized = False
return pgast.CommonTableExpr(
name=n["ctename"],
query=_build_query(n["ctequery"], c),
recursive=_bool_or_false(n, "cterecursive"),
aliascolnames=_maybe_list(
n, c, "aliascolnames", _build_str # type: ignore
),
materialized=materialized,
context=_build_context(n, c),
)
def _build_keyword(name: str) -> Builder[pgast.Keyword]:
return lambda n, c: pgast.Keyword(name=name, context=_build_context(n, c))
def _build_param_ref(n: Node, c: Context) -> pgast.BaseParamRef:
return pgast.ParamRef(number=n["number"], context=_build_context(n, c))
def _build_sub_link(n: Node, c: Context) -> pgast.SubLink:
typ = n["subLinkType"]
if typ == "EXISTS_SUBLINK":
type = pgast.SubLinkType.EXISTS
elif typ == "NOT_EXISTS_SUBLINK":
type = pgast.SubLinkType.NOT_EXISTS
elif typ == "ALL_SUBLINK":
type = pgast.SubLinkType.ALL
elif typ == "ANY_SUBLINK":
type = pgast.SubLinkType.ANY
elif typ == "EXPR_SUBLINK":
type = pgast.SubLinkType.EXPR
else:
raise PSqlUnsupportedError(f"unknown SubLink type: `{typ}`")
return pgast.SubLink(
type=type,
expr=_build_query(n["subselect"], c),
test_expr=_maybe(n, c, 'testexpr', _build_base_expr),
context=_build_context(n, c),
)
def _build_row_expr(n: Node, c: Context) -> pgast.ImplicitRowExpr:
return pgast.ImplicitRowExpr(
args=_list(n, c, "args", _build_base_expr),
context=_build_context(n, c),
)
def _build_boolean_test(n: Node, c: Context) -> pgast.BooleanTest:
return pgast.BooleanTest(
arg=_build_base_expr(n["arg"], c),
negated=n["booltesttype"].startswith("IS_NOT"),
is_true=n["booltesttype"].endswith("TRUE"),
context=_build_context(n, c),
)
def _build_null_test(n: Node, c: Context) -> pgast.NullTest:
return pgast.NullTest(
arg=_build_base_expr(n["arg"], c),
negated=n["nulltesttype"] == "IS_NOT_NULL",
context=_build_context(n, c),
)
def _build_type_cast(n: Node, c: Context) -> pgast.TypeCast:
return pgast.TypeCast(
arg=_build_base_expr(n["arg"], c),
type_name=_build_type_name(n["typeName"], c),
context=_build_context(n, c),
)
def _build_type_name(n: Node, c: Context) -> pgast.TypeName:
return pgast.TypeName(
name=tuple(_list(n, c, "names", _build_str)),
setof=_bool_or_false(n, "setof"),
typmods=None,
array_bounds=None,
context=_build_context(n, c),
)
def _build_case_expr(n: Node, c: Context) -> pgast.CaseExpr:
return pgast.CaseExpr(
arg=_maybe(n, c, "arg", _build_base_expr),
args=_list(n, c, "args", _build_case_when),
defresult=_maybe(n, c, "defresult", _build_base_expr),
context=_build_context(n, c),
)
def _build_case_when(n: Node, c: Context) -> pgast.CaseWhen:
n = _unwrap(n, "CaseWhen")
return pgast.CaseWhen(
expr=_build_base_expr(n["expr"], c),
result=_build_base_expr(n["result"], c),
context=_build_context(n, c),
)
def _build_bool_expr(n: Node, c: Context) -> pgast.Expr:
name = _build_str(n["boolop"], c)[0:-5]
res = pgast.Expr(
kind=pgast.ExprKind.OP,
name=name,
lexpr=_build_base_expr(n["args"].pop(0), c),
rexpr=_build_base_expr(n["args"].pop(0), c),
context=_build_context(n, c),
)
while len(n["args"]) > 0:
res = pgast.Expr(
kind=pgast.ExprKind.OP,
name=_build_str(n["boolop"], c)[0:-5],
lexpr=res,
rexpr=_build_base_expr(n["args"].pop(0), c),
context=_build_context(n, c),
)
return res
def _build_base_range_var(n: Node, c: Context) -> pgast.BaseRangeVar:
return _enum(
n,
c,
{
"RangeVar": _build_rel_range_var,
"JoinExpr": _build_join_expr,
"RangeFunction": _build_range_function,
"RangeSubselect": _build_range_subselect,
},
)
def _build_const(n: Node, c: Context) -> pgast.BaseConstant:
val = n["val"]
context = _build_context(n, c)
if "Integer" in val:
return pgast.NumericConstant(
val=str(val["Integer"]["ival"]), context=context
)
if "Float" in val:
return pgast.NumericConstant(val=val["Float"]["str"], context=context)
if "Null" in val:
return pgast.NullConstant(context=context)
if "String" in val:
return pgast.StringConstant(val=_build_str(val, c), context=context)
raise PSqlUnsupportedError(f'unknown Const: {val}')
def _build_range_subselect(n: Node, c: Context) -> pgast.RangeSubselect:
return pgast.RangeSubselect(
alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""),
lateral=_bool_or_false(n, "lateral"),
subquery=_build_query(n["subquery"], c),
)
def _build_range_function(n: Node, c: Context) -> pgast.RangeFunction:
return pgast.RangeFunction(
alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""),
lateral=_bool_or_false(n, "lateral"),
with_ordinality=_bool_or_false(n, "with_ordinality"),
is_rowsfrom=_bool_or_false(n, "is_rowsfrom"),
functions=_build_implicit_row(n["functions"], c).args, # type: ignore
)
def _build_join_expr(n: Node, c: Context) -> pgast.JoinExpr:
return pgast.JoinExpr(
alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""),
type=n["jointype"][5:],
larg=_build_base_expr(n["larg"], c),
rarg=_build_base_expr(n["rarg"], c),
using_clause=_maybe_list(
n, c, "usingClause", _build_str, _as_column_ref
),
quals=_maybe(n, c, "quals", _build_base_expr),
)
def _build_rel_range_var(n: Node, c: Context) -> pgast.RelRangeVar:
return pgast.RelRangeVar(
alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""),
relation=_build_base_relation(n, c),
include_inherited=_bool_or_false(n, "inh"),
context=_build_context(n, c),
)
def _build_alias(n: Node, c: Context) -> pgast.Alias:
return pgast.Alias(
aliasname=_build_str(n["aliasname"], c),
colnames=_maybe_list(n, c, "colnames", _build_str),
)
def _build_base_relation(n: Node, c: Context) -> pgast.BaseRelation:
return pgast.Relation(
name=_maybe(n, c, "relname", _build_str),
context=_build_context(n, c),
)
def _build_implicit_row(n: Node, c: Context) -> pgast.ImplicitRowExpr:
if isinstance(n, list):
n = n[0]
n = _unwrap(n, "List")
return pgast.ImplicitRowExpr(
args=[_build_base_expr(e, c) for e in n["items"] if len(e) > 0],
)
def _build_array_expr(n: Node, c: Context) -> pgast.ArrayExpr:
return pgast.ArrayExpr(elements=_list(n, c, "elements", _build_base_expr))
def _build_a_expr(n: Node, c: Context) -> pgast.Expr:
if n["kind"] == "AEXPR_OP":
return pgast.Expr(
kind=pgast.ExprKind.OP,
name=_build_str(n["name"][0], c),
lexpr=_maybe(n, c, "lexpr", _build_base_expr),
rexpr=_maybe(n, c, "rexpr", _build_base_expr),
context=_build_context(n, c),
)
elif n["kind"] == "AEXPR_LIKE":
return pgast.Expr(
kind=pgast.ExprKind.OP,
name="LIKE",
lexpr=_maybe(n, c, "lexpr", _build_base_expr),
rexpr=_maybe(n, c, "rexpr", _build_base_expr),
context=_build_context(n, c),
)
elif n["kind"] == "AEXPR_IN":
return pgast.Expr(
kind=pgast.ExprKind.OP,
name="IN",
lexpr=_maybe(n, c, "lexpr", _build_base_expr),
rexpr=_maybe(n, c, "rexpr", _build_base_expr),
context=_build_context(n, c),
)
else:
raise PSqlUnsupportedError(f'unknown ExprKind: {n["kind"]}')
def _build_func_call(n: Node, c: Context) -> pgast.FuncCall:
n = _unwrap(n, "FuncCall")
return pgast.FuncCall(
name=tuple(_list(n, c, "funcname", _build_str)),
args=_maybe_list(n, c, "args", _build_base_expr) or [],
agg_order=_maybe_list(n, c, "aggOrder", _build_sort_by),
agg_filter=_maybe(n, c, "aggFilter", _build_base_expr),
agg_star=_bool_or_false(n, "agg_star"),
agg_distinct=_bool_or_false(n, "agg_distinct"),
over=_maybe(n, c, "over", _build_window_def),
with_ordinality=_bool_or_false(n, "withOrdinality"),
context=_build_context(n, c),
)
def _build_index_or_slice(n: Node, c: Context) -> pgast.Slice | pgast.Index:
if n['is_slice']:
return pgast.Slice(
lidx=_build_base_expr(n['lidx'], c),
ridx=_build_base_expr(n['uidx'], c),
)
else:
return pgast.Index(
idx=_build_base_expr(n['lidx'], c),
)
def _build_res_target(n: Node, c: Context) -> pgast.ResTarget:
n = _unwrap(n, "ResTarget")
return pgast.ResTarget(
name=_maybe(n, c, "name", _build_str),
indirection=_maybe_list(n, c, "indirection", _build_indirection_op),
val=_build_base_expr(n["val"], c),
context=_build_context(n, c),
)
def _build_insert_target(n: Node, c: Context) -> pgast.InsertTarget:
n = _unwrap(n, "ResTarget")
return pgast.InsertTarget(
name=_build_str(n['name'], c),
context=_build_context(n, c),
)
def _build_update_target(n: Node, c: Context) -> pgast.UpdateTarget:
n = _unwrap(n, "ResTarget")
return pgast.UpdateTarget(
name=_build_str(n['name'], c),
val=_build_base_expr(n['val'], c),
indirection=_maybe_list(n, c, "indirection", _build_indirection_op),
context=_build_context(n, c),
)
def _build_window_def(n: Node, c: Context) -> pgast.WindowDef:
return pgast.WindowDef(
name=_maybe(n, c, "name", _build_str),
refname=_maybe(n, c, "refname", _build_str),
partition_clause=_maybe_list(
n, c, "partitionClause", _build_base_expr
),
order_clause=_maybe_list(n, c, "orderClause", _build_sort_by),
frame_options=None,
start_offset=_maybe(n, c, "startOffset", _build_base_expr),
end_offset=_maybe(n, c, "endOffset", _build_base_expr),
context=_build_context(n, c),
)
def _build_sort_by(n: Node, c: Context) -> pgast.SortBy:
n = _unwrap(n, "SortBy")
return pgast.SortBy(
node=_build_base_expr(n["node"], c),
dir=_maybe(n, c, "sortby_dir", _build_sort_order),
nulls=_maybe(n, c, "sortby_nulls", _build_nones_order),
context=_build_context(n, c),
)
def _build_nones_order(n: Node, _c: Context) -> qlast.NonesOrder:
if n == "SORTBY_NULLS_FIRST":
return qlast.NonesFirst
return qlast.NonesLast
def _build_sort_order(n: Node, _c: Context) -> qlast.SortOrder:
if n == "SORTBY_DESC":
return qlast.SortOrder.Desc
return qlast.SortOrder.Asc
def _build_targets(
n: Node, c: Context, key: str
) -> Optional[List[pgast.UpdateTarget | pgast.MultiAssignRef]]:
if _probe(n, [key, 0, "ResTarget", "val", "MultiAssignRef"]):
return [_build_multi_assign_ref(n[key], c)]
else:
return _maybe_list(n, c, key, _build_update_target)
def _build_multi_assign_ref(
targets: List[Node], c: Context
) -> pgast.MultiAssignRef:
mar = targets[0]['ResTarget']['val']['MultiAssignRef']
return pgast.MultiAssignRef(
source=_build_base_expr(mar['source'], c),
columns=[
_as_column_ref(target['ResTarget']['name']) for target in targets
],
context=_build_context(targets[0]['ResTarget'], c),
)
def _build_column_ref(n: Node, c: Context) -> pgast.ColumnRef:
return pgast.ColumnRef(
name=_list(n, c, "fields", _build_string_or_star),
optional=_maybe(n, c, "optional", _build_bool),
context=_build_context(n, c),
)
def _build_infer_clause(n: Node, c: Context) -> pgast.InferClause:
return pgast.InferClause(
index_elems=_maybe_list(n, c, "indexElems", _build_str),
where_clause=_maybe(n, c, "whereClause", _build_base_expr),
conname=_maybe(n, c, "conname", _build_str),
context=_build_context(n, c),
)
def _build_on_conflict(n: Node, c: Context) -> pgast.OnConflictClause:
return pgast.OnConflictClause(
action=_build_str(n["action"], c),
infer=_maybe(n, c, "infer", _build_infer_clause),
target_list=_build_targets(n, c, "targetList"),
where=_maybe(n, c, "where", _build_base_expr),
context=_build_context(n, c),
)
def _build_star(_n: Node, _c: Context) -> pgast.Star | str:
return pgast.Star()
def _build_string_or_star(node: Node, c: Context) -> pgast.Star | str:
return _enum(node, c, {"String": _build_str, "A_Star": _build_star})

View file

@ -0,0 +1,40 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2010-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
class PSqlParseError(Exception):
def __init__(self, message, lineno, cursorpos):
self.message = message
self.lineno = lineno
self.cursorpos = cursorpos
def __str__(self):
return self.message
class PSqlUnsupportedError(Exception):
def __init__(self, construct: Optional[str] = None):
self.construct = construct
def __str__(self):
if self.construct is not None:
return f"unsupported SQL construct: {self.construct}"
return "unsupported SQL construct"

@ -0,0 +1 @@
Subproject commit 4b30b03cb3944f01d4807ee89532549ccf115a44

View file

@ -0,0 +1,51 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2010-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .exceptions import PSqlParseError
cdef extern from "pg_query.h":
ctypedef struct PgQueryError:
char *message
int lineno
int cursorpos
ctypedef struct PgQueryParseResult:
char *parse_tree
PgQueryError *error
PgQueryParseResult pg_query_parse(const char* input)
void pg_query_free_parse_result(PgQueryParseResult result);
def pg_parse(query) -> str:
cdef PgQueryParseResult result
result = pg_query_parse(query)
if result.error:
error = PSqlParseError(
result.error.message.decode('utf8'),
result.error.lineno, result.error.cursorpos
)
pg_query_free_parse_result(result)
raise error
result_utf8 = result.parse_tree.decode('utf8')
pg_query_free_parse_result(result)
return result_utf8

View file

@ -58,7 +58,14 @@ EXT_LDFLAGS: list[str] = []
ROOT_PATH = pathlib.Path(__file__).parent.resolve()
EXT_INC_DIRS = [(ROOT_PATH / 'edb' / 'server' / 'pgproto').as_posix()]
EXT_INC_DIRS = [
(ROOT_PATH / 'edb' / 'server' / 'pgproto').as_posix(),
(ROOT_PATH / 'edb' / 'pgsql' / 'parser' / 'libpg_query').as_posix()
]
EXT_LIB_DIRS = [
(ROOT_PATH / 'edb' / 'pgsql' / 'parser' / 'libpg_query').as_posix()
]
if platform.uname().system != 'Windows':
@ -278,6 +285,27 @@ def _compile_postgres(build_base, *,
)
def _compile_libpg_query():
proc = subprocess.run(
['git', 'submodule', 'status', 'edb/pgsql/parser/libpg_query'],
stdout=subprocess.PIPE,
universal_newlines=True,
check=True,
cwd=ROOT_PATH,
)
status = proc.stdout
if status[0] == '-':
print('libpg_query submodule not initialized, '
'run `git submodule init; git submodule update`')
exit(1)
dir = (ROOT_PATH / 'edb' / 'pgsql' / 'parser' / 'libpg_query').resolve()
subprocess.run(
['make'] + ['build', '-j', str(max(os.cpu_count() - 1, 1))],
cwd=str(dir), check=True)
def _check_rust():
import packaging.version
@ -374,8 +402,9 @@ class build(setuptools_build.build):
user_options = setuptools_build.build.user_options
sub_commands = (
setuptools_build.build.sub_commands
+ [
[
("build_libpg_query", lambda self: True),
*setuptools_build.build.sub_commands,
("build_metadata", lambda self: True),
("build_parsers", lambda self: True),
("build_postgres", lambda self: True),
@ -559,6 +588,20 @@ class build_postgres(setuptools.Command):
produce_compile_commands_json=self.compile_commands,
)
class build_libpg_query(setuptools.Command):
description = "build libpg_query"
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
_compile_libpg_query()
class build_ext(setuptools_build_ext.build_ext):
@ -865,6 +908,7 @@ setuptools.setup(
'build_cli': build_cli,
'build_parsers': build_parsers,
'build_ui': build_ui,
'build_libpg_query': build_libpg_query,
'ci_helper': ci_helper,
},
ext_modules=[
@ -987,6 +1031,16 @@ setuptools.setup(
extra_link_args=EXT_LDFLAGS,
include_dirs=EXT_INC_DIRS,
),
setuptools_extension.Extension(
"edb.pgsql.parser.parser",
["edb/pgsql/parser/parser.pyx"],
extra_compile_args=EXT_CFLAGS,
extra_link_args=EXT_LDFLAGS,
include_dirs=EXT_INC_DIRS,
library_dirs=EXT_LIB_DIRS,
libraries=['pg_query']
),
],
rust_extensions=[
setuptools_rust.RustExtension(

786
tests/test_pgsql_parse.py Normal file
View file

@ -0,0 +1,786 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2012-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from edb.pgsql import codegen, parser
from edb.testbase import lang as tb
from edb.tools import test
class TestEdgeQLSelect(tb.BaseDocTest):
def run_test(self, *, source, spec, expected):
def inline(text):
lines = (line.strip() for line in text.split('\n'))
return ' '.join((line for line in lines if len(line) > 0))
def normalize(s):
return s.replace(" ", " ").replace("( ", "(").replace(" )", ")")
source = normalize(inline(source))
can_omit_expected = False
if expected:
expected = normalize(inline(expected))
can_omit_expected = source == expected
else:
expected = source
ast = parser.parse(source)
sql_stmts = [
codegen.generate_source(stmt, pretty=False) for stmt in ast
]
sql = normalize("; ".join(sql_stmts))
self.assertEqual(expected, sql)
if can_omit_expected:
raise BaseException(
'Warning: test''s `source` is same as `expected`. '
'You can omit `expected`.'
)
def test_pgsql_parse_select_00(self):
"""
SELECT * FROM my_table
"""
def test_pgsql_parse_select_01(self):
"""
SELECT col1 FROM my_table WHERE
my_attribute LIKE 'condition' AND other = 5.6 AND extra > 5
% OK %
SELECT col1 FROM my_table WHERE
(((my_attribute LIKE 'condition') AND
(other = 5.6)) AND (extra > 5))
"""
def test_pgsql_parse_select_02(self):
"""
SELECT * FROM table_one JOIN table_two USING (common)
"""
def test_pgsql_parse_select_03(self):
"""
WITH fake_table AS (
SELECT SUM(countable) AS total FROM inner_table
GROUP BY groupable
) SELECT * FROM fake_table
% OK %
WITH fake_table AS ((
SELECT sum(countable) AS total FROM inner_table
GROUP BY groupable
)) SELECT * FROM fake_table
"""
def test_pgsql_parse_select_04(self):
"""
SELECT * FROM (SELECT something FROM dataset) AS other
"""
def test_pgsql_parse_select_05(self):
"""
SELECT a, CASE WHEN a=1 THEN 'one' WHEN a=2
THEN 'two' ELSE 'other' END FROM test
% OK %
SELECT a, (CASE WHEN (a = 1) THEN 'one' WHEN (a = 2)
THEN 'two' ELSE 'other' END) FROM test
"""
def test_pgsql_parse_select_06(self):
"""
SELECT CASE a.value WHEN 0 THEN '1' ELSE '2' END
FROM sometable a
% OK %
SELECT (CASE a.value WHEN 0 THEN '1' ELSE '2' END)
FROM sometable AS a
"""
def test_pgsql_parse_select_07(self):
"""
SELECT * FROM table_one UNION select * FROM table_two
% OK %
SELECT * FROM table_one UNION (SELECT * FROM table_two)
"""
def test_pgsql_parse_select_08(self):
"""
SELECT * FROM my_table WHERE ST_Intersects(geo1, geo2)
% OK %
SELECT * FROM my_table WHERE st_intersects(geo1, geo2)
"""
def test_pgsql_parse_select_09(self):
"""
SELECT 'accbf276-705b-11e7-b8e4-0242ac120002'::UUID
% OK %
SELECT ('accbf276-705b-11e7-b8e4-0242ac120002')::uuid
"""
def test_pgsql_parse_select_10(self):
"""
SELECT * FROM my_table ORDER BY field DESC NULLS FIRST
"""
def test_pgsql_parse_select_11(self):
"""
SELECT * FROM my_table ORDER BY field
% OK %
SELECT * FROM my_table ORDER BY field ASC NULLS LAST
"""
def test_pgsql_parse_select_12(self):
"""
SELECT salary, sum(salary) OVER () FROM empsalary
"""
def test_pgsql_parse_select_13(self):
"""
SELECT salary, sum(salary)
OVER (ORDER BY salary) FROM empsalary
% OK %
SELECT salary, sum(salary)
OVER (ORDER BY salary ASC NULLS LAST) FROM empsalary
"""
def test_pgsql_parse_select_14(self):
"""
SELECT salary, avg(salary)
OVER (PARTITION BY depname) FROM empsalary
"""
def test_pgsql_parse_select_15(self):
"""
SELECT m.* FROM mytable m WHERE m.foo IS NULL
% OK %
SELECT m.* FROM mytable AS m WHERE (m.foo IS NULL)
"""
def test_pgsql_parse_select_16(self):
"""
SELECT m.* FROM mytable m WHERE m.foo IS NOT NULL
% OK %
SELECT m.* FROM mytable AS m WHERE (m.foo IS NOT NULL)
"""
def test_pgsql_parse_select_17(self):
"""
SELECT m.* FROM mytable m WHERE m.foo IS TRUE
% OK %
SELECT m.* FROM mytable AS m WHERE (m.foo IS TRUE)
"""
def test_pgsql_parse_select_18(self):
"""
SELECT m.name AS mname, pname FROM manufacturers m,
LATERAL get_product_names(m.id) pname
% OK %
SELECT m.name AS mname, pname FROM manufacturers AS m,
LATERAL get_product_names(m.id) AS pname
"""
def test_pgsql_parse_select_19(self):
"""
SELECT * FROM unnest(ARRAY['a','b','c','d','e','f'])
% OK %
SELECT * FROM unnest(ARRAY['a', 'b', 'c', 'd', 'e', 'f'])
"""
def test_pgsql_parse_select_20(self):
"""
SELECT * FROM my_table
WHERE (a, b) in (('a', 'b'), ('c', 'd'))
% OK %
SELECT * FROM my_table
WHERE ((a, b) IN (('a', 'b'), ('c', 'd')))
"""
@test.xerror('bad FRO keyword')
def test_pgsql_parse_select_21(self):
"""
SELECT * FRO my_table
"""
@test.xerror('missing expression after THEN')
def test_pgsql_parse_select_22(self):
"""
SELECT a, CASE WHEN a=1 THEN 'one'
WHEN a=2 THEN ELSE 'other' END FROM test
"""
def test_pgsql_parse_select_23(self):
"""
SELECT * FROM table_one, table_two
"""
def test_pgsql_parse_select_24(self):
"""
SELECT * FROM table_one, public.table_one
% OK %
SELECT * FROM table_one, table_one
"""
def test_pgsql_parse_select_25(self):
"""
WITH fake_table AS (SELECT * FROM inner_table)
SELECT * FROM fake_table
% OK %
WITH fake_table AS ((SELECT * FROM inner_table))
SELECT * FROM fake_table
"""
def test_pgsql_parse_select_26(self):
"""
SELECT * FROM table_one JOIN table_two USING (common_1)
JOIN table_three USING (common_2)
"""
def test_pgsql_parse_select_27(self):
"""
select * FROM table_one UNION select * FROM table_two
% OK %
SELECT * FROM table_one UNION (SELECT * FROM table_two)
"""
def test_pgsql_parse_select_28(self):
"""
SELECT * FROM my_table WHERE (a, b) in ('a', 'b')
% OK %
SELECT * FROM my_table WHERE ((a, b) IN ('a', 'b'))
"""
def test_pgsql_parse_select_29(self):
"""
SELECT * FROM my_table
WHERE (a, b) in (('a', 'b'), ('c', 'd'))
% OK %
SELECT * FROM my_table
WHERE ((a, b) IN (('a', 'b'), ('c', 'd')))
"""
def test_pgsql_parse_select_30(self):
"""
SELECT (SELECT * FROM table_one)
% OK %
SELECT ((SELECT * FROM table_one))
"""
def test_pgsql_parse_select_31(self):
"""
SELECT my_func((select * from table_one))
% OK %
SELECT my_func(((SELECT * FROM table_one)))
"""
def test_pgsql_parse_select_32(self):
"""
SELECT 1
"""
def test_pgsql_parse_select_33(self):
"""
SELECT 2
"""
def test_pgsql_parse_select_34(self):
"""
SELECT $1
"""
def test_pgsql_parse_select_35(self):
"""
SELECT 1; SELECT a FROM b
"""
def test_pgsql_parse_select_36(self):
"""
SELECT COUNT(DISTINCT id), * FROM targets
WHERE something IS NOT NULL
AND elsewhere::interval < now()
% OK %
SELECT count(DISTINCT id), * FROM targets
WHERE ((something IS NOT NULL)
AND ((elsewhere)::pg_catalog.interval < now()))
"""
def test_pgsql_parse_select_37(self):
"""
SELECT b AS x, a AS y FROM z
"""
def test_pgsql_parse_select_38(self):
"""
WITH a AS (SELECT * FROM x WHERE x.y = $1 AND x.z = 1)
SELECT * FROM a
% OK %
WITH a AS ((SELECT * FROM x WHERE ((x.y = $1) AND (x.z = 1))))
SELECT * FROM a
"""
def test_pgsql_parse_select_39(self):
"""
SELECT * FROM x WHERE y IN ($1)
% OK %
SELECT * FROM x WHERE (y IN ($1))
"""
def test_pgsql_parse_select_40(self):
"""
SELECT * FROM x WHERE y IN ($1, $2, $3)
% OK %
SELECT * FROM x WHERE (y IN ($1, $2, $3))
"""
def test_pgsql_parse_select_41(self):
"""
SELECT * FROM x WHERE y IN ( $1::uuid )
% OK %
SELECT * FROM x WHERE (y IN (($1)::uuid))
"""
def test_pgsql_parse_select_42(self):
"""
SELECT * FROM x
WHERE y IN ( $1::uuid, $2::uuid, $3::uuid )
% OK %
SELECT * FROM x
WHERE (y IN (($1)::uuid, ($2)::uuid, ($3)::uuid))
"""
def test_pgsql_parse_select_43(self):
"""
SELECT * FROM x AS a, y AS b
"""
def test_pgsql_parse_select_44(self):
"""
SELECT * FROM y AS a, x AS b
"""
def test_pgsql_parse_select_45(self):
"""
SELECT x AS a, y AS b FROM x
"""
def test_pgsql_parse_select_46(self):
"""
SELECT x, y FROM z
"""
def test_pgsql_parse_select_47(self):
"""
SELECT y, x FROM z
"""
def test_pgsql_parse_select_48(self):
"""
SELECT * FROM a
"""
def test_pgsql_parse_select_49(self):
"""
SELECT * FROM a AS b
"""
def test_pgsql_parse_select_50(self):
"""
-- nothing
% OK %
"""
# TODO: is this ok? What is `(0)`?
def test_pgsql_parse_select_51(self):
"""
SELECT INTERVAL (0) $2
% OK %
SELECT ($2)::pg_catalog.interval
"""
def test_pgsql_parse_select_52(self):
"""
SELECT INTERVAL (2) $2
% OK %
SELECT ($2)::pg_catalog.interval
"""
def test_pgsql_parse_select_53(self):
"""
SELECT * FROM t WHERE t.a IN (1, 2) AND t.b = 3
% OK %
SELECT * FROM t WHERE ((t.a IN (1, 2)) AND (t.b = 3))
"""
def test_pgsql_parse_select_54(self):
"""
SELECT * FROM t WHERE t.b = 3 AND t.a IN (1, 2)
% OK %
SELECT * FROM t WHERE ((t.b = 3) AND (t.a IN (1, 2)))
"""
def test_pgsql_parse_select_55(self):
"""
SELECT * FROM t WHERE a && '[1,2]'
% OK %
SELECT * FROM t WHERE (a && '[1,2]')
"""
def test_pgsql_parse_select_56(self):
"""
SELECT * FROM t WHERE a && '[1,2]'::int4range
% OK %
SELECT * FROM t WHERE (a && ('[1,2]')::int4range)
"""
def test_pgsql_parse_select_57(self):
"""
SELECT * FROM t_20210301_x
"""
def test_pgsql_parse_insert_00(self):
"""
INSERT INTO my_table (id, name) VALUES (1, 'some')
"""
def test_pgsql_parse_insert_01(self):
"""
INSERT INTO my_table (id, name) SELECT 1, 'some'
% OK %
INSERT INTO my_table (id, name) ((SELECT 1, 'some'))
"""
def test_pgsql_parse_insert_02(self):
"""
INSERT INTO my_table (id) VALUES (5) RETURNING id, date
"""
def test_pgsql_parse_insert_03(self):
"""
INSERT INTO my_table (id) VALUES (5) RETURNING id, "date"
% OK %
INSERT INTO my_table (id) VALUES (5) RETURNING id, date
"""
def test_pgsql_parse_insert_04(self):
"""
INSERT INTO my_table (id) VALUES(1); SELECT * FROM my_table
% OK %
INSERT INTO my_table (id) VALUES (1); SELECT * FROM my_table
"""
@test.xerror('missing VALUES or SELECT')
def test_pgsql_parse_insert_05(self):
"""
INSERT INTO my_table
"""
def test_pgsql_parse_insert_06(self):
"""
INSERT INTO table_one (id, name) SELECT * from table_two
% OK %
INSERT INTO table_one (id, name) ((SELECT * FROM table_two))
"""
def test_pgsql_parse_insert_07(self):
"""
WITH fake as (SELECT * FROM inner_table)
INSERT INTO dataset SELECT * FROM fake
% OK %
WITH fake AS ((SELECT * FROM inner_table))
INSERT INTO dataset ((SELECT * FROM fake))
"""
def test_pgsql_parse_insert_08(self):
"""
INSERT INTO test (a, b) VALUES
(ARRAY[$1, $1, $2, $3], $4::timestamptz),
(ARRAY[$1, $1, $2, $3], $4::timestamptz),
($5, $6::timestamptz)
% OK %
INSERT INTO test (a, b) VALUES
(ARRAY[$1, $1, $2, $3], ($4)::timestamptz),
(ARRAY[$1, $1, $2, $3], ($4)::timestamptz),
($5, ($6)::timestamptz)
"""
def test_pgsql_parse_insert_09(self):
"""
INSERT INTO films (code, title, did) VALUES
('UA502', 'Bananas', 105), ('T_601', 'Yojimbo', DEFAULT)
"""
def test_pgsql_parse_insert_10(self):
"""
INSERT INTO films (code, title, did) VALUES ($1, $2, $3)
"""
def test_pgsql_parse_update_00(self):
"""
UPDATE my_table SET the_value = DEFAULT
"""
def test_pgsql_parse_update_01(self):
"""
UPDATE tictactoe SET board[1:3][1:3] = '{{,,},{,,},{,,}}'
WHERE game = 1
% OK %
UPDATE tictactoe SET board[1:3][1:3] = '{{,,},{,,},{,,}}'
WHERE (game = 1)
"""
def test_pgsql_parse_update_02(self):
"""
UPDATE accounts SET
(contact_first_name, contact_last_name) =
(SELECT first_name, last_name
FROM salesmen WHERE salesmen.id = accounts.sales_id)
% OK %
UPDATE accounts SET
(contact_first_name, contact_last_name) =
((SELECT first_name, last_name
FROM salesmen WHERE (salesmen.id = accounts.sales_id)))
"""
def test_pgsql_parse_update_03(self):
"""
UPDATE my_table SET id = 5; DELETE FROM my_table
"""
def test_pgsql_parse_update_04(self):
"""
UPDATE dataset SET a = 5
WHERE id IN (SELECT * from table_one)
OR age IN (select * from table_two)
% OK %
UPDATE dataset SET a = 5
WHERE (id IN ((SELECT * FROM table_one))
OR age IN ((SELECT * FROM table_two)))
"""
def test_pgsql_parse_update_05(self):
"""
UPDATE dataset SET a = 5 FROM extra WHERE b = c
% OK %
UPDATE dataset SET a = 5 FROM extra WHERE (b = c)
"""
def test_pgsql_parse_update_06(self):
"""
UPDATE users SET one_thing = $1, second_thing = $2
WHERE users.id = $1
% OK %
UPDATE users SET one_thing = $1, second_thing = $2
WHERE (users.id = $1)
"""
def test_pgsql_parse_update_07(self):
"""
UPDATE users SET something_else = $1 WHERE users.id = $1
% OK %
UPDATE users SET something_else = $1 WHERE (users.id = $1)
"""
def test_pgsql_parse_update_08(self):
"""
UPDATE users SET something_else =
(SELECT a FROM x WHERE uid = users.id LIMIT 1)
WHERE users.id = $1
% OK %
UPDATE users SET something_else =
((SELECT a FROM x WHERE (uid = users.id) LIMIT 1))
WHERE (users.id = $1)
"""
def test_pgsql_parse_update_09(self):
"""
UPDATE x SET a = 1, b = 2, c = 3
"""
def test_pgsql_parse_update_10(self):
"""
UPDATE x SET z = now()
"""
def test_pgsql_parse_delete(self):
"""
DELETE FROM dataset USING table_one
WHERE x = y OR x IN (SELECT * from table_two)
% OK %
DELETE FROM dataset USING table_one
WHERE ((x = y) OR x IN ((SELECT * FROM table_two)))
"""
def test_pgsql_parse_query_00(self):
"""
SELECT * FROM
(VALUES (1, 'one'), (2, 'two')) AS t(num, letter)
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_01(self):
"""
SELECT * FROM my_table ORDER BY field ASC NULLS LAST USING @>
"""
@test.xfail("unsupported")
def test_pgsql_parse_query_02(self):
"""
SELECT m.* FROM mytable AS m FOR UPDATE
"""
@test.xfail("unsupported")
def test_pgsql_parse_query_03(self):
"""
SELECT m.* FROM mytable m FOR SHARE of m nowait
"""
def test_pgsql_parse_query_04(self):
"""
SELECT * FROM unnest(ARRAY['a', 'b', 'c', 'd', 'e', 'f'])
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_06(self):
"""
SELECT ?
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_07(self):
"""
SELECT * FROM x WHERE y = ?
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_08(self):
"""
SELECT * FROM x WHERE y = ANY ($1)
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_09(self):
"""
PREPARE a123 AS SELECT a
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_10(self):
"""
EXECUTE a123
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_11(self):
"""
DEALLOCATE a123
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_12(self):
"""
DEALLOCATE ALL
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_13(self):
"""
EXPLAIN ANALYZE SELECT a
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_14(self):
"""
VACUUM FULL my_table
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_15(self):
"""
SAVEPOINT some_id
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_16(self):
"""
RELEASE some_id
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_17(self):
"""
PREPARE TRANSACTION 'some_id'
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_18(self):
"""
START TRANSACTION READ WRITE
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_19(self):
"""
DECLARE cursor_123 CURSOR FOR
SELECT * FROM test WHERE id = 123
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_20(self):
"""
FETCH 1000 FROM cursor_123
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_21(self):
"""
CLOSE cursor_123
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_22(self):
"""
CREATE VIEW view_a (a, b) AS WITH RECURSIVE view_a (a, b) AS
(SELECT * FROM a(1)) SELECT "a", "b" FROM "view_a"
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_23(self):
"""
CREATE FOREIGN TABLE ft1 () SERVER no_server
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_24(self):
"""
CREATE TEMPORARY TABLE my_temp_table
(test_id integer NOT NULL) ON COMMIT DROP
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_25(self):
"""
CREATE TEMPORARY TABLE my_temp_table AS SELECT 1
"""
@test.xerror("unsupported")
def test_pgsql_parse_query_26(self):
"""
CREATE TABLE types (
a float(2), b float(49),
c NUMERIC(2, 3), d character(4), e char(5),
f varchar(6), g character varying(7))
"""