Fix DML-containing FOR when the iterator is an object set with duplicates (#6609)

DML-containing FOR loops get compiled into CTEs, and after the DML is
performed, the CTEs need to get joined back with the DML CTEs in order
to produce the output. Currently we do these joins on the 'identity'
of the iterator set; for non-object sets, we generate a transient
identity with uuid_generate_v4, and for object sets we use the actual
object id.

That works ok if all the objects in the iterator set are unique, but
if they aren't we produce a lot of duplicate output rows.

To solve this, we need to use transient ids for object types too, but
overriding object identity with something transient causes problems:
if we need to actually join the object against some *actual* table, we
are in trouble, since we've overridden the identity.

Solve this by introducing a new 'iterator' aspect and distinguishing
between whether we want to join on an iterator aspect or a real identity
when making path bonds.

Fixes #6608
This commit is contained in:
Michael J. Sullivan 2023-12-15 11:31:18 -08:00 committed by GitHub
parent 2f88787ab3
commit f741fc9455
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 127 additions and 72 deletions

View file

@ -148,7 +148,7 @@ class EdgeQLPathInfo(Base):
# Ignore the below fields in AST visitor/transformer.
__ast_meta__ = {
'path_id', 'path_scope', 'path_outputs', 'is_distinct',
'path_id', 'path_bonds', 'path_outputs', 'is_distinct',
'path_id_mask', 'path_namespace',
'packed_path_outputs', 'packed_path_namespace',
}
@ -160,7 +160,7 @@ class EdgeQLPathInfo(Base):
is_distinct: bool = True
# A subset of paths necessary to perform joining.
path_scope: typing.Set[irast.PathId] = ast.field(factory=set)
path_bonds: typing.Set[tuple[irast.PathId, bool]] = ast.field(factory=set)
# Map of res target names corresponding to paths.
path_outputs: typing.Dict[
@ -880,7 +880,7 @@ class RangeSubselect(PathRangeVar):
subquery: Query
@property
def query(self):
def query(self) -> Query:
return self.subquery
@ -1056,6 +1056,11 @@ class IteratorCTE(ImmutableBase):
# A list of other paths to *also* register the iterator rvar as
# providing when it is merged into a statement.
other_paths: tuple[tuple[irast.PathId, str], ...] = ()
iterator_bond: bool = False
@property
def aspect(self) -> str:
return 'iterator' if self.iterator_bond else 'identity'
class Statement(Base):

View file

@ -48,7 +48,10 @@ def get_volatility_ref(
"""Produce an appropriate volatility_ref from a path_id."""
ref: Optional[pgast.BaseExpr] = relctx.maybe_get_path_var(
stmt, path_id, aspect='identity', ctx=ctx)
stmt, path_id, aspect='iterator', ctx=ctx)
if not ref:
ref = relctx.maybe_get_path_var(
stmt, path_id, aspect='identity', ctx=ctx)
if not ref:
rvar = relctx.maybe_get_path_rvar(
stmt, path_id, aspect='value', ctx=ctx)
@ -179,8 +182,6 @@ def compile_iterator_expr(
subctx.expr_exposed = False
subctx.rel = query
already_existed = bool(relctx.maybe_get_path_rvar(
query, iterator_expr.path_id, aspect='value', ctx=ctx))
dispatch.visit(iterator_expr, ctx=subctx)
iterator_rvar = relctx.get_path_rvar(
query, iterator_expr.path_id, aspect='value', ctx=ctx)
@ -190,6 +191,9 @@ def compile_iterator_expr(
# makes sure that we don't spuriously produce output when
# iterating over optional pointers.
is_optional = ctx.scope_tree.is_optional(iterator_expr.path_id)
if isinstance(iterator_query, pgast.SelectStmt):
iterator_var = pathctx.get_path_value_var(
iterator_query, path_id=iterator_expr.path_id, env=ctx.env)
if not is_optional:
if isinstance(iterator_query, pgast.SelectStmt):
iterator_var = pathctx.get_path_value_var(
@ -204,18 +208,19 @@ def compile_iterator_expr(
else:
raise NotImplementedError()
# Regardless of result type, we use transient identity,
# for path identity of the iterator expression. This is
# necessary to maintain correct correlation for the state
# of iteration in DML statements.
# The already_existed check is to avoid adding in bogus volatility refs
# when we reprocess an iterator that was hoisted.
if not already_existed:
relctx.ensure_bond_for_expr(
iterator_expr.expr.result, iterator_query, ctx=subctx)
if is_optional:
relctx.ensure_bond_for_expr(
iterator_expr, iterator_query, ctx=subctx)
# Regardless of result type, iterators need their own
# transient identity for path identity of the iterator
# expression in order to maintain correct correlation for the
# state of iteration in DML statements, even when there
# are duplicates in the iterator.
# This gets tracked as a special 'iterator' aspect in order
# to distinguish it from actual object identity.
relctx.create_iterator_identity_for_path(
iterator_expr.path_id, iterator_query, ctx=subctx)
pathctx.put_path_rvar(
query, iterator_expr.path_id, iterator_rvar,
aspect='iterator')
return iterator_rvar

View file

@ -160,8 +160,7 @@ def init_dml_stmt(
dml_rvar = relctx.rvar_for_rel(dml_cte, ctx=ctx)
else_cte = (dml_cte, dml_rvar)
if ctx.enclosing_cte_iterator:
pathctx.put_path_bond(ctx.rel, ctx.enclosing_cte_iterator.path_id)
put_iterator_bond(ctx.enclosing_cte_iterator, ctx.rel)
ctx.dml_stmt_stack.append(ir_stmt)
@ -210,8 +209,7 @@ def gen_dml_union(
)
assert qry.larg
if ctx.enclosing_cte_iterator:
pathctx.put_path_bond(qry.larg, ctx.enclosing_cte_iterator.path_id)
put_iterator_bond(ctx.enclosing_cte_iterator, qry.larg)
union_cte = pgast.CommonTableExpr(
query=qry.larg,
@ -350,6 +348,15 @@ def wrap_dml_cte(
return dml_rvar
def put_iterator_bond(
iterator: Optional[pgast.IteratorCTE],
select: pgast.Query,
) -> None:
if iterator:
pathctx.put_path_bond(
select, iterator.path_id, iterator=iterator.iterator_bond)
def merge_iterator_scope(
iterator: Optional[pgast.IteratorCTE],
select: pgast.SelectStmt,
@ -372,7 +379,7 @@ def merge_iterator(
if iterator:
iterator_rvar = relctx.rvar_for_rel(iterator.cte, ctx=ctx)
pathctx.put_path_bond(select, iterator.path_id)
put_iterator_bond(iterator, select)
relctx.include_rvar(
select, iterator_rvar,
path_id=iterator.path_id,
@ -537,7 +544,7 @@ def compile_iterator_cte(
iterator_cte = ctx.dml_stmts[iterator_set]
return pgast.IteratorCTE(
path_id=iterator_set.path_id, cte=iterator_cte,
parent=last_iterator)
parent=last_iterator, iterator_bond=True)
with ctx.newrel() as ictx:
ictx.scope_tree = ctx.scope_tree
@ -551,18 +558,21 @@ def compile_iterator_cte(
if iterator_set.path_id.is_objtype_path():
relgen.ensure_source_rvar(iterator_set, ictx.rel, ctx=ictx)
ictx.rel.path_id = iterator_set.path_id
pathctx.put_path_bond(ictx.rel, iterator_set.path_id)
pathctx.put_path_bond(ictx.rel, iterator_set.path_id, iterator=True)
iterator_cte = pgast.CommonTableExpr(
query=ictx.rel,
name=ctx.env.aliases.get('iter')
name=ctx.env.aliases.get('iter'),
)
ictx.toplevel_stmt.append_cte(iterator_cte)
ctx.dml_stmts[iterator_set] = iterator_cte
return pgast.IteratorCTE(
path_id=iterator_set.path_id, cte=iterator_cte,
parent=last_iterator)
path_id=iterator_set.path_id,
cte=iterator_cte,
parent=last_iterator,
iterator_bond=True,
)
def _mk_dynamic_get_path(
@ -1006,7 +1016,9 @@ def process_insert_shape(
subctx.path_scope = ctx.path_scope.new_child()
merge_iterator(inner_iterator, select, ctx=subctx)
inner_iterator_id = relctx.get_path_var(
select, inner_iterator.path_id, aspect='identity', ctx=ctx)
select, inner_iterator.path_id,
aspect=inner_iterator.aspect,
ctx=ctx)
# Process the Insert IR and separate links that go
# into the main table from links that are inserted into
@ -1050,8 +1062,7 @@ def process_insert_shape(
if link_ptr_info and link_ptr_info.table_type == 'link':
external_inserts.append(element)
if iterator is not None:
pathctx.put_path_bond(select, iterator.path_id)
put_iterator_bond(iterator, select)
for aspect in ('value', 'identity'):
pathctx._put_path_output_var(
@ -1491,7 +1502,7 @@ def compile_insert_else_body(
dummy_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('dummy'))
with ictx.subrel() as dctx:
dummy_q = dctx.rel
relctx.ensure_transient_identity_for_path(
relctx.create_iterator_identity_for_path(
dummy_pathid, dummy_q, ctx=dctx)
dummy_rvar = relctx.rvar_for_rel(
dummy_q, lateral=True, ctx=ictx)
@ -1507,7 +1518,12 @@ def compile_insert_else_body(
iter_path_id = (
enclosing_cte_iterator.path_id if
enclosing_cte_iterator else None)
relctx.anti_join(ictx.rel, subrel, iter_path_id, ctx=ctx)
aspect = (
enclosing_cte_iterator.aspect if enclosing_cte_iterator
else 'identity'
)
relctx.anti_join(ictx.rel, subrel, iter_path_id,
aspect=aspect, ctx=ctx)
# Package it up as a CTE
anti_cte = pgast.CommonTableExpr(
@ -1517,7 +1533,9 @@ def compile_insert_else_body(
ictx.toplevel_stmt.append_cte(anti_cte)
anti_cte_iterator = pgast.IteratorCTE(
path_id=dummy_pathid, cte=anti_cte,
parent=ictx.enclosing_cte_iterator)
parent=ictx.enclosing_cte_iterator,
iterator_bond=True
)
return anti_cte_iterator
@ -1602,9 +1620,7 @@ def process_update_body(
contents_select = update_cte.query
toplevel = ctx.toplevel_stmt
if ctx.enclosing_cte_iterator:
pathctx.put_path_bond(
contents_select, ctx.enclosing_cte_iterator.path_id)
put_iterator_bond(ctx.enclosing_cte_iterator, contents_select)
assert dml_parts.range_cte
iterator = pgast.IteratorCTE(

View file

@ -45,6 +45,7 @@ class PathAspect(s_enum.StrEnum):
VALUE = 'value'
SOURCE = 'source'
SERIALIZED = 'serialized'
ITERATOR = 'iterator'
# A mapping of more specific aspect -> less specific aspect for objects
@ -698,8 +699,14 @@ def put_path_serialized_var_if_not_exists(
def put_path_bond(
stmt: pgast.BaseRelation, path_id: irast.PathId) -> None:
stmt.path_scope.add(path_id)
stmt: pgast.BaseRelation, path_id: irast.PathId, iterator: bool=False
) -> None:
'''Register a path id that should be joined on when joining stmt
iterator indicates whether the identity or iterator aspect should
be used.
'''
stmt.path_bonds.add((path_id, iterator))
def put_rvar_path_bond(

View file

@ -212,7 +212,7 @@ def include_rvar(
Compiler context.
"""
if aspects is None:
aspects = ('value',)
aspects = ('value', 'iterator')
if path_id.is_objtype_path():
if isinstance(rvar, pgast.RangeSubselect):
if pathctx.has_path_aspect(
@ -713,19 +713,6 @@ def semi_join(
return set_rvar
def ensure_bond_for_expr(
ir_set: irast.Set,
stmt: pgast.BaseRelation,
*,
ctx: context.CompilerContextLevel,
) -> None:
if ir_set.path_id.is_objtype_path():
# ObjectTypes have inherent identity
return
ensure_transient_identity_for_path(ir_set.path_id, stmt, ctx=ctx)
def apply_volatility_ref(
stmt: pgast.SelectStmt, *,
ctx: context.CompilerContextLevel) -> None:
@ -744,7 +731,7 @@ def apply_volatility_ref(
)
def ensure_transient_identity_for_path(
def create_iterator_identity_for_path(
path_id: irast.PathId,
stmt: pgast.BaseRelation,
*,
@ -756,12 +743,13 @@ def ensure_transient_identity_for_path(
args=[],
)
pathctx.put_path_identity_var(stmt, path_id, id_expr, force=True)
pathctx.put_path_bond(stmt, path_id)
if isinstance(stmt, pgast.SelectStmt):
path_id = pathctx.map_path_id(path_id, stmt.view_path_id_map)
apply_volatility_ref(stmt, ctx=ctx)
pathctx.put_path_var(stmt, path_id, id_expr, force=True, aspect='iterator')
pathctx.put_path_bond(stmt, path_id, iterator=True)
def get_scope(
ir_set: irast.Set, *,
@ -1289,15 +1277,19 @@ def _plain_join(
) -> None:
condition = None
for path_id in right_rvar.query.path_scope:
lref = maybe_get_path_var(query, path_id, aspect='identity', ctx=ctx)
if lref is None:
for path_id, iterator_var in right_rvar.query.path_bonds:
lref = None
aspect = 'iterator' if iterator_var else 'identity'
lref = maybe_get_path_var(
query, path_id, aspect=aspect, ctx=ctx)
if lref is None and not iterator_var:
lref = maybe_get_path_var(query, path_id, aspect='value', ctx=ctx)
if lref is None:
continue
rref = pathctx.get_rvar_path_identity_var(
right_rvar, path_id, env=ctx.env)
rref = pathctx.get_rvar_path_var(
right_rvar, path_id, aspect=aspect, env=ctx.env)
assert isinstance(lref, pgast.ColumnRef)
assert isinstance(rref, pgast.ColumnRef)
@ -1332,17 +1324,18 @@ def _lateral_union_join(
for component in astutils.each_query_in_set(right_rvar.subquery):
condition = None
for path_id in right_rvar.query.path_scope:
for path_id, iterator_var in right_rvar.query.path_bonds:
aspect = 'iterator' if iterator_var else 'identity'
lref = maybe_get_path_var(
query, path_id, aspect='identity', ctx=ctx)
if lref is None:
query, path_id, aspect=aspect, ctx=ctx)
if lref is None and not iterator_var:
lref = maybe_get_path_var(
query, path_id, aspect='value', ctx=ctx)
if lref is None:
continue
rref = pathctx.get_path_identity_var(
component, path_id, env=ctx.env)
rref = pathctx.get_path_var(
component, path_id, aspect=aspect, env=ctx.env)
assert isinstance(lref, pgast.ColumnRef)
assert isinstance(rref, pgast.ColumnRef)
@ -1717,6 +1710,7 @@ def wrap_set_op_query(
def anti_join(
lhs: pgast.SelectStmt, rhs: pgast.SelectStmt,
path_id: Optional[irast.PathId], *,
aspect: str='identity',
ctx: context.CompilerContextLevel,
) -> None:
"""Filter elements out of the LHS that appear on the RHS"""
@ -1724,10 +1718,10 @@ def anti_join(
if path_id:
# grab the identity from the LHS and do an
# anti-join against the RHS.
src_ref = pathctx.get_path_identity_var(
lhs, path_id=path_id, env=ctx.env)
pathctx.get_path_identity_output(
rhs, path_id=path_id, env=ctx.env)
src_ref = pathctx.get_path_var(
lhs, path_id=path_id, aspect=aspect, env=ctx.env)
pathctx.get_path_output(
rhs, path_id=path_id, aspect=aspect, env=ctx.env)
cond_expr: pgast.BaseExpr = astutils.new_binop(
src_ref, rhs, 'NOT IN')
else:

View file

@ -1992,7 +1992,6 @@ def process_set_as_tuple(
typeref=ir_set.typeref,
)
relctx.ensure_bond_for_expr(ir_set, stmt, ctx=ctx)
pathctx.put_path_value_var(stmt, ir_set.path_id, set_expr)
# This is an unfortunate hack. If any of those types that we

View file

@ -1737,6 +1737,35 @@ class TestInsert(tb.QueryTestCase):
],
)
async def test_edgeql_insert_for_23(self):
await self.con.execute(r"""
INSERT Subordinate { name := "a" }
""")
await self.assert_query_result(
"""
for x in {Subordinate, Subordinate} union (
(x { name }, (insert Note { name := '', subject := x }))
);
""",
[
[{'name': "a"}, {}],
[{'name': "a"}, {}],
],
)
await self.assert_query_result(
"""
for x in {Subordinate, Subordinate} union (
(x { name }, (insert InsertTest { l2 := 0, sub := x }))
);
""",
[
[{'name': "a"}, {}],
[{'name': "a"}, {}],
],
)
async def test_edgeql_insert_for_bad_01(self):
with self.assertRaisesRegex(
edgedb.errors.QueryError,