From 5701b57ed14cab2f34af2494faa98c81c95654fa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dan=20Michael=20O=2E=20Hegg=C3=B8?= <danmichaelo@gmail.com>
Date: Sat, 29 Apr 2017 00:30:34 +0200
Subject: [PATCH] [#149] Fix login tokens for read protected wikis

On read-protected wikis, we must make sure that we don't include extra
parameters when requesting the login token, or the API will respond with
a `readapidenied` error.

- Remove extraneous `continue` parameter from non-query calls
- Remove `userinfo` from `meta=tokens` calls
---
 mwclient/client.py   |  27 +++++++---
 tests/test_client.py | 116 +++++++++++++++++++++++--------------------
 2 files changed, 83 insertions(+), 60 deletions(-)

diff --git a/mwclient/client.py b/mwclient/client.py
index e0816b1..9831aae 100644
--- a/mwclient/client.py
+++ b/mwclient/client.py
@@ -261,7 +261,7 @@ class Site(object):
         """
         kwargs.update(args)
 
-        if 'continue' not in kwargs:
+        if action == 'query' and 'continue' not in kwargs:
             kwargs['continue'] = ''
         if action == 'query':
             if 'meta' in kwargs:
@@ -482,12 +482,20 @@ class Site(object):
                 'lgname': self.credentials[0],
                 'lgpassword': self.credentials[1]
             }
-            if self.version[:2] >= (1, 27):
-                kwargs['lgtoken'] = self.get_token('login')
             if self.credentials[2]:
                 kwargs['lgdomain'] = self.credentials[2]
+
+            # Try to log in using the scheme for MW 1.27+. On a read-protected wiki, the
+            # version cannot be fetched up front via the API, so we just have to try.
+            # If the attempt fails, we fall back to the old method.
+            try:
+                kwargs['lgtoken'] = self.get_token('login')
+            except KeyError:
+                log.debug('Failed to get login token, MediaWiki is older than 1.27.')
+
             while True:
                 login = self.post('login', **kwargs)
+
                 if login['login']['result'] == 'Success':
                     break
                 elif login['login']['result'] == 'NeedToken':
@@ -501,7 +509,7 @@ class Site(object):
 
     def get_token(self, type, force=False, title=None):
 
-        if self.version[:2] >= (1, 24):
+        if self.version is None or self.version[:2] >= (1, 24):
             # The 'csrf' (cross-site request forgery) token introduced in 1.24 replaces
             # the majority of older tokens, like edittoken and movetoken.
             if type not in {'watch', 'patrol', 'rollback', 'userrights', 'login'}:
@@ -512,8 +520,15 @@ class Site(object):
 
         if self.tokens.get(type, '0') == '0' or force:
 
-            if self.version[:2] >= (1, 24):
-                info = self.post('query', meta='tokens', type=type)
+            if self.version is None or self.version[:2] >= (1, 24):
+                # We use raw_api() rather than api() because api() adds "userinfo" to
+                # the query, which raises a readapidenied error if the wiki is
+                # read-protected and we're trying to fetch a login token.
+                info = self.raw_api('query', 'GET', meta='tokens', type=type)
+
+                # Note that for read-protected wikis, we don't know the version when
+                # fetching the login token. If it's < 1.27, the lookup below will
+                # raise a KeyError that the caller should catch.
                 self.tokens[type] = info['query']['tokens']['%stoken' % type]
 
             else:
diff --git a/tests/test_client.py b/tests/test_client.py
index caaa881..c976082 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -301,46 +301,12 @@ class TestClient(TestCase):
         assert repr(site) == '<Site object \'test.wikipedia.org/w/\'>'
 
 
-class TestClientApiMethods(TestCase):
-
-    def setUp(self):
-        self.api = mock.patch('mwclient.client.Site.api').start()
-        self.api.return_value = self.metaResponse()
-        self.site = mwclient.Site('test.wikipedia.org')
-
-    def tearDown(self):
-        mock.patch.stopall()
-
-    def test_revisions(self):
-
-        self.api.return_value = {
-            'query': {'pages': {'1': {
-                'pageid': 1,
-                'title': 'Test page',
-                'revisions': [{
-                    'revid': 689697696,
-                    'timestamp': '2015-11-08T21:52:46Z',
-                    'comment': 'Test comment 1'
-                }, {
-                    'revid': 689816909,
-                    'timestamp': '2015-11-09T16:09:28Z',
-                    'comment': 'Test comment 2'
-                }]
-            }}}}
-
-        revisions = [rev for rev in self.site.revisions([689697696, 689816909], prop='content')]
-
-        args, kwargs = self.api.call_args
-        assert kwargs.get('revids') == '689697696|689816909'
-        assert len(revisions) == 2
-        assert revisions[0]['pageid'] == 1
-        assert revisions[0]['pagetitle'] == 'Test page'
-        assert revisions[0]['revid'] == 689697696
-        assert revisions[0]['timestamp'] == time.strptime('2015-11-08T21:52:46Z', '%Y-%m-%dT%H:%M:%SZ')
-        assert revisions[1]['revid'] == 689816909
-
-    def test_login_flow_1(self):
+class TestLogin(TestCase):
 
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_old_login_flow(self, raw_api, site_init):
+        # The login flow used before MW 1.27 that starts with an action=login POST request
         login_token = 'abc+\\'
 
         def side_effect(*args, **kwargs):
@@ -349,49 +315,91 @@ class TestClientApiMethods(TestCase):
                 return {
                     'login': {'result': 'NeedToken', 'token': login_token}
                 }
-            else:
+            elif 'lgname' in kwargs:
                 assert kwargs['lgtoken'] == login_token
                 return {
                     'login': {'result': 'Success'}
                 }
 
-        self.api.side_effect = side_effect
+        raw_api.side_effect = side_effect
 
-        with mock.patch('mwclient.client.Site.site_init'):
-            self.site.login('myusername', 'mypassword')
+        site = mwclient.Site('test.wikipedia.org')
+        site.login('myusername', 'mypassword')
 
-        call_args = self.api.call_args_list
+        call_args = raw_api.call_args_list
 
         assert len(call_args) == 3
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='login')
         assert call_args[1] == mock.call('login', 'POST', lgname='myusername', lgpassword='mypassword')
         assert call_args[2] == mock.call('login', 'POST', lgname='myusername', lgpassword='mypassword', lgtoken=login_token)
 
-    def test_login_flow_2(self):
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_new_login_flow(self, raw_api, site_init):
+        # The login flow used from MW 1.27 onwards that starts with a meta=tokens GET request
 
         login_token = 'abc+\\'
-        self.site.version = (1, 29, 0, '-wmf', 21)
 
         def side_effect(*args, **kwargs):
             if kwargs.get('meta') == 'tokens':
                 return {
                     'query': {'tokens': {'logintoken': login_token}}
                 }
-            else:
+            elif 'lgname' in kwargs:
                 assert kwargs['lgtoken'] == login_token
                 return {
                     'login': {'result': 'Success'}
                 }
 
-        self.api.side_effect = side_effect
+        raw_api.side_effect = side_effect
 
-        with mock.patch('mwclient.client.Site.site_init'):
-            self.site.login('myusername', 'mypassword')
+        site = mwclient.Site('test.wikipedia.org')
+        site.login('myusername', 'mypassword')
 
-        call_args = self.api.call_args_list
+        call_args = raw_api.call_args_list
 
-        assert len(call_args) == 3
-        assert call_args[1] == mock.call('query', 'POST', meta='tokens', type='login')
-        assert call_args[2] == mock.call('login', 'POST', lgname='myusername', lgpassword='mypassword', lgtoken=login_token)
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='login')
+        assert call_args[1] == mock.call('login', 'POST', lgname='myusername', lgpassword='mypassword', lgtoken=login_token)
+
+
+class TestClientApiMethods(TestCase):
+
+    def setUp(self):
+        self.api = mock.patch('mwclient.client.Site.api').start()
+        self.api.return_value = self.metaResponse()
+        self.site = mwclient.Site('test.wikipedia.org')
+
+    def tearDown(self):
+        mock.patch.stopall()
+
+    def test_revisions(self):
+
+        self.api.return_value = {
+            'query': {'pages': {'1': {
+                'pageid': 1,
+                'title': 'Test page',
+                'revisions': [{
+                    'revid': 689697696,
+                    'timestamp': '2015-11-08T21:52:46Z',
+                    'comment': 'Test comment 1'
+                }, {
+                    'revid': 689816909,
+                    'timestamp': '2015-11-09T16:09:28Z',
+                    'comment': 'Test comment 2'
+                }]
+            }}}}
+
+        revisions = [rev for rev in self.site.revisions([689697696, 689816909], prop='content')]
+
+        args, kwargs = self.api.call_args
+        assert kwargs.get('revids') == '689697696|689816909'
+        assert len(revisions) == 2
+        assert revisions[0]['pageid'] == 1
+        assert revisions[0]['pagetitle'] == 'Test page'
+        assert revisions[0]['revid'] == 689697696
+        assert revisions[0]['timestamp'] == time.strptime('2015-11-08T21:52:46Z', '%Y-%m-%dT%H:%M:%SZ')
+        assert revisions[1]['revid'] == 689816909
 
 
 class TestClientUploadArgs(TestCase):
-- 
GitLab