From f1291672533db8e73724ccb58b86e4cf5eec3e52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dan=20Michael=20O=2E=20Hegg=C3=B8?= <danmichaelo@gmail.com> Date: Sun, 26 Jul 2015 22:06:42 +0200 Subject: [PATCH] Factor out waiting code into a new Sleeper class - Moving code related to waiting/sleeping/retrying into a new class for a more object oriented approach. - Removing any reference to wait "tokens" to avoid confusion with edit tokens. - Note: `max_retries` and `retry_timeout` are no longer available on `Site`, but can still be passed into the constructor as before. --- mwclient/client.py | 63 ++++++++++++--------------------------------- mwclient/sleep.py | 50 +++++++++++++++++++++++++++++++++++ tests/test_sleep.py | 58 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 47 deletions(-) create mode 100644 mwclient/sleep.py create mode 100644 tests/test_sleep.py diff --git a/mwclient/client.py b/mwclient/client.py index 2cfe2c5..2be421a 100644 --- a/mwclient/client.py +++ b/mwclient/client.py @@ -23,6 +23,7 @@ from requests.auth import HTTPBasicAuth, AuthBase import mwclient.errors as errors import mwclient.listing as listing +from mwclient.sleep import Sleepers try: import gzip @@ -34,15 +35,6 @@ __ver__ = '0.8.0.dev1' log = logging.getLogger(__name__) -class WaitToken(object): - - def __init__(self): - self.id = '%032x' % random.getrandbits(128) - - def __hash__(self): - return hash(self.id) - - class Site(object): api_limit = 500 @@ -55,9 +47,6 @@ class Site(object): self.ext = ext self.credentials = None self.compress = compress - self.retry_timeout = retry_timeout - self.max_retries = max_retries - self.wait_callback = wait_callback self.max_lag = text_type(max_lag) self.force_login = force_login @@ -68,8 +57,7 @@ class Site(object): else: raise RuntimeError('Authentication is not a tuple or an instance of AuthBase') - # The token string => token object mapping - self.wait_tokens = weakref.WeakKeyDictionary() + self.sleepers = Sleepers(max_retries, retry_timeout, wait_callback) # Site properties self.blocked = False # Whether current user is blocked @@ -192,18 +180,18 @@ class Site(object): else: kwargs['uiprop'] = 'blockinfo|hasmsg' - token = self.wait_token() + sleeper = self.sleepers.make() while True: info = self.raw_api(action, **kwargs) if not info: info = {} - if self.handle_api_result(info, token=token): + if self.handle_api_result(info, sleeper=sleeper): return info - def handle_api_result(self, info, kwargs=None, token=None): - if token is None: - token = self.wait_token() + def handle_api_result(self, info, kwargs=None, sleeper=None): + if sleeper is None: + sleeper = self.sleepers.make() try: userinfo = info['query']['userinfo'] @@ -217,7 +205,7 @@ class Site(object): self.logged_in = 'anon' not in userinfo if 'error' in info: if info['error']['code'] in (u'internal_api_error_DBConnectionError', u'internal_api_error_DBQueryError'): - self.wait(token) + sleeper.sleep() return False if '*' in info['error']: raise errors.APIError(info['error']['code'], @@ -258,7 +246,7 @@ class Site(object): headers = {} if self.compress and gzip: headers['Accept-Encoding'] = 'gzip' - token = self.wait_token((script, data)) + sleeper = self.sleepers.make((script, data)) while True: scheme = 'http' # Should we move to 'https' as default? host = self.host @@ -272,7 +260,7 @@ class Site(object): if stream.headers.get('x-database-lag'): wait_time = int(stream.headers.get('retry-after')) log.warn('Database lag exceeds max lag. Waiting for %d seconds', wait_time) - self.wait(token, wait_time) + sleeper.sleep(wait_time) elif stream.status_code == 200: return stream.text elif stream.status_code < 500 or stream.status_code > 599: @@ -281,7 +269,7 @@ class Site(object): if not retry_on_error: stream.raise_for_status() log.warn('Received %s response: %s. Retrying in a moment.', stream.status_code, stream.text) - self.wait(token) + sleeper.sleep() except requests.exceptions.ConnectionError: # In the event of a network problem (e.g. DNS failure, refused connection, etc), @@ -289,7 +277,7 @@ class Site(object): if not retry_on_error: raise log.warn('Connection error. Retrying in a moment.') - self.wait(token) + sleeper.sleep() def raw_api(self, action, *args, **kwargs): """Sends a call to the API.""" @@ -316,25 +304,6 @@ class Site(object): data = self._query_string(*args, **kwargs) return self.raw_call('index', data) - def wait_token(self, args=None): - token = WaitToken() - self.wait_tokens[token] = (0, args) - return token - - def wait(self, token, min_wait=0): - retry, args = self.wait_tokens[token] - self.wait_tokens[token] = (retry + 1, args) - if retry > self.max_retries and self.max_retries != -1: - raise errors.MaximumRetriesExceeded(self, token, args) - self.wait_callback(self, token, retry, args) - - timeout = self.retry_timeout * retry - if timeout < min_wait: - timeout = min_wait - log.debug('Sleeping for %d seconds', timeout) - time.sleep(timeout) - return self.wait_tokens[token] - def require(self, major, minor, revision=None, raise_error=True): if self.version is None: if raise_error is None: @@ -399,7 +368,7 @@ class Site(object): self.conn.cookies[self.host].update(cookies) if self.credentials: - wait_token = self.wait_token() + sleeper = self.sleepers.make() kwargs = { 'lgname': self.credentials[0], 'lgpassword': self.credentials[1] @@ -413,7 +382,7 @@ class Site(object): elif login['login']['result'] == 'NeedToken': kwargs['lgtoken'] = login['login']['token'] elif login['login']['result'] == 'Throttled': - self.wait(wait_token, login['login'].get('wait', 5)) + sleeper.sleep(int(login['login'].get('wait', 5))) else: raise errors.LoginError(self, login['login']) @@ -541,13 +510,13 @@ class Site(object): files = {'file': file} - wait_token = self.wait_token() + sleeper = self.sleepers.make() while True: data = self.raw_call('api', postdata, files) info = json.loads(data) if not info: info = {} - if self.handle_api_result(info, kwargs=predata, token=wait_token): + if self.handle_api_result(info, kwargs=predata, sleeper=sleeper): return info.get('upload', {}) def parse(self, text=None, title=None, page=None): diff --git a/mwclient/sleep.py b/mwclient/sleep.py new file mode 100644 index 0000000..353f558 --- /dev/null +++ b/mwclient/sleep.py @@ -0,0 +1,50 @@ +import random +import time +import logging +from mwclient.errors import MaximumRetriesExceeded + +log = logging.getLogger(__name__) + + +class Sleepers(object): + + def __init__(self, max_retries, retry_timeout, callback=lambda *x: None): + self.max_retries = max_retries + self.retry_timeout = retry_timeout + self.callback = callback + + def make(self, args=None): + return Sleeper(args, self.max_retries, self.retry_timeout, self.callback) + + +class Sleeper(object): + """ + For any given operation, a `Sleeper` object keeps count of the number of + retries. For each retry, the sleep time increases until the max number of + retries is reached and a `MaximumRetriesExceeded` is raised. The sleeper + object should be discarded once the operation is successful. + """ + + def __init__(self, args, max_retries, retry_timeout, callback): + self.args = args + self.retries = 0 + self.max_retries = max_retries + self.retry_timeout = retry_timeout + self.callback = callback + + def sleep(self, min_time=0): + """ + Sleep a minimum of `min_time` seconds. + The actual sleeping time will increase with the number of retries. + """ + self.retries += 1 + if self.retries > self.max_retries: + raise MaximumRetriesExceeded(self, self.args) + + self.callback(self, self.retries, self.args) + + timeout = self.retry_timeout * (self.retries - 1) + if timeout < min_time: + timeout = min_time + log.debug('Sleeping for %d seconds', timeout) + time.sleep(timeout) diff --git a/tests/test_sleep.py b/tests/test_sleep.py new file mode 100644 index 0000000..d41d831 --- /dev/null +++ b/tests/test_sleep.py @@ -0,0 +1,58 @@ +# encoding=utf-8 +from __future__ import print_function +import unittest +import time +import mock +import pytest +from mwclient.sleep import Sleepers +from mwclient.sleep import Sleeper +from mwclient.errors import MaximumRetriesExceeded + +if __name__ == "__main__": + print() + print("Note: Running in stand-alone mode. Consult the README") + print(" (section 'Contributing') for advice on running tests.") + print() + + +class TestSleepers(unittest.TestCase): + + def setUp(self): + self.sleep = mock.patch('time.sleep').start() + self.max_retries = 10 + self.sleepers = Sleepers(self.max_retries, 30) + + def tearDown(self): + mock.patch.stopall() + + def test_make(self): + sleeper = self.sleepers.make() + assert type(sleeper) == Sleeper + assert sleeper.retries == 0 + + def test_sleep(self): + sleeper = self.sleepers.make() + sleeper.sleep() + sleeper.sleep() + self.sleep.assert_has_calls([mock.call(0), mock.call(30)]) + + def test_min_time(self): + sleeper = self.sleepers.make() + sleeper.sleep(5) + self.sleep.assert_has_calls([mock.call(5)]) + + def test_retries_count(self): + sleeper = self.sleepers.make() + sleeper.sleep() + sleeper.sleep() + assert sleeper.retries == 2 + + def test_max_retries(self): + sleeper = self.sleepers.make() + for x in range(self.max_retries): + sleeper.sleep() + with pytest.raises(MaximumRetriesExceeded): + sleeper.sleep() + +if __name__ == '__main__': + unittest.main() -- GitLab