@@ -16,15 +16,33 @@ func CatchMatches(r, expect any) bool {
1616 return false
1717 }
1818
19+ expectType := expect .(reflect.Type )
20+
1921 // if expect is an error type, check if r is an instance of it
2022 if rErr , ok := r .(error ); ok {
21- if expectTyp , ok := expect .(reflect.Type ); ok && expectTyp .Implements (errorType ) {
22- expectVal := reflect .New (expectTyp ).Elem ().Interface ().(error )
23- if errors .Is (rErr , expectVal ) {
23+ if expectType .Implements (errorType ) {
24+ // if expectType is a pointer type, instantiate a new value of that type
25+ // and check if rErr is an instance of it
26+ if expectType .Kind () == reflect .Ptr {
27+ expectVal := reflect .New (expectType .Elem ()).Interface ()
28+ if errors .As (rErr , expectVal ) {
29+ return true
30+ }
31+ }
32+ // if expectType is an interface type, check if rErr implements it
33+ if expectType .Kind () == reflect .Interface {
34+ if reflect .TypeOf (rErr ).Implements (expectType ) {
35+ return true
36+ }
37+ }
38+ // otherwise, create a new value of the expectType and check if
39+ // rErr is an instance of it
40+ expectVal := reflect .New (expectType ).Interface ()
41+ if errors .As (rErr , expectVal ) {
2442 return true
2543 }
2644 }
2745 }
2846
29- return reflect .TypeOf (r ).AssignableTo (expect .(reflect. Type ) )
47+ return reflect .TypeOf (r ).AssignableTo (expectType )
3048}
0 commit comments