diff --git a/test/test_jsinterp.py b/test/test_jsinterp.py index 665af4668a..863e52458b 100644 --- a/test/test_jsinterp.py +++ b/test/test_jsinterp.py @@ -7,8 +7,10 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import math +import re -from yt_dlp.jsinterp import JSInterpreter +from yt_dlp.jsinterp import JS_Undefined, JSInterpreter class TestJSInterpreter(unittest.TestCase): @@ -66,6 +68,9 @@ def test_operators(self): jsi = JSInterpreter('function f(){return 0 && 1 || 2;}') self.assertEqual(jsi.call_function('f'), 2) + jsi = JSInterpreter('function f(){return 0 ?? 42;}') + self.assertEqual(jsi.call_function('f'), 0) + def test_array_access(self): jsi = JSInterpreter('function f(){var x = [1,2,3]; x[0] = 4; x[0] = 5; x[2.0] = 7; return x;}') self.assertEqual(jsi.call_function('f'), [5, 2, 7]) @@ -229,6 +234,119 @@ def test_return_function(self): ''') self.assertEqual(jsi.call_function('x')([]), 1) + def test_null(self): + jsi = JSInterpreter(''' + function x() { return null; } + ''') + self.assertEqual(jsi.call_function('x'), None) + + jsi = JSInterpreter(''' + function x() { return [null > 0, null < 0, null == 0, null === 0]; } + ''') + self.assertEqual(jsi.call_function('x'), [False, False, False, False]) + + jsi = JSInterpreter(''' + function x() { return [null >= 0, null <= 0]; } + ''') + self.assertEqual(jsi.call_function('x'), [True, True]) + + def test_undefined(self): + jsi = JSInterpreter(''' + function x() { return undefined === undefined; } + ''') + self.assertEqual(jsi.call_function('x'), True) + + jsi = JSInterpreter(''' + function x() { return undefined; } + ''') + self.assertEqual(jsi.call_function('x'), JS_Undefined) + + jsi = JSInterpreter(''' + function x() { let v; return v; } + ''') + self.assertEqual(jsi.call_function('x'), JS_Undefined) + + jsi = JSInterpreter(''' + function x() { return [undefined === undefined, undefined == undefined, undefined < undefined, undefined > undefined]; } + ''') + self.assertEqual(jsi.call_function('x'), [True, True, False, False]) + + jsi = JSInterpreter(''' + function x() { return [undefined === 0, undefined == 0, undefined < 0, undefined > 0]; } + ''') + self.assertEqual(jsi.call_function('x'), [False, False, False, False]) + + jsi = JSInterpreter(''' + function x() { return [undefined >= 0, undefined <= 0]; } + ''') + self.assertEqual(jsi.call_function('x'), [False, False]) + + jsi = JSInterpreter(''' + function x() { return [undefined > null, undefined < null, undefined == null, undefined === null]; } + ''') + self.assertEqual(jsi.call_function('x'), [False, False, True, False]) + + jsi = JSInterpreter(''' + function x() { return [undefined === null, undefined == null, undefined < null, undefined > null]; } + ''') + self.assertEqual(jsi.call_function('x'), [False, True, False, False]) + + jsi = JSInterpreter(''' + function x() { let v; return [42+v, v+42, v**42, 42**v, 0**v]; } + ''') + for y in jsi.call_function('x'): + self.assertTrue(math.isnan(y)) + + jsi = JSInterpreter(''' + function x() { let v; return v**0; } + ''') + self.assertEqual(jsi.call_function('x'), 1) + + jsi = JSInterpreter(''' + function x() { let v; return [v>42, v<=42, v&&42, 42&&v]; } + ''') + self.assertEqual(jsi.call_function('x'), [False, False, JS_Undefined, JS_Undefined]) + + jsi = JSInterpreter('function x(){return undefined ?? 42; }') + self.assertEqual(jsi.call_function('x'), 42) + + def test_object(self): + jsi = JSInterpreter(''' + function x() { return {}; } + ''') + self.assertEqual(jsi.call_function('x'), {}) + + jsi = JSInterpreter(''' + function x() { let a = {m1: 42, m2: 0 }; return [a["m1"], a.m2]; } + ''') + self.assertEqual(jsi.call_function('x'), [42, 0]) + + jsi = JSInterpreter(''' + function x() { let a; return a?.qq; } + ''') + self.assertEqual(jsi.call_function('x'), JS_Undefined) + + jsi = JSInterpreter(''' + function x() { let a = {m1: 42, m2: 0 }; return a?.qq; } + ''') + self.assertEqual(jsi.call_function('x'), JS_Undefined) + + def test_regex(self): + jsi = JSInterpreter(''' + function x() { let a=/,,[/,913,/](,)}/; } + ''') + self.assertEqual(jsi.call_function('x'), None) + + jsi = JSInterpreter(''' + function x() { let a=/,,[/,913,/](,)}/; return a; } + ''') + self.assertIsInstance(jsi.call_function('x'), re.Pattern) + + jsi = JSInterpreter(''' + function x() { let a=/,,[/,913,/](,)}/i; return a; } + ''') + self.assertEqual(jsi.call_function('x').flags & re.I, re.I) + if __name__ == '__main__': unittest.main() diff --git a/yt_dlp/jsinterp.py b/yt_dlp/jsinterp.py index d3994e90c2..2b68f53fae 100644 --- a/yt_dlp/jsinterp.py +++ b/yt_dlp/jsinterp.py @@ -16,50 +16,69 @@ write_string, ) -_NAME_RE = r'[a-zA-Z_$][\w$]*' -# Ref: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Operator_Precedence -_OPERATORS = { # None => Defined in JSInterpreter._operator - '?': None, +def _js_bit_op(op): + def wrapped(a, b): + def zeroise(x): + return 0 if x in (None, JS_Undefined) else x + return op(zeroise(a), zeroise(b)) - '||': None, - '&&': None, - '&': lambda a, b: (a or 0) & (b or 0), - '|': lambda a, b: (a or 0) | (b or 0), - '^': lambda a, b: (a or 0) ^ (b or 0), - - '===': operator.is_, - '!==': operator.is_not, - '==': operator.eq, - '!=': operator.ne, - - '<=': lambda a, b: (a or 0) <= (b or 0), - '>=': lambda a, b: (a or 0) >= (b or 0), - '<': lambda a, b: (a or 0) < (b or 0), - '>': lambda a, b: (a or 0) > (b or 0), - - '>>': operator.rshift, - '<<': operator.lshift, - - '+': lambda a, b: (a or 0) + (b or 0), - '-': lambda a, b: (a or 0) - (b or 0), - - '*': lambda a, b: (a or 0) * (b or 0), - '/': lambda a, b: (a or 0) / b if b else float('NaN'), - '%': lambda a, b: (a or 0) % b if b else float('NaN'), - - '**': operator.pow, -} - -_COMP_OPERATORS = {'===', '!==', '==', '!=', '<=', '>=', '<', '>'} - -_MATCHING_PARENS = dict(zip('({[', ')}]')) -_QUOTES = '\'"/' + return wrapped -def _ternary(cndn, if_true=True, if_false=False): +def _js_arith_op(op): + + def wrapped(a, b): + if JS_Undefined in (a, b): + return float('nan') + return op(a or 0, b or 0) + + return wrapped + + +def _js_div(a, b): + if JS_Undefined in (a, b) or not (a and b): + return float('nan') + return (a or 0) / b if b else float('inf') + + +def _js_mod(a, b): + if JS_Undefined in (a, b) or not b: + return float('nan') + return (a or 0) % b + + +def _js_exp(a, b): + if not b: + return 1 # even 0 ** 0 !! + elif JS_Undefined in (a, b): + return float('nan') + return (a or 0) ** b + + +def _js_eq_op(op): + + def wrapped(a, b): + if {a, b} <= {None, JS_Undefined}: + return op(a, a) + return op(a, b) + + return wrapped + + +def _js_comp_op(op): + + def wrapped(a, b): + if JS_Undefined in (a, b): + return False + return op(a or 0, b or 0) + + return wrapped + + +def _js_ternary(cndn, if_true=True, if_false=False): """Simulate JS's ternary operator (cndn?if_true:if_false)""" - if cndn in (False, None, 0, ''): + if cndn in (False, None, 0, '', JS_Undefined): return if_false with contextlib.suppress(TypeError): if math.isnan(cndn): # NB: NaN cannot be checked by membership @@ -67,6 +86,50 @@ def _ternary(cndn, if_true=True, if_false=False): return if_true +# Ref: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Operator_Precedence +_OPERATORS = { # None => Defined in JSInterpreter._operator + '?': None, + '??': None, + '||': None, + '&&': None, + + '|': _js_bit_op(operator.or_), + '^': _js_bit_op(operator.xor), + '&': _js_bit_op(operator.and_), + + '===': operator.is_, + '==': _js_eq_op(operator.eq), + '!==': operator.is_not, + '!=': _js_eq_op(operator.ne), + + '<=': _js_comp_op(operator.le), + '>=': _js_comp_op(operator.ge), + '<': _js_comp_op(operator.lt), + '>': _js_comp_op(operator.gt), + + '>>': _js_bit_op(operator.rshift), + '<<': _js_bit_op(operator.lshift), + + '+': _js_arith_op(operator.add), + '-': _js_arith_op(operator.sub), + + '*': _js_arith_op(operator.mul), + '/': _js_div, + '%': _js_mod, + '**': _js_exp, +} + +_COMP_OPERATORS = {'===', '!==', '==', '!=', '<=', '>=', '<', '>'} + +_NAME_RE = r'[a-zA-Z_$][\w$]*' +_MATCHING_PARENS = dict(zip(*zip('()', '{}', '[]'))) +_QUOTES = '\'"/' + + +class JS_Undefined: + pass + + class JS_Break(ExtractorError): def __init__(self): ExtractorError.__init__(self, 'Invalid break') @@ -119,6 +182,21 @@ def interpret_statement(self, stmt, local_vars, allow_recursion, *args, **kwargs class JSInterpreter: __named_object_counter = 0 + _RE_FLAGS = { + # special knowledge: Python's re flags are bitmask values, current max 128 + # invent new bitmask values well above that for literal parsing + # TODO: new pattern class to execute matches with these flags + 'd': 1024, # Generate indices for substring matches + 'g': 2048, # Global search + 'i': re.I, # Case-insensitive search + 'm': re.M, # Multi-line search + 's': re.S, # Allows . to match newline characters + 'u': re.U, # Treat a pattern as a sequence of unicode code points + 'y': 4096, # Perform a "sticky" search that matches starting at the current position in the target string + } + + _EXC_NAME = '__yt_dlp_exception__' + def __init__(self, code, objects=None): self.code, self._functions = code, {} self._objects = {} if objects is None else objects @@ -135,6 +213,17 @@ def _named_object(self, namespace, obj): namespace[name] = obj return name + @classmethod + def _regex_flags(cls, expr): + flags = 0 + if not expr: + return flags, expr + for idx, ch in enumerate(expr): + if ch not in cls._RE_FLAGS: + break + flags |= cls._RE_FLAGS[ch] + return flags, expr[idx + 1:] + @staticmethod def _separate(expr, delim=',', max_split=None): OP_CHARS = '+-*/%&|^=<>!,;' @@ -178,10 +267,13 @@ def _separate_at_paren(cls, expr, delim): def _operator(self, op, left_val, right_expr, expr, local_vars, allow_recursion): if op in ('||', '&&'): - if (op == '&&') ^ _ternary(left_val): + if (op == '&&') ^ _js_ternary(left_val): return left_val # short circuiting + elif op == '??': + if left_val not in (None, JS_Undefined): + return left_val elif op == '?': - right_expr = _ternary(left_val, *self._separate(right_expr, ':', 1)) + right_expr = _js_ternary(left_val, *self._separate(right_expr, ':', 1)) right_val = self.interpret_expression(right_expr, local_vars, allow_recursion) if not _OPERATORS.get(op): @@ -192,12 +284,14 @@ def _operator(self, op, left_val, right_expr, expr, local_vars, allow_recursion) except Exception as e: raise self.Exception(f'Failed to evaluate {left_val!r} {op} {right_val!r}', expr, cause=e) - def _index(self, obj, idx): + def _index(self, obj, idx, allow_undefined=False): if idx == 'length': return len(obj) try: return obj[int(idx)] if isinstance(obj, list) else obj[idx] except Exception as e: + if allow_undefined: + return JS_Undefined raise self.Exception(f'Cannot get index {idx}', repr(obj), cause=e) def _dump(self, obj, namespace): @@ -233,8 +327,8 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): if expr[0] in _QUOTES: inner, outer = self._separate(expr, expr[0], 1) if expr[0] == '/': - inner = inner[1:].replace('"', R'\"') - inner = re.compile(json.loads(js_to_json(f'"{inner}"', strict=True))) + flags, outer = self._regex_flags(outer) + inner = re.compile(inner[1:], flags=flags) else: inner = json.loads(js_to_json(f'{inner}{expr[0]}', strict=True)) if not outer: @@ -259,6 +353,17 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): if expr.startswith('{'): inner, outer = self._separate_at_paren(expr, '}') + # Look for Map first + sub_expressions = [list(self._separate(sub_expr.strip(), ':', 1)) for sub_expr in self._separate(inner)] + if all(len(sub_expr) == 2 for sub_expr in sub_expressions): + def dict_item(key, val): + val = self.interpret_expression(val, local_vars, allow_recursion) + if re.match(_NAME_RE, key): + return key, val + return self.interpret_expression(key, local_vars, allow_recursion), val + + return dict(dict_item(k, v) for k, v in sub_expressions), should_return + inner, should_abort = self.interpret_statement(inner, local_vars, allow_recursion) if not outer or should_abort: return inner, should_abort or should_return @@ -295,17 +400,17 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): if should_abort: return ret, True except JS_Throw as e: - local_vars['__ytdlp_exception__'] = e.error + local_vars[self._EXC_NAME] = e.error except Exception as e: # XXX: This works for now, but makes debugging future issues very hard - local_vars['__ytdlp_exception__'] = e + local_vars[self._EXC_NAME] = e ret, should_abort = self.interpret_statement(expr, local_vars, allow_recursion) return ret, should_abort or should_return elif m and m.group('catch'): catch_expr, expr = self._separate_at_paren(expr[m.end():], '}') - if '__ytdlp_exception__' in local_vars: - catch_vars = local_vars.new_child({m.group('err'): local_vars.pop('__ytdlp_exception__')}) + if self._EXC_NAME in local_vars: + catch_vars = local_vars.new_child({m.group('err'): local_vars.pop(self._EXC_NAME)}) ret, should_abort = self.interpret_statement(catch_expr, catch_vars, allow_recursion) if should_abort: return ret, True @@ -328,7 +433,7 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): start, cndn, increment = self._separate(constructor, ';') self.interpret_expression(start, local_vars, allow_recursion) while True: - if not _ternary(self.interpret_expression(cndn, local_vars, allow_recursion)): + if not _js_ternary(self.interpret_expression(cndn, local_vars, allow_recursion)): break try: ret, should_abort = self.interpret_statement(body, local_vars, allow_recursion) @@ -397,13 +502,13 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): (?P (?P{_NAME_RE})(?:\[(?P[^\]]+?)\])?\s* (?P{"|".join(map(re.escape, set(_OPERATORS) - _COMP_OPERATORS))})? - =(?P.*)$ + =(?!=)(?P.*)$ )|(?P (?!if|return|true|false|null|undefined)(?P{_NAME_RE})$ )|(?P (?P{_NAME_RE})\[(?P.+)\]$ )|(?P - (?P{_NAME_RE})(?:\.(?P[^(]+)|\[(?P[^\]]+)\])\s* + (?P{_NAME_RE})(?:(?P\?)?\.(?P[^(]+)|\[(?P[^\]]+)\])\s* )|(?P (?P{_NAME_RE})\((?P.*)\)$ )''', expr) @@ -414,7 +519,7 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): local_vars[m.group('out')] = self._operator( m.group('op'), left_val, m.group('expr'), expr, local_vars, allow_recursion) return local_vars[m.group('out')], should_return - elif left_val is None: + elif left_val in (None, JS_Undefined): raise self.Exception(f'Cannot index undefined variable {m.group("out")}', expr) idx = self.interpret_expression(m.group('index'), local_vars, allow_recursion) @@ -432,9 +537,11 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): raise JS_Break() elif expr == 'continue': raise JS_Continue() + elif expr == 'undefined': + return JS_Undefined, should_return elif m and m.group('return'): - return local_vars[m.group('name')], should_return + return local_vars.get(m.group('name'), JS_Undefined), should_return with contextlib.suppress(ValueError): return json.loads(js_to_json(expr, strict=True)), should_return @@ -447,8 +554,11 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): for op in _OPERATORS: separated = list(self._separate(expr, op)) right_expr = separated.pop() - while op in '<>*-' and len(separated) > 1 and not separated[-1].strip(): - separated.pop() + while True: + if op in '?<>*-' and len(separated) > 1 and not separated[-1].strip(): + separated.pop() + elif not (separated and op == '?' and right_expr.startswith('.')): + break right_expr = f'{op}{right_expr}' if op != '-': right_expr = f'{separated.pop()}{op}{right_expr}' @@ -458,8 +568,7 @@ def interpret_statement(self, stmt, local_vars, allow_recursion=100): return self._operator(op, left_val, right_expr, expr, local_vars, allow_recursion), should_return if m and m.group('attribute'): - variable = m.group('var') - member = m.group('member') + variable, member, nullish = m.group('var', 'member', 'nullish') if not member: member = self.interpret_expression(m.group('member2'), local_vars, allow_recursion) arg_str = expr[m.end():] @@ -486,12 +595,19 @@ def eval_method(): obj = local_vars.get(variable, types.get(variable, NO_DEFAULT)) if obj is NO_DEFAULT: if variable not in self._objects: - self._objects[variable] = self.extract_object(variable) - obj = self._objects[variable] + try: + self._objects[variable] = self.extract_object(variable) + except self.Exception: + if not nullish: + raise + obj = self._objects.get(variable, JS_Undefined) + + if nullish and obj is JS_Undefined: + return JS_Undefined # Member access if arg_str is None: - return self._index(obj, member) + return self._index(obj, member, nullish) # Function call argvals = [