Fix data issues when spanning >16k bytes (#16)

* init

* Add tests for packetcontract
This commit is contained in:
Quin Lynch 2023-07-25 14:06:41 -03:00 committed by GitHub
parent abb7f8a797
commit 40eb6448a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 193 additions and 39 deletions

View file

@ -0,0 +1,10 @@
CREATE MIGRATION m1tig3qk3mnrb2xpszwyodgurkdeyza6yt67zo7kfljc2icy3e7yma
ONTO m1vxu37wczr357ppyrbhfon2msem5oczk7mjszhx2xzp2qlxpazana
{
CREATE MODULE tests IF NOT EXISTS;
CREATE TYPE tests::TestDatastructure {
CREATE REQUIRED PROPERTY a: std::str;
CREATE REQUIRED PROPERTY b: std::str;
CREATE REQUIRED PROPERTY c: std::str;
};
};

7
dbschema/tests.esdl Normal file
View file

@ -0,0 +1,7 @@
module tests {
type TestDatastructure {
required property a -> str;
required property b -> str;
required property c -> str;
}
}

View file

@ -10,7 +10,7 @@ import java.util.UUID;
import java.util.concurrent.CompletionStage;
public final class GlobalsAndConfig implements Example {
private static final Logger logger = LoggerFactory.getLogger(AbstractTypes.class);
private static final Logger logger = LoggerFactory.getLogger(GlobalsAndConfig.class);
@Override
public CompletionStage<Void> run(EdgeDBClient client) {

View file

@ -27,9 +27,6 @@ import java.util.concurrent.Flow;
import java.util.function.Function;
import java.util.stream.Collectors;
import static com.edgedb.driver.util.BinaryProtocolUtils.BYTE_SIZE;
import static com.edgedb.driver.util.BinaryProtocolUtils.INT_SIZE;
public class PacketSerializer {
private static final Logger logger = LoggerFactory.getLogger(PacketSerializer.class);
private static final @NotNull Map<ServerMessageType, Function<PacketReader, Receivable>> deserializerMap;
@ -73,6 +70,26 @@ public class PacketSerializer {
@Override
protected void decode(@NotNull ChannelHandlerContext ctx, @NotNull ByteBuf msg, @NotNull List<Object> out) throws Exception {
var fromContract = false;
if(contracts.containsKey(ctx.channel())){
var contract = contracts.get(ctx.channel());
logger.debug("Attempting to complete contract {}", contract);
if (contract.tryComplete(msg)) {
logger.debug("Contract completed of type {} with size {}", contract.messageType, contract.length);
out.add(contract.getPacket());
contracts.remove(ctx.channel());
fromContract = true;
msg = contract.data;
} else {
logger.debug("Contract pending [{}]: {}/{}", contract.messageType, contract.getSize(), contract.length);
return;
}
}
while (msg.readableBytes() > 5) {
var type = getEnumValue(ServerMessageType.class, msg.readByte());
var length = msg.readUnsignedInt() - 4; // remove length of self.
@ -80,34 +97,31 @@ public class PacketSerializer {
// can we read this packet?
if (msg.readableBytes() >= length) {
var packet = PacketSerializer.deserialize(type, length, msg.readSlice((int) length));
if(packet == null) {
logger.error("Got null result for packet type {}", type);
throw new EdgeDBException("Failed to read message type: malformed data");
}
logger.debug("S->C: T:{}", type);
out.add(packet);
continue;
}
if (contracts.containsKey(ctx.channel())) {
var contract = contracts.get(ctx.channel());
if (contract.tryComplete(msg)) {
out.add(contract.getPacket());
}
return;
} else {
contracts.put(ctx.channel(), new PacketContract(msg, type, length));
}
// if we cannot read the full packet, create a contract for it.
msg.retain();
contracts.put(ctx.channel(), new PacketContract(msg, type, length));
return;
}
if (msg.readableBytes() > 0) {
if (contracts.containsKey(ctx.channel())) {
var contract = contracts.get(ctx.channel());
if(msg.readableBytes() > 0){
msg.retain();
contracts.put(ctx.channel(), new PacketContract(msg, null, null));
return;
}
if (contract.tryComplete(msg)) {
out.add(contract.getPacket());
}
} else {
contracts.put(ctx.channel(), new PacketContract(msg, null, null));
}
if(fromContract){
msg.release();
}
}
@ -118,6 +132,8 @@ public class PacketSerializer {
private @Nullable ServerMessageType messageType;
private @Nullable Long length;
private final List<ByteBuf> components;
public PacketContract(
ByteBuf data,
@Nullable ServerMessageType messageType,
@ -126,38 +142,47 @@ public class PacketSerializer {
this.data = data;
this.length = length;
this.messageType = messageType;
this.components = new ArrayList<>() {{
add(data);
}};
}
public long getSize() {
long size = 0;
for (var component : components) {
size += component.readableBytes();
}
return size;
}
public boolean tryComplete(@NotNull ByteBuf other) {
var orig = data.slice();
data = Unpooled.wrappedBuffer(orig, other);
if (messageType == null) {
messageType = pick(other, b -> getEnumValue(ServerMessageType.class, b.readByte()), BYTE_SIZE);
messageType = getEnumValue(ServerMessageType.class, data.readByte());
}
if (length == null) {
length = pick(other, b -> b.readUnsignedInt() - 4, INT_SIZE);
length = data.readUnsignedInt() - 4;
}
data = Unpooled.wrappedBuffer(data, other);
other.retain();
components.add(other);
if (data.readableBytes() >= length) {
// read
packet = PacketSerializer.deserialize(messageType, length, data);
packet = PacketSerializer.deserialize(messageType, length, data, false);
return true;
}
return false;
}
private <T> T pick(@NotNull ByteBuf other, @NotNull Function<ByteBuf, T> map, long sz) {
if (data.readableBytes() > sz) {
return map.apply(data);
} else if (other.readableBytes() < sz) {
throw new IndexOutOfBoundsException();
}
return map.apply(other);
}
public @NotNull Receivable getPacket() throws OperationNotSupportedException {
if (packet == null) {
throw new OperationNotSupportedException("Packet contract was incomplete");
@ -192,11 +217,20 @@ public class PacketSerializer {
};
}
public static @Nullable Receivable deserialize(ServerMessageType messageType, long length, @NotNull ByteBuf buffer) {
public static @Nullable Receivable deserialize(
ServerMessageType messageType, long length, @NotNull ByteBuf buffer
) {
var reader = new PacketReader(buffer);
return deserializeSingle(messageType, length, reader, true);
}
public static @Nullable Receivable deserialize(
ServerMessageType messageType, long length, @NotNull ByteBuf buffer, boolean verifyEmpty
) {
var reader = new PacketReader(buffer);
return deserializeSingle(messageType, length, reader, verifyEmpty);
}
public static @Nullable Receivable deserializeSingle(PacketReader reader) {
var messageType = reader.readEnum(ServerMessageType.class, Byte.TYPE);
var length = reader.readUInt32().longValue();

View file

@ -0,0 +1,103 @@
import com.edgedb.driver.EdgeDBClient;
import com.edgedb.driver.annotations.EdgeDBType;
import com.edgedb.driver.exceptions.EdgeDBException;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ExecutionException;
import static org.assertj.core.api.Assertions.assertThat;
public class ProtocolTests {
private static final Logger logger = LoggerFactory.getLogger(ProtocolTests.class);
@EdgeDBType
public static class TestDatastructure {
public UUID id;
public String a;
public String b;
public String c;
}
/**
* The goal is to test the contract logic in {@linkplain com.edgedb.driver.binary.PacketSerializer}, specifically
* the decoder returned from the <b>createDecoder</b> function. To achieve this, we can query something that
* returns either multiple data packets amounting up to >16k bytes, or a single data packet that is >16k bytes.
*/
@Test
public void testPacketContract() throws EdgeDBException, IOException, ExecutionException, InterruptedException {
var client = new EdgeDBClient().withModule("tests");
// insert 1k items
logger.info("Removing old data structures...");
client.execute("DELETE TestDatastructure")
.toCompletableFuture().get();
var results = new HashMap<UUID, String[]>();
logger.info("Inserting 1000 items...");
for(int i = 0; i != 1000; i++){
var data = new String[] {
generateRandomString(),
generateRandomString(),
generateRandomString()
};
var result = client.queryRequiredSingle(TestDatastructure.class, "INSERT TestDatastructure { a := <str>$a, b := <str>$b, c := <str>$c }", new HashMap<>(){{
put("a", data[0]);
put("b", data[1]);
put("c", data[2]);
}}).toCompletableFuture().get();
results.put(result.id, data);
}
logger.info("Querying all items...");
// assert the data can be read via binary and json
var structures = client.query(TestDatastructure.class, "SELECT TestDatastructure { id, a, b, c }")
.toCompletableFuture().get();
var json = client.queryJson("SELECT TestDatastructure { id, a, b, c }")
.toCompletableFuture().get();
var structuresFromJson = List.of(new JsonMapper().readValue(json.getValue(), TestDatastructure[].class));
assertStructuresMatch(structures, results);
assertStructuresMatch(structuresFromJson, results);
}
private void assertStructuresMatch(List<TestDatastructure> source, Map<UUID, String[]> truth) {
for(var structure : source) {
assert structure != null;
var expected = truth.get(structure.id);
assertThat(structure.a).isEqualTo(expected[0]);
assertThat(structure.b).isEqualTo(expected[1]);
assertThat(structure.c).isEqualTo(expected[2]);
logger.info("{} passed [a: {}, b: {}, c: {}]", structure.id, structure.a, structure.b, structure.c);
}
}
private static String generateRandomString() {
final var chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
Random rand =new Random();
StringBuilder res=new StringBuilder();
for (int i = 0; i < 17; i++) {
int randIndex=rand.nextInt(chars.length());
res.append(chars.charAt(randIndex));
}
return res.toString();
}
}