diff --git a/pf/function.go b/pf/function.go
index be691ba..9dbb406 100644
--- a/pf/function.go
+++ b/pf/function.go
@@ -52,9 +52,18 @@ const (
 	ErrorMark                   = "XXX_PULSAR_ERROR_XXX:"
 	EmptyMark                   = "XXX_PULSAR_EMPTY_XXX"
 	extendedStdinMetadataMarker = 255
+	binaryV2StatusOK            = 0
+	binaryV2StatusEmpty         = 1
+	binaryV2StatusError         = 2
+	binaryV2FrameHeaderBytes    = 16
+	binaryV2MaxMetadataBytes    = 64 * 1024
+	binaryV2MaxPayloadBytes     = 128 * 1024 * 1024
+	binaryV2MaxFrameBytes       = binaryV2FrameHeaderBytes + binaryV2MaxMetadataBytes + binaryV2MaxPayloadBytes
 )
 
 var (
+	binaryV2InputMagic          = [4]byte{'P', 'F', 'I', '2'}
+	binaryV2OutputMagic         = [4]byte{'P', 'F', 'O', '2'}
 	stdout                      *os.File
 	tenant                      string
 	namespace                   string
@@ -73,6 +82,16 @@ var (
 	expectedHealthCheckInterval int
 )
 
+// childProtocol identifies the wire framing of an input frame and of the reply
+// written for it.
+type childProtocol int
+
+const (
+	// childProtocolLineV1 is the newline-delimited line protocol.
+	childProtocolLineV1 childProtocol = iota
+	// childProtocolBinaryV2 is the length-prefixed binary protocol.
+	childProtocolBinaryV2
+)
+
 type function interface {
 	process(ctx context.Context, input []byte) ([]byte, error)
 }
@@ -252,16 +271,12 @@ func Start(funcName interface{}) {
 	defer cancel()
 
 	for {
-		line, err := reader.ReadBytes('\n')
+		protocol, msgID, msg, err := readInput(reader)
 		if err != nil {
-			if err != io.EOF {
-				logrus.Errorf("Error reading from stdout: %v", err)
+			if err == io.EOF {
+				break
 			}
-			break
-		}
-		msgID, msg, err := readInputFrame(line)
-		if err != nil {
-			writeResult([]byte(ErrorMark + err.Error()))
+			writeResult(stdout, protocol, nil, err)
 			continue
 		}
 		functionContext.setMessageId(&MessageId{
@@ -269,21 +284,46 @@ func Start(funcName interface{}) {
 		})
 
 		if len(msg) == 0 {
-			writeResult([]byte(ErrorMark + "msg length is 0"))
+			writeResult(stdout, protocol, nil, fmt.Errorf("msg length is 0"))
 			continue
 		}
 
 		valuedCtx := NewContext(ctxWithCancel, functionContext)
 		result, err := function.process(valuedCtx, msg)
 		if err != nil {
-			writeResult([]byte(ErrorMark + "handle message: " + err.Error()))
+			writeResult(stdout, protocol, nil, fmt.Errorf("handle message: %w", err))
 			continue
 		}
 
-		writeResult(result)
+		writeResult(stdout, protocol, result, nil)
 	}
 }
 
+// readInput reads the next input frame from reader and reports which protocol
+// it arrived in. A frame that starts with binaryV2InputMagic is decoded by
+// readBinaryV2InputFrame; anything else is read up to the next newline and
+// decoded by readInputFrame. io.EOF is returned once the input is exhausted.
+func readInput(reader *bufio.Reader) (childProtocol, string, []byte, error) {
+	magic, err := reader.Peek(len(binaryV2InputMagic))
+	if err == io.EOF && len(magic) == 0 {
+		return childProtocolLineV1, "", nil, io.EOF
+	}
+	if err == nil && bytes.Equal(magic, binaryV2InputMagic[:]) {
+		msgID, payload, err := readBinaryV2InputFrame(reader)
+		return childProtocolBinaryV2, msgID, payload, err
+	}
+
+	// A failed or short peek cannot be a binary frame. Hand whatever is buffered
+	// to the line reader so a final line shorter than the magic is still decoded
+	// instead of being dropped as a bare EOF.
+	line, err := reader.ReadBytes('\n')
+	if err != nil {
+		return childProtocolLineV1, "", nil, err
+	}
+	msgID, payload, err := readInputFrame(line)
+	return childProtocolLineV1, msgID, payload, err
+}
+
 func readInputFrame(line []byte) (string, []byte, error) {
 	if len(line) == 0 {
 		return "", nil, fmt.Errorf("input frame is empty")
@@ -313,14 +353,107 @@ func readInputFrame(line []byte) (string, []byte, error) {
 	return meta[0], line[metaEnd : len(line)-1], nil
 }
 
-func writeResult(result []byte) {
+// readBinaryV2InputFrame decodes one binary v2 input frame from reader.
+//
+// The frame is the 4-byte magic "PFI2", a big-endian uint32 metadata length,
+// a big-endian uint64 payload length, the metadata ("message-id@topic") and
+// the payload. Both lengths are checked against their limits before any
+// buffer is allocated. It returns the message id and the raw payload bytes.
+func readBinaryV2InputFrame(reader io.Reader) (string, []byte, error) {
+	var magic [4]byte
+	if _, err := io.ReadFull(reader, magic[:]); err != nil {
+		return "", nil, fmt.Errorf("could not read binary v2 input magic: %w", err)
+	}
+	if magic != binaryV2InputMagic {
+		return "", nil, fmt.Errorf("invalid binary v2 input magic")
+	}
+
+	var metadataLenBytes [4]byte
+	if _, err := io.ReadFull(reader, metadataLenBytes[:]); err != nil {
+		return "", nil, fmt.Errorf("could not read binary v2 metadata length: %w", err)
+	}
+	metadataLen := binary.BigEndian.Uint32(metadataLenBytes[:])
+	if metadataLen > binaryV2MaxMetadataBytes {
+		return "", nil, fmt.Errorf("binary v2 metadata length %d exceeds max %d", metadataLen, binaryV2MaxMetadataBytes)
+	}
+
+	var payloadLenBytes [8]byte
+	if _, err := io.ReadFull(reader, payloadLenBytes[:]); err != nil {
+		return "", nil, fmt.Errorf("could not read binary v2 payload length: %w", err)
+	}
+	payloadLen := binary.BigEndian.Uint64(payloadLenBytes[:])
+	if payloadLen > binaryV2MaxPayloadBytes {
+		return "", nil, fmt.Errorf("binary v2 payload length %d exceeds max %d", payloadLen, binaryV2MaxPayloadBytes)
+	}
+	frameLen := uint64(binaryV2FrameHeaderBytes) + uint64(metadataLen) + payloadLen
+	if frameLen > binaryV2MaxFrameBytes {
+		return "", nil, fmt.Errorf("binary v2 frame length %d exceeds max %d", frameLen, binaryV2MaxFrameBytes)
+	}
+
+	metadata := make([]byte, metadataLen)
+	if _, err := io.ReadFull(reader, metadata); err != nil {
+		return "", nil, fmt.Errorf("could not read binary v2 metadata: %w", err)
+	}
+
+	// Read the payload before validating the metadata: a frame whose metadata
+	// is malformed is then still consumed in full, so the next read starts on
+	// a frame boundary instead of in the middle of this payload.
+	payload := make([]byte, int(payloadLen))
+	if _, err := io.ReadFull(reader, payload); err != nil {
+		return "", nil, fmt.Errorf("could not read binary v2 payload: %w", err)
+	}
+
+	meta := strings.Split(string(metadata), "@")
+	if len(meta) != 2 {
+		return "", nil, fmt.Errorf("invalid metadata format: expected message id and topic separated by @")
+	}
+	return meta[0], payload, nil
+}
+
+// writeResult reports one processing outcome to writer using protocol.
+//
+// Binary v2 replies are frames whose status byte distinguishes a result, an
+// empty result and an error. Line v1 replies are a single newline-terminated
+// line holding ErrorMark plus the error text, the result bytes, or EmptyMark.
+// Newlines are stripped from v1 results and error text alike, because an
+// embedded newline would end the reply line early.
+func writeResult(writer io.Writer, protocol childProtocol, result []byte, resultErr error) {
+	if protocol == childProtocolBinaryV2 {
+		switch {
+		case resultErr != nil:
+			writeBinaryV2OutputFrame(writer, binaryV2StatusError, []byte(resultErr.Error()))
+		case len(result) > 0:
+			writeBinaryV2OutputFrame(writer, binaryV2StatusOK, result)
+		default:
+			writeBinaryV2OutputFrame(writer, binaryV2StatusEmpty, nil)
+		}
+		return
+	}
+
+	if resultErr != nil {
+		// Route the error through the same newline stripping as a result.
+		result = []byte(ErrorMark + resultErr.Error())
+	}
 	if len(result) > 0 {
 		result = bytes.ReplaceAll(result, []byte("\n"), []byte(""))
-		_, _ = stdout.Write(result)
+		_, _ = writer.Write(result)
 	} else {
-		_, _ = stdout.Write([]byte(EmptyMark))
+		_, _ = writer.Write([]byte(EmptyMark))
 	}
-	_, _ = stdout.Write([]byte("\n"))
+	_, _ = writer.Write([]byte("\n"))
+}
+
+// writeBinaryV2OutputFrame writes one binary v2 output frame to writer: the
+// 4-byte magic "PFO2", the status byte, the body length as a big-endian
+// uint64, and the body. The header is assembled first and written with a
+// single Write. Write errors are ignored.
+func writeBinaryV2OutputFrame(writer io.Writer, status byte, body []byte) {
+	var header [len(binaryV2OutputMagic) + 1 + 8]byte
+	copy(header[:], binaryV2OutputMagic[:])
+	header[len(binaryV2OutputMagic)] = status
+	binary.BigEndian.PutUint64(header[len(binaryV2OutputMagic)+1:], uint64(len(body)))
+	_, _ = writer.Write(header[:])
+	_, _ = writer.Write(body)
 }
 
 func init() {
diff --git a/pf/function_test.go b/pf/function_test.go
index 5349a92..f9f456c 100644
--- a/pf/function_test.go
+++ b/pf/function_test.go
@@ -1,7 +1,11 @@
 package pf
 
 import (
+	"bufio"
+	"bytes"
 	"encoding/binary"
+	"errors"
+	"io"
 	"strings"
 	"testing"
 )
@@ -57,3 +61,177 @@ func TestReadInputFrameRejectsMalformedMetadata(t *testing.T) {
 		t.Fatalf("error = %q, want malformed metadata error", err.Error())
 	}
 }
+
+func TestReadInputPreservesBinaryV2Payload(t *testing.T) {
+	payload := []byte{0, 'a', '\n', '\r', 255}
+	input := binaryV2InputFrame("1:2:3@topic", payload)
+
+	protocol, msgID, got, err := readInput(bufioReader(input))
+	if err != nil {
+		t.Fatalf("readInput returned error: %v", err)
+	}
+	if protocol != childProtocolBinaryV2 {
+		t.Fatalf("protocol = %v, want %v", protocol, childProtocolBinaryV2)
+	}
+	if msgID != "1:2:3" {
+		t.Fatalf("msgID = %q, want %q", msgID, "1:2:3")
+	}
+	if !bytes.Equal(got, payload) {
+		t.Fatalf("payload = %v, want %v", got, payload)
+	}
+}
+
+func TestReadInputStaysAlignedAfterMalformedBinaryV2Metadata(t *testing.T) {
+	input := binaryV2InputFrame("missing-separator", []byte("bad"))
+	input = append(input, binaryV2InputFrame("1:2:3@topic", []byte("good"))...)
+	reader := bufioReader(input)
+
+	if _, _, _, err := readInput(reader); err == nil {
+		t.Fatal("readInput returned nil error, want malformed metadata error")
+	}
+	protocol, msgID, got, err := readInput(reader)
+	if err != nil {
+		t.Fatalf("readInput returned error for the following frame: %v", err)
+	}
+	if protocol != childProtocolBinaryV2 {
+		t.Fatalf("protocol = %v, want %v", protocol, childProtocolBinaryV2)
+	}
+	if msgID != "1:2:3" {
+		t.Fatalf("msgID = %q, want %q", msgID, "1:2:3")
+	}
+	if string(got) != "good" {
+		t.Fatalf("payload = %q, want %q", got, "good")
+	}
+}
+
+func TestReadBinaryV2InputFrameRejectsOversizedMetadataBeforeAllocating(t *testing.T) {
+	input := binaryV2InputFrameWithLengths(binaryV2MaxMetadataBytes+1, 0)
+
+	_, _, err := readBinaryV2InputFrame(bytes.NewReader(input))
+	if err == nil {
+		t.Fatal("readBinaryV2InputFrame returned nil error, want oversized metadata error")
+	}
+	if !strings.Contains(err.Error(), "metadata length") {
+		t.Fatalf("error = %q, want metadata length error", err.Error())
+	}
+}
+
+func TestReadBinaryV2InputFrameRejectsOversizedPayloadBeforeAllocating(t *testing.T) {
+	metadata := "1:2:3@topic"
+	input := binaryV2InputFrameWithLengths(uint32(len(metadata)), binaryV2MaxPayloadBytes+1)
+	input = append(input, []byte(metadata)...)
+
+	_, _, err := readBinaryV2InputFrame(bytes.NewReader(input))
+	if err == nil {
+		t.Fatal("readBinaryV2InputFrame returned nil error, want oversized payload error")
+	}
+	if !strings.Contains(err.Error(), "payload length") {
+		t.Fatalf("error = %q, want payload length error", err.Error())
+	}
+}
+
+func TestWriteResultPreservesBinaryV2Payload(t *testing.T) {
+	payload := []byte{0, 'a', '\n', '\r', 255}
+	var out bytes.Buffer
+
+	writeResult(&out, childProtocolBinaryV2, payload, nil)
+
+	status, got, err := readBinaryV2OutputFrame(&out)
+	if err != nil {
+		t.Fatalf("readBinaryV2OutputFrame returned error: %v", err)
+	}
+	if status != binaryV2StatusOK {
+		t.Fatalf("status = %d, want %d", status, binaryV2StatusOK)
+	}
+	if !bytes.Equal(got, payload) {
+		t.Fatalf("payload = %v, want %v", got, payload)
+	}
+}
+
+func TestWriteResultUsesBinaryV2EmptyAndErrorStatuses(t *testing.T) {
+	var emptyOut bytes.Buffer
+	writeResult(&emptyOut, childProtocolBinaryV2, nil, nil)
+	status, body, err := readBinaryV2OutputFrame(&emptyOut)
+	if err != nil {
+		t.Fatalf("readBinaryV2OutputFrame returned error: %v", err)
+	}
+	if status != binaryV2StatusEmpty {
+		t.Fatalf("status = %d, want %d", status, binaryV2StatusEmpty)
+	}
+	if len(body) != 0 {
+		t.Fatalf("empty body length = %d, want 0", len(body))
+	}
+
+	var errorOut bytes.Buffer
+	writeResult(&errorOut, childProtocolBinaryV2, nil, io.ErrUnexpectedEOF)
+	status, body, err = readBinaryV2OutputFrame(&errorOut)
+	if err != nil {
+		t.Fatalf("readBinaryV2OutputFrame returned error: %v", err)
+	}
+	if status != binaryV2StatusError {
+		t.Fatalf("status = %d, want %d", status, binaryV2StatusError)
+	}
+	if string(body) != io.ErrUnexpectedEOF.Error() {
+		t.Fatalf("error body = %q, want %q", body, io.ErrUnexpectedEOF.Error())
+	}
+}
+
+func TestWriteResultStripsNewlinesFromLineV1Errors(t *testing.T) {
+	var out bytes.Buffer
+
+	writeResult(&out, childProtocolLineV1, nil, errors.New("first\nsecond"))
+
+	want := ErrorMark + "firstsecond\n"
+	if got := out.String(); got != want {
+		t.Fatalf("output = %q, want %q", got, want)
+	}
+}
+
+// binaryV2InputFrame encodes metadata and payload as one binary v2 input frame.
+func binaryV2InputFrame(metadata string, payload []byte) []byte {
+	frame := append([]byte{}, binaryV2InputMagic[:]...)
+	frame = binary.BigEndian.AppendUint32(frame, uint32(len(metadata)))
+	frame = binary.BigEndian.AppendUint64(frame, uint64(len(payload)))
+	frame = append(frame, []byte(metadata)...)
+	frame = append(frame, payload...)
+	return frame
+}
+
+// binaryV2InputFrameWithLengths encodes only a binary v2 input frame header
+// announcing the given lengths, without any metadata or payload bytes.
+func binaryV2InputFrameWithLengths(metadataLen uint32, payloadLen uint64) []byte {
+	frame := append([]byte{}, binaryV2InputMagic[:]...)
+	frame = binary.BigEndian.AppendUint32(frame, metadataLen)
+	frame = binary.BigEndian.AppendUint64(frame, payloadLen)
+	return frame
+}
+
+// readBinaryV2OutputFrame decodes one binary v2 output frame from reader and
+// returns its status byte and body.
+func readBinaryV2OutputFrame(reader io.Reader) (byte, []byte, error) {
+	var magic [4]byte
+	if _, err := io.ReadFull(reader, magic[:]); err != nil {
+		return 0, nil, err
+	}
+	if magic != binaryV2OutputMagic {
+		return 0, nil, errors.New("invalid binary v2 output magic")
+	}
+	var status [1]byte
+	if _, err := io.ReadFull(reader, status[:]); err != nil {
+		return 0, nil, err
+	}
+	var bodyLen [8]byte
+	if _, err := io.ReadFull(reader, bodyLen[:]); err != nil {
+		return 0, nil, err
+	}
+	body := make([]byte, binary.BigEndian.Uint64(bodyLen[:]))
+	if _, err := io.ReadFull(reader, body); err != nil {
+		return 0, nil, err
+	}
+	return status[0], body, nil
+}
+
+// bufioReader wraps data in a *bufio.Reader, the reader type readInput takes.
+func bufioReader(data []byte) *bufio.Reader {
+	return bufio.NewReader(bytes.NewReader(data))
+}