From 5219cb3e7567143ea704d299ebe6e7135341ebc1 Mon Sep 17 00:00:00 2001
From: shirt-dev <2660574+shirt-dev@users.noreply.github.com>
Date: Mon, 8 Feb 2021 11:46:01 -0500
Subject: [PATCH] #55 Add aria2c support for DASH (mpd) and HLS (m3u8)

Co-authored-by: Dan <2660574+shirtjs@users.noreply.github.com>
Co-authored-by: pukkandan <pukkandan@gmail.com>
---
 youtube_dlc/downloader/__init__.py | 40 +++++++++------
 youtube_dlc/downloader/dash.py     | 37 +++++++++++---
 youtube_dlc/downloader/external.py | 80 +++++++++++++++++++++++++++---
 youtube_dlc/downloader/fragment.py | 21 ++++++++
 youtube_dlc/downloader/hls.py      | 36 ++++++++++++--
 5 files changed, 179 insertions(+), 35 deletions(-)

diff --git a/youtube_dlc/downloader/__init__.py b/youtube_dlc/downloader/__init__.py
index 4ae81f516..0af65890b 100644
--- a/youtube_dlc/downloader/__init__.py
+++ b/youtube_dlc/downloader/__init__.py
@@ -1,11 +1,24 @@
 from __future__ import unicode_literals
 
+from ..utils import (
+    determine_protocol,
+)
+
+
+def _get_real_downloader(info_dict, protocol=None, *args, **kwargs):
+    info_copy = info_dict.copy()
+    if protocol:
+        info_copy['protocol'] = protocol
+    return get_suitable_downloader(info_copy, *args, **kwargs)
+
+
+# Some of these require _get_real_downloader
 from .common import FileDownloader
+from .dash import DashSegmentsFD
 from .f4m import F4mFD
 from .hls import HlsFD
 from .http import HttpFD
 from .rtmp import RtmpFD
-from .dash import DashSegmentsFD
 from .rtsp import RtspFD
 from .ism import IsmFD
 from .youtube_live_chat import YoutubeLiveChatReplayFD
@@ -14,10 +27,6 @@
     FFmpegFD,
 )
 
-from ..utils import (
-    determine_protocol,
-)
-
 PROTOCOL_MAP = {
     'rtmp': RtmpFD,
     'm3u8_native': HlsFD,
@@ -31,7 +40,7 @@
 }
 
 
-def get_suitable_downloader(info_dict, params={}):
+def get_suitable_downloader(info_dict, params={}, default=HttpFD):
     """Get the downloader class that can handle the info dict."""
     protocol = determine_protocol(info_dict)
     info_dict['protocol'] = protocol
@@ -45,16 +54,17 @@ def get_suitable_downloader(info_dict, params={}):
         if ed.can_download(info_dict):
             return ed
 
-    if protocol.startswith('m3u8') and info_dict.get('is_live'):
-        return FFmpegFD
+    if protocol.startswith('m3u8'):
+        if info_dict.get('is_live'):
+            return FFmpegFD
+        elif _get_real_downloader(info_dict, 'frag_urls', params, None):
+            return HlsFD
+        elif params.get('hls_prefer_native') is True:
+            return HlsFD
+        elif params.get('hls_prefer_native') is False:
+            return FFmpegFD
 
-    if protocol == 'm3u8' and params.get('hls_prefer_native') is True:
-        return HlsFD
-
-    if protocol == 'm3u8_native' and params.get('hls_prefer_native') is False:
-        return FFmpegFD
-
-    return PROTOCOL_MAP.get(protocol, HttpFD)
+    return PROTOCOL_MAP.get(protocol, default)
 
 
 __all__ = [
diff --git a/youtube_dlc/downloader/dash.py b/youtube_dlc/downloader/dash.py
index c6d674bc6..d758282c1 100644
--- a/youtube_dlc/downloader/dash.py
+++ b/youtube_dlc/downloader/dash.py
@@ -1,6 +1,8 @@
 from __future__ import unicode_literals
 
+from ..downloader import _get_real_downloader
 from .fragment import FragmentFD
+
 from ..compat import compat_urllib_error
 from ..utils import (
     DownloadError,
@@ -20,31 +22,42 @@ def real_download(self, filename, info_dict):
         fragments = info_dict['fragments'][:1] if self.params.get(
             'test', False) else info_dict['fragments']
 
+        real_downloader = _get_real_downloader(info_dict, 'frag_urls', self.params, None)
+
         ctx = {
             'filename': filename,
             'total_frags': len(fragments),
         }
 
-        self._prepare_and_start_frag_download(ctx)
+        if real_downloader:
+            self._prepare_external_frag_download(ctx)
+        else:
+            self._prepare_and_start_frag_download(ctx)
 
         fragment_retries = self.params.get('fragment_retries', 0)
         skip_unavailable_fragments = self.params.get('skip_unavailable_fragments', True)
 
+        fragment_urls = []
         frag_index = 0
         for i, fragment in enumerate(fragments):
             frag_index += 1
             if frag_index <= ctx['fragment_index']:
                 continue
+            fragment_url = fragment.get('url')
+            if not fragment_url:
+                assert fragment_base_url
+                fragment_url = urljoin(fragment_base_url, fragment['path'])
+
+            if real_downloader:
+                fragment_urls.append(fragment_url)
+                continue
+
             # In DASH, the first segment contains necessary headers to
             # generate a valid MP4 file, so always abort for the first segment
             fatal = i == 0 or not skip_unavailable_fragments
             count = 0
             while count <= fragment_retries:
                 try:
-                    fragment_url = fragment.get('url')
-                    if not fragment_url:
-                        assert fragment_base_url
-                        fragment_url = urljoin(fragment_base_url, fragment['path'])
                     success, frag_content = self._download_fragment(ctx, fragment_url, info_dict)
                     if not success:
                         return False
@@ -75,6 +88,16 @@ def real_download(self, filename, info_dict):
                 self.report_error('giving up after %s fragment retries' % fragment_retries)
                 return False
 
-        self._finish_frag_download(ctx)
-
+        if real_downloader:
+            info_copy = info_dict.copy()
+            info_copy['url_list'] = fragment_urls
+            fd = real_downloader(self.ydl, self.params)
+            # TODO: Make progress updates work without hooking twice
+            # for ph in self._progress_hooks:
+            #     fd.add_progress_hook(ph)
+            success = fd.real_download(filename, info_copy)
+            if not success:
+                return False
+        else:
+            self._finish_frag_download(ctx)
         return True
diff --git a/youtube_dlc/downloader/external.py b/youtube_dlc/downloader/external.py
index 8f82acdf4..67a3b9aea 100644
--- a/youtube_dlc/downloader/external.py
+++ b/youtube_dlc/downloader/external.py
@@ -5,6 +5,13 @@
 import subprocess
 import sys
 import time
+import shutil
+
+try:
+    from Crypto.Cipher import AES
+    can_decrypt_frag = True
+except ImportError:
+    can_decrypt_frag = False
 
 from .common import FileDownloader
 from ..compat import (
@@ -18,15 +25,19 @@
     cli_bool_option,
     cli_configuration_args,
     encodeFilename,
+    error_to_compat_str,
     encodeArgument,
     handle_youtubedl_headers,
     check_executable,
     is_outdated_version,
     process_communicate_or_kill,
+    sanitized_Request,
 )
 
 
 class ExternalFD(FileDownloader):
+    SUPPORTED_PROTOCOLS = ('http', 'https', 'ftp', 'ftps')
+
     def real_download(self, filename, info_dict):
         self.report_destination(filename)
         tmpfilename = self.temp_name(filename)
@@ -79,7 +90,7 @@ def available(cls):
 
     @classmethod
     def supports(cls, info_dict):
-        return info_dict['protocol'] in ('http', 'https', 'ftp', 'ftps')
+        return info_dict['protocol'] in cls.SUPPORTED_PROTOCOLS
 
     @classmethod
     def can_download(cls, info_dict):
@@ -109,8 +120,47 @@ def _call_downloader(self, tmpfilename, info_dict):
         _, stderr = process_communicate_or_kill(p)
         if p.returncode != 0:
             self.to_stderr(stderr.decode('utf-8', 'replace'))
+
+        if 'url_list' in info_dict:
+            file_list = []
+            for [i, url] in enumerate(info_dict['url_list']):
+                tmpsegmentname = '%s_%s.frag' % (tmpfilename, i)
+                file_list.append(tmpsegmentname)
+            with open(tmpfilename, 'wb') as dest:
+                for i in file_list:
+                    if 'decrypt_info' in info_dict:
+                        decrypt_info = info_dict['decrypt_info']
+                        with open(i, 'rb') as src:
+                            if decrypt_info['METHOD'] == 'AES-128':
+                                iv = decrypt_info.get('IV')
+                                decrypt_info['KEY'] = decrypt_info.get('KEY') or self.ydl.urlopen(
+                                    self._prepare_url(info_dict, info_dict.get('_decryption_key_url') or decrypt_info['URI'])).read()
+                                encrypted_data = src.read()
+                                decrypted_data = AES.new(
+                                    decrypt_info['KEY'], AES.MODE_CBC, iv).decrypt(encrypted_data)
+                                dest.write(decrypted_data)
+                            else:
+                                shutil.copyfileobj(open(i, 'rb'), dest)
+                    else:
+                        shutil.copyfileobj(open(i, 'rb'), dest)
+            if not self.params.get('keep_fragments', False):
+                for file_path in file_list:
+                    try:
+                        os.remove(file_path)
+                    except OSError as ose:
+                        self.report_error("Unable to delete file %s; %s" % (file_path, error_to_compat_str(ose)))
+                try:
+                    file_path = '%s.frag.urls' % tmpfilename
+                    os.remove(file_path)
+                except OSError as ose:
+                    self.report_error("Unable to delete file %s; %s" % (file_path, error_to_compat_str(ose)))
+
         return p.returncode
 
+    def _prepare_url(self, info_dict, url):
+        headers = info_dict.get('http_headers')
+        return sanitized_Request(url, None, headers) if headers else url
+
 
 class CurlFD(ExternalFD):
     AVAILABLE_OPT = '-V'
@@ -186,15 +236,17 @@ def _make_cmd(self, tmpfilename, info_dict):
 
 class Aria2cFD(ExternalFD):
     AVAILABLE_OPT = '-v'
+    SUPPORTED_PROTOCOLS = ('http', 'https', 'ftp', 'ftps', 'frag_urls')
 
     def _make_cmd(self, tmpfilename, info_dict):
         cmd = [self.exe, '-c']
-        cmd += self._configuration_args([
-            '--min-split-size', '1M', '--max-connection-per-server', '4'])
         dn = os.path.dirname(tmpfilename)
+        if 'url_list' not in info_dict:
+            cmd += ['--out', os.path.basename(tmpfilename)]
+        verbose_level_args = ['--console-log-level=warn', '--summary-interval=0']
+        cmd += self._configuration_args(['--file-allocation=none', '-x16', '-j16', '-s16'] + verbose_level_args)
         if dn:
             cmd += ['--dir', dn]
-        cmd += ['--out', os.path.basename(tmpfilename)]
         if info_dict.get('http_headers') is not None:
             for key, val in info_dict['http_headers'].items():
                 cmd += ['--header', '%s: %s' % (key, val)]
@@ -202,7 +254,21 @@ def _make_cmd(self, tmpfilename, info_dict):
         cmd += self._option('--all-proxy', 'proxy')
         cmd += self._bool_option('--check-certificate', 'nocheckcertificate', 'false', 'true', '=')
         cmd += self._bool_option('--remote-time', 'updatetime', 'true', 'false', '=')
-        cmd += ['--', info_dict['url']]
+        cmd += ['--auto-file-renaming=false']
+        if 'url_list' in info_dict:
+            cmd += verbose_level_args
+            cmd += ['--uri-selector', 'inorder', '--download-result=hide']
+            url_list_file = '%s.frag.urls' % tmpfilename
+            url_list = []
+            for [i, url] in enumerate(info_dict['url_list']):
+                tmpsegmentname = '%s_%s.frag' % (os.path.basename(tmpfilename), i)
+                url_list.append('%s\n\tout=%s' % (url, tmpsegmentname))
+            with open(url_list_file, 'w') as f:
+                f.write('\n'.join(url_list))
+
+            cmd += ['-i', url_list_file]
+        else:
+            cmd += ['--', info_dict['url']]
         return cmd
 
 
@@ -221,9 +287,7 @@ def _make_cmd(self, tmpfilename, info_dict):
 
 
 class FFmpegFD(ExternalFD):
-    @classmethod
-    def supports(cls, info_dict):
-        return info_dict['protocol'] in ('http', 'https', 'ftp', 'ftps', 'm3u8', 'rtsp', 'rtmp', 'mms')
+    SUPPORTED_PROTOCOLS = ('http', 'https', 'ftp', 'ftps', 'm3u8', 'rtsp', 'rtmp', 'mms')
 
     @classmethod
     def available(cls):
diff --git a/youtube_dlc/downloader/fragment.py b/youtube_dlc/downloader/fragment.py
index cf4fd41da..f4104c713 100644
--- a/youtube_dlc/downloader/fragment.py
+++ b/youtube_dlc/downloader/fragment.py
@@ -277,3 +277,24 @@ def _finish_frag_download(self, ctx):
             'status': 'finished',
             'elapsed': elapsed,
         })
+
+    def _prepare_external_frag_download(self, ctx):
+        if 'live' not in ctx:
+            ctx['live'] = False
+        if not ctx['live']:
+            total_frags_str = '%d' % ctx['total_frags']
+            ad_frags = ctx.get('ad_frags', 0)
+            if ad_frags:
+                total_frags_str += ' (not including %d ad)' % ad_frags
+        else:
+            total_frags_str = 'unknown (live)'
+        self.to_screen(
+            '[%s] Total fragments: %s' % (self.FD_NAME, total_frags_str))
+
+        tmpfilename = self.temp_name(ctx['filename'])
+
+        # Should be initialized before ytdl file check
+        ctx.update({
+            'tmpfilename': tmpfilename,
+            'fragment_index': 0,
+        })
diff --git a/youtube_dlc/downloader/hls.py b/youtube_dlc/downloader/hls.py
index 7aaebc940..c3c862410 100644
--- a/youtube_dlc/downloader/hls.py
+++ b/youtube_dlc/downloader/hls.py
@@ -8,6 +8,7 @@
 except ImportError:
     can_decrypt_frag = False
 
+from ..downloader import _get_real_downloader
 from .fragment import FragmentFD
 from .external import FFmpegFD
 
@@ -73,10 +74,13 @@ def real_download(self, filename, info_dict):
                 'hlsnative has detected features it does not support, '
                 'extraction will be delegated to ffmpeg')
             fd = FFmpegFD(self.ydl, self.params)
-            for ph in self._progress_hooks:
-                fd.add_progress_hook(ph)
+            # TODO: Make progress updates work without hooking twice
+            # for ph in self._progress_hooks:
+            #     fd.add_progress_hook(ph)
             return fd.real_download(filename, info_dict)
 
+        real_downloader = _get_real_downloader(info_dict, 'frag_urls', self.params, None)
+
         def is_ad_fragment_start(s):
             return (s.startswith('#ANVATO-SEGMENT-INFO') and 'type=ad' in s
                     or s.startswith('#UPLYNK-SEGMENT') and s.endswith(',ad'))
@@ -85,6 +89,8 @@ def is_ad_fragment_end(s):
             return (s.startswith('#ANVATO-SEGMENT-INFO') and 'type=master' in s
                     or s.startswith('#UPLYNK-SEGMENT') and s.endswith(',segment'))
 
+        fragment_urls = []
+
         media_frags = 0
         ad_frags = 0
         ad_frag_next = False
@@ -109,7 +115,10 @@ def is_ad_fragment_end(s):
             'ad_frags': ad_frags,
         }
 
-        self._prepare_and_start_frag_download(ctx)
+        if real_downloader:
+            self._prepare_external_frag_download(ctx)
+        else:
+            self._prepare_and_start_frag_download(ctx)
 
         fragment_retries = self.params.get('fragment_retries', 0)
         skip_unavailable_fragments = self.params.get('skip_unavailable_fragments', True)
@@ -140,6 +149,11 @@ def is_ad_fragment_end(s):
                         else compat_urlparse.urljoin(man_url, line))
                     if extra_query:
                         frag_url = update_url_query(frag_url, extra_query)
+
+                    if real_downloader:
+                        fragment_urls.append(frag_url)
+                        continue
+
                     count = 0
                     headers = info_dict.get('http_headers', {})
                     if byte_range:
@@ -168,6 +182,7 @@ def is_ad_fragment_end(s):
                         self.report_error(
                             'giving up after %s fragment retries' % fragment_retries)
                         return False
+
                     if decrypt_info['METHOD'] == 'AES-128':
                         iv = decrypt_info.get('IV') or compat_struct_pack('>8xq', media_sequence)
                         decrypt_info['KEY'] = decrypt_info.get('KEY') or self.ydl.urlopen(
@@ -211,6 +226,17 @@ def is_ad_fragment_end(s):
                 elif is_ad_fragment_end(line):
                     ad_frag_next = False
 
-        self._finish_frag_download(ctx)
-
+        if real_downloader:
+            info_copy = info_dict.copy()
+            info_copy['url_list'] = fragment_urls
+            info_copy['decrypt_info'] = decrypt_info
+            fd = real_downloader(self.ydl, self.params)
+            # TODO: Make progress updates work without hooking twice
+            # for ph in self._progress_hooks:
+            #     fd.add_progress_hook(ph)
+            success = fd.real_download(filename, info_copy)
+            if not success:
+                return False
+        else:
+            self._finish_frag_download(ctx)
         return True