[Iterator] batch pre-decoding (#852)

* feat(backend): batch pre-decoding implementation
This commit is contained in:
Alexander 2022-12-05 11:21:26 +01:00 committed by GitHub
parent 1035e70a4a
commit 17d9a4a116
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 2052 additions and 1849 deletions

View file

@ -0,0 +1,119 @@
package messages
import (
"errors"
"fmt"
"io"
)
// BytesReader is a cursor-based decoder over an in-memory byte slice.
// Each Read* method consumes bytes starting at the current pointer; the
// per-method notes below describe the behavior of bytesReaderImpl.
type BytesReader interface {
// ReadSize reads a fixed 3-byte little-endian unsigned size.
ReadSize() (uint64, error)
// ReadByte reads a single byte, returning io.EOF when the data is exhausted.
ReadByte() (byte, error)
// ReadUint reads a variable-length unsigned integer (7 payload bits per
// byte, high bit set on every byte except the last).
ReadUint() (uint64, error)
// ReadInt reads a variable-length unsigned integer and zigzag-decodes it
// into a signed value.
ReadInt() (int64, error)
// ReadBoolean reads one byte and reports whether it equals 1.
ReadBoolean() (bool, error)
// ReadString reads a ReadUint length prefix followed by that many bytes.
ReadString() (string, error)
// Data returns the underlying byte slice (not a copy).
Data() []byte
// Pointer returns the current read offset into Data.
Pointer() int64
// SetPointer moves the read offset to p.
SetPointer(p int64)
}
// bytesReaderImpl is the slice-backed implementation of BytesReader.
type bytesReaderImpl struct {
// data is the buffer being decoded; it is referenced, not copied.
data []byte
// curr is the offset in data of the next byte to read.
curr int64
}
// NewBytesReader returns a BytesReader over data with its read pointer at
// offset zero. The slice is used as-is, without copying.
func NewBytesReader(data []byte) BytesReader {
	r := new(bytesReaderImpl)
	r.data = data
	return r
}
// ReadSize decodes a fixed-width, 3-byte little-endian unsigned integer at
// the current pointer and advances the pointer past it. If fewer than three
// bytes remain, it returns an "out of range" error and leaves the pointer
// untouched.
func (m *bytesReaderImpl) ReadSize() (uint64, error) {
	const sizeLen = 3
	if len(m.data)-int(m.curr) < sizeLen {
		return 0, fmt.Errorf("out of range")
	}
	raw := m.data[m.curr : m.curr+sizeLen]
	// Least significant byte comes first.
	size := uint64(raw[0]) | uint64(raw[1])<<8 | uint64(raw[2])<<16
	m.curr += sizeLen
	return size, nil
}
// ReadByte returns the byte at the current pointer and advances the pointer
// by one. It returns io.EOF once the pointer has reached the end of the data.
func (m *bytesReaderImpl) ReadByte() (byte, error) {
	pos := m.curr
	if int(pos) >= len(m.data) {
		return 0, io.EOF
	}
	b := m.data[pos]
	m.curr = pos + 1
	return b, nil
}
// ReadUint decodes a variable-length unsigned integer: each byte carries 7
// payload bits (least significant group first) and has its high bit set on
// every byte except the last.
//
// A uint64 needs at most 10 such bytes, so decoding stops with a
// "uint overflow" error as soon as 10 bytes have been consumed without a
// valid terminator. Without this bound, malformed input made of continuation
// bytes would be consumed to the end of the buffer and reported as io.EOF,
// which callers may interpret as a clean end of data.
//
// If the data ends in the middle of a value, the error from ReadByte is
// returned together with the partially decoded value.
func (m *bytesReaderImpl) ReadUint() (uint64, error) {
	// Maximum encoded length of a 64-bit varint.
	const maxVarintLen = 10

	var x uint64
	var s uint
	for i := 0; i < maxVarintLen; i++ {
		b, err := m.ReadByte()
		if err != nil {
			return x, err
		}
		if b < 0x80 {
			// The 10th byte may only contribute the single remaining bit.
			if i == maxVarintLen-1 && b > 1 {
				return x, errors.New("uint overflow")
			}
			return x | uint64(b)<<s, nil
		}
		x |= uint64(b&0x7f) << s
		s += 7
	}
	return x, errors.New("uint overflow")
}
// ReadInt reads a variable-length unsigned integer via ReadUint and
// zigzag-decodes it: the low bit selects the sign and the remaining bits hold
// the magnitude. On a read error the (undecoded) shifted value is returned
// alongside the error.
func (m *bytesReaderImpl) ReadInt() (int64, error) {
	raw, err := m.ReadUint()
	val := int64(raw >> 1)
	switch {
	case err != nil:
		return val, err
	case raw&1 == 1:
		// Odd encodings map to negative numbers.
		return ^val, nil
	default:
		return val, nil
	}
}
// ReadBoolean consumes one byte and reports whether it is exactly 1; any
// other value decodes as false. A read error yields (false, err).
func (m *bytesReaderImpl) ReadBoolean() (bool, error) {
	b, err := m.ReadByte()
	if err == nil {
		return b == 1, nil
	}
	return false, err
}
// ReadString decodes a length-prefixed string: a ReadUint byte count followed
// by that many bytes of content. Lengths above 10e6 are rejected with
// "too long string"; a length larger than the remaining data yields
// "out of range". In both error cases the length prefix has already been
// consumed.
func (m *bytesReaderImpl) ReadString() (string, error) {
	const maxLen = 10e6

	n, err := m.ReadUint()
	switch {
	case err != nil:
		return "", err
	case n > maxLen:
		return "", errors.New("too long string")
	}

	start := int(m.curr)
	if remaining := len(m.data) - start; remaining < int(n) {
		return "", fmt.Errorf("out of range")
	}
	end := start + int(n)
	s := string(m.data[start:end])
	m.curr += int64(n)
	return s, nil
}
// Data returns the underlying byte slice. The slice is shared with the
// reader, not copied.
func (m *bytesReaderImpl) Data() []byte {
return m.data
}
// Pointer returns the current read offset, i.e. the index in the data of the
// next byte to be consumed.
func (m *bytesReaderImpl) Pointer() int64 {
return m.curr
}
// SetPointer moves the read offset to p. No bounds validation is performed
// here, so the caller is responsible for passing an offset that is valid for
// the data.
func (m *bytesReaderImpl) SetPointer(p int64) {
m.curr = p
}

View file

@ -1,9 +1,7 @@
package messages
import (
"bytes"
"fmt"
"io"
"log"
)
@ -54,76 +52,32 @@ func (i *messageIteratorImpl) prepareVars(batchInfo *BatchInfo) {
}
func (i *messageIteratorImpl) Iterate(batchData []byte, batchInfo *BatchInfo) {
// Create new message reader
reader := NewMessageReader(batchData)
// Pre-decode batch data
if err := reader.Parse(); err != nil {
log.Printf("pre-decode batch err: %s, info: %s", err, batchInfo.Info())
return
}
// Prepare iterator before processing messages in batch
i.prepareVars(batchInfo)
// Initialize batch reader
reader := bytes.NewReader(batchData)
// Process until end of batch or parsing error
for {
for reader.Next() {
// Increase message index (can be overwritten by batch info message)
i.messageInfo.Index++
if i.broken {
log.Printf("skipping broken batch, info: %s", i.batchInfo.Info())
return
}
if i.canSkip {
if _, err := reader.Seek(int64(i.size), io.SeekCurrent); err != nil {
log.Printf("can't skip message: %s, info: %s", err, i.batchInfo.Info())
return
}
}
i.canSkip = false
// Read message type
msgType, err := ReadUint(reader)
if err != nil {
if err != io.EOF {
log.Printf("can't read message type: %s, info: %s", err, i.batchInfo.Info())
}
return
}
var msg Message
// Read message body (and decode if protocol version less than 1)
if i.version > 0 && messageHasSize(msgType) {
// Read message size if it is a new protocol version
i.size, err = ReadSize(reader)
if err != nil {
log.Printf("can't read message size: %s, info: %s", err, i.batchInfo.Info())
return
}
msg = &RawMessage{
tp: msgType,
size: i.size,
reader: reader,
raw: batchData,
skipped: &i.canSkip,
broken: &i.broken,
meta: i.messageInfo,
}
i.canSkip = true
} else {
msg, err = ReadMessage(msgType, reader)
if err != nil {
if err != io.EOF {
log.Printf("can't read message body: %s, info: %s", err, i.batchInfo.Info())
}
return
}
msg = transformDeprecated(msg)
}
msg := reader.Message()
// Preprocess "system" messages
if _, ok := i.preFilter[msg.TypeID()]; ok {
msg = msg.Decode()
if msg == nil {
log.Printf("decode error, type: %d, info: %s", msgType, i.batchInfo.Info())
log.Printf("decode error, type: %d, info: %s", msg.TypeID(), i.batchInfo.Info())
return
}
msg = transformDeprecated(msg)
if err := i.preprocessing(msg); err != nil {
log.Printf("message preprocessing err: %s", err)
return
@ -140,7 +94,7 @@ func (i *messageIteratorImpl) Iterate(batchData []byte, batchInfo *BatchInfo) {
if i.autoDecode {
msg = msg.Decode()
if msg == nil {
log.Printf("decode error, type: %d, info: %s", msgType, i.batchInfo.Info())
log.Printf("decode error, type: %d, info: %s", msg.TypeID(), i.batchInfo.Info())
return
}
}

View file

@ -1,58 +1,23 @@
package messages
import (
"bytes"
"io"
"log"
)
// RawMessage is a not decoded message
type RawMessage struct {
tp uint64
size uint64
data []byte
reader *bytes.Reader
raw []byte
meta *message
encoded bool
skipped *bool
broken *bool
tp uint64
data []byte
broken *bool
meta *message
}
func (m *RawMessage) Encode() []byte {
if m.encoded {
return m.data
}
// Try to avoid EOF error
if m.reader.Len() < int(m.size) {
return nil
}
// Get current batch position
currPos, err := m.reader.Seek(0, io.SeekCurrent)
if err != nil {
log.Printf("can't get current batch position: %s", err)
return nil
}
// "Move" message type
if currPos == 0 {
log.Printf("can't move message type, curr position = %d", currPos)
return nil
}
// Dirty hack to avoid extra memory allocation
m.raw[currPos-1] = uint8(m.tp)
m.data = m.raw[currPos-1 : currPos+int64(m.size)]
m.encoded = true
return m.data
}
func (m *RawMessage) Decode() Message {
if !m.encoded {
if m.Encode() == nil {
*m.broken = true
return nil
}
}
msg, err := ReadMessage(m.tp, bytes.NewReader(m.data[1:]))
msg, err := ReadMessage(m.tp, NewBytesReader(m.data[1:]))
if err != nil {
log.Printf("decode err: %s", err)
*m.broken = true

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,166 @@
package messages
import (
"fmt"
"io"
)
// MessageReader iterates over the messages of a single batch.
// The notes below describe messageReaderImpl, the implementation in this file.
type MessageReader interface {
// Parse pre-scans the batch and records where each message is located.
// It returns an error when the batch cannot be parsed.
Parse() (err error)
// Next advances to the next message and reports whether one is available.
Next() bool
// Message returns the message selected by the most recent successful Next.
Message() Message
}
// NewMessageReader builds a MessageReader over the raw batch bytes. The
// reader keeps a reference to data (no copy) and pre-allocates room for 1024
// message descriptors.
func NewMessageReader(data []byte) MessageReader {
	const initialListCap = 1024

	mr := new(messageReaderImpl)
	mr.data = data
	mr.reader = NewBytesReader(data)
	mr.list = make([]*MessageMeta, 0, initialListCap)
	return mr
}
// MessageMeta records where one message lives inside the batch buffer, as
// discovered by Parse.
type MessageMeta struct {
// msgType is the message type id read from the batch.
msgType uint64
// msgSize is the length in bytes of the slice that Next hands to RawMessage.
// For size-prefixed messages this is the body size plus one, for the type
// byte Parse writes in front of the body.
msgSize uint64
// msgFrom is the offset in the batch buffer where that slice starts.
msgFrom uint64
}
// messageReaderImpl is the MessageReader implementation backed by a
// BytesReader over the raw batch bytes.
type messageReaderImpl struct {
// data is the raw batch buffer; Parse mutates it in place for size-prefixed
// messages.
data []byte
// reader is the decoding cursor over data.
reader BytesReader
// msgType and msgSize hold the type and size of the message currently being
// read.
msgType uint64
msgSize uint64
// msgBody is not referenced by the methods in this file.
msgBody []byte
// version is the protocol version taken from the batch metadata message
// (zero until one is seen).
version int
// broken stops iteration once set; a pointer to it is handed to every
// RawMessage created by Next.
broken bool
// message is the message produced by the last successful Next.
message Message
// err stores the last error hit by Next on the non-pre-parsed path.
err error
// list holds the message descriptors collected by Parse; listPtr is the
// index of the next descriptor Next will serve.
list []*MessageMeta
listPtr int
}
// Parse pre-scans the whole batch and records, for every message, its type
// and the location of its bytes in the batch buffer, so that Next can later
// hand out RawMessage values without re-reading the stream.
//
// Two layouts are handled per message:
//   - When the protocol version is greater than zero and messageHasSize
//     reports true for the type, a 3-byte size follows the type. The body is
//     skipped rather than decoded, and the byte just before the body is
//     overwritten with the message type so that type+body form one contiguous
//     slice. This mutates the batch buffer in place.
//   - Otherwise the message is decoded with ReadMessage to find where it ends.
//
// A batch metadata message must be the first message in the batch. If it
// announces a version other than 1, the collected list is cleared, the read
// pointer is reset to the start, and Parse returns nil so that Next falls
// back to decoding messages one by one.
//
// Parse returns nil when the end of the batch is reached cleanly. Any other
// failure is returned with the underlying error wrapped (%w), so callers can
// inspect the cause with errors.Is / errors.As.
func (m *messageReaderImpl) Parse() (err error) {
	m.listPtr = 0
	m.list = m.list[:0]
	m.broken = false
	for {
		// Read the message type; io.EOF here means the batch is exhausted.
		m.msgType, err = m.reader.ReadUint()
		if err != nil {
			if err != io.EOF {
				return fmt.Errorf("read message type err: %w", err)
			}
			// Reached the end of batch
			return nil
		}

		if m.version > 0 && messageHasSize(m.msgType) {
			// New protocol version: the message carries an explicit size.
			m.msgSize, err = m.reader.ReadSize()
			if err != nil {
				return fmt.Errorf("read message size err: %w", err)
			}

			// Make sure the announced body fits into the remaining data.
			curr := m.reader.Pointer()
			if len(m.data)-int(curr) < int(m.msgSize) {
				return fmt.Errorf("can't read message body")
			}

			// Dirty hack to avoid extra memory allocation: reuse the last size
			// byte as the type byte so type+body are contiguous.
			m.data[curr-1] = uint8(m.msgType)

			// Add message meta to list (size includes the type byte).
			m.list = append(m.list, &MessageMeta{
				msgType: m.msgType,
				msgSize: m.msgSize + 1,
				msgFrom: uint64(curr - 1),
			})

			// Skip over the body without decoding it.
			m.reader.SetPointer(curr + int64(m.msgSize))
			continue
		}

		// No explicit size: decode the message to learn where it ends.
		// The start offset assumes the type occupied a single byte.
		from := m.reader.Pointer() - 1
		var msg Message
		msg, err = ReadMessage(m.msgType, m.reader)
		if err != nil {
			return fmt.Errorf("read message err: %w", err)
		}

		if m.msgType == MsgBatchMeta || m.msgType == MsgBatchMetadata {
			if len(m.list) > 0 {
				return fmt.Errorf("batch meta not at the start of batch")
			}
			switch message := msg.(type) {
			case *BatchMetadata:
				m.version = int(message.Version)
			case *BatchMeta:
				m.version = 0
			}
			if m.version != 1 {
				// Unsupported tracker version, reset reader
				m.list = m.list[:0]
				m.reader.SetPointer(0)
				return nil
			}
		}

		// Add message meta to list
		m.list = append(m.list, &MessageMeta{
			msgType: m.msgType,
			msgSize: uint64(m.reader.Pointer() - from),
			msgFrom: uint64(from),
		})
	}
}
// Next advances to the following message and reports whether one is
// available through Message.
//
// When Parse collected message descriptors, Next serves them in order as
// RawMessage values that reference the batch buffer. Otherwise it reads and
// decodes the next message directly from the stream; failures on that path
// are stored in m.err and reported as false. Once the reader is marked
// broken, Next always returns false.
func (m *messageReaderImpl) Next() bool {
	if m.broken {
		return false
	}

	// Pre-parsed batch (new version of tracker): serve from the list.
	if total := len(m.list); total > 0 {
		if m.listPtr >= total {
			return false
		}
		entry := m.list[m.listPtr]
		m.listPtr++
		start, end := entry.msgFrom, entry.msgFrom+entry.msgSize
		m.message = &RawMessage{
			tp:     entry.msgType,
			data:   m.data[start:end],
			broken: &m.broken,
			meta:   &message{},
		}
		return true
	}

	// Previous version of tracker: decode straight from the stream.
	tp, err := m.reader.ReadUint()
	m.msgType = tp
	if err != nil {
		// io.EOF simply means the end of the batch and is not recorded.
		if err != io.EOF {
			m.err = fmt.Errorf("read message type err: %s", err)
		}
		return false
	}

	decoded, err := ReadMessage(m.msgType, m.reader)
	if err != nil {
		m.err = fmt.Errorf("read message err: %s", err)
		return false
	}
	m.message = decoded
	return true
}
// Message returns the message stored by the most recent successful call to
// Next. It does not advance the reader.
func (m *messageReaderImpl) Message() Message {
return m.message
}

View file

@ -6,17 +6,17 @@ import (
"io"
)
<% $messages.each do |msg| %>
func Decode<%= msg.name %>(reader io.Reader) (Message, error) {
func Decode<%= msg.name %>(reader BytesReader) (Message, error) {
var err error = nil
msg := &<%= msg.name %>{}
<%= msg.attributes.map { |attr|
" if msg.#{attr.name}, err = Read#{attr.type.to_s.pascal_case}(reader); err != nil {
" if msg.#{attr.name}, err = reader.Read#{attr.type.to_s.pascal_case}(); err != nil {
return nil, err
}" }.join "\n" %>
return msg, err
}
<% end %>
func ReadMessage(t uint64, reader io.Reader) (Message, error) {
func ReadMessage(t uint64, reader BytesReader) (Message, error) {
switch t {<% $messages.each do |msg| %>
case <%= msg.id %>:
return Decode<%= msg.name %>(reader)<% end %>