mirror of
https://github.com/maxkratz/edgedb.git
synced 2024-09-16 18:59:05 +00:00
Fix DML-containing FOR when the iterator is an object set with duplicates (#6609)
DML-containing FOR loops get compiled into CTEs, and after the DML is performed, the CTEs need to get joined back with the DML CTEs in order to produce the output. Currently we do these joins on the 'identity' of the iterator set; for non-object sets, we generate a transient identity with uuid_generate_v4, and for object sets we use the actual object id. That works ok if all the objects in the iterator set are unique, but if they aren't we produce a lot of duplicate output rows. To solve this, we need to use transient ids for object types too, but overriding object identity with something transient causes problems: if we need to actually join the object against some *actual* table, we are in trouble, since we've overridden the identity. Solve this by introducing a new 'iterator' aspect and distinguishing whether we want to join on an iterator aspect or a real identity when making path bonds. Fixes #6608
This commit is contained in:
parent
2f88787ab3
commit
f741fc9455
7 changed files with 127 additions and 72 deletions
|
@ -148,7 +148,7 @@ class EdgeQLPathInfo(Base):
|
|||
|
||||
# Ignore the below fields in AST visitor/transformer.
|
||||
__ast_meta__ = {
|
||||
'path_id', 'path_scope', 'path_outputs', 'is_distinct',
|
||||
'path_id', 'path_bonds', 'path_outputs', 'is_distinct',
|
||||
'path_id_mask', 'path_namespace',
|
||||
'packed_path_outputs', 'packed_path_namespace',
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ class EdgeQLPathInfo(Base):
|
|||
is_distinct: bool = True
|
||||
|
||||
# A subset of paths necessary to perform joining.
|
||||
path_scope: typing.Set[irast.PathId] = ast.field(factory=set)
|
||||
path_bonds: typing.Set[tuple[irast.PathId, bool]] = ast.field(factory=set)
|
||||
|
||||
# Map of res target names corresponding to paths.
|
||||
path_outputs: typing.Dict[
|
||||
|
@ -880,7 +880,7 @@ class RangeSubselect(PathRangeVar):
|
|||
subquery: Query
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
def query(self) -> Query:
|
||||
return self.subquery
|
||||
|
||||
|
||||
|
@ -1056,6 +1056,11 @@ class IteratorCTE(ImmutableBase):
|
|||
# A list of other paths to *also* register the iterator rvar as
|
||||
# providing when it is merged into a statement.
|
||||
other_paths: tuple[tuple[irast.PathId, str], ...] = ()
|
||||
iterator_bond: bool = False
|
||||
|
||||
@property
|
||||
def aspect(self) -> str:
|
||||
return 'iterator' if self.iterator_bond else 'identity'
|
||||
|
||||
|
||||
class Statement(Base):
|
||||
|
|
|
@ -48,7 +48,10 @@ def get_volatility_ref(
|
|||
"""Produce an appropriate volatility_ref from a path_id."""
|
||||
|
||||
ref: Optional[pgast.BaseExpr] = relctx.maybe_get_path_var(
|
||||
stmt, path_id, aspect='identity', ctx=ctx)
|
||||
stmt, path_id, aspect='iterator', ctx=ctx)
|
||||
if not ref:
|
||||
ref = relctx.maybe_get_path_var(
|
||||
stmt, path_id, aspect='identity', ctx=ctx)
|
||||
if not ref:
|
||||
rvar = relctx.maybe_get_path_rvar(
|
||||
stmt, path_id, aspect='value', ctx=ctx)
|
||||
|
@ -179,8 +182,6 @@ def compile_iterator_expr(
|
|||
subctx.expr_exposed = False
|
||||
subctx.rel = query
|
||||
|
||||
already_existed = bool(relctx.maybe_get_path_rvar(
|
||||
query, iterator_expr.path_id, aspect='value', ctx=ctx))
|
||||
dispatch.visit(iterator_expr, ctx=subctx)
|
||||
iterator_rvar = relctx.get_path_rvar(
|
||||
query, iterator_expr.path_id, aspect='value', ctx=ctx)
|
||||
|
@ -190,6 +191,9 @@ def compile_iterator_expr(
|
|||
# makes sure that we don't spuriously produce output when
|
||||
# iterating over optional pointers.
|
||||
is_optional = ctx.scope_tree.is_optional(iterator_expr.path_id)
|
||||
if isinstance(iterator_query, pgast.SelectStmt):
|
||||
iterator_var = pathctx.get_path_value_var(
|
||||
iterator_query, path_id=iterator_expr.path_id, env=ctx.env)
|
||||
if not is_optional:
|
||||
if isinstance(iterator_query, pgast.SelectStmt):
|
||||
iterator_var = pathctx.get_path_value_var(
|
||||
|
@ -204,18 +208,19 @@ def compile_iterator_expr(
|
|||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Regardless of result type, we use transient identity,
|
||||
# for path identity of the iterator expression. This is
|
||||
# necessary to maintain correct correlation for the state
|
||||
# of iteration in DML statements.
|
||||
# The already_existed check is to avoid adding in bogus volatility refs
|
||||
# when we reprocess an iterator that was hoisted.
|
||||
if not already_existed:
|
||||
relctx.ensure_bond_for_expr(
|
||||
iterator_expr.expr.result, iterator_query, ctx=subctx)
|
||||
if is_optional:
|
||||
relctx.ensure_bond_for_expr(
|
||||
iterator_expr, iterator_query, ctx=subctx)
|
||||
# Regardless of result type, iterators need their own
|
||||
# transient identity for path identity of the iterator
|
||||
# expression in order to maintain correct correlation for the
|
||||
# state of iteration in DML statements, even when there
|
||||
# are duplicates in the iterator.
|
||||
# This gets tracked as a special 'iterator' aspect in order
|
||||
# to distinguish it from actual object identity.
|
||||
relctx.create_iterator_identity_for_path(
|
||||
iterator_expr.path_id, iterator_query, ctx=subctx)
|
||||
|
||||
pathctx.put_path_rvar(
|
||||
query, iterator_expr.path_id, iterator_rvar,
|
||||
aspect='iterator')
|
||||
|
||||
return iterator_rvar
|
||||
|
||||
|
|
|
@ -160,8 +160,7 @@ def init_dml_stmt(
|
|||
dml_rvar = relctx.rvar_for_rel(dml_cte, ctx=ctx)
|
||||
else_cte = (dml_cte, dml_rvar)
|
||||
|
||||
if ctx.enclosing_cte_iterator:
|
||||
pathctx.put_path_bond(ctx.rel, ctx.enclosing_cte_iterator.path_id)
|
||||
put_iterator_bond(ctx.enclosing_cte_iterator, ctx.rel)
|
||||
|
||||
ctx.dml_stmt_stack.append(ir_stmt)
|
||||
|
||||
|
@ -210,8 +209,7 @@ def gen_dml_union(
|
|||
)
|
||||
|
||||
assert qry.larg
|
||||
if ctx.enclosing_cte_iterator:
|
||||
pathctx.put_path_bond(qry.larg, ctx.enclosing_cte_iterator.path_id)
|
||||
put_iterator_bond(ctx.enclosing_cte_iterator, qry.larg)
|
||||
|
||||
union_cte = pgast.CommonTableExpr(
|
||||
query=qry.larg,
|
||||
|
@ -350,6 +348,15 @@ def wrap_dml_cte(
|
|||
return dml_rvar
|
||||
|
||||
|
||||
def put_iterator_bond(
|
||||
iterator: Optional[pgast.IteratorCTE],
|
||||
select: pgast.Query,
|
||||
) -> None:
|
||||
if iterator:
|
||||
pathctx.put_path_bond(
|
||||
select, iterator.path_id, iterator=iterator.iterator_bond)
|
||||
|
||||
|
||||
def merge_iterator_scope(
|
||||
iterator: Optional[pgast.IteratorCTE],
|
||||
select: pgast.SelectStmt,
|
||||
|
@ -372,7 +379,7 @@ def merge_iterator(
|
|||
if iterator:
|
||||
iterator_rvar = relctx.rvar_for_rel(iterator.cte, ctx=ctx)
|
||||
|
||||
pathctx.put_path_bond(select, iterator.path_id)
|
||||
put_iterator_bond(iterator, select)
|
||||
relctx.include_rvar(
|
||||
select, iterator_rvar,
|
||||
path_id=iterator.path_id,
|
||||
|
@ -537,7 +544,7 @@ def compile_iterator_cte(
|
|||
iterator_cte = ctx.dml_stmts[iterator_set]
|
||||
return pgast.IteratorCTE(
|
||||
path_id=iterator_set.path_id, cte=iterator_cte,
|
||||
parent=last_iterator)
|
||||
parent=last_iterator, iterator_bond=True)
|
||||
|
||||
with ctx.newrel() as ictx:
|
||||
ictx.scope_tree = ctx.scope_tree
|
||||
|
@ -551,18 +558,21 @@ def compile_iterator_cte(
|
|||
if iterator_set.path_id.is_objtype_path():
|
||||
relgen.ensure_source_rvar(iterator_set, ictx.rel, ctx=ictx)
|
||||
ictx.rel.path_id = iterator_set.path_id
|
||||
pathctx.put_path_bond(ictx.rel, iterator_set.path_id)
|
||||
pathctx.put_path_bond(ictx.rel, iterator_set.path_id, iterator=True)
|
||||
iterator_cte = pgast.CommonTableExpr(
|
||||
query=ictx.rel,
|
||||
name=ctx.env.aliases.get('iter')
|
||||
name=ctx.env.aliases.get('iter'),
|
||||
)
|
||||
ictx.toplevel_stmt.append_cte(iterator_cte)
|
||||
|
||||
ctx.dml_stmts[iterator_set] = iterator_cte
|
||||
|
||||
return pgast.IteratorCTE(
|
||||
path_id=iterator_set.path_id, cte=iterator_cte,
|
||||
parent=last_iterator)
|
||||
path_id=iterator_set.path_id,
|
||||
cte=iterator_cte,
|
||||
parent=last_iterator,
|
||||
iterator_bond=True,
|
||||
)
|
||||
|
||||
|
||||
def _mk_dynamic_get_path(
|
||||
|
@ -1006,7 +1016,9 @@ def process_insert_shape(
|
|||
subctx.path_scope = ctx.path_scope.new_child()
|
||||
merge_iterator(inner_iterator, select, ctx=subctx)
|
||||
inner_iterator_id = relctx.get_path_var(
|
||||
select, inner_iterator.path_id, aspect='identity', ctx=ctx)
|
||||
select, inner_iterator.path_id,
|
||||
aspect=inner_iterator.aspect,
|
||||
ctx=ctx)
|
||||
|
||||
# Process the Insert IR and separate links that go
|
||||
# into the main table from links that are inserted into
|
||||
|
@ -1050,8 +1062,7 @@ def process_insert_shape(
|
|||
if link_ptr_info and link_ptr_info.table_type == 'link':
|
||||
external_inserts.append(element)
|
||||
|
||||
if iterator is not None:
|
||||
pathctx.put_path_bond(select, iterator.path_id)
|
||||
put_iterator_bond(iterator, select)
|
||||
|
||||
for aspect in ('value', 'identity'):
|
||||
pathctx._put_path_output_var(
|
||||
|
@ -1491,7 +1502,7 @@ def compile_insert_else_body(
|
|||
dummy_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('dummy'))
|
||||
with ictx.subrel() as dctx:
|
||||
dummy_q = dctx.rel
|
||||
relctx.ensure_transient_identity_for_path(
|
||||
relctx.create_iterator_identity_for_path(
|
||||
dummy_pathid, dummy_q, ctx=dctx)
|
||||
dummy_rvar = relctx.rvar_for_rel(
|
||||
dummy_q, lateral=True, ctx=ictx)
|
||||
|
@ -1507,7 +1518,12 @@ def compile_insert_else_body(
|
|||
iter_path_id = (
|
||||
enclosing_cte_iterator.path_id if
|
||||
enclosing_cte_iterator else None)
|
||||
relctx.anti_join(ictx.rel, subrel, iter_path_id, ctx=ctx)
|
||||
aspect = (
|
||||
enclosing_cte_iterator.aspect if enclosing_cte_iterator
|
||||
else 'identity'
|
||||
)
|
||||
relctx.anti_join(ictx.rel, subrel, iter_path_id,
|
||||
aspect=aspect, ctx=ctx)
|
||||
|
||||
# Package it up as a CTE
|
||||
anti_cte = pgast.CommonTableExpr(
|
||||
|
@ -1517,7 +1533,9 @@ def compile_insert_else_body(
|
|||
ictx.toplevel_stmt.append_cte(anti_cte)
|
||||
anti_cte_iterator = pgast.IteratorCTE(
|
||||
path_id=dummy_pathid, cte=anti_cte,
|
||||
parent=ictx.enclosing_cte_iterator)
|
||||
parent=ictx.enclosing_cte_iterator,
|
||||
iterator_bond=True
|
||||
)
|
||||
|
||||
return anti_cte_iterator
|
||||
|
||||
|
@ -1602,9 +1620,7 @@ def process_update_body(
|
|||
contents_select = update_cte.query
|
||||
toplevel = ctx.toplevel_stmt
|
||||
|
||||
if ctx.enclosing_cte_iterator:
|
||||
pathctx.put_path_bond(
|
||||
contents_select, ctx.enclosing_cte_iterator.path_id)
|
||||
put_iterator_bond(ctx.enclosing_cte_iterator, contents_select)
|
||||
|
||||
assert dml_parts.range_cte
|
||||
iterator = pgast.IteratorCTE(
|
||||
|
|
|
@ -45,6 +45,7 @@ class PathAspect(s_enum.StrEnum):
|
|||
VALUE = 'value'
|
||||
SOURCE = 'source'
|
||||
SERIALIZED = 'serialized'
|
||||
ITERATOR = 'iterator'
|
||||
|
||||
|
||||
# A mapping of more specific aspect -> less specific aspect for objects
|
||||
|
@ -698,8 +699,14 @@ def put_path_serialized_var_if_not_exists(
|
|||
|
||||
|
||||
def put_path_bond(
|
||||
stmt: pgast.BaseRelation, path_id: irast.PathId) -> None:
|
||||
stmt.path_scope.add(path_id)
|
||||
stmt: pgast.BaseRelation, path_id: irast.PathId, iterator: bool=False
|
||||
) -> None:
|
||||
'''Register a path id that should be joined on when joining stmt
|
||||
|
||||
iterator indicates whether the identity or iterator aspect should
|
||||
be used.
|
||||
'''
|
||||
stmt.path_bonds.add((path_id, iterator))
|
||||
|
||||
|
||||
def put_rvar_path_bond(
|
||||
|
|
|
@ -212,7 +212,7 @@ def include_rvar(
|
|||
Compiler context.
|
||||
"""
|
||||
if aspects is None:
|
||||
aspects = ('value',)
|
||||
aspects = ('value', 'iterator')
|
||||
if path_id.is_objtype_path():
|
||||
if isinstance(rvar, pgast.RangeSubselect):
|
||||
if pathctx.has_path_aspect(
|
||||
|
@ -713,19 +713,6 @@ def semi_join(
|
|||
return set_rvar
|
||||
|
||||
|
||||
def ensure_bond_for_expr(
|
||||
ir_set: irast.Set,
|
||||
stmt: pgast.BaseRelation,
|
||||
*,
|
||||
ctx: context.CompilerContextLevel,
|
||||
) -> None:
|
||||
if ir_set.path_id.is_objtype_path():
|
||||
# ObjectTypes have inherent identity
|
||||
return
|
||||
|
||||
ensure_transient_identity_for_path(ir_set.path_id, stmt, ctx=ctx)
|
||||
|
||||
|
||||
def apply_volatility_ref(
|
||||
stmt: pgast.SelectStmt, *,
|
||||
ctx: context.CompilerContextLevel) -> None:
|
||||
|
@ -744,7 +731,7 @@ def apply_volatility_ref(
|
|||
)
|
||||
|
||||
|
||||
def ensure_transient_identity_for_path(
|
||||
def create_iterator_identity_for_path(
|
||||
path_id: irast.PathId,
|
||||
stmt: pgast.BaseRelation,
|
||||
*,
|
||||
|
@ -756,12 +743,13 @@ def ensure_transient_identity_for_path(
|
|||
args=[],
|
||||
)
|
||||
|
||||
pathctx.put_path_identity_var(stmt, path_id, id_expr, force=True)
|
||||
pathctx.put_path_bond(stmt, path_id)
|
||||
|
||||
if isinstance(stmt, pgast.SelectStmt):
|
||||
path_id = pathctx.map_path_id(path_id, stmt.view_path_id_map)
|
||||
apply_volatility_ref(stmt, ctx=ctx)
|
||||
|
||||
pathctx.put_path_var(stmt, path_id, id_expr, force=True, aspect='iterator')
|
||||
pathctx.put_path_bond(stmt, path_id, iterator=True)
|
||||
|
||||
|
||||
def get_scope(
|
||||
ir_set: irast.Set, *,
|
||||
|
@ -1289,15 +1277,19 @@ def _plain_join(
|
|||
) -> None:
|
||||
condition = None
|
||||
|
||||
for path_id in right_rvar.query.path_scope:
|
||||
lref = maybe_get_path_var(query, path_id, aspect='identity', ctx=ctx)
|
||||
if lref is None:
|
||||
for path_id, iterator_var in right_rvar.query.path_bonds:
|
||||
lref = None
|
||||
aspect = 'iterator' if iterator_var else 'identity'
|
||||
|
||||
lref = maybe_get_path_var(
|
||||
query, path_id, aspect=aspect, ctx=ctx)
|
||||
if lref is None and not iterator_var:
|
||||
lref = maybe_get_path_var(query, path_id, aspect='value', ctx=ctx)
|
||||
if lref is None:
|
||||
continue
|
||||
|
||||
rref = pathctx.get_rvar_path_identity_var(
|
||||
right_rvar, path_id, env=ctx.env)
|
||||
rref = pathctx.get_rvar_path_var(
|
||||
right_rvar, path_id, aspect=aspect, env=ctx.env)
|
||||
|
||||
assert isinstance(lref, pgast.ColumnRef)
|
||||
assert isinstance(rref, pgast.ColumnRef)
|
||||
|
@ -1332,17 +1324,18 @@ def _lateral_union_join(
|
|||
for component in astutils.each_query_in_set(right_rvar.subquery):
|
||||
condition = None
|
||||
|
||||
for path_id in right_rvar.query.path_scope:
|
||||
for path_id, iterator_var in right_rvar.query.path_bonds:
|
||||
aspect = 'iterator' if iterator_var else 'identity'
|
||||
lref = maybe_get_path_var(
|
||||
query, path_id, aspect='identity', ctx=ctx)
|
||||
if lref is None:
|
||||
query, path_id, aspect=aspect, ctx=ctx)
|
||||
if lref is None and not iterator_var:
|
||||
lref = maybe_get_path_var(
|
||||
query, path_id, aspect='value', ctx=ctx)
|
||||
if lref is None:
|
||||
continue
|
||||
|
||||
rref = pathctx.get_path_identity_var(
|
||||
component, path_id, env=ctx.env)
|
||||
rref = pathctx.get_path_var(
|
||||
component, path_id, aspect=aspect, env=ctx.env)
|
||||
|
||||
assert isinstance(lref, pgast.ColumnRef)
|
||||
assert isinstance(rref, pgast.ColumnRef)
|
||||
|
@ -1717,6 +1710,7 @@ def wrap_set_op_query(
|
|||
def anti_join(
|
||||
lhs: pgast.SelectStmt, rhs: pgast.SelectStmt,
|
||||
path_id: Optional[irast.PathId], *,
|
||||
aspect: str='identity',
|
||||
ctx: context.CompilerContextLevel,
|
||||
) -> None:
|
||||
"""Filter elements out of the LHS that appear on the RHS"""
|
||||
|
@ -1724,10 +1718,10 @@ def anti_join(
|
|||
if path_id:
|
||||
# grab the identity from the LHS and do an
|
||||
# anti-join against the RHS.
|
||||
src_ref = pathctx.get_path_identity_var(
|
||||
lhs, path_id=path_id, env=ctx.env)
|
||||
pathctx.get_path_identity_output(
|
||||
rhs, path_id=path_id, env=ctx.env)
|
||||
src_ref = pathctx.get_path_var(
|
||||
lhs, path_id=path_id, aspect=aspect, env=ctx.env)
|
||||
pathctx.get_path_output(
|
||||
rhs, path_id=path_id, aspect=aspect, env=ctx.env)
|
||||
cond_expr: pgast.BaseExpr = astutils.new_binop(
|
||||
src_ref, rhs, 'NOT IN')
|
||||
else:
|
||||
|
|
|
@ -1992,7 +1992,6 @@ def process_set_as_tuple(
|
|||
typeref=ir_set.typeref,
|
||||
)
|
||||
|
||||
relctx.ensure_bond_for_expr(ir_set, stmt, ctx=ctx)
|
||||
pathctx.put_path_value_var(stmt, ir_set.path_id, set_expr)
|
||||
|
||||
# This is an unfortunate hack. If any of those types that we
|
||||
|
|
|
@ -1737,6 +1737,35 @@ class TestInsert(tb.QueryTestCase):
|
|||
],
|
||||
)
|
||||
|
||||
async def test_edgeql_insert_for_23(self):
|
||||
await self.con.execute(r"""
|
||||
INSERT Subordinate { name := "a" }
|
||||
""")
|
||||
|
||||
await self.assert_query_result(
|
||||
"""
|
||||
for x in {Subordinate, Subordinate} union (
|
||||
(x { name }, (insert Note { name := '', subject := x }))
|
||||
);
|
||||
""",
|
||||
[
|
||||
[{'name': "a"}, {}],
|
||||
[{'name': "a"}, {}],
|
||||
],
|
||||
)
|
||||
|
||||
await self.assert_query_result(
|
||||
"""
|
||||
for x in {Subordinate, Subordinate} union (
|
||||
(x { name }, (insert InsertTest { l2 := 0, sub := x }))
|
||||
);
|
||||
""",
|
||||
[
|
||||
[{'name': "a"}, {}],
|
||||
[{'name': "a"}, {}],
|
||||
],
|
||||
)
|
||||
|
||||
async def test_edgeql_insert_for_bad_01(self):
|
||||
with self.assertRaisesRegex(
|
||||
edgedb.errors.QueryError,
|
||||
|
|
Loading…
Reference in a new issue