@@ -127,34 +127,40 @@ func (s *SQS) ReceiveWithContextAttributes(ctx context.Context, queueURL string,
127127 WaitTimeSeconds : 20 ,
128128 AttributeNames : attrs ,
129129 }
130- return s .receiveMessage (ctx , & input )
130+ msgs , err := s .receiveMessages (ctx , & input )
131+ if err != nil {
132+ return Raw {}, err
133+ }
134+ return msgs [0 ], err
131135}
132136
133- // receiveMessage is the common code used internally to receive an SQS message based
137+ // receiveMessages is the common code used internally to receive an SQS messages based
134138// on the provided input.
135- func (s * SQS ) receiveMessage (ctx context.Context , input * sqs.ReceiveMessageInput ) (Raw , error ) {
139+ func (s * SQS ) receiveMessages (ctx context.Context , input * sqs.ReceiveMessageInput ) ([] Raw , error ) {
136140 r , err := s .client .ReceiveMessage (ctx , input )
137141 if err != nil {
138- return Raw {}, err
142+ return [] Raw {}, err
139143 }
140144
141145 switch {
142146 case r == nil || len (r .Messages ) == 0 :
143147 // no message received
144- return Raw {}, ErrNoMessages
148+ return [] Raw {}, ErrNoMessages
145149
146- case len (r .Messages ) == 1 :
147- raw := r .Messages [0 ]
150+ case len (r .Messages ) >= 1 :
148151
149- m := Raw {
150- Body : aws .ToString (raw .Body ),
151- ReceiptHandle : aws .ToString (raw .ReceiptHandle ),
152- Attributes : raw .Attributes ,
152+ messages := make ([]Raw , len (r .Messages ))
153+ for i := range r .Messages {
154+ messages [i ] = Raw {
155+ Body : aws .ToString (r .Messages [i ].Body ),
156+ ReceiptHandle : aws .ToString (r .Messages [i ].ReceiptHandle ),
157+ Attributes : r .Messages [i ].Attributes ,
158+ }
153159 }
154- return m , nil
160+ return messages , nil
155161
156162 default :
157- return Raw {}, fmt .Errorf ("received unexpected messages: %d" , len (r .Messages ))
163+ return [] Raw {}, fmt .Errorf ("received unexpected number of messages: %d" , len (r .Messages )) // Probably an impossible case
158164 }
159165}
160166
@@ -169,7 +175,28 @@ func (s *SQS) ReceiveWithContext(ctx context.Context, queueURL string, visibilit
169175 VisibilityTimeout : visibilityTimeout ,
170176 WaitTimeSeconds : 20 ,
171177 }
172- return s .receiveMessage (ctx , & input )
178+ msgs , err := s .receiveMessages (ctx , & input )
179+ if err != nil {
180+ return Raw {}, err
181+ }
182+ return msgs [0 ], err
183+ }
184+
185+ // ReceiveBatch is similar to Receive, however it can return up to 10 messages.
186+ func (s * SQS ) ReceiveBatch (ctx context.Context , queueURL string , visibilityTimeout int32 ) ([]Raw , error ) {
187+
188+ input := sqs.ReceiveMessageInput {
189+ QueueUrl : aws .String (queueURL ),
190+ MaxNumberOfMessages : 10 ,
191+ VisibilityTimeout : visibilityTimeout ,
192+ WaitTimeSeconds : 20 ,
193+ }
194+
195+ msgs , err := s .receiveMessages (ctx , & input )
196+ if err != nil {
197+ return []Raw {}, err
198+ }
199+ return msgs , nil
173200}
174201
175202// Delete deletes the message referred to by receiptHandle from the queue.
@@ -231,44 +258,284 @@ func (s *SQS) SendFifoMessage(queue, group, dedupe string, msg []byte) (string,
231258 return "" , nil
232259}
233260
234- // Leverage the sendbatch api for uploading large numbers of messages
235- func (s * SQS ) SendBatch (ctx context.Context , queueURL string , bodies []string ) error {
236- if len (bodies ) > 11 {
237- return errors .New ("too many messages to batch" )
261+ type SendBatchError struct {
262+ Err error
263+ Info []SendBatchErrorEntry
264+ }
265+ type SendBatchErrorEntry struct {
266+ Entry types.BatchResultErrorEntry
267+ Index int
268+ }
269+
270+ func (s * SendBatchError ) Error () string {
271+ return fmt .Sprintf ("%v: %v messages failed to send" , s .Err , len (s .Info ))
272+ }
273+ func (s * SendBatchError ) Unwrap () error {
274+ return s .Err
275+ }
276+
277+ type SendNBatchError struct {
278+ Errors []error
279+ Info []SendBatchErrorEntry
280+ }
281+
282+ func (s * SendNBatchError ) Error () string {
283+ var allErrors string
284+ for _ , err := range s .Errors {
285+ allErrors += fmt .Sprintf ("%s," , err .Error ())
238286 }
287+ allErrors = strings .TrimSuffix (allErrors , "," )
288+ return fmt .Sprintf ("%v error(s) sending batches: %s" , len (s .Errors ), allErrors )
289+ }
290+
291+ // SendBatch sends up to 10 messages to a given SQS queue with one API call.
292+ // If an error occurs on any or all messages, a SendBatchError is returned that lets
293+ // the caller know the index of the message/s in bodies that failed.
294+ func (s * SQS ) SendBatch (ctx context.Context , queueURL string , bodies []string ) error {
295+
239296 var err error
240297 entries := make ([]types.SendMessageBatchRequestEntry , len (bodies ))
241298 for j , body := range bodies {
242299 entries [j ] = types.SendMessageBatchRequestEntry {
243- Id : aws .String (fmt .Sprintf ("gamitjob %d" , j )),
300+ Id : aws .String (fmt .Sprintf ("message- %d" , j )),
244301 MessageBody : aws .String (body ),
245302 }
246303 }
247- _ , err = s .client .SendMessageBatch (ctx , & sqs.SendMessageBatchInput {
304+ output , err : = s .client .SendMessageBatch (ctx , & sqs.SendMessageBatchInput {
248305 Entries : entries ,
249306 QueueUrl : & queueURL ,
250307 })
251- return err
308+ if err != nil {
309+ info := make ([]SendBatchErrorEntry , len (entries ))
310+ for i := range entries {
311+ info [i ] = SendBatchErrorEntry {
312+ Index : i ,
313+ }
314+ }
315+ return & SendBatchError {Err : err , Info : info }
316+ }
317+ if len (output .Failed ) > 0 {
318+ info := make ([]SendBatchErrorEntry , len (output .Failed ))
319+ for i , entry := range output .Failed {
320+ for j , msg := range entries {
321+ if aws .ToString (msg .Id ) == aws .ToString (entry .Id ) {
322+ info [i ] = SendBatchErrorEntry {
323+ Entry : entry ,
324+ Index : j ,
325+ }
326+ break
327+ }
328+ }
329+ }
330+ return & SendBatchError {Err : errors .New ("partial message failure" ), Info : info }
331+ }
332+ return nil
252333}
253334
254- func (s * SQS ) SendNBatch (ctx context.Context , queueURL string , bodies []string ) error {
335+ // SendNBatch sends any number of messages to a given SQS queue via a series of SendBatch calls.
336+ // If an error occurs on any or all messages, a SendNBatchError is returned that lets
337+ // the caller know the index of the message/s in bodies that failed.
338+ // Returns the number of API calls to SendBatch made.
339+ func (s * SQS ) SendNBatch (ctx context.Context , queueURL string , bodies []string ) (int , error ) {
340+
341+ const (
342+ maxCount = 10
343+ maxSize = 262144 // 256 KiB
344+ )
345+
346+ allErrors := make ([]error , 0 )
347+ allInfo := make ([]SendBatchErrorEntry , 0 )
348+
349+ batchesSent := 0
350+
351+ batch := make ([]int , 0 )
352+ totalSize := 0
353+
354+ sendBatch := func () {
355+ batchBodies := make ([]string , len (batch ))
356+
357+ for i , batchIndex := range batch {
358+ batchBodies [i ] = bodies [batchIndex ]
359+ }
360+
361+ err := s .SendBatch (ctx , queueURL , batchBodies )
362+ var sbe * SendBatchError
363+ if errors .As (err , & sbe ) {
364+ allErrors = append (allErrors , err )
365+
366+ // Update index so that index refers to the position in given bodies slice.
367+ for i := range sbe .Info {
368+ sbe .Info [i ].Index = batch [sbe .Info [i ].Index ]
369+ }
370+
371+ allInfo = append (allInfo , sbe .Info ... )
372+ }
373+
374+ batchesSent ++
375+ batch = batch [:0 ]
376+ totalSize = 0
377+ }
378+
379+ for i , body := range bodies {
380+
381+ // Check if any single message is too big
382+ if len (body ) > maxSize {
383+ allErrors = append (allErrors , errors .New ("message too big to send" ))
384+ allInfo = append (allInfo , SendBatchErrorEntry {
385+ Index : i ,
386+ })
387+ continue
388+ }
389+ // If adding the current message would exceed the batch max size or count, send the current batch.
390+ if totalSize + len (body ) > maxSize || len (batch ) == maxCount {
391+ sendBatch ()
392+ }
393+ batch = append (batch , i )
394+ totalSize += len (body )
395+ }
396+
397+ if len (batch ) > 0 {
398+ sendBatch ()
399+ }
400+
401+ if len (allErrors ) > 0 {
402+ return batchesSent , & SendNBatchError {
403+ Errors : allErrors ,
404+ Info : allInfo ,
405+ }
406+ }
407+
408+ return batchesSent , nil
409+ }
410+
411+ type DeleteBatchError struct {
412+ Err error
413+ Info []DeleteBatchErrorEntry
414+ }
415+
416+ type DeleteBatchErrorEntry struct {
417+ Entry types.BatchResultErrorEntry
418+ Index int
419+ }
420+
421+ func (d * DeleteBatchError ) Error () string {
422+ return fmt .Sprintf ("%v: %v messages failed to delete" , d .Err , len (d .Info ))
423+ }
424+
425+ func (d * DeleteBatchError ) Unwrap () error {
426+ return d .Err
427+ }
428+
429+ type DeleteNBatchError struct {
430+ Errors []error
431+ Info []DeleteBatchErrorEntry
432+ }
433+
434+ func (s * DeleteNBatchError ) Error () string {
435+ var allErrors string
436+ for _ , err := range s .Errors {
437+ allErrors += fmt .Sprintf ("%s," , err .Error ())
438+ }
439+ allErrors = strings .TrimSuffix (allErrors , "," )
440+ return fmt .Sprintf ("%v error(s) deleting batches: %s" , len (s .Errors ), allErrors )
441+ }
442+
443+ // DeleteBatch deletes up to 10 messages from an SQS queue in a single batch.
444+ // If an error occurs on any or all messages, a DeleteBatchError is returned that lets
445+ // the caller know the indice/s in receiptHandles that failed.
446+ func (s * SQS ) DeleteBatch (ctx context.Context , queueURL string , receiptHandles []string ) error {
447+ entries := make ([]types.DeleteMessageBatchRequestEntry , len (receiptHandles ))
448+ for i , receipt := range receiptHandles {
449+ entries [i ] = types.DeleteMessageBatchRequestEntry {
450+ Id : aws .String (fmt .Sprintf ("delete-message-%d" , i )),
451+ ReceiptHandle : aws .String (receipt ),
452+ }
453+ }
454+
455+ output , err := s .client .DeleteMessageBatch (ctx , & sqs.DeleteMessageBatchInput {
456+ Entries : entries ,
457+ QueueUrl : & queueURL ,
458+ })
459+ if err != nil {
460+ info := make ([]DeleteBatchErrorEntry , len (entries ))
461+ for i := range entries {
462+ info [i ] = DeleteBatchErrorEntry {
463+ Index : i ,
464+ }
465+ }
466+ return & DeleteBatchError {Err : err , Info : info }
467+ }
468+ if len (output .Failed ) > 0 {
469+ info := make ([]DeleteBatchErrorEntry , len (output .Failed ))
470+ for i , errorEntry := range output .Failed {
471+ for j , requestEntry := range entries {
472+ if aws .ToString (requestEntry .Id ) == aws .ToString (errorEntry .Id ) {
473+ info [i ] = DeleteBatchErrorEntry {
474+ Entry : errorEntry ,
475+ Index : j ,
476+ }
477+ break
478+ }
479+ }
480+ }
481+ return & DeleteBatchError {Info : info }
482+ }
483+ return nil
484+ }
485+
486+ // DeleteNBatch deletes any number of messages from a given SQS queue via a series of DeleteBatch calls.
487+ // If an error occurs on any or all messages, a DeleteNBatchError is returned that lets
488+ // the caller know the receipt handles that failed.
489+ // Returns the number of API calls to DeleteBatch made.
490+ func (s * SQS ) DeleteNBatch (ctx context.Context , queueURL string , receiptHandles []string ) (int , error ) {
491+
255492 var (
256- bodiesLen = len (bodies )
257- maxlen = 10
258- times = int (math .Ceil (float64 (bodiesLen ) / float64 (maxlen )))
493+ receiptCount = len (receiptHandles )
494+ maxlen = 10
495+ times = int (math .Ceil (float64 (receiptCount ) / float64 (maxlen )))
259496 )
497+
498+ allErrors := make ([]error , 0 )
499+ allInfo := make ([]DeleteBatchErrorEntry , 0 )
500+
501+ batchesDeleted := 0
502+
260503 for i := 0 ; i < times ; i ++ {
261504 batch_end := maxlen * (i + 1 )
262- if maxlen * (i + 1 ) > bodiesLen {
263- batch_end = bodiesLen
505+ if maxlen * (i + 1 ) > receiptCount {
506+ batch_end = receiptCount
264507 }
265- var bodies_batch = bodies [maxlen * i : batch_end ]
266- err := s .SendBatch (ctx , queueURL , bodies_batch )
267- if err != nil {
268- return err
508+ var receipt_batch = receiptHandles [maxlen * i : batch_end ]
509+
510+ indexMap := make (map [int ]int , 0 )
511+ count := 0
512+ for j := maxlen * i ; j < batch_end ; j ++ {
513+ indexMap [count ] = j
514+ count ++
269515 }
516+
517+ err := s .DeleteBatch (ctx , queueURL , receipt_batch )
518+ var dbe * DeleteBatchError
519+ if errors .As (err , & dbe ) {
520+ allErrors = append (allErrors , err )
521+
522+ // Update index so that index refers to the position in given receiptHandles slice.
523+ for i := range dbe .Info {
524+ dbe .Info [i ].Index = indexMap [dbe .Info [i ].Index ]
525+ }
526+
527+ allInfo = append (allInfo , dbe .Info ... )
528+ }
529+ batchesDeleted ++
270530 }
271- return nil
531+
532+ if len (allErrors ) > 0 {
533+ return batchesDeleted , & DeleteNBatchError {
534+ Errors : allErrors ,
535+ Info : allInfo ,
536+ }
537+ }
538+ return batchesDeleted , nil
272539}
273540
274541// GetQueueUrl returns an AWS SQS queue URL given its name.
0 commit comments