@@ -202,3 +202,34 @@ def test_bitwise_shift_error():
202202 assert_raises (
203203 ValueError , lambda : bitwise_right_shift (asarray ([1 , 1 ]), asarray ([1 , - 1 ]))
204204 )
205+
206+
207+ def test_scalars ():
208+ # Test that binary functions accept (array, scalar) and (scalar, array) arguments
209+ # and reject (scalar, scalar) arguments
210+
211+ def _sample_scalar (category ):
212+ if 'boolean' in category :
213+ return True
214+ elif 'floating-point' in category :
215+ return 1.0
216+ elif 'numeric' in category or 'integer' in category or 'all' in category :
217+ return 1
218+ else :
219+ raise ValueError (f'Unknown { category = } ' )
220+
221+ for func_name , types in elementwise_function_input_types .items ():
222+ dtypes = _dtype_categories [types ]
223+ func = getattr (_elementwise_functions , func_name )
224+ if nargs (func ) == 2 :
225+ print (func_name , types , _sample_scalar (types ))
226+ scalar = _sample_scalar (types )
227+ for dt in dtypes :
228+ array = asarray (scalar , dtype = dt )
229+ conv_scalar = asarray (scalar , dtype = array .dtype )
230+ assert func (scalar , array ) == func (conv_scalar , array )
231+ assert func (array , scalar ) == func (array , conv_scalar )
232+
233+ with assert_raises (TypeError ):
234+ func (scalar , scalar )
235+
0 commit comments