Skip to content

Commit d71b97b

Browse files
authored
Merge pull request #187 from GeoNet/batchMethods
Batch methods
2 parents 73fdbd3 + d2a8033 commit d71b97b

2 files changed

Lines changed: 688 additions & 33 deletions

File tree

aws/sqs/sqs.go

Lines changed: 299 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -127,34 +127,40 @@ func (s *SQS) ReceiveWithContextAttributes(ctx context.Context, queueURL string,
127127
WaitTimeSeconds: 20,
128128
AttributeNames: attrs,
129129
}
130-
return s.receiveMessage(ctx, &input)
130+
msgs, err := s.receiveMessages(ctx, &input)
131+
if err != nil {
132+
return Raw{}, err
133+
}
134+
return msgs[0], err
131135
}
132136

133-
// receiveMessage is the common code used internally to receive an SQS message based
137+
// receiveMessages is the common code used internally to receive an SQS messages based
134138
// on the provided input.
135-
func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput) (Raw, error) {
139+
func (s *SQS) receiveMessages(ctx context.Context, input *sqs.ReceiveMessageInput) ([]Raw, error) {
136140
r, err := s.client.ReceiveMessage(ctx, input)
137141
if err != nil {
138-
return Raw{}, err
142+
return []Raw{}, err
139143
}
140144

141145
switch {
142146
case r == nil || len(r.Messages) == 0:
143147
// no message received
144-
return Raw{}, ErrNoMessages
148+
return []Raw{}, ErrNoMessages
145149

146-
case len(r.Messages) == 1:
147-
raw := r.Messages[0]
150+
case len(r.Messages) >= 1:
148151

149-
m := Raw{
150-
Body: aws.ToString(raw.Body),
151-
ReceiptHandle: aws.ToString(raw.ReceiptHandle),
152-
Attributes: raw.Attributes,
152+
messages := make([]Raw, len(r.Messages))
153+
for i := range r.Messages {
154+
messages[i] = Raw{
155+
Body: aws.ToString(r.Messages[i].Body),
156+
ReceiptHandle: aws.ToString(r.Messages[i].ReceiptHandle),
157+
Attributes: r.Messages[i].Attributes,
158+
}
153159
}
154-
return m, nil
160+
return messages, nil
155161

156162
default:
157-
return Raw{}, fmt.Errorf("received unexpected messages: %d", len(r.Messages))
163+
return []Raw{}, fmt.Errorf("received unexpected number of messages: %d", len(r.Messages)) // Probably an impossible case
158164
}
159165
}
160166

@@ -169,7 +175,28 @@ func (s *SQS) ReceiveWithContext(ctx context.Context, queueURL string, visibilit
169175
VisibilityTimeout: visibilityTimeout,
170176
WaitTimeSeconds: 20,
171177
}
172-
return s.receiveMessage(ctx, &input)
178+
msgs, err := s.receiveMessages(ctx, &input)
179+
if err != nil {
180+
return Raw{}, err
181+
}
182+
return msgs[0], err
183+
}
184+
185+
// ReceiveBatch is similar to Receive, however it can return up to 10 messages.
186+
func (s *SQS) ReceiveBatch(ctx context.Context, queueURL string, visibilityTimeout int32) ([]Raw, error) {
187+
188+
input := sqs.ReceiveMessageInput{
189+
QueueUrl: aws.String(queueURL),
190+
MaxNumberOfMessages: 10,
191+
VisibilityTimeout: visibilityTimeout,
192+
WaitTimeSeconds: 20,
193+
}
194+
195+
msgs, err := s.receiveMessages(ctx, &input)
196+
if err != nil {
197+
return []Raw{}, err
198+
}
199+
return msgs, nil
173200
}
174201

175202
// Delete deletes the message referred to by receiptHandle from the queue.
@@ -231,44 +258,284 @@ func (s *SQS) SendFifoMessage(queue, group, dedupe string, msg []byte) (string,
231258
return "", nil
232259
}
233260

234-
// Leverage the sendbatch api for uploading large numbers of messages
235-
func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error {
236-
if len(bodies) > 11 {
237-
return errors.New("too many messages to batch")
261+
type SendBatchError struct {
262+
Err error
263+
Info []SendBatchErrorEntry
264+
}
265+
type SendBatchErrorEntry struct {
266+
Entry types.BatchResultErrorEntry
267+
Index int
268+
}
269+
270+
func (s *SendBatchError) Error() string {
271+
return fmt.Sprintf("%v: %v messages failed to send", s.Err, len(s.Info))
272+
}
273+
func (s *SendBatchError) Unwrap() error {
274+
return s.Err
275+
}
276+
277+
type SendNBatchError struct {
278+
Errors []error
279+
Info []SendBatchErrorEntry
280+
}
281+
282+
func (s *SendNBatchError) Error() string {
283+
var allErrors string
284+
for _, err := range s.Errors {
285+
allErrors += fmt.Sprintf("%s,", err.Error())
238286
}
287+
allErrors = strings.TrimSuffix(allErrors, ",")
288+
return fmt.Sprintf("%v error(s) sending batches: %s", len(s.Errors), allErrors)
289+
}
290+
291+
// SendBatch sends up to 10 messages to a given SQS queue with one API call.
292+
// If an error occurs on any or all messages, a SendBatchError is returned that lets
293+
// the caller know the index of the message/s in bodies that failed.
294+
func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error {
295+
239296
var err error
240297
entries := make([]types.SendMessageBatchRequestEntry, len(bodies))
241298
for j, body := range bodies {
242299
entries[j] = types.SendMessageBatchRequestEntry{
243-
Id: aws.String(fmt.Sprintf("gamitjob%d", j)),
300+
Id: aws.String(fmt.Sprintf("message-%d", j)),
244301
MessageBody: aws.String(body),
245302
}
246303
}
247-
_, err = s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{
304+
output, err := s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{
248305
Entries: entries,
249306
QueueUrl: &queueURL,
250307
})
251-
return err
308+
if err != nil {
309+
info := make([]SendBatchErrorEntry, len(entries))
310+
for i := range entries {
311+
info[i] = SendBatchErrorEntry{
312+
Index: i,
313+
}
314+
}
315+
return &SendBatchError{Err: err, Info: info}
316+
}
317+
if len(output.Failed) > 0 {
318+
info := make([]SendBatchErrorEntry, len(output.Failed))
319+
for i, entry := range output.Failed {
320+
for j, msg := range entries {
321+
if aws.ToString(msg.Id) == aws.ToString(entry.Id) {
322+
info[i] = SendBatchErrorEntry{
323+
Entry: entry,
324+
Index: j,
325+
}
326+
break
327+
}
328+
}
329+
}
330+
return &SendBatchError{Err: errors.New("partial message failure"), Info: info}
331+
}
332+
return nil
252333
}
253334

254-
func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) error {
335+
// SendNBatch sends any number of messages to a given SQS queue via a series of SendBatch calls.
336+
// If an error occurs on any or all messages, a SendNBatchError is returned that lets
337+
// the caller know the index of the message/s in bodies that failed.
338+
// Returns the number of API calls to SendBatch made.
339+
func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) (int, error) {
340+
341+
const (
342+
maxCount = 10
343+
maxSize = 262144 // 256 KiB
344+
)
345+
346+
allErrors := make([]error, 0)
347+
allInfo := make([]SendBatchErrorEntry, 0)
348+
349+
batchesSent := 0
350+
351+
batch := make([]int, 0)
352+
totalSize := 0
353+
354+
sendBatch := func() {
355+
batchBodies := make([]string, len(batch))
356+
357+
for i, batchIndex := range batch {
358+
batchBodies[i] = bodies[batchIndex]
359+
}
360+
361+
err := s.SendBatch(ctx, queueURL, batchBodies)
362+
var sbe *SendBatchError
363+
if errors.As(err, &sbe) {
364+
allErrors = append(allErrors, err)
365+
366+
// Update index so that index refers to the position in given bodies slice.
367+
for i := range sbe.Info {
368+
sbe.Info[i].Index = batch[sbe.Info[i].Index]
369+
}
370+
371+
allInfo = append(allInfo, sbe.Info...)
372+
}
373+
374+
batchesSent++
375+
batch = batch[:0]
376+
totalSize = 0
377+
}
378+
379+
for i, body := range bodies {
380+
381+
// Check if any single message is too big
382+
if len(body) > maxSize {
383+
allErrors = append(allErrors, errors.New("message too big to send"))
384+
allInfo = append(allInfo, SendBatchErrorEntry{
385+
Index: i,
386+
})
387+
continue
388+
}
389+
// If adding the current message would exceed the batch max size or count, send the current batch.
390+
if totalSize+len(body) > maxSize || len(batch) == maxCount {
391+
sendBatch()
392+
}
393+
batch = append(batch, i)
394+
totalSize += len(body)
395+
}
396+
397+
if len(batch) > 0 {
398+
sendBatch()
399+
}
400+
401+
if len(allErrors) > 0 {
402+
return batchesSent, &SendNBatchError{
403+
Errors: allErrors,
404+
Info: allInfo,
405+
}
406+
}
407+
408+
return batchesSent, nil
409+
}
410+
411+
type DeleteBatchError struct {
412+
Err error
413+
Info []DeleteBatchErrorEntry
414+
}
415+
416+
type DeleteBatchErrorEntry struct {
417+
Entry types.BatchResultErrorEntry
418+
Index int
419+
}
420+
421+
func (d *DeleteBatchError) Error() string {
422+
return fmt.Sprintf("%v: %v messages failed to delete", d.Err, len(d.Info))
423+
}
424+
425+
func (d *DeleteBatchError) Unwrap() error {
426+
return d.Err
427+
}
428+
429+
type DeleteNBatchError struct {
430+
Errors []error
431+
Info []DeleteBatchErrorEntry
432+
}
433+
434+
func (s *DeleteNBatchError) Error() string {
435+
var allErrors string
436+
for _, err := range s.Errors {
437+
allErrors += fmt.Sprintf("%s,", err.Error())
438+
}
439+
allErrors = strings.TrimSuffix(allErrors, ",")
440+
return fmt.Sprintf("%v error(s) deleting batches: %s", len(s.Errors), allErrors)
441+
}
442+
443+
// DeleteBatch deletes up to 10 messages from an SQS queue in a single batch.
444+
// If an error occurs on any or all messages, a DeleteBatchError is returned that lets
445+
// the caller know the indice/s in receiptHandles that failed.
446+
func (s *SQS) DeleteBatch(ctx context.Context, queueURL string, receiptHandles []string) error {
447+
entries := make([]types.DeleteMessageBatchRequestEntry, len(receiptHandles))
448+
for i, receipt := range receiptHandles {
449+
entries[i] = types.DeleteMessageBatchRequestEntry{
450+
Id: aws.String(fmt.Sprintf("delete-message-%d", i)),
451+
ReceiptHandle: aws.String(receipt),
452+
}
453+
}
454+
455+
output, err := s.client.DeleteMessageBatch(ctx, &sqs.DeleteMessageBatchInput{
456+
Entries: entries,
457+
QueueUrl: &queueURL,
458+
})
459+
if err != nil {
460+
info := make([]DeleteBatchErrorEntry, len(entries))
461+
for i := range entries {
462+
info[i] = DeleteBatchErrorEntry{
463+
Index: i,
464+
}
465+
}
466+
return &DeleteBatchError{Err: err, Info: info}
467+
}
468+
if len(output.Failed) > 0 {
469+
info := make([]DeleteBatchErrorEntry, len(output.Failed))
470+
for i, errorEntry := range output.Failed {
471+
for j, requestEntry := range entries {
472+
if aws.ToString(requestEntry.Id) == aws.ToString(errorEntry.Id) {
473+
info[i] = DeleteBatchErrorEntry{
474+
Entry: errorEntry,
475+
Index: j,
476+
}
477+
break
478+
}
479+
}
480+
}
481+
return &DeleteBatchError{Info: info}
482+
}
483+
return nil
484+
}
485+
486+
// DeleteNBatch deletes any number of messages from a given SQS queue via a series of DeleteBatch calls.
487+
// If an error occurs on any or all messages, a DeleteNBatchError is returned that lets
488+
// the caller know the receipt handles that failed.
489+
// Returns the number of API calls to DeleteBatch made.
490+
func (s *SQS) DeleteNBatch(ctx context.Context, queueURL string, receiptHandles []string) (int, error) {
491+
255492
var (
256-
bodiesLen = len(bodies)
257-
maxlen = 10
258-
times = int(math.Ceil(float64(bodiesLen) / float64(maxlen)))
493+
receiptCount = len(receiptHandles)
494+
maxlen = 10
495+
times = int(math.Ceil(float64(receiptCount) / float64(maxlen)))
259496
)
497+
498+
allErrors := make([]error, 0)
499+
allInfo := make([]DeleteBatchErrorEntry, 0)
500+
501+
batchesDeleted := 0
502+
260503
for i := 0; i < times; i++ {
261504
batch_end := maxlen * (i + 1)
262-
if maxlen*(i+1) > bodiesLen {
263-
batch_end = bodiesLen
505+
if maxlen*(i+1) > receiptCount {
506+
batch_end = receiptCount
264507
}
265-
var bodies_batch = bodies[maxlen*i : batch_end]
266-
err := s.SendBatch(ctx, queueURL, bodies_batch)
267-
if err != nil {
268-
return err
508+
var receipt_batch = receiptHandles[maxlen*i : batch_end]
509+
510+
indexMap := make(map[int]int, 0)
511+
count := 0
512+
for j := maxlen * i; j < batch_end; j++ {
513+
indexMap[count] = j
514+
count++
269515
}
516+
517+
err := s.DeleteBatch(ctx, queueURL, receipt_batch)
518+
var dbe *DeleteBatchError
519+
if errors.As(err, &dbe) {
520+
allErrors = append(allErrors, err)
521+
522+
// Update index so that index refers to the position in given receiptHandles slice.
523+
for i := range dbe.Info {
524+
dbe.Info[i].Index = indexMap[dbe.Info[i].Index]
525+
}
526+
527+
allInfo = append(allInfo, dbe.Info...)
528+
}
529+
batchesDeleted++
270530
}
271-
return nil
531+
532+
if len(allErrors) > 0 {
533+
return batchesDeleted, &DeleteNBatchError{
534+
Errors: allErrors,
535+
Info: allInfo,
536+
}
537+
}
538+
return batchesDeleted, nil
272539
}
273540

274541
// GetQueueUrl returns an AWS SQS queue URL given its name.

0 commit comments

Comments
 (0)