main.py 9.57 KB
Newer Older
jvoisin's avatar
jvoisin committed
1
import os
2
import hmac
3
import mimetypes as mtype
4
from uuid import uuid4
5 6 7 8
import jinja2
import base64
import io
import binascii
9
import zipfile
jvoisin's avatar
jvoisin committed
10

11 12
from cerberus import Validator
import utils
jvoisin's avatar
jvoisin committed
13
from libmat2 import parser_factory
14 15 16 17 18 19
from flask import Flask, flash, request, redirect, url_for, render_template, send_from_directory, after_this_request
from flask_restful import Resource, Api, reqparse, abort
from werkzeug.utils import secure_filename
from werkzeug.datastructures import FileStorage
from flask_cors import CORS
from urllib.parse import urljoin
jvoisin's avatar
jvoisin committed
20 21


22 23 24 25 26 27
def create_app(test_config=None):
    app = Flask(__name__)
    app.config['SECRET_KEY'] = os.urandom(32)
    app.config['UPLOAD_FOLDER'] = './uploads/'
    app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB
    app.config['CUSTOM_TEMPLATES_DIR'] = 'custom_templates'
jvoisin's avatar
jvoisin committed
28

29 30 31 32
    app.jinja_loader = jinja2.ChoiceLoader([  # type: ignore
        jinja2.FileSystemLoader(app.config['CUSTOM_TEMPLATES_DIR']),
        app.jinja_loader,
        ])
jvoisin's avatar
jvoisin committed
33

34 35
    api = Api(app)
    CORS(app, resources={r"/api/*": {"origins": utils.get_allow_origin_header_value()}})
jvoisin's avatar
jvoisin committed
36

37
    @app.route('/download/<string:key>/<string:filename>')
    def download_file(key: str, filename: str):
        """Serve a file from the upload folder once, then delete it.

        The file is only served when `filename` is already a secure filename,
        the file exists, and `key` equals the hash of the file on disk. Every
        failed check redirects to the upload page.

        :param key: hash the caller claims for the file, taken from the URL.
        :param filename: name of the file inside the upload folder.
        :return: the file as an attachment, or a redirect to `upload_file`.
        """
        if filename != secure_filename(filename):
            return redirect(url_for('upload_file'))

        complete_path, filepath = get_file_paths(filename)

        if not os.path.exists(complete_path):
            return redirect(url_for('upload_file'))

        # `hmac.compare_digest` raises TypeError when given a non-ASCII `str`,
        # and `key` comes straight from the URL: compare the encoded bytes so
        # that a bogus key is rejected instead of crashing the request.
        expected_key = utils.hash_file(complete_path).encode('utf-8')
        if not hmac.compare_digest(expected_key, key.encode('utf-8')):
            return redirect(url_for('upload_file'))

        # Flask also answers HEAD on GET routes; only a real download may
        # consume the file (same rule as the API download endpoint).
        if request.method == 'GET':
            @after_this_request
            def remove_file(response):
                os.remove(complete_path)
                return response

        return send_from_directory(app.config['UPLOAD_FOLDER'], filepath, as_attachment=True)
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108

    @app.route('/', methods=['GET', 'POST'])
    def upload_file():
        """Render the upload page and handle files posted from it.

        GET renders `index.html` with the supported extensions and the maximum
        upload size in megabytes. POST saves the uploaded file, removes its
        metadata and renders `download.html` with the metadata found before
        and after cleaning plus the key needed to download the result.

        Problems are reported with `flash` followed by a redirect: a missing
        file part, an empty filename, an unsupported type or a failed cleaning.
        """
        utils.check_upload_folder(app.config['UPLOAD_FOLDER'])
        mimetypes = get_supported_extensions()

        if request.method == 'POST':
            if 'file' not in request.files:  # check if the post request has the file part
                flash('No file part')
                return redirect(request.url)

            uploaded_file = request.files['file']
            if not uploaded_file.filename:
                flash('No selected file')
                return redirect(request.url)

            _, filepath = save_file(uploaded_file)
            parser, mime = get_file_parser(filepath)

            if parser is None:
                # Nothing will ever clean or serve this file: do not keep the
                # user's original (metadata included) in the upload folder.
                os.remove(filepath)
                flash('The type %s is not supported' % mime)
                return redirect(url_for('upload_file'))

            meta = parser.get_meta()

            if parser.remove_all() is not True:
                # Same as above: a file that could not be cleaned must not
                # linger on disk.
                os.remove(filepath)
                flash('Unable to clean %s' % mime)
                return redirect(url_for('upload_file'))

            # `cleanup` deletes the original upload and hashes the cleaned file.
            key, meta_after, output_filename = cleanup(parser, filepath)

            return render_template(
                'download.html', mimetypes=mimetypes, meta=meta, filename=output_filename, meta_after=meta_after, key=key
            )

        max_file_size = int(app.config['MAX_CONTENT_LENGTH'] / 1024 / 1024)
        return render_template('index.html', max_file_size=max_file_size, mimetypes=mimetypes)

    def get_supported_extensions():
        """Return the sorted file extensions of every mimetype mat2 can parse.

        Each parser's mimetypes are mapped to all their known extensions
        (non-strict lookup); duplicates are collapsed.
        """
        found = set()
        for parser_cls in parser_factory._get_parsers():
            for mime in parser_cls.mimetypes:
                found.update(mtype.guess_all_extensions(mime, strict=False))
        # drop falsy entries (e.g. `None`) before sorting
        return sorted(ext for ext in found if ext)

    def save_file(file):
        """Store an uploaded file in the upload folder under a secure name.

        :param file: object exposing `filename` and `save(path)`.
        :return: tuple of (secured filename, path the file was saved to).
        """
        safe_name = secure_filename(file.filename)
        destination = os.path.join(app.config['UPLOAD_FOLDER'], safe_name)
        file.save(destination)
        return safe_name, destination

    def get_file_parser(filepath: str):
        """Ask `parser_factory` for the parser and mimetype matching `filepath`.

        :return: tuple of (parser, mimetype); callers in this file treat a
                 `None` parser as "unsupported type".
        """
        lookup = parser_factory.get_parser(filepath)
        file_parser, mime_type = lookup
        return file_parser, mime_type
jvoisin's avatar
jvoisin committed
109

110 111
    def cleanup(parser, filepath):
        """Finish a cleaning run: inspect the result and drop the original.

        :param parser: parser that produced the cleaned file
                       (`parser.output_filename`).
        :param filepath: path of the original file, which gets deleted.
        :return: tuple of (hash of the cleaned file, metadata read back from
                 the cleaned file, basename of the cleaned file).
        """
        cleaned_path = parser.output_filename
        output_filename = os.path.basename(cleaned_path)

        # Re-parse the cleaned file to report what metadata is left in it.
        cleaned_parser, _ = parser_factory.get_parser(cleaned_path)
        meta_after = cleaned_parser.get_meta()

        os.remove(filepath)

        hashed_path = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
        key = utils.hash_file(hashed_path)
        return key, meta_after, output_filename

    def get_file_paths(filename):
        """Resolve `filename` to its location inside the upload folder.

        :return: tuple of (path including the upload folder, secured filename).
        """
        safe_name = secure_filename(filename)
        return os.path.join(app.config['UPLOAD_FOLDER'], safe_name), safe_name

125 126 127 128 129 130 131 132 133 134 135 136 137
    def is_valid_api_download_file(filename, key):
        """Validate an API download request and resolve the requested file.

        Aborts the request with
        - 400 when `filename` is not already a secure filename,
        - 404 when the file does not exist in the upload folder,
        - 400 when `key` does not equal the hash of the file on disk.

        :param filename: name of the file inside the upload folder.
        :param key: hash the caller claims for the file.
        :return: tuple of (path including the upload folder, secured filename).
        """
        if filename != secure_filename(filename):
            abort(400, message='Insecure filename')

        complete_path, filepath = get_file_paths(filename)

        if not os.path.exists(complete_path):
            abort(404, message='File not found')

        # `hmac.compare_digest` raises TypeError when given a non-ASCII `str`;
        # `key` is caller-controlled, so compare the encoded bytes to answer
        # with the intended 400 instead of an unhandled exception.
        expected_key = utils.hash_file(complete_path).encode('utf-8')
        if not hmac.compare_digest(expected_key, key.encode('utf-8')):
            abort(400, message='The file hash does not match')
        return complete_path, filepath

138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
    class APIUpload(Resource):
        """API endpoint accepting a base64 encoded file and cleaning it."""

        def post(self):
            """Clean the posted file and describe the result.

            Expects the parameters `file_name` and `file` (base64 encoded
            content). Aborts with 400 when the content cannot be decoded, 415
            when the file type is not supported and 500 when cleaning fails.

            :return: the response built by `utils.return_file_created_response`,
                     including the download link of the cleaned file.
            """
            utils.check_upload_folder(app.config['UPLOAD_FOLDER'])
            req_parser = reqparse.RequestParser()
            req_parser.add_argument('file_name', type=str, required=True, help='Post parameter is not specified: file_name')
            req_parser.add_argument('file', type=str, required=True, help='Post parameter is not specified: file')

            args = req_parser.parse_args()
            try:
                file_data = base64.b64decode(args['file'])
            except binascii.Error as err:
                abort(400, message='Failed decoding file: ' + str(err))

            upload = FileStorage(stream=io.BytesIO(file_data), filename=args['file_name'])
            _, filepath = save_file(upload)
            parser, mime = get_file_parser(filepath)

            if parser is None:
                # Nothing will ever clean or serve this file: do not keep the
                # caller's original (metadata included) in the upload folder.
                os.remove(filepath)
                abort(415, message='The type %s is not supported' % mime)

            meta = parser.get_meta()
            if not parser.remove_all():
                # Same as above: a file that could not be cleaned must not
                # linger on disk.
                os.remove(filepath)
                abort(500, message='Unable to clean %s' % mime)

            # `cleanup` deletes the original upload and hashes the cleaned file.
            key, meta_after, output_filename = cleanup(parser, filepath)
            return utils.return_file_created_response(
                output_filename,
                mime,
                key,
                meta,
                meta_after,
                urljoin(request.host_url, '%s/%s/%s/%s' % ('api', 'download', key, output_filename))
            )
172 173 174

    class APIDownload(Resource):
        """API endpoint serving a cleaned file identified by key and name."""

        def get(self, key: str, filename: str):
            """Send the requested file; a GET request also deletes it afterwards."""
            complete_path, filepath = is_valid_api_download_file(filename, key)

            def remove_file(response):
                os.remove(complete_path)
                return response

            # Make sure the file is NOT deleted on HEAD requests
            if request.method == 'GET':
                after_this_request(remove_file)

            upload_folder = app.config['UPLOAD_FOLDER']
            return send_from_directory(upload_folder, filepath, as_attachment=True)
184

185 186 187 188 189 190 191 192 193
    class APIBulkDownloadCreator(Resource):
        """API endpoint bundling several already cleaned files into one zip."""

        # Cerberus schema of the JSON body: `download_list` holds between 2 and
        # MAT2_MAX_FILES_BULK_DOWNLOAD (environment variable, default 10)
        # entries, each naming a file and the key that authorizes its download.
        schema = {
            'download_list': {
                'type': 'list',
                'minlength': 2,
                'maxlength': int(os.environ.get('MAT2_MAX_FILES_BULK_DOWNLOAD', 10)),
                'schema': {
                    'type': 'dict',
                    'schema': {
                        'key': {'type': 'string', 'required': True},
                        'file_name': {'type': 'string', 'required': True}
                    }
                }
            }
        }
        v = Validator(schema)

        def post(self):
            """Zip the listed files, clean the archive and describe the result.

            Every listed file is validated with `is_valid_api_download_file`,
            added to the archive and then deleted. Aborts with 400 when the
            body does not match the schema or the archive cannot be built, 415
            when no parser accepts the archive and 500 when cleaning it fails.

            :return: tuple of (description of the cleaned archive including
                     its download link, 201).
            """
            utils.check_upload_folder(app.config['UPLOAD_FOLDER'])
            data = request.json
            if not self.v.validate(data):
                abort(400, message=self.v.errors)
            # prevent the zip file from being overwritten
            zip_filename = 'files.' + str(uuid4()) + '.zip'
            zip_path = os.path.join(app.config['UPLOAD_FOLDER'], zip_filename)
            with zipfile.ZipFile(zip_path, 'w') as cleaned_files_zip:
                for file_candidate in data['download_list']:
                    complete_path, _ = is_valid_api_download_file(
                        file_candidate['file_name'],
                        file_candidate['key']
                    )
                    try:
                        cleaned_files_zip.write(complete_path)
                        os.remove(complete_path)
                    except ValueError:
                        abort(400, message='Creating the archive failed')

                try:
                    cleaned_files_zip.testzip()
                except ValueError as e:
                    abort(400, message=str(e))

            parser, mime = get_file_parser(zip_path)
            if parser is None:
                # `get_file_parser` may yield no parser (the upload endpoint
                # handles that case too); answer like it does instead of
                # failing on `None.remove_all()`, and drop the useless archive.
                os.remove(zip_path)
                abort(415, message='The type %s is not supported' % mime)
            if not parser.remove_all():
                # The archive's name is never handed out on failure, so it
                # would stay on disk forever.
                os.remove(zip_path)
                abort(500, message='Unable to clean %s' % mime)
            # `cleanup` deletes the uncleaned archive and hashes the cleaned one.
            key, meta_after, output_filename = cleanup(parser, zip_path)
            return {
                'output_filename': output_filename,
                'mime': mime,
                'key': key,
                'meta_after': meta_after,
                'download_link': urljoin(request.host_url, '%s/%s/%s/%s' % ('api', 'download', key, output_filename))
            }, 201

    class APISupportedExtensions(Resource):
        """API endpoint listing the file extensions that can be cleaned."""

        def get(self):
            """Return the list produced by `get_supported_extensions`."""
            extensions = get_supported_extensions()
            return extensions
243

244 245
    api.add_resource(APIUpload, '/api/upload')
    api.add_resource(APIDownload, '/api/download/<string:key>/<string:filename>')
246 247
    api.add_resource(APIBulkDownloadCreator, '/api/download/bulk')
    api.add_resource(APISupportedExtensions, '/api/extension')
jvoisin's avatar
jvoisin committed
248

249
    return app
jvoisin's avatar
jvoisin committed
250

251

252
# Application instance, built once when this module is imported.
app = create_app()


if __name__ == '__main__':  # pragma: no cover
    # Executed as a script: start the server through `Flask.run`.
    app.run()