22
33import errno
44import socket
5+ import threading
6+ import time
57import unittest
68from unittest import mock
79try :
@@ -337,18 +339,6 @@ def test__sock_sendall_none(self):
337339 (10 , self .loop ._sock_sendall , f , True , sock , b'data' ),
338340 self .loop .add_writer .call_args [0 ])
339341
340- def test_sock_connect (self ):
341- sock = test_utils .mock_nonblocking_socket ()
342- self .loop ._sock_connect = mock .Mock ()
343-
344- f = self .loop .sock_connect (sock , ('127.0.0.1' , 8080 ))
345- self .assertIsInstance (f , asyncio .Future )
346- self .loop ._run_once ()
347- future_in , sock_in , address_in = self .loop ._sock_connect .call_args [0 ]
348- self .assertEqual (future_in , f )
349- self .assertEqual (sock_in , sock )
350- self .assertEqual (address_in , ('127.0.0.1' , 8080 ))
351-
352342 def test_sock_connect_timeout (self ):
353343 # asyncio issue #205: sock_connect() must unregister the socket on
354344 # timeout error
@@ -360,29 +350,34 @@ def test_sock_connect_timeout(self):
360350 sock .connect .side_effect = BlockingIOError
361351
362352 # first call to sock_connect() registers the socket
363- fut = self .loop .sock_connect (sock , ('127.0.0.1' , 80 ))
353+ fut = self .loop .create_task (
354+ self .loop .sock_connect (sock , ('127.0.0.1' , 80 )))
364355 self .loop ._run_once ()
365356 self .assertTrue (sock .connect .called )
366357 self .assertTrue (self .loop .add_writer .called )
367- self .assertEqual (len (fut ._callbacks ), 1 )
368358
369359 # on timeout, the socket must be unregistered
370360 sock .connect .reset_mock ()
371- fut .set_exception ( asyncio . TimeoutError )
372- with self .assertRaises (asyncio .TimeoutError ):
361+ fut .cancel ( )
362+ with self .assertRaises (asyncio .CancelledError ):
373363 self .loop .run_until_complete (fut )
374364 self .assertTrue (self .loop .remove_writer .called )
375365
376- def test_sock_connect_resolve_using_socket_params (self ):
366+ @mock .patch ('socket.getaddrinfo' )
367+ def test_sock_connect_resolve_using_socket_params (self , m_gai ):
377368 addr = ('need-resolution.com' , 8080 )
378369 sock = test_utils .mock_nonblocking_socket ()
379- self .loop .getaddrinfo = mock .Mock ()
380- self .loop .sock_connect (sock , addr )
381- while not self .loop .getaddrinfo .called :
370+ m_gai .side_effect = (None , None , None , None , ('127.0.0.1' , 0 ))
371+ m_gai ._is_coroutine = False
372+ con = self .loop .create_task (self .loop .sock_connect (sock , addr ))
373+ while not m_gai .called :
382374 self .loop ._run_once ()
383- self .loop .getaddrinfo .assert_called_with (
384- * addr , type = sock .type , family = sock .family , proto = sock .proto ,
385- flags = 0 )
375+ m_gai .assert_called_with (
376+ addr [0 ], addr [1 ], sock .family , sock .type , sock .proto , 0 )
377+
378+ con .cancel ()
379+ with self .assertRaises (asyncio .CancelledError ):
380+ self .loop .run_until_complete (con )
386381
387382 def test__sock_connect (self ):
388383 f = asyncio .Future (loop = self .loop )
@@ -1792,5 +1787,88 @@ def test_fatal_error_connected(self, m_exc):
17921787 exc_info = (ConnectionRefusedError , MOCK_ANY , MOCK_ANY ))
17931788
17941789
1790+ class SelectorLoopFunctionalTests (unittest .TestCase ):
1791+
1792+ def setUp (self ):
1793+ self .loop = asyncio .new_event_loop ()
1794+ asyncio .set_event_loop (None )
1795+
1796+ def tearDown (self ):
1797+ self .loop .close ()
1798+
1799+ @asyncio .coroutine
1800+ def recv_all (self , sock , nbytes ):
1801+ buf = b''
1802+ while len (buf ) < nbytes :
1803+ buf += yield from self .loop .sock_recv (sock , nbytes - len (buf ))
1804+ return buf
1805+
1806+ def test_sock_connect_sock_write_race (self ):
1807+ TIMEOUT = 3.0
1808+ PAYLOAD = b'DATA' * 1024 * 1024
1809+
1810+ class Server (threading .Thread ):
1811+ def __init__ (self , * args , srv_sock , ** kwargs ):
1812+ super ().__init__ (* args , ** kwargs )
1813+ self .srv_sock = srv_sock
1814+
1815+ def run (self ):
1816+ with self .srv_sock :
1817+ srv_sock .listen (100 )
1818+
1819+ sock , addr = self .srv_sock .accept ()
1820+ sock .settimeout (TIMEOUT )
1821+
1822+ with sock :
1823+ sock .sendall (b'helo' )
1824+
1825+ buf = bytearray ()
1826+ while len (buf ) < len (PAYLOAD ):
1827+ pack = sock .recv (1024 * 65 )
1828+ if not pack :
1829+ break
1830+ buf .extend (pack )
1831+
1832+ @asyncio .coroutine
1833+ def client (addr ):
1834+ sock = socket .socket ()
1835+ with sock :
1836+ sock .setblocking (False )
1837+
1838+ started = time .monotonic ()
1839+ while True :
1840+ if time .monotonic () - started > TIMEOUT :
1841+ self .fail ('unable to connect to the socket' )
1842+ return
1843+ try :
1844+ yield from self .loop .sock_connect (sock , addr )
1845+ except OSError :
1846+ yield from asyncio .sleep (0.05 , loop = self .loop )
1847+ else :
1848+ break
1849+
1850+ # Give 'Server' thread a chance to accept and send b'helo'
1851+ time .sleep (0.1 )
1852+
1853+ data = yield from self .recv_all (sock , 4 )
1854+ self .assertEqual (data , b'helo' )
1855+ yield from self .loop .sock_sendall (sock , PAYLOAD )
1856+
1857+ srv_sock = socket .socket ()
1858+ srv_sock .settimeout (TIMEOUT )
1859+ srv_sock .bind (('127.0.0.1' , 0 ))
1860+ srv_addr = srv_sock .getsockname ()
1861+
1862+ srv = Server (srv_sock = srv_sock , daemon = True )
1863+ srv .start ()
1864+
1865+ try :
1866+ self .loop .run_until_complete (
1867+ asyncio .wait_for (client (srv_addr ), loop = self .loop ,
1868+ timeout = TIMEOUT ))
1869+ finally :
1870+ srv .join ()
1871+
1872+
17951873if __name__ == '__main__' :
17961874 unittest .main ()
0 commit comments