mirror of
https://github.com/maxkratz/edgedb-java.git
synced 2024-09-16 16:27:58 +00:00
Fix data issues when spanning >16k bytes (#16)
* init * Add tests for packetcontract
This commit is contained in:
parent
abb7f8a797
commit
40eb6448a3
5 changed files with 193 additions and 39 deletions
10
dbschema/migrations/00007.edgeql
Normal file
10
dbschema/migrations/00007.edgeql
Normal file
|
@ -0,0 +1,10 @@
|
|||
CREATE MIGRATION m1tig3qk3mnrb2xpszwyodgurkdeyza6yt67zo7kfljc2icy3e7yma
|
||||
ONTO m1vxu37wczr357ppyrbhfon2msem5oczk7mjszhx2xzp2qlxpazana
|
||||
{
|
||||
CREATE MODULE tests IF NOT EXISTS;
|
||||
CREATE TYPE tests::TestDatastructure {
|
||||
CREATE REQUIRED PROPERTY a: std::str;
|
||||
CREATE REQUIRED PROPERTY b: std::str;
|
||||
CREATE REQUIRED PROPERTY c: std::str;
|
||||
};
|
||||
};
|
7
dbschema/tests.esdl
Normal file
7
dbschema/tests.esdl
Normal file
|
@ -0,0 +1,7 @@
|
|||
module tests {
|
||||
type TestDatastructure {
|
||||
required property a -> str;
|
||||
required property b -> str;
|
||||
required property c -> str;
|
||||
}
|
||||
}
|
|
@ -10,7 +10,7 @@ import java.util.UUID;
|
|||
import java.util.concurrent.CompletionStage;
|
||||
|
||||
public final class GlobalsAndConfig implements Example {
|
||||
private static final Logger logger = LoggerFactory.getLogger(AbstractTypes.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(GlobalsAndConfig.class);
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> run(EdgeDBClient client) {
|
||||
|
|
|
@ -27,9 +27,6 @@ import java.util.concurrent.Flow;
|
|||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.edgedb.driver.util.BinaryProtocolUtils.BYTE_SIZE;
|
||||
import static com.edgedb.driver.util.BinaryProtocolUtils.INT_SIZE;
|
||||
|
||||
public class PacketSerializer {
|
||||
private static final Logger logger = LoggerFactory.getLogger(PacketSerializer.class);
|
||||
private static final @NotNull Map<ServerMessageType, Function<PacketReader, Receivable>> deserializerMap;
|
||||
|
@ -73,6 +70,26 @@ public class PacketSerializer {
|
|||
|
||||
@Override
|
||||
protected void decode(@NotNull ChannelHandlerContext ctx, @NotNull ByteBuf msg, @NotNull List<Object> out) throws Exception {
|
||||
var fromContract = false;
|
||||
|
||||
if(contracts.containsKey(ctx.channel())){
|
||||
var contract = contracts.get(ctx.channel());
|
||||
|
||||
logger.debug("Attempting to complete contract {}", contract);
|
||||
|
||||
if (contract.tryComplete(msg)) {
|
||||
logger.debug("Contract completed of type {} with size {}", contract.messageType, contract.length);
|
||||
|
||||
out.add(contract.getPacket());
|
||||
contracts.remove(ctx.channel());
|
||||
fromContract = true;
|
||||
msg = contract.data;
|
||||
} else {
|
||||
logger.debug("Contract pending [{}]: {}/{}", contract.messageType, contract.getSize(), contract.length);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
while (msg.readableBytes() > 5) {
|
||||
var type = getEnumValue(ServerMessageType.class, msg.readByte());
|
||||
var length = msg.readUnsignedInt() - 4; // remove length of self.
|
||||
|
@ -80,34 +97,31 @@ public class PacketSerializer {
|
|||
// can we read this packet?
|
||||
if (msg.readableBytes() >= length) {
|
||||
var packet = PacketSerializer.deserialize(type, length, msg.readSlice((int) length));
|
||||
|
||||
if(packet == null) {
|
||||
logger.error("Got null result for packet type {}", type);
|
||||
throw new EdgeDBException("Failed to read message type: malformed data");
|
||||
}
|
||||
|
||||
logger.debug("S->C: T:{}", type);
|
||||
out.add(packet);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (contracts.containsKey(ctx.channel())) {
|
||||
var contract = contracts.get(ctx.channel());
|
||||
|
||||
if (contract.tryComplete(msg)) {
|
||||
out.add(contract.getPacket());
|
||||
}
|
||||
|
||||
return;
|
||||
} else {
|
||||
contracts.put(ctx.channel(), new PacketContract(msg, type, length));
|
||||
}
|
||||
// if we cannot read the full packet, create a contract for it.
|
||||
msg.retain();
|
||||
contracts.put(ctx.channel(), new PacketContract(msg, type, length));
|
||||
return;
|
||||
}
|
||||
|
||||
if (msg.readableBytes() > 0) {
|
||||
if (contracts.containsKey(ctx.channel())) {
|
||||
var contract = contracts.get(ctx.channel());
|
||||
if(msg.readableBytes() > 0){
|
||||
msg.retain();
|
||||
contracts.put(ctx.channel(), new PacketContract(msg, null, null));
|
||||
return;
|
||||
}
|
||||
|
||||
if (contract.tryComplete(msg)) {
|
||||
out.add(contract.getPacket());
|
||||
}
|
||||
} else {
|
||||
contracts.put(ctx.channel(), new PacketContract(msg, null, null));
|
||||
}
|
||||
if(fromContract){
|
||||
msg.release();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,6 +132,8 @@ public class PacketSerializer {
|
|||
private @Nullable ServerMessageType messageType;
|
||||
private @Nullable Long length;
|
||||
|
||||
private final List<ByteBuf> components;
|
||||
|
||||
public PacketContract(
|
||||
ByteBuf data,
|
||||
@Nullable ServerMessageType messageType,
|
||||
|
@ -126,38 +142,47 @@ public class PacketSerializer {
|
|||
this.data = data;
|
||||
this.length = length;
|
||||
this.messageType = messageType;
|
||||
|
||||
this.components = new ArrayList<>() {{
|
||||
add(data);
|
||||
}};
|
||||
}
|
||||
|
||||
public long getSize() {
|
||||
long size = 0;
|
||||
|
||||
for (var component : components) {
|
||||
size += component.readableBytes();
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
public boolean tryComplete(@NotNull ByteBuf other) {
|
||||
var orig = data.slice();
|
||||
data = Unpooled.wrappedBuffer(orig, other);
|
||||
|
||||
if (messageType == null) {
|
||||
messageType = pick(other, b -> getEnumValue(ServerMessageType.class, b.readByte()), BYTE_SIZE);
|
||||
messageType = getEnumValue(ServerMessageType.class, data.readByte());
|
||||
}
|
||||
|
||||
if (length == null) {
|
||||
length = pick(other, b -> b.readUnsignedInt() - 4, INT_SIZE);
|
||||
length = data.readUnsignedInt() - 4;
|
||||
}
|
||||
|
||||
data = Unpooled.wrappedBuffer(data, other);
|
||||
other.retain();
|
||||
components.add(other);
|
||||
|
||||
if (data.readableBytes() >= length) {
|
||||
// read
|
||||
packet = PacketSerializer.deserialize(messageType, length, data);
|
||||
packet = PacketSerializer.deserialize(messageType, length, data, false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private <T> T pick(@NotNull ByteBuf other, @NotNull Function<ByteBuf, T> map, long sz) {
|
||||
if (data.readableBytes() > sz) {
|
||||
return map.apply(data);
|
||||
} else if (other.readableBytes() < sz) {
|
||||
throw new IndexOutOfBoundsException();
|
||||
}
|
||||
|
||||
return map.apply(other);
|
||||
}
|
||||
|
||||
public @NotNull Receivable getPacket() throws OperationNotSupportedException {
|
||||
if (packet == null) {
|
||||
throw new OperationNotSupportedException("Packet contract was incomplete");
|
||||
|
@ -192,11 +217,20 @@ public class PacketSerializer {
|
|||
};
|
||||
}
|
||||
|
||||
public static @Nullable Receivable deserialize(ServerMessageType messageType, long length, @NotNull ByteBuf buffer) {
|
||||
public static @Nullable Receivable deserialize(
|
||||
ServerMessageType messageType, long length, @NotNull ByteBuf buffer
|
||||
) {
|
||||
var reader = new PacketReader(buffer);
|
||||
return deserializeSingle(messageType, length, reader, true);
|
||||
}
|
||||
|
||||
public static @Nullable Receivable deserialize(
|
||||
ServerMessageType messageType, long length, @NotNull ByteBuf buffer, boolean verifyEmpty
|
||||
) {
|
||||
var reader = new PacketReader(buffer);
|
||||
return deserializeSingle(messageType, length, reader, verifyEmpty);
|
||||
}
|
||||
|
||||
public static @Nullable Receivable deserializeSingle(PacketReader reader) {
|
||||
var messageType = reader.readEnum(ServerMessageType.class, Byte.TYPE);
|
||||
var length = reader.readUInt32().longValue();
|
||||
|
|
103
src/driver/src/test/java/ProtocolTests.java
Normal file
103
src/driver/src/test/java/ProtocolTests.java
Normal file
|
@ -0,0 +1,103 @@
|
|||
import com.edgedb.driver.EdgeDBClient;
|
||||
import com.edgedb.driver.annotations.EdgeDBType;
|
||||
import com.edgedb.driver.exceptions.EdgeDBException;
|
||||
import com.fasterxml.jackson.databind.json.JsonMapper;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class ProtocolTests {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ProtocolTests.class);
|
||||
|
||||
@EdgeDBType
|
||||
public static class TestDatastructure {
|
||||
public UUID id;
|
||||
|
||||
public String a;
|
||||
|
||||
public String b;
|
||||
|
||||
public String c;
|
||||
}
|
||||
|
||||
/**
|
||||
* The goal is to test the contract logic in {@linkplain com.edgedb.driver.binary.PacketSerializer}, specifically
|
||||
* the decoder returned from the <b>createDecoder</b> function. To achieve this, we can query something that
|
||||
* returns either multiple data packets amounting up to >16k bytes, or a single data packet that is >16k bytes.
|
||||
*/
|
||||
@Test
|
||||
public void testPacketContract() throws EdgeDBException, IOException, ExecutionException, InterruptedException {
|
||||
var client = new EdgeDBClient().withModule("tests");
|
||||
|
||||
// insert 1k items
|
||||
logger.info("Removing old data structures...");
|
||||
client.execute("DELETE TestDatastructure")
|
||||
.toCompletableFuture().get();
|
||||
|
||||
var results = new HashMap<UUID, String[]>();
|
||||
|
||||
logger.info("Inserting 1000 items...");
|
||||
|
||||
for(int i = 0; i != 1000; i++){
|
||||
var data = new String[] {
|
||||
generateRandomString(),
|
||||
generateRandomString(),
|
||||
generateRandomString()
|
||||
};
|
||||
|
||||
var result = client.queryRequiredSingle(TestDatastructure.class, "INSERT TestDatastructure { a := <str>$a, b := <str>$b, c := <str>$c }", new HashMap<>(){{
|
||||
put("a", data[0]);
|
||||
put("b", data[1]);
|
||||
put("c", data[2]);
|
||||
}}).toCompletableFuture().get();
|
||||
|
||||
results.put(result.id, data);
|
||||
}
|
||||
|
||||
logger.info("Querying all items...");
|
||||
|
||||
// assert the data can be read via binary and json
|
||||
var structures = client.query(TestDatastructure.class, "SELECT TestDatastructure { id, a, b, c }")
|
||||
.toCompletableFuture().get();
|
||||
|
||||
var json = client.queryJson("SELECT TestDatastructure { id, a, b, c }")
|
||||
.toCompletableFuture().get();
|
||||
|
||||
var structuresFromJson = List.of(new JsonMapper().readValue(json.getValue(), TestDatastructure[].class));
|
||||
|
||||
assertStructuresMatch(structures, results);
|
||||
assertStructuresMatch(structuresFromJson, results);
|
||||
}
|
||||
|
||||
private void assertStructuresMatch(List<TestDatastructure> source, Map<UUID, String[]> truth) {
|
||||
for(var structure : source) {
|
||||
assert structure != null;
|
||||
|
||||
var expected = truth.get(structure.id);
|
||||
|
||||
assertThat(structure.a).isEqualTo(expected[0]);
|
||||
assertThat(structure.b).isEqualTo(expected[1]);
|
||||
assertThat(structure.c).isEqualTo(expected[2]);
|
||||
|
||||
logger.info("{} passed [a: {}, b: {}, c: {}]", structure.id, structure.a, structure.b, structure.c);
|
||||
}
|
||||
}
|
||||
|
||||
private static String generateRandomString() {
|
||||
final var chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
|
||||
|
||||
Random rand =new Random();
|
||||
StringBuilder res=new StringBuilder();
|
||||
for (int i = 0; i < 17; i++) {
|
||||
int randIndex=rand.nextInt(chars.length());
|
||||
res.append(chars.charAt(randIndex));
|
||||
}
|
||||
return res.toString();
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue