diff --git a/test/helper.py b/test/helper.py index 5b7e3dfe2..d51ed5634 100644 --- a/test/helper.py +++ b/test/helper.py @@ -20,6 +20,7 @@ from youtube_dl.compat import ( from youtube_dl.utils import ( IDENTITY, preferredencoding, + variadic, write_string, ) @@ -66,7 +67,7 @@ def report_warning(message): class FakeYDL(YoutubeDL): def __init__(self, override=None): # Different instances of the downloader can't share the same dictionary - # some test set the "sublang" parameter, which would break the md5 checks. + # some tests set the "sublang" parameter, which would break the md5 checks. params = get_params(override=override) super(FakeYDL, self).__init__(params, auto_init=False) self.result = [] @@ -83,13 +84,7 @@ class FakeYDL(YoutubeDL): def expect_warning(self, regex): # Silence an expected warning matching a regex - old_report_warning = self.report_warning - - def report_warning(self, message): - if re.match(regex, message): - return - old_report_warning(message) - self.report_warning = types.MethodType(report_warning, self) + expect_warnings(self, regex) class FakeLogger(object): @@ -285,12 +280,14 @@ def assertEqual(self, got, expected, msg=None): def expect_warnings(ydl, warnings_re): real_warning = ydl.report_warning + # to facilitate matching, don't prettify messages + ydl.params['no_color'] = True - def _report_warning(w): - if not any(re.search(w_re, w) for w_re in warnings_re): - real_warning(w) + def _report_warning(self, w, *args, **kwargs): + if not any(re.search(w_re, w) for w_re in variadic(warnings_re)): + real_warning(w, *args, **kwargs) - ydl.report_warning = _report_warning + ydl.report_warning = types.MethodType(_report_warning, ydl) def http_server_port(httpd): diff --git a/test/test_download.py b/test/test_download.py index e0bc8cb95..4e728e980 100644 --- a/test/test_download.py +++ b/test/test_download.py @@ -20,13 +20,17 @@ from test.helper import ( import hashlib +import itertools import json import socket +import re import youtube_dl.YoutubeDL from youtube_dl.compat import ( + compat_filter as filter, compat_http_client, compat_HTTPError, + compat_map as map, compat_open as open, compat_urllib_error, ) @@ -35,9 +39,11 @@ from youtube_dl.utils import ( ExtractorError, error_to_compat_str, format_bytes, + std_headers, UnavailableVideoError, ) from youtube_dl.extractor import get_info_extractor +from youtube_dl.downloader.common import FileDownloader RETRIES = 3 @@ -48,7 +54,7 @@ class YoutubeDL(youtube_dl.YoutubeDL): self.processed_info_dicts = [] super(YoutubeDL, self).__init__(*args, **kwargs) - def report_warning(self, message): + def report_warning(self, message, *args, **kwargs): # Don't accept warnings during tests raise ExtractorError(message) @@ -57,9 +63,10 @@ class YoutubeDL(youtube_dl.YoutubeDL): return super(YoutubeDL, self).process_info(info_dict) -def _file_md5(fn): +def _file_md5(fn, length=None): with open(fn, 'rb') as f: - return hashlib.md5(f.read()).hexdigest() + return hashlib.md5( + f.read() if length is None else f.read(length)).hexdigest() defs = gettestcases() @@ -84,6 +91,13 @@ class TestDownload(unittest.TestCase): strclass(self.__class__), ' [%s]' % add_ie if add_ie else '') + @classmethod + def addTest(cls, test_method, test_method_name, add_ie): + test_method.__name__ = str(test_method_name) + test_method.add_ie = add_ie + setattr(cls, test_method.__name__, test_method) + del test_method + def setUp(self): self.defs = defs @@ -125,6 +139,17 @@ def generator(test_case, tname): params.setdefault('playlistend', test_case.get('playlist_mincount')) params.setdefault('skip_download', True) + if 'user_agent' in params: + std_headers['User-Agent'] = params['user_agent'] + + if 'referer' in params: + std_headers['Referer'] = params['referer'] + + for h in params.get('headers', []): + h = h.split(':', 1) + if len(h) > 1: + std_headers[h[0]] = h[1] + ydl = YoutubeDL(params, auto_init=False) ydl.add_default_info_extractors() finished_hook_called = set() @@ -151,8 +176,7 @@ def generator(test_case, tname): try_rm_tcs_files() try: - try_num = 1 - while True: + for try_num in itertools.count(1): try: # We're not using .download here since that is just a shim # for outside error handling, and returns the exit code @@ -161,7 +185,7 @@ def generator(test_case, tname): test_case['url'], force_generic_extractor=params.get('force_generic_extractor', False)) except (DownloadError, ExtractorError) as err: - # Check if the exception is not a network related one + # Retry, or raise if the exception is not network-related if not err.exc_info[0] in (compat_urllib_error.URLError, socket.timeout, UnavailableVideoError, compat_http_client.BadStatusLine) or (err.exc_info[0] == compat_HTTPError and err.exc_info[1].code == 503): msg = getattr(err, 'msg', error_to_compat_str(err)) err.msg = '%s (%s)' % (msg, tname, ) @@ -172,8 +196,6 @@ def generator(test_case, tname): return print('Retrying: {0} failed tries\n\n##########\n\n'.format(try_num)) - - try_num += 1 else: break @@ -237,7 +259,7 @@ def generator(test_case, tname): (tc_filename, format_bytes(expected_minsize), format_bytes(got_fsize))) if 'md5' in tc: - md5_for_file = _file_md5(tc_filename) + md5_for_file = _file_md5(tc_filename) if not params.get('test') else _file_md5(tc_filename, FileDownloader._TEST_FILE_SIZE) self.assertEqual(tc['md5'], md5_for_file) # Finally, check test cases' data again but this time against # extracted data from info JSON file written during processing @@ -267,12 +289,42 @@ for n, test_case in enumerate(defs): tname = 'test_%s_%d' % (test_case['name'], i) i += 1 test_method = generator(test_case, tname) - test_method.__name__ = str(tname) - ie_list = test_case.get('add_ie') - test_method.add_ie = ie_list and ','.join(ie_list) - setattr(TestDownload, test_method.__name__, test_method) - del test_method - + ie_list = ','.join(test_case.get('add_ie', [])) + TestDownload.addTest(test_method, tname, ie_list) + + +def tests_for_ie(ie_key): + return filter( + lambda a: callable(getattr(TestDownload, a, None)), + filter(lambda a: re.match(r'test_%s(?:_\d+)?$' % ie_key, a), + dir(TestDownload))) + + +def gen_test_suite(ie_key): + def test_all(self): + print(self) + suite = unittest.TestSuite( + map(TestDownload, tests_for_ie(ie_key))) + result = self.defaultTestResult() + suite.run(result) + print('Errors: %d\t Failures: %d\tSkipped: %d' % + tuple(map(len, (result.errors, result.failures, result.skipped)))) + print('Expected failures: %d\tUnexpected successes: %d' % + tuple(map(len, (result.expectedFailures, result.unexpectedSuccesses)))) + return result + + return test_all + + +for ie_key in set( + map(lambda a: a[5:], + filter( + lambda x: callable(getattr(TestDownload, x, None)), + filter( + lambda t: re.match(r"test_.+(?