Skip to content

Commit 5b8e2dd

Browse files
committed
Fix mock
1 parent 4bd9bce commit 5b8e2dd

File tree

2 files changed

+107
-82
lines changed

2 files changed

+107
-82
lines changed

asyncio/unix_events.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,10 +519,7 @@ def _write_ready(self):
519519
assert self._buffer, 'Data should not be empty'
520520

521521
try:
522-
# list() is used only to workaround BUG in @mock.patch('os.writev').
523-
# Mock stores internal reference to _buffer, which is then MODIFIED
524-
# during this call. That's why we create separate superfluous copy.
525-
n = os.writev(self._fileno, list(self._buffer))
522+
n = os.writev(self._fileno, self._buffer)
526523
except (BlockingIOError, InterruptedError):
527524
return
528525
except Exception as exc:

tests/test_unix_events.py

Lines changed: 106 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for unix_events.py."""
22

33
import collections
4+
import contextlib
45
import errno
56
import io
67
import os
@@ -35,6 +36,33 @@ def close_pipe_transport(transport):
3536
transport._pipe = None
3637

3738

39+
@contextlib.contextmanager
40+
def patched_os_writev():
41+
42+
class Wrapper:
43+
def __init__(self, wrapped):
44+
self._wrapped = wrapped
45+
46+
def __call__(self, fd, buffers):
47+
return self._wrapped(fd, list(buffers))
48+
49+
def __getattr__(self, name):
50+
if name.startswith('_'):
51+
return object.__getattr__(self, name)
52+
else:
53+
return getattr(self._wrapped, name)
54+
55+
def __setattr__(self, name, val):
56+
if name.startswith('_'):
57+
object.__setattr__(self, name, val)
58+
else:
59+
setattr(self._wrapped, name, val)
60+
61+
m_writev = Wrapper(mock.Mock())
62+
with mock.patch('os.writev', m_writev):
63+
yield m_writev
64+
65+
3866
@unittest.skipUnless(signal, 'Signals are not supported')
3967
class SelectorEventLoopSignalTests(test_utils.TestCase):
4068

@@ -602,86 +630,86 @@ def test__read_ready(self):
602630
test_utils.run_briefly(self.loop)
603631
self.protocol.connection_lost.assert_called_with(None)
604632

605-
@mock.patch('os.writev')
606-
def test__write_ready(self, m_writev):
607-
tr = self.write_pipe_transport()
608-
self.loop.add_writer(5, tr._write_ready)
609-
tr._buffer.extend([b'da', b'ta'])
610-
m_writev.return_value = 4
611-
tr._write_ready()
612-
m_writev.assert_called_with(5, [b'da', b'ta'])
613-
self.assertFalse(self.loop.writers)
614-
self.assertEqual([], list(tr._buffer))
615-
616-
@mock.patch('os.writev')
617-
def test__write_ready_partial(self, m_writev):
618-
tr = self.write_pipe_transport()
619-
self.loop.add_writer(5, tr._write_ready)
620-
tr._buffer.extend([b'da', b'ta'])
621-
m_writev.return_value = 3
622-
tr._write_ready()
623-
m_writev.assert_called_with(5, [b'da', b'ta'])
624-
self.loop.assert_writer(5, tr._write_ready)
625-
self.assertEqual([b'a'], list(tr._buffer))
626-
627-
@mock.patch('os.writev')
628-
def test__write_ready_again(self, m_writev):
629-
tr = self.write_pipe_transport()
630-
self.loop.add_writer(5, tr._write_ready)
631-
tr._buffer.extend([b'da', b'ta'])
632-
m_writev.side_effect = BlockingIOError()
633-
tr._write_ready()
634-
m_writev.assert_called_with(5, [b'da', b'ta'])
635-
self.loop.assert_writer(5, tr._write_ready)
636-
self.assertEqual([b'da', b'ta'], list(tr._buffer))
637-
638-
@mock.patch('os.writev')
639-
def test__write_ready_empty(self, m_writev):
640-
tr = self.write_pipe_transport()
641-
self.loop.add_writer(5, tr._write_ready)
642-
tr._buffer.extend([b'da', b'ta'])
643-
m_writev.return_value = 0
644-
tr._write_ready()
645-
m_writev.assert_called_with(5, [b'da', b'ta'])
646-
self.loop.assert_writer(5, tr._write_ready)
647-
self.assertEqual([b'da', b'ta'], list(tr._buffer))
633+
def test__write_ready(self):
634+
with patched_os_writev() as m_writev:
635+
tr = self.write_pipe_transport()
636+
self.loop.add_writer(5, tr._write_ready)
637+
tr._buffer.extend([b'da', b'ta'])
638+
m_writev.return_value = 4
639+
tr._write_ready()
640+
m_writev.assert_called_with(5, [b'da', b'ta'])
641+
self.assertFalse(self.loop.writers)
642+
self.assertEqual([], list(tr._buffer))
643+
644+
def test__write_ready_partial(self):
645+
with patched_os_writev() as m_writev:
646+
tr = self.write_pipe_transport()
647+
self.loop.add_writer(5, tr._write_ready)
648+
tr._buffer.extend([b'da', b'ta'])
649+
m_writev.return_value = 3
650+
tr._write_ready()
651+
m_writev.assert_called_with(5, [b'da', b'ta'])
652+
self.loop.assert_writer(5, tr._write_ready)
653+
self.assertEqual([b'a'], list(tr._buffer))
654+
655+
def test__write_ready_again(self):
656+
with patched_os_writev() as m_writev:
657+
tr = self.write_pipe_transport()
658+
self.loop.add_writer(5, tr._write_ready)
659+
tr._buffer.extend([b'da', b'ta'])
660+
m_writev.side_effect = BlockingIOError()
661+
tr._write_ready()
662+
m_writev.assert_called_with(5, [b'da', b'ta'])
663+
self.loop.assert_writer(5, tr._write_ready)
664+
self.assertEqual([b'da', b'ta'], list(tr._buffer))
665+
666+
def test__write_ready_empty(self):
667+
with patched_os_writev() as m_writev:
668+
tr = self.write_pipe_transport()
669+
self.loop.add_writer(5, tr._write_ready)
670+
tr._buffer.extend([b'da', b'ta'])
671+
m_writev.return_value = 0
672+
tr._write_ready()
673+
m_writev.assert_called_with(5, [b'da', b'ta'])
674+
self.loop.assert_writer(5, tr._write_ready)
675+
self.assertEqual([b'da', b'ta'], list(tr._buffer))
648676

649677
@mock.patch('asyncio.log.logger.error')
650-
@mock.patch('os.writev')
651-
def test__write_ready_err(self, m_writev, m_logexc):
652-
tr = self.write_pipe_transport()
653-
self.loop.add_writer(5, tr._write_ready)
654-
tr._buffer.extend([b'da', b'ta'])
655-
m_writev.side_effect = err = OSError()
656-
tr._write_ready()
657-
m_writev.assert_called_with(5, [b'da', b'ta'])
658-
self.assertFalse(self.loop.writers)
659-
self.assertFalse(self.loop.readers)
660-
self.assertEqual([], list(tr._buffer))
661-
self.assertTrue(tr.is_closing())
662-
m_logexc.assert_called_with(
663-
test_utils.MockPattern(
664-
'Fatal write error on pipe transport'
665-
'\nprotocol:.*\ntransport:.*'),
666-
exc_info=(OSError, MOCK_ANY, MOCK_ANY))
667-
self.assertEqual(1, tr._conn_lost)
668-
test_utils.run_briefly(self.loop)
669-
self.protocol.connection_lost.assert_called_with(err)
670-
671-
@mock.patch('os.writev')
672-
def test__write_ready_closing(self, m_writev):
673-
tr = self.write_pipe_transport()
674-
self.loop.add_writer(5, tr._write_ready)
675-
tr._closing = True
676-
tr._buffer.extend([b'da', b'ta'])
677-
m_writev.return_value = 4
678-
tr._write_ready()
679-
m_writev.assert_called_with(5, [b'da', b'ta'])
680-
self.assertFalse(self.loop.writers)
681-
self.assertFalse(self.loop.readers)
682-
self.assertEqual([], list(tr._buffer))
683-
self.protocol.connection_lost.assert_called_with(None)
684-
self.pipe.close.assert_called_with()
678+
def test__write_ready_err(self, m_logexc):
679+
with patched_os_writev() as m_writev:
680+
tr = self.write_pipe_transport()
681+
self.loop.add_writer(5, tr._write_ready)
682+
tr._buffer.extend([b'da', b'ta'])
683+
m_writev.side_effect = err = OSError()
684+
tr._write_ready()
685+
m_writev.assert_called_with(5, [b'da', b'ta'])
686+
self.assertFalse(self.loop.writers)
687+
self.assertFalse(self.loop.readers)
688+
self.assertEqual([], list(tr._buffer))
689+
self.assertTrue(tr.is_closing())
690+
m_logexc.assert_called_with(
691+
test_utils.MockPattern(
692+
'Fatal write error on pipe transport'
693+
'\nprotocol:.*\ntransport:.*'),
694+
exc_info=(OSError, MOCK_ANY, MOCK_ANY))
695+
self.assertEqual(1, tr._conn_lost)
696+
test_utils.run_briefly(self.loop)
697+
self.protocol.connection_lost.assert_called_with(err)
698+
699+
def test__write_ready_closing(self):
700+
with patched_os_writev() as m_writev:
701+
tr = self.write_pipe_transport()
702+
self.loop.add_writer(5, tr._write_ready)
703+
tr._closing = True
704+
tr._buffer.extend([b'da', b'ta'])
705+
m_writev.return_value = 4
706+
tr._write_ready()
707+
m_writev.assert_called_with(5, [b'da', b'ta'])
708+
self.assertFalse(self.loop.writers)
709+
self.assertFalse(self.loop.readers)
710+
self.assertEqual([], list(tr._buffer))
711+
self.protocol.connection_lost.assert_called_with(None)
712+
self.pipe.close.assert_called_with()
685713

686714
@mock.patch('os.write')
687715
def test_abort(self, m_write):

0 commit comments

Comments
 (0)