From f1291672533db8e73724ccb58b86e4cf5eec3e52 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dan=20Michael=20O=2E=20Hegg=C3=B8?= <danmichaelo@gmail.com>
Date: Sun, 26 Jul 2015 22:06:42 +0200
Subject: [PATCH] Factor out waiting code into a new Sleeper class

- Moving code related to waiting/sleeping/retrying into a new class for a more object oriented approach.
- Removing any reference to wait "tokens" to avoid confusion with edit tokens.
- Note: `max_retries` and `retry_timeout` are no longer available on `Site`, but can still be passed into the constructor as before.
---
 mwclient/client.py  | 63 ++++++++++++---------------------------------
 mwclient/sleep.py   | 50 +++++++++++++++++++++++++++++++++++
 tests/test_sleep.py | 58 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 124 insertions(+), 47 deletions(-)
 create mode 100644 mwclient/sleep.py
 create mode 100644 tests/test_sleep.py

diff --git a/mwclient/client.py b/mwclient/client.py
index 2cfe2c5..2be421a 100644
--- a/mwclient/client.py
+++ b/mwclient/client.py
@@ -23,6 +23,7 @@ from requests.auth import HTTPBasicAuth, AuthBase
 
 import mwclient.errors as errors
 import mwclient.listing as listing
+from mwclient.sleep import Sleepers
 
 try:
     import gzip
@@ -34,15 +35,6 @@ __ver__ = '0.8.0.dev1'
 log = logging.getLogger(__name__)
 
 
-class WaitToken(object):
-
-    def __init__(self):
-        self.id = '%032x' % random.getrandbits(128)
-
-    def __hash__(self):
-        return hash(self.id)
-
-
 class Site(object):
     api_limit = 500
 
@@ -55,9 +47,6 @@ class Site(object):
         self.ext = ext
         self.credentials = None
         self.compress = compress
-        self.retry_timeout = retry_timeout
-        self.max_retries = max_retries
-        self.wait_callback = wait_callback
         self.max_lag = text_type(max_lag)
         self.force_login = force_login
 
@@ -68,8 +57,7 @@ class Site(object):
         else:
             raise RuntimeError('Authentication is not a tuple or an instance of AuthBase')
 
-        # The token string => token object mapping
-        self.wait_tokens = weakref.WeakKeyDictionary()
+        self.sleepers = Sleepers(max_retries, retry_timeout, wait_callback)
 
         # Site properties
         self.blocked = False    # Whether current user is blocked
@@ -192,18 +180,18 @@ class Site(object):
             else:
                 kwargs['uiprop'] = 'blockinfo|hasmsg'
 
-        token = self.wait_token()
+        sleeper = self.sleepers.make()
 
         while True:
             info = self.raw_api(action, **kwargs)
             if not info:
                 info = {}
-            if self.handle_api_result(info, token=token):
+            if self.handle_api_result(info, sleeper=sleeper):
                 return info
 
-    def handle_api_result(self, info, kwargs=None, token=None):
-        if token is None:
-            token = self.wait_token()
+    def handle_api_result(self, info, kwargs=None, sleeper=None):
+        if sleeper is None:
+            sleeper = self.sleepers.make()
 
         try:
             userinfo = info['query']['userinfo']
@@ -217,7 +205,7 @@ class Site(object):
         self.logged_in = 'anon' not in userinfo
         if 'error' in info:
             if info['error']['code'] in (u'internal_api_error_DBConnectionError', u'internal_api_error_DBQueryError'):
-                self.wait(token)
+                sleeper.sleep()
                 return False
             if '*' in info['error']:
                 raise errors.APIError(info['error']['code'],
@@ -258,7 +246,7 @@ class Site(object):
         headers = {}
         if self.compress and gzip:
             headers['Accept-Encoding'] = 'gzip'
-        token = self.wait_token((script, data))
+        sleeper = self.sleepers.make((script, data))
         while True:
             scheme = 'http'  # Should we move to 'https' as default?
             host = self.host
@@ -272,7 +260,7 @@ class Site(object):
                 if stream.headers.get('x-database-lag'):
                     wait_time = int(stream.headers.get('retry-after'))
                     log.warn('Database lag exceeds max lag. Waiting for %d seconds', wait_time)
-                    self.wait(token, wait_time)
+                    sleeper.sleep(wait_time)
                 elif stream.status_code == 200:
                     return stream.text
                 elif stream.status_code < 500 or stream.status_code > 599:
@@ -281,7 +269,7 @@ class Site(object):
                     if not retry_on_error:
                         stream.raise_for_status()
                     log.warn('Received %s response: %s. Retrying in a moment.', stream.status_code, stream.text)
-                    self.wait(token)
+                    sleeper.sleep()
 
             except requests.exceptions.ConnectionError:
                 # In the event of a network problem (e.g. DNS failure, refused connection, etc),
@@ -289,7 +277,7 @@ class Site(object):
                 if not retry_on_error:
                     raise
                 log.warn('Connection error. Retrying in a moment.')
-                self.wait(token)
+                sleeper.sleep()
 
     def raw_api(self, action, *args, **kwargs):
         """Sends a call to the API."""
@@ -316,25 +304,6 @@ class Site(object):
         data = self._query_string(*args, **kwargs)
         return self.raw_call('index', data)
 
-    def wait_token(self, args=None):
-        token = WaitToken()
-        self.wait_tokens[token] = (0, args)
-        return token
-
-    def wait(self, token, min_wait=0):
-        retry, args = self.wait_tokens[token]
-        self.wait_tokens[token] = (retry + 1, args)
-        if retry > self.max_retries and self.max_retries != -1:
-            raise errors.MaximumRetriesExceeded(self, token, args)
-        self.wait_callback(self, token, retry, args)
-
-        timeout = self.retry_timeout * retry
-        if timeout < min_wait:
-            timeout = min_wait
-        log.debug('Sleeping for %d seconds', timeout)
-        time.sleep(timeout)
-        return self.wait_tokens[token]
-
     def require(self, major, minor, revision=None, raise_error=True):
         if self.version is None:
             if raise_error is None:
@@ -399,7 +368,7 @@ class Site(object):
             self.conn.cookies[self.host].update(cookies)
 
         if self.credentials:
-            wait_token = self.wait_token()
+            sleeper = self.sleepers.make()
             kwargs = {
                 'lgname': self.credentials[0],
                 'lgpassword': self.credentials[1]
@@ -413,7 +382,7 @@ class Site(object):
                 elif login['login']['result'] == 'NeedToken':
                     kwargs['lgtoken'] = login['login']['token']
                 elif login['login']['result'] == 'Throttled':
-                    self.wait(wait_token, login['login'].get('wait', 5))
+                    sleeper.sleep(int(login['login'].get('wait', 5)))
                 else:
                     raise errors.LoginError(self, login['login'])
 
@@ -541,13 +510,13 @@ class Site(object):
 
             files = {'file': file}
 
-        wait_token = self.wait_token()
+        sleeper = self.sleepers.make()
         while True:
             data = self.raw_call('api', postdata, files)
             info = json.loads(data)
             if not info:
                 info = {}
-            if self.handle_api_result(info, kwargs=predata, token=wait_token):
+            if self.handle_api_result(info, kwargs=predata, sleeper=sleeper):
                 return info.get('upload', {})
 
     def parse(self, text=None, title=None, page=None):
diff --git a/mwclient/sleep.py b/mwclient/sleep.py
new file mode 100644
index 0000000..353f558
--- /dev/null
+++ b/mwclient/sleep.py
@@ -0,0 +1,50 @@
+import random
+import time
+import logging
+from mwclient.errors import MaximumRetriesExceeded
+
+log = logging.getLogger(__name__)
+
+
+class Sleepers(object):
+
+    def __init__(self, max_retries, retry_timeout, callback=lambda *x: None):
+        self.max_retries = max_retries
+        self.retry_timeout = retry_timeout
+        self.callback = callback
+
+    def make(self, args=None):
+        return Sleeper(args, self.max_retries, self.retry_timeout, self.callback)
+
+
+class Sleeper(object):
+    """
+    For any given operation, a `Sleeper` object keeps count of the number of
+    retries. For each retry, the sleep time increases until the max number of
+    retries is reached and a `MaximumRetriesExceeded` is raised. The sleeper
+    object should be discarded once the operation is successful.
+    """
+
+    def __init__(self, args, max_retries, retry_timeout, callback):
+        self.args = args
+        self.retries = 0
+        self.max_retries = max_retries
+        self.retry_timeout = retry_timeout
+        self.callback = callback
+
+    def sleep(self, min_time=0):
+        """
+        Sleep a minimum of `min_time` seconds.
+        The actual sleeping time will increase with the number of retries.
+        """
+        self.retries += 1
+        if self.retries > self.max_retries:
+            raise MaximumRetriesExceeded(self, self.args)
+
+        self.callback(self, self.retries, self.args)
+
+        timeout = self.retry_timeout * (self.retries - 1)
+        if timeout < min_time:
+            timeout = min_time
+        log.debug('Sleeping for %d seconds', timeout)
+        time.sleep(timeout)
diff --git a/tests/test_sleep.py b/tests/test_sleep.py
new file mode 100644
index 0000000..d41d831
--- /dev/null
+++ b/tests/test_sleep.py
@@ -0,0 +1,58 @@
+# encoding=utf-8
+from __future__ import print_function
+import unittest
+import time
+import mock
+import pytest
+from mwclient.sleep import Sleepers
+from mwclient.sleep import Sleeper
+from mwclient.errors import MaximumRetriesExceeded
+
+if __name__ == "__main__":
+    print()
+    print("Note: Running in stand-alone mode. Consult the README")
+    print("      (section 'Contributing') for advice on running tests.")
+    print()
+
+
+class TestSleepers(unittest.TestCase):
+
+    def setUp(self):
+        self.sleep = mock.patch('time.sleep').start()
+        self.max_retries = 10
+        self.sleepers = Sleepers(self.max_retries, 30)
+
+    def tearDown(self):
+        mock.patch.stopall()
+
+    def test_make(self):
+        sleeper = self.sleepers.make()
+        assert type(sleeper) == Sleeper
+        assert sleeper.retries == 0
+
+    def test_sleep(self):
+        sleeper = self.sleepers.make()
+        sleeper.sleep()
+        sleeper.sleep()
+        self.sleep.assert_has_calls([mock.call(0), mock.call(30)])
+
+    def test_min_time(self):
+        sleeper = self.sleepers.make()
+        sleeper.sleep(5)
+        self.sleep.assert_has_calls([mock.call(5)])
+
+    def test_retries_count(self):
+        sleeper = self.sleepers.make()
+        sleeper.sleep()
+        sleeper.sleep()
+        assert sleeper.retries == 2
+
+    def test_max_retries(self):
+        sleeper = self.sleepers.make()
+        for x in range(self.max_retries):
+            sleeper.sleep()
+        with pytest.raises(MaximumRetriesExceeded):
+            sleeper.sleep()
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab