From f3cae5a7e8ba2d925af341a58375f70a5829d7f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dan=20Michael=20O=2E=20Hegg=C3=B8?= <danmichaelo@gmail.com>
Date: Sun, 3 Jul 2016 17:06:38 +0200
Subject: [PATCH] [#125] Send assert=true when force_login=True

- When `Site.force_login` is True we can send `assert=true`
  with `edit` actions to have the api reject our edits if
  we got logged out for some reason.
- Adding a new exception `AssertUserFailedError` as a subclass
  of `LoginError`.
---
 mwclient/errors.py | 14 ++++++++++++
 mwclient/page.py   | 14 ++++++------
 tests/test_page.py | 53 +++++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 68 insertions(+), 13 deletions(-)

diff --git a/mwclient/errors.py b/mwclient/errors.py
index 3dcc83a..4971afb 100644
--- a/mwclient/errors.py
+++ b/mwclient/errors.py
@@ -56,6 +56,20 @@ class OAuthAuthorizationError(LoginError):
         return self.info
 
 
+class AssertUserFailedError(LoginError):
+
+    def __init__(self):
+        self.message = 'By default, mwclient protects you from ' + \
+                       'accidentally editing without being logged in. If you ' + \
+                       'actually want to edit without logging in, you can set ' + \
+                       'force_login on the Site object to False.'
+
+        LoginError.__init__(self)
+
+    def __str__(self):
+        return self.message
+
+
 class EmailError(MwClientError):
     pass
 
diff --git a/mwclient/page.py b/mwclient/page.py
index 54c763c..ae495ba 100644
--- a/mwclient/page.py
+++ b/mwclient/page.py
@@ -173,14 +173,7 @@ class Page(object):
         """Update the text of a section or the whole page by performing an edit operation.
         """
         if not self.site.logged_in and self.site.force_login:
-            # Should we really check for this?
-            raise mwclient.errors.LoginError(
-                self.site,
-                'By default, mwclient protects you from accidentally editing '
-                'without being logged in. '
-                'If you actually want to edit without logging in, '
-                'you can set force_login on the Site object to False.'
-            )
+            raise mwclient.errors.AssertUserFailedError()
         if self.site.blocked:
             raise mwclient.errors.UserBlocked(self.site.blocked)
         if not self.can('edit'):
@@ -205,6 +198,9 @@ class Page(object):
 
         data.update(kwargs)
 
+        if self.site.force_login:
+            data['assert'] = 'user'
+
         def do_edit():
             result = self.site.post('edit', title=self.name, text=text,
                                     summary=summary, token=self.get_token('edit'),
@@ -240,6 +236,8 @@ class Page(object):
                         'noimageredirect-anon', 'noimageredirect', 'noedit-anon',
                         'noedit'}:
             raise mwclient.errors.ProtectedPageError(self, e.code, e.info)
+        elif e.code == 'assertuserfailed':
+            raise mwclient.errors.AssertUserFailedError()
         else:
             raise
 
diff --git a/tests/test_page.py b/tests/test_page.py
index d7babe5..45a4edd 100644
--- a/tests/test_page.py
+++ b/tests/test_page.py
@@ -10,6 +10,7 @@ import mock
 import mwclient
 from mwclient.page import Page
 from mwclient.client import Site
+from mwclient.errors import APIError, AssertUserFailedError
 
 try:
     import json
@@ -215,8 +216,11 @@ class TestPageApiArgs(unittest.TestCase):
             'ns': 0, 'pageid': 2, 'revisions': [{'*': 'Hello world', 'timestamp': '2014-08-29T22:25:15Z'}], 'title': title
         }}}}
 
-    def get_last_api_call_args(self):
-        args, kwargs = self.site.get.call_args
+    def get_last_api_call_args(self, http_method='POST'):
+        if http_method == 'GET':
+            args, kwargs = self.site.get.call_args
+        else:
+            args, kwargs = self.site.post.call_args
         action = args[0]
         args = args[1:]
         kwargs.update(args)
@@ -228,7 +232,7 @@ class TestPageApiArgs(unittest.TestCase):
     def test_get_page_text(self):
         # Check that page.text() works, and that a correct API call is made
         text = self.page.text()
-        args = self.get_last_api_call_args()
+        args = self.get_last_api_call_args(http_method='GET')
 
         assert text == self.page_text
         assert args == {
@@ -253,17 +257,56 @@ class TestPageApiArgs(unittest.TestCase):
     def test_get_section_text(self):
         # Check that the 'rvsection' parameter is sent to the API
         text = self.page.text(section=0)
-        args = self.get_last_api_call_args()
+        args = self.get_last_api_call_args(http_method='GET')
 
         assert args['rvsection'] == '0'
 
     def test_get_text_expanded(self):
         # Check that the 'rvexpandtemplates' parameter is sent to the API
         text = self.page.text(expandtemplates=True)
-        args = self.get_last_api_call_args()
+        args = self.get_last_api_call_args(http_method='GET')
 
         assert args['rvexpandtemplates'] == '1'
 
+    def test_assertuser_true(self):
+        # Check that assert=user is sent when force_login=True
+        self.site.blocked = False
+        self.site.rights = ['read', 'edit']
+        self.site.logged_in = True
+        self.site.force_login = True
+
+        self.site.api.return_value = {
+            'edit': {'result': 'Ok'}
+        }
+        self.page.save('Some text')
+        args = self.get_last_api_call_args()
+
+        assert args['assert'] == 'user'
+
+    def test_assertuser_false(self):
+        # Check that assert=user is not sent when force_login=False
+        self.site.blocked = False
+        self.site.rights = ['read', 'edit']
+        self.site.logged_in = False
+        self.site.force_login = False
+
+        self.site.api.return_value = {
+            'edit': {'result': 'Ok'}
+        }
+        self.page.save('Some text')
+        args = self.get_last_api_call_args()
+
+        assert 'assert' not in args
+
+    def test_handle_edit_error_assertuserfailed(self):
+        # Check that AssertUserFailedError is triggered
+        api_error = APIError('assertuserfailed',
+                             'Assertion that the user is logged in failed',
+                             'See https://en.wikipedia.org/w/api.php for API usage')
+
+        with pytest.raises(AssertUserFailedError):
+            self.page.handle_edit_error(api_error, 'n/a')
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab