From 73d2966e8c10eb6c083a2abacc53f3297d16376e Mon Sep 17 00:00:00 2001
From: jvoisin <julien.voisin@dustri.org>
Date: Wed, 27 Feb 2019 23:04:38 +0100
Subject: [PATCH] Improve epub support

---
 libmat2/epub.py               | 46 +++++++++++++++---
 libmat2/web.py                | 87 ++++++++++++++++++++++++++---------
 tests/test_corrupted_files.py |  7 ++-
 tests/test_libmat2.py         |  6 ++-
 4 files changed, 114 insertions(+), 32 deletions(-)

diff --git a/libmat2/epub.py b/libmat2/epub.py
index 09b7937..d385465 100644
--- a/libmat2/epub.py
+++ b/libmat2/epub.py
@@ -1,11 +1,13 @@
 import logging
 import re
+import uuid
 import xml.etree.ElementTree as ET  # type: ignore
 
 from . import archive, office
 
 class EPUBParser(archive.ArchiveBasedAbstractParser):
     mimetypes = {'application/epub+zip', }
+    metadata_namespace = '{http://purl.org/dc/elements/1.1/}'
 
     def __init__(self, filename):
         super().__init__(filename)
@@ -14,6 +16,7 @@ class EPUBParser(archive.ArchiveBasedAbstractParser):
             'mimetype',
             'OEBPS/content.opf',
             }))
+        self.uniqid = uuid.uuid4()
 
     def _specific_get_meta(self, full_path, file_path):
         if file_path != 'OEBPS/content.opf':
@@ -25,23 +28,52 @@ class EPUBParser(archive.ArchiveBasedAbstractParser):
                                      f.read(), re.I|re.M)
                 return {k:v for (k, v) in results}
             except (TypeError, UnicodeDecodeError):
-                # We didn't manage to parse the xml file
                 return {file_path: 'harmful content', }
 
     def _specific_cleanup(self, full_path: str):
-        if not full_path.endswith('OEBPS/content.opf'):
-            return True
+        if full_path.endswith('OEBPS/content.opf'):
+            return self.__handle_contentopf(full_path)
+        elif full_path.endswith('OEBPS/toc.ncx'):
+            return self.__handle_tocncx(full_path)
+        return True
+
+    def __handle_tocncx(self, full_path: str):
+        try:
+            tree, namespace = office._parse_xml(full_path)
+        except ET.ParseError:  # pragma: nocover
+            logging.error("Unable to parse %s in %s.", full_path, self.filename)
+            return False
+
+        for item in tree.iterfind('.//', namespace):  # pragma: nocover
+            if item.tag.strip().lower().endswith('head'):
+                item.clear()
+                ET.SubElement(item, 'meta', attrib={'name': '', 'content': ''})
+                break
+        tree.write(full_path, xml_declaration=True, encoding='utf-8',
+                   short_empty_elements=False)
+        return True
 
+    def __handle_contentopf(self, full_path: str):
         try:
             tree, namespace = office._parse_xml(full_path)
         except ET.ParseError:
             logging.error("Unable to parse %s in %s.", full_path, self.filename)
             return False
-        parent_map = {c:p for p in tree.iter() for c in p}
 
-        for item in tree.iterfind('.//', namespace):
+        for item in tree.iterfind('.//', namespace):  # pragma: nocover
             if item.tag.strip().lower().endswith('metadata'):
-                parent_map[item].remove(item)
+                item.clear()
+
+                # item with mandatory content
+                uniqid = ET.Element(self.metadata_namespace + 'identifier')
+                uniqid.text = str(self.uniqid)
+                uniqid.set('id', 'id')
+                item.append(uniqid)
+
+                # items without mandatory content
+                for name in {'language', 'title'}:
+                    uniqid = ET.Element(self.metadata_namespace + name)
+                    item.append(uniqid)
                 break  # there is only a single <metadata> block
-        tree.write(full_path, xml_declaration=True)
+        tree.write(full_path, xml_declaration=True, encoding='utf-8')
         return True
diff --git a/libmat2/web.py b/libmat2/web.py
index c11b47d..067f5f9 100644
--- a/libmat2/web.py
+++ b/libmat2/web.py
@@ -1,10 +1,13 @@
-from html import parser
-from typing import Dict, Any, List, Tuple
+from html import parser, escape
+from typing import Dict, Any, List, Tuple, Set
 import re
 import string
 
 from . import abstract
 
+assert Set
+
+# pylint: disable=too-many-instance-attributes
 
 class CSSParser(abstract.AbstractParser):
     """There is no such things as metadata in CSS files,
@@ -33,11 +36,16 @@ class CSSParser(abstract.AbstractParser):
         return metadata
 
 
-class HTMLParser(abstract.AbstractParser):
-    mimetypes = {'text/html', 'application/x-dtbncx+xml', }
+class AbstractHTMLParser(abstract.AbstractParser):
+    tags_blacklist = set()  # type: Set[str]
+    # In some html/xml based formats some tags are mandatory,
+    # so we're keeping them, but are discaring their contents
+    tags_required_blacklist = set()  # type: Set[str]
+
     def __init__(self, filename):
         super().__init__(filename)
-        self.__parser = _HTMLParser(self.filename)
+        self.__parser = _HTMLParser(self.filename, self.tags_blacklist,
+                                    self.tags_required_blacklist)
         with open(filename, encoding='utf-8') as f:
             self.__parser.feed(f.read())
         self.__parser.close()
@@ -49,29 +57,50 @@ class HTMLParser(abstract.AbstractParser):
         return self.__parser.remove_all(self.output_filename)
 
 
+class HTMLParser(AbstractHTMLParser):
+    mimetypes = {'text/html', }
+    tags_blacklist = {'meta', }
+    tags_required_blacklist = {'title', }
+
+
+class DTBNCXParser(AbstractHTMLParser):
+    mimetypes = {'application/x-dtbncx+xml', }
+    tags_required_blacklist = {'title', 'doctitle', 'meta'}
+
+
 class _HTMLParser(parser.HTMLParser):
     """Python doesn't have a validating html parser in its stdlib, so
     we're using an internal queue to track all the opening/closing tags,
     and hoping for the best.
     """
-    tag_blacklist = {'doctitle', 'meta', 'title'}  # everything is lowercase
-    def __init__(self, filename):
+    def __init__(self, filename, blacklisted_tags, required_blacklisted_tags):
         super().__init__()
         self.filename = filename
         self.__textrepr = ''
         self.__meta = {}
-        self.__validation_queue = []
-        # We're using a counter instead of a boolean to handle nested tags
+        self.__validation_queue = []  # type: List[str]
+        # We're using counters instead of booleans, to handle nested tags
+        self.__in_dangerous_but_required_tag = 0
         self.__in_dangerous_tag = 0
 
+        if required_blacklisted_tags & blacklisted_tags:  # pragma: nocover
+            raise ValueError("There is an overlap between %s and %s" % (
+                required_blacklisted_tags, blacklisted_tags))
+        self.tag_required_blacklist = required_blacklisted_tags
+        self.tag_blacklist = blacklisted_tags
+
     def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
-        self.__validation_queue.append(tag)
+        original_tag = self.get_starttag_text()
+        self.__validation_queue.append(original_tag)
+
+        if tag in self.tag_required_blacklist:
+            self.__in_dangerous_but_required_tag += 1
         if tag in self.tag_blacklist:
             self.__in_dangerous_tag += 1
-            return
 
         if self.__in_dangerous_tag == 0:
-            self.__textrepr += self.get_starttag_text()
+            if self.__in_dangerous_but_required_tag <= 1:
+                self.__textrepr += original_tag
 
     def handle_endtag(self, tag: str):
         if not self.__validation_queue:
@@ -79,29 +108,43 @@ class _HTMLParser(parser.HTMLParser):
                              "opening one in %s." % (tag, self.filename))
 
         previous_tag = self.__validation_queue.pop()
-        if tag != previous_tag:
+        previous_tag = previous_tag[1:-1]  # remove < and >
+        previous_tag = previous_tag.split(' ')[0]  # remove attributes
+        if tag != previous_tag.lower():
             raise ValueError("The closing tag %s doesn't match the previous "
                              "tag %s in %s" %
                              (tag, previous_tag, self.filename))
-        elif tag in self.tag_blacklist:
-            self.__in_dangerous_tag -= 1
-            return
 
         if self.__in_dangerous_tag == 0:
-            # There is no `get_endtag_text()` method :/
-            self.__textrepr += '</' + tag + '>\n'
+            if self.__in_dangerous_but_required_tag <= 1:
+                # There is no `get_endtag_text()` method :/
+                self.__textrepr += '</' + previous_tag + '>'
+
+        if tag in self.tag_required_blacklist:
+            self.__in_dangerous_but_required_tag -= 1
+        elif tag in self.tag_blacklist:
+            self.__in_dangerous_tag -= 1
 
     def handle_data(self, data: str):
-        if self.__in_dangerous_tag == 0 and data.strip():
-            self.__textrepr += data
+        if self.__in_dangerous_but_required_tag == 0:
+            if self.__in_dangerous_tag == 0:
+                if data.strip():
+                    self.__textrepr += escape(data)
 
     def handle_startendtag(self, tag: str, attrs: List[Tuple[str, str]]):
-        if tag in self.tag_blacklist:
+        if tag in self.tag_required_blacklist | self.tag_blacklist:
             meta = {k:v for k, v in attrs}
             name = meta.get('name', 'harmful metadata')
             content = meta.get('content', 'harmful data')
             self.__meta[name] = content
-        else:
+
+            if self.__in_dangerous_tag != 0:
+                return
+            elif tag in self.tag_required_blacklist:
+                self.__textrepr += '<' + tag + ' />'
+            return
+
+        if self.__in_dangerous_but_required_tag == 0:
             if self.__in_dangerous_tag == 0:
                 self.__textrepr += self.get_starttag_text()
 
diff --git a/tests/test_corrupted_files.py b/tests/test_corrupted_files.py
index 53c856a..b2cec00 100644
--- a/tests/test_corrupted_files.py
+++ b/tests/test_corrupted_files.py
@@ -253,13 +253,13 @@ class TestCorruptedFiles(unittest.TestCase):
         os.remove('./tests/data/clean.cleaned.html')
 
         with open('./tests/data/clean.html', 'w') as f:
-            f.write('</close>')
+            f.write('</meta>')
         with self.assertRaises(ValueError):
             web.HTMLParser('./tests/data/clean.html')
         os.remove('./tests/data/clean.html')
 
         with open('./tests/data/clean.html', 'w') as f:
-            f.write('<notclosed>')
+            f.write('<meta><a>test</a><set/></meta><title></title><meta>')
         p = web.HTMLParser('./tests/data/clean.html')
         with self.assertRaises(ValueError):
             p.get_meta()
@@ -269,6 +269,9 @@ class TestCorruptedFiles(unittest.TestCase):
         os.remove('./tests/data/clean.html')
 
         with open('./tests/data/clean.html', 'w') as f:
+            f.write('<meta><meta/></meta>')
+            f.write('<title><title>pouet</title></title>')
+            f.write('<title><mysupertag/></title>')
             f.write('<doctitle><br/></doctitle><br/><notclosed>')
         p = web.HTMLParser('./tests/data/clean.html')
         with self.assertRaises(ValueError):
diff --git a/tests/test_libmat2.py b/tests/test_libmat2.py
index 249c56d..f4b1890 100644
--- a/tests/test_libmat2.py
+++ b/tests/test_libmat2.py
@@ -3,6 +3,7 @@
 import unittest
 import shutil
 import os
+import re
 import zipfile
 
 from libmat2 import pdf, images, audio, office, parser_factory, torrent, harmless
@@ -644,7 +645,10 @@ class TestCleaning(unittest.TestCase):
         self.assertTrue(ret)
 
         p = epub.EPUBParser('./tests/data/clean.cleaned.epub')
-        self.assertEqual(p.get_meta(), {})
+        meta = p.get_meta()
+        res = re.match(meta['OEBPS/content.opf']['metadata'], '^<dc:identifier>[0-9a-f-]+</dc:identifier><dc:title /><dc:language />$')
+        self.assertNotEqual(res, False)
+
         self.assertTrue(p.remove_all())
 
         os.remove('./tests/data/clean.epub')
-- 
GitLab