Skip to content
Snippets Groups Projects
Commit 6a1b0b31 authored by Julien (jvoisin) Voisin
Browse files

Add more typing and use mypy in the CI

parent 4ebf9754
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,14 @@ pyflakes: ...@@ -18,6 +18,14 @@ pyflakes:
- apt-get -qqy install --no-install-recommends pyflakes3 - apt-get -qqy install --no-install-recommends pyflakes3
- pyflakes3 ./libmat2 - pyflakes3 ./libmat2
mypy:
stage: linting
script:
- apt-get -qqy update
- apt-get -qqy install --no-install-recommends python3-pip
- pip3 install mypy
- mypy mat2 libmat2/*.py --ignore-missing-imports
tests: tests:
stage: test stage: test
script: script:
......
from typing import Dict
from . import abstract from . import abstract
class HarmlessParser(abstract.AbstractParser): class HarmlessParser(abstract.AbstractParser):
""" This is the parser for filetypes that do not contain metadata. """ """ This is the parser for filetypes that do not contain metadata. """
mimetypes = {'application/xml', 'text/plain'} mimetypes = {'application/xml', 'text/plain', 'application/rdf+xml'}
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
super().__init__(filename) super().__init__(filename)
self.filename = filename self.filename = filename
self.output_filename = filename self.output_filename = filename
def get_meta(self): def get_meta(self) -> Dict[str, str]:
return dict() return dict()
def remove_all(self): def remove_all(self) -> bool:
return True return True
...@@ -4,11 +4,15 @@ import shutil ...@@ -4,11 +4,15 @@ import shutil
import tempfile import tempfile
import datetime import datetime
import zipfile import zipfile
from typing import Dict, Set
from . import abstract, parser_factory from . import abstract, parser_factory
assert Set # make pyflakes happy
class ArchiveBasedAbstractParser(abstract.AbstractParser): class ArchiveBasedAbstractParser(abstract.AbstractParser):
whitelist = set() # type: Set[str]
def _clean_zipinfo(self, zipinfo: zipfile.ZipInfo) -> zipfile.ZipInfo: def _clean_zipinfo(self, zipinfo: zipfile.ZipInfo) -> zipfile.ZipInfo:
zipinfo.compress_type = zipfile.ZIP_DEFLATED zipinfo.compress_type = zipfile.ZIP_DEFLATED
zipinfo.create_system = 3 # Linux zipinfo.create_system = 3 # Linux
...@@ -16,7 +20,7 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser): ...@@ -16,7 +20,7 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
zipinfo.date_time = (1980, 1, 1, 0, 0, 0) zipinfo.date_time = (1980, 1, 1, 0, 0, 0)
return zipinfo return zipinfo
def _get_zipinfo_meta(self, zipinfo: zipfile.ZipInfo) -> dict: def _get_zipinfo_meta(self, zipinfo: zipfile.ZipInfo) -> Dict[str, str]:
metadata = {} metadata = {}
if zipinfo.create_system == 3: if zipinfo.create_system == 3:
#metadata['create_system'] = 'Linux' #metadata['create_system'] = 'Linux'
...@@ -27,25 +31,31 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser): ...@@ -27,25 +31,31 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
metadata['create_system'] = 'Weird' metadata['create_system'] = 'Weird'
if zipinfo.comment: if zipinfo.comment:
metadata['comment'] = zipinfo.comment metadata['comment'] = zipinfo.comment # type: ignore
if zipinfo.date_time != (1980, 1, 1, 0, 0, 0): if zipinfo.date_time != (1980, 1, 1, 0, 0, 0):
metadata['date_time'] = datetime.datetime(*zipinfo.date_time) metadata['date_time'] =str(datetime.datetime(*zipinfo.date_time))
return metadata return metadata
def _clean_internal_file(self, item: zipfile.ZipInfo, temp_folder: str, def _clean_internal_file(self, item: zipfile.ZipInfo, temp_folder: str,
zin: zipfile.ZipFile, zout: zipfile.ZipFile): zin: zipfile.ZipFile, zout: zipfile.ZipFile):
output = ''
zin.extract(member=item, path=temp_folder) zin.extract(member=item, path=temp_folder)
tmp_parser, mtype = parser_factory.get_parser(os.path.join(temp_folder, item.filename)) if item.filename not in self.whitelist:
if not tmp_parser: full_path = os.path.join(temp_folder, item.filename)
print("%s's format (%s) isn't supported" % (item.filename, mtype)) tmp_parser, mtype = parser_factory.get_parser(full_path) # type: ignore
return if not tmp_parser:
tmp_parser.remove_all() print("%s's format (%s) isn't supported" % (item.filename, mtype))
zinfo = zipfile.ZipInfo(item.filename) return
tmp_parser.remove_all()
output = tmp_parser.output_filename
else:
output = os.path.join(temp_folder, item.filename)
zinfo = zipfile.ZipInfo(item.filename) # type: ignore
clean_zinfo = self._clean_zipinfo(zinfo) clean_zinfo = self._clean_zipinfo(zinfo)
with open(tmp_parser.output_filename, 'rb') as f: with open(output, 'rb') as f:
zout.writestr(clean_zinfo, f.read()) zout.writestr(clean_zinfo, f.read())
...@@ -72,7 +82,8 @@ class MSOfficeParser(ArchiveBasedAbstractParser): ...@@ -72,7 +82,8 @@ class MSOfficeParser(ArchiveBasedAbstractParser):
if not metadata: # better safe than sorry if not metadata: # better safe than sorry
metadata[item] = 'harmful content' metadata[item] = 'harmful content'
metadata = {**metadata, **self._get_zipinfo_meta(item)} for key, value in self._get_zipinfo_meta(item).items():
metadata[key] = value
zipin.close() zipin.close()
return metadata return metadata
...@@ -112,6 +123,8 @@ class LibreOfficeParser(ArchiveBasedAbstractParser): ...@@ -112,6 +123,8 @@ class LibreOfficeParser(ArchiveBasedAbstractParser):
'application/vnd.oasis.opendocument.formula', 'application/vnd.oasis.opendocument.formula',
'application/vnd.oasis.opendocument.image', 'application/vnd.oasis.opendocument.image',
} }
whitelist = {'mimetype', 'manifest.rdf'}
def get_meta(self): def get_meta(self):
""" """
...@@ -127,7 +140,8 @@ class LibreOfficeParser(ArchiveBasedAbstractParser): ...@@ -127,7 +140,8 @@ class LibreOfficeParser(ArchiveBasedAbstractParser):
metadata[key] = value metadata[key] = value
if not metadata: # better safe than sorry if not metadata: # better safe than sorry
metadata[item] = 'harmful content' metadata[item] = 'harmful content'
metadata = {**metadata, **self._get_zipinfo_meta(item)} for key, value in self._get_zipinfo_meta(item).items():
metadata[key] = value
zipin.close() zipin.close()
return metadata return metadata
......
...@@ -2,10 +2,12 @@ import glob ...@@ -2,10 +2,12 @@ import glob
import os import os
import mimetypes import mimetypes
import importlib import importlib
from typing import TypeVar, List from typing import TypeVar, List, Tuple, Optional
from . import abstract, unsupported_extensions from . import abstract, unsupported_extensions
assert Tuple # make pyflakes happy
T = TypeVar('T', bound='abstract.AbstractParser') T = TypeVar('T', bound='abstract.AbstractParser')
def __load_all_parsers(): def __load_all_parsers():
...@@ -28,14 +30,14 @@ def _get_parsers() -> List[T]: ...@@ -28,14 +30,14 @@ def _get_parsers() -> List[T]:
return __get_parsers(abstract.AbstractParser) return __get_parsers(abstract.AbstractParser)
def get_parser(filename: str) -> (T, str): def get_parser(filename: str) -> Tuple[Optional[T], Optional[str]]:
mtype, _ = mimetypes.guess_type(filename) mtype, _ = mimetypes.guess_type(filename)
_, extension = os.path.splitext(filename) _, extension = os.path.splitext(filename)
if extension in unsupported_extensions: if extension in unsupported_extensions:
return None, mtype return None, mtype
for c in _get_parsers(): for c in _get_parsers(): # type: ignore
if mtype in c.mimetypes: if mtype in c.mimetypes:
try: try:
return c(filename), mtype return c(filename), mtype
......
...@@ -131,5 +131,6 @@ class PDFParser(abstract.AbstractParser): ...@@ -131,5 +131,6 @@ class PDFParser(abstract.AbstractParser):
metadata[key] = document.get_property(key) metadata[key] = document.get_property(key)
if 'metadata' in metadata: if 'metadata' in metadata:
parsed_meta = self.__parse_metadata_field(metadata['metadata']) parsed_meta = self.__parse_metadata_field(metadata['metadata'])
return {**metadata, **parsed_meta} for key, value in parsed_meta.items():
metadata[key] = value
return metadata return metadata
from typing import Union, Tuple, Dict
from . import abstract from . import abstract
class TorrentParser(abstract.AbstractParser): class TorrentParser(abstract.AbstractParser):
mimetypes = {b'application/x-bittorrent', } mimetypes = {'application/x-bittorrent', }
whitelist = {b'announce', b'announce-list', b'info'} whitelist = {b'announce', b'announce-list', b'info'}
def get_meta(self) -> dict: def get_meta(self) -> Dict[str, str]:
metadata = {} metadata = {}
with open(self.filename, 'rb') as f: with open(self.filename, 'rb') as f:
d = _BencodeHandler().bdecode(f.read()) d = _BencodeHandler().bdecode(f.read())
...@@ -54,7 +55,7 @@ class _BencodeHandler(object): ...@@ -54,7 +55,7 @@ class _BencodeHandler(object):
} }
@staticmethod @staticmethod
def __decode_int(s: str) -> (int, str): def __decode_int(s: bytes) -> Tuple[int, bytes]:
s = s[1:] s = s[1:]
next_idx = s.index(b'e') next_idx = s.index(b'e')
if s.startswith(b'-0'): if s.startswith(b'-0'):
...@@ -64,7 +65,7 @@ class _BencodeHandler(object): ...@@ -64,7 +65,7 @@ class _BencodeHandler(object):
return int(s[:next_idx]), s[next_idx+1:] return int(s[:next_idx]), s[next_idx+1:]
@staticmethod @staticmethod
def __decode_string(s: str) -> (str, str): def __decode_string(s: bytes) -> Tuple[bytes, bytes]:
sep = s.index(b':') sep = s.index(b':')
str_len = int(s[:sep]) str_len = int(s[:sep])
if str_len < 0: if str_len < 0:
...@@ -74,7 +75,7 @@ class _BencodeHandler(object): ...@@ -74,7 +75,7 @@ class _BencodeHandler(object):
s = s[1:] s = s[1:]
return s[sep:sep+str_len], s[sep+str_len:] return s[sep:sep+str_len], s[sep+str_len:]
def __decode_list(self, s: str) -> (list, str): def __decode_list(self, s: bytes) -> Tuple[list, bytes]:
r = list() r = list()
s = s[1:] # skip leading `l` s = s[1:] # skip leading `l`
while s[0] != ord('e'): while s[0] != ord('e'):
...@@ -82,7 +83,7 @@ class _BencodeHandler(object): ...@@ -82,7 +83,7 @@ class _BencodeHandler(object):
r.append(v) r.append(v)
return r, s[1:] return r, s[1:]
def __decode_dict(self, s: str) -> (dict, str): def __decode_dict(self, s: bytes) -> Tuple[dict, bytes]:
r = dict() r = dict()
s = s[1:] # skip leading `d` s = s[1:] # skip leading `d`
while s[0] != ord(b'e'): while s[0] != ord(b'e'):
...@@ -91,11 +92,11 @@ class _BencodeHandler(object): ...@@ -91,11 +92,11 @@ class _BencodeHandler(object):
return r, s[1:] return r, s[1:]
@staticmethod @staticmethod
def __encode_int(x: str) -> bytes: def __encode_int(x: bytes) -> bytes:
return b'i' + bytes(str(x), 'utf-8') + b'e' return b'i' + bytes(str(x), 'utf-8') + b'e'
@staticmethod @staticmethod
def __encode_string(x: str) -> bytes: def __encode_string(x: bytes) -> bytes:
return bytes((str(len(x))), 'utf-8') + b':' + x return bytes((str(len(x))), 'utf-8') + b':' + x
def __encode_list(self, x: str) -> bytes: def __encode_list(self, x: str) -> bytes:
...@@ -104,17 +105,17 @@ class _BencodeHandler(object): ...@@ -104,17 +105,17 @@ class _BencodeHandler(object):
ret += self.__encode_func[type(i)](i) ret += self.__encode_func[type(i)](i)
return b'l' + ret + b'e' return b'l' + ret + b'e'
def __encode_dict(self, x: str) -> bytes: def __encode_dict(self, x: dict) -> bytes:
ret = b'' ret = b''
for k, v in sorted(x.items()): for k, v in sorted(x.items()):
ret += self.__encode_func[type(k)](k) ret += self.__encode_func[type(k)](k)
ret += self.__encode_func[type(v)](v) ret += self.__encode_func[type(v)](v)
return b'd' + ret + b'e' return b'd' + ret + b'e'
def bencode(self, s: str) -> bytes: def bencode(self, s: Union[dict, list, bytes, int]) -> bytes:
return self.__encode_func[type(s)](s) return self.__encode_func[type(s)](s)
def bdecode(self, s: str): def bdecode(self, s: bytes) -> Union[dict, None]:
try: try:
r, l = self.__decode_func[s[0]](s) r, l = self.__decode_func[s[0]](s)
except (IndexError, KeyError, ValueError) as e: except (IndexError, KeyError, ValueError) as e:
......
...@@ -44,7 +44,7 @@ def show_meta(filename: str): ...@@ -44,7 +44,7 @@ def show_meta(filename: str):
if not __check_file(filename): if not __check_file(filename):
return return
p, mtype = parser_factory.get_parser(filename) p, mtype = parser_factory.get_parser(filename) # type: ignore
if p is None: if p is None:
print("[-] %s's format (%s) is not supported" % (filename, mtype)) print("[-] %s's format (%s) is not supported" % (filename, mtype))
return return
...@@ -61,7 +61,7 @@ def clean_meta(params: Tuple[str, bool]) -> bool: ...@@ -61,7 +61,7 @@ def clean_meta(params: Tuple[str, bool]) -> bool:
if not __check_file(filename, os.R_OK|os.W_OK): if not __check_file(filename, os.R_OK|os.W_OK):
return False return False
p, mtype = parser_factory.get_parser(filename) p, mtype = parser_factory.get_parser(filename) # type: ignore
if p is None: if p is None:
print("[-] %s's format (%s) is not supported" % (filename, mtype)) print("[-] %s's format (%s) is not supported" % (filename, mtype))
return False return False
......
...@@ -67,6 +67,13 @@ class TestCleanMeta(unittest.TestCase): ...@@ -67,6 +67,13 @@ class TestCleanMeta(unittest.TestCase):
os.remove('./tests/data/clean.jpg') os.remove('./tests/data/clean.jpg')
class TestIsSupported(unittest.TestCase):
def test_pdf(self):
proc = subprocess.Popen(['./mat2', '--show', './tests/data/dirty.pdf'],
stdout=subprocess.PIPE)
stdout, _ = proc.communicate()
self.assertNotIn(b"isn't supported", stdout)
class TestGetMeta(unittest.TestCase): class TestGetMeta(unittest.TestCase):
def test_pdf(self): def test_pdf(self):
proc = subprocess.Popen(['./mat2', '--show', './tests/data/dirty.pdf'], proc = subprocess.Popen(['./mat2', '--show', './tests/data/dirty.pdf'],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment