From f3cef319b90a5a82ca879380c213651d74390a72 Mon Sep 17 00:00:00 2001
From: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
Date: Wed, 5 Sep 2018 18:49:35 -0400
Subject: [PATCH] Unknown Members: make policy use an Enum

Closes #60

Note: this changeset also ensures that clean.cleaned.docx is removed
after the pytest run is over.
---
 libmat2/__init__.py  |  6 ++++++
 libmat2/office.py    | 14 +++++---------
 mat2                 | 17 ++++++++---------
 tests/test_policy.py | 11 ++++++-----
 4 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/libmat2/__init__.py b/libmat2/__init__.py
index bf4e813..8a5b064 100644
--- a/libmat2/__init__.py
+++ b/libmat2/__init__.py
@@ -2,6 +2,7 @@
 
 import os
 import collections
+from enum import Enum
 import importlib
 from typing import Dict, Optional
 
@@ -62,3 +63,8 @@ def check_dependencies() -> dict:
             ret[value] = False  # pragma: no cover
 
     return ret
+
+class UnknownMemberPolicy(Enum):
+    ABORT = 'abort'
+    OMIT = 'omit'
+    KEEP = 'keep'
diff --git a/libmat2/office.py b/libmat2/office.py
index 29100df..60c5478 100644
--- a/libmat2/office.py
+++ b/libmat2/office.py
@@ -9,7 +9,7 @@ from typing import Dict, Set, Pattern
 
 import xml.etree.ElementTree as ET  # type: ignore
 
-from . import abstract, parser_factory
+from . import abstract, parser_factory, UnknownMemberPolicy
 
 # Make pyflakes happy
 assert Set
@@ -37,8 +37,8 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
     files_to_omit = set() # type: Set[Pattern]
 
     # what should the parser do if it encounters an unknown file in
-    # the archive?  valid policies are 'abort', 'omit', 'keep'
-    unknown_member_policy = 'abort' # type: str
+    # the archive?
+    unknown_member_policy = UnknownMemberPolicy.ABORT # type: UnknownMemberPolicy
 
     def __init__(self, filename):
         super().__init__(filename)
@@ -81,10 +81,6 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
     def remove_all(self) -> bool:
         # pylint: disable=too-many-branches
 
-        if self.unknown_member_policy not in ['omit', 'keep', 'abort']:
-            logging.error("The policy %s is invalid.", self.unknown_member_policy)
-            raise ValueError
-
         with zipfile.ZipFile(self.filename) as zin,\
              zipfile.ZipFile(self.output_filename, 'w') as zout:
 
@@ -113,11 +109,11 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
                     # supported files that we want to clean then add
                     tmp_parser, mtype = parser_factory.get_parser(full_path)  # type: ignore
                     if not tmp_parser:
-                        if self.unknown_member_policy == 'omit':
+                        if self.unknown_member_policy == UnknownMemberPolicy.OMIT:
                             logging.warning("In file %s, omitting unknown element %s (format: %s)",
                                             self.filename, item.filename, mtype)
                             continue
-                        elif self.unknown_member_policy == 'keep':
+                        elif self.unknown_member_policy == UnknownMemberPolicy.KEEP:
                             logging.warning("In file %s, keeping unknown element %s (format: %s)",
                                             self.filename, item.filename, mtype)
                         else:
diff --git a/mat2 b/mat2
index 2a8ef46..0aba8d1 100755
--- a/mat2
+++ b/mat2
@@ -10,7 +10,8 @@ import multiprocessing
 import logging
 
 try:
-    from libmat2 import parser_factory, UNSUPPORTED_EXTENSIONS, check_dependencies
+    from libmat2 import (parser_factory, UNSUPPORTED_EXTENSIONS, check_dependencies,
+                         UnknownMemberPolicy)
 except ValueError as e:
     print(e)
     sys.exit(1)
@@ -42,8 +43,8 @@ def create_arg_parser():
     parser.add_argument('-V', '--verbose', action='store_true',
                         help='show more verbose status information')
     parser.add_argument('--unknown-members', metavar='policy', default='abort',
-                        help='how to handle unknown members of archive-style files ' +
-                        '(policy should be abort, omit, or keep)')
+                        help='how to handle unknown members of archive-style files (policy should' +
+                        ' be one of: ' + ', '.join([x.value for x in UnknownMemberPolicy]) + ')')
 
 
     info = parser.add_mutually_exclusive_group()
@@ -70,7 +71,7 @@ def show_meta(filename: str):
         except UnicodeEncodeError:
             print("  %s: harmful content" % k)
 
-def clean_meta(params: Tuple[str, bool, str]) -> bool:
+def clean_meta(params: Tuple[str, bool, UnknownMemberPolicy]) -> bool:
     filename, is_lightweight, unknown_member_policy = params
     if not __check_file(filename, os.R_OK|os.W_OK):
         return False
@@ -137,15 +138,13 @@ def main():
         return 0
 
     else:
-        if args.unknown_members == 'keep':
+        unknown_member_policy = UnknownMemberPolicy(args.unknown_members)
+        if unknown_member_policy == UnknownMemberPolicy.KEEP:
             logging.warning('Keeping unknown member files may leak metadata in the resulting file!')
-        elif args.unknown_members not in ['omit', 'abort']:
-            logging.warning('Undefined policy for handling unknown member files: "%s"',
-                            args.unknown_members)
         p = multiprocessing.Pool()
         mode = (args.lightweight is True)
         l = zip(__get_files_recursively(args.files), itertools.repeat(mode),
-                itertools.repeat(args.unknown_members))
+                itertools.repeat(unknown_member_policy))
 
         ret = list(p.imap_unordered(clean_meta, list(l)))
         return 0 if all(ret) else -1
diff --git a/tests/test_policy.py b/tests/test_policy.py
index 39282b1..5a8447b 100644
--- a/tests/test_policy.py
+++ b/tests/test_policy.py
@@ -4,28 +4,29 @@ import unittest
 import shutil
 import os
 
-from libmat2 import office
+from libmat2 import office, UnknownMemberPolicy
 
 class TestPolicy(unittest.TestCase):
     def test_policy_omit(self):
         shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
         p = office.MSOfficeParser('./tests/data/clean.docx')
-        p.unknown_member_policy = 'omit'
+        p.unknown_member_policy = UnknownMemberPolicy.OMIT
         self.assertTrue(p.remove_all())
         os.remove('./tests/data/clean.docx')
+        os.remove('./tests/data/clean.cleaned.docx')
 
     def test_policy_keep(self):
         shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
         p = office.MSOfficeParser('./tests/data/clean.docx')
-        p.unknown_member_policy = 'keep'
+        p.unknown_member_policy = UnknownMemberPolicy.KEEP
         self.assertTrue(p.remove_all())
         os.remove('./tests/data/clean.docx')
+        os.remove('./tests/data/clean.cleaned.docx')
 
     def test_policy_unknown(self):
         shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
         p = office.MSOfficeParser('./tests/data/clean.docx')
-        p.unknown_member_policy = 'unknown_policy_name_totally_invalid'
         with self.assertRaises(ValueError):
-            p.remove_all()
+            p.unknown_member_policy = UnknownMemberPolicy('unknown_policy_name_totally_invalid')
         os.remove('./tests/data/clean.docx')
 
-- 
GitLab