diff --git a/runners/kafka-streams/src/main/java/org/apache/beam/runners/kafka/streams/translation/KStreamsPayloadSerde.java b/runners/kafka-streams/src/main/java/org/apache/beam/runners/kafka/streams/translation/KStreamsPayloadSerde.java new file mode 100644 index 000000000000..dd30ed3696eb --- /dev/null +++ b/runners/kafka-streams/src/main/java/org/apache/beam/runners/kafka/streams/translation/KStreamsPayloadSerde.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.kafka.streams.translation; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; + +/** + * Kafka {@link Serde} for {@link KStreamsPayload}, enabling the envelope to cross topic boundaries + * (e.g. the repartition topic a {@code GroupByKey} introduces). Until now {@link KStreamsPayload} + * only flowed in-JVM via {@code ProcessorContext#forward}, so no serialization was needed. + * + *

The wire format is a one-byte discriminator followed by the variant body: + * + *

+ * + *

A {@link KStreamsPayloadSerde} is parameterized by the {@link Coder} for the data variant + * because different topics carry different element types; the watermark variant needs no coder. + * + *

The serde assumes non-null payloads: the topics it is used on (repartition and watermark + * fan-out) are not log-compacted, so no tombstone (null-valued) records occur. + * + * @param the data element type carried by data payloads on this topic + */ +public final class KStreamsPayloadSerde implements Serde> { + + private static final byte DATA_TAG = 0; + private static final byte WATERMARK_TAG = 1; + + private final Coder> dataCoder; + + public KStreamsPayloadSerde(Coder> dataCoder) { + this.dataCoder = dataCoder; + } + + @Override + public Serializer> serializer() { + return new PayloadSerializer(); + } + + @Override + public Deserializer> deserializer() { + return new PayloadDeserializer(); + } + + private final class PayloadSerializer implements Serializer> { + @Override + public byte[] serialize(String topic, KStreamsPayload payload) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + if (payload.isData()) { + out.write(DATA_TAG); + dataCoder.encode(payload.getData(), out); + } else { + WatermarkPayload watermark = payload.asWatermark(); + DataOutputStream dataOut = new DataOutputStream(out); + dataOut.writeByte(WATERMARK_TAG); + dataOut.writeLong(watermark.getWatermarkMillis()); + dataOut.writeInt(watermark.getSourcePartition()); + dataOut.writeInt(watermark.getTotalSourcePartitions()); + dataOut.flush(); + } + } catch (IOException e) { + throw new SerializationException("Failed to serialize KStreamsPayload", e); + } + return out.toByteArray(); + } + } + + private final class PayloadDeserializer implements Deserializer> { + @Override + public KStreamsPayload deserialize(String topic, byte[] bytes) { + ByteArrayInputStream in = new ByteArrayInputStream(bytes); + try { + int tag = in.read(); + if (tag == DATA_TAG) { + return KStreamsPayload.data(dataCoder.decode(in)); + } else if (tag == WATERMARK_TAG) { + DataInputStream dataIn = new DataInputStream(in); + long watermarkMillis = dataIn.readLong(); + int sourcePartition = dataIn.readInt(); + int totalSourcePartitions = dataIn.readInt(); + return KStreamsPayload.watermark(watermarkMillis, sourcePartition, totalSourcePartitions); + } else { + throw new SerializationException("Unknown KStreamsPayload tag: " + tag); + } + } catch (IOException e) { + throw new SerializationException("Failed to deserialize KStreamsPayload", e); + } + } + } +} diff --git a/runners/kafka-streams/src/test/java/org/apache/beam/runners/kafka/streams/translation/KStreamsPayloadSerdeTest.java b/runners/kafka-streams/src/test/java/org/apache/beam/runners/kafka/streams/translation/KStreamsPayloadSerdeTest.java new file mode 100644 index 000000000000..1e4c4c0263a4 --- /dev/null +++ b/runners/kafka-streams/src/test/java/org/apache/beam/runners/kafka/streams/translation/KStreamsPayloadSerdeTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.kafka.streams.translation; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.junit.Test; + +/** Tests for {@link KStreamsPayloadSerde}. */ +public class KStreamsPayloadSerdeTest { + + private static final String TOPIC = "ks-payload-serde-test"; + + private final Coder> dataCoder = + WindowedValues.getFullCoder(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE); + private final KStreamsPayloadSerde serde = new KStreamsPayloadSerde<>(dataCoder); + + private KStreamsPayload roundTrip(KStreamsPayload payload) { + Serializer> serializer = serde.serializer(); + Deserializer> deserializer = serde.deserializer(); + return deserializer.deserialize(TOPIC, serializer.serialize(TOPIC, payload)); + } + + @Test + public void roundTripsDataPayload() { + KStreamsPayload payload = KStreamsPayload.data(WindowedValues.valueInGlobalWindow(42)); + KStreamsPayload out = roundTrip(payload); + assertThat(out.isData(), is(true)); + assertThat(out.getData().getValue(), is(42)); + assertThat(out, is(payload)); + } + + @Test + public void roundTripsWatermarkPayload() { + KStreamsPayload payload = KStreamsPayload.watermark(12345L, 2, 4); + KStreamsPayload out = roundTrip(payload); + assertThat(out.isWatermark(), is(true)); + assertThat(out.asWatermark().getWatermarkMillis(), is(12345L)); + assertThat(out.asWatermark().getSourcePartition(), is(2)); + assertThat(out.asWatermark().getTotalSourcePartitions(), is(4)); + assertThat(out, is(payload)); + } + + @Test + public void roundTripsTerminalMaxWatermark() { + KStreamsPayload payload = + KStreamsPayload.watermark(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(), 0, 1); + assertThat( + roundTrip(payload).asWatermark().getWatermarkMillis(), + is(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis())); + } + + @Test + public void unknownTagThrows() { + byte[] bogus = new byte[] {(byte) 0x7f}; + assertThrows( + SerializationException.class, () -> serde.deserializer().deserialize(TOPIC, bogus)); + } +}