mirror of
				https://github.com/yt-dlp/yt-dlp.git
				synced 2025-10-31 14:45:14 +00:00 
			
		
		
		
	[utils] traverse_obj:  Various improvements
				
					
				
			- Add `set` key for transformations/filters - Add `re.Match` group names - Fix behavior for `expected_type` with `dict` key - Raise for filter function signature mismatch in debug Authored by: Grub4K
This commit is contained in:
		| @@ -105,6 +105,7 @@ from yt_dlp.utils import ( | |||||||
|     sanitized_Request, |     sanitized_Request, | ||||||
|     shell_quote, |     shell_quote, | ||||||
|     smuggle_url, |     smuggle_url, | ||||||
|  |     str_or_none, | ||||||
|     str_to_int, |     str_to_int, | ||||||
|     strip_jsonp, |     strip_jsonp, | ||||||
|     strip_or_none, |     strip_or_none, | ||||||
| @@ -2015,6 +2016,29 @@ Line 1 | |||||||
|                          msg='function as query key should perform a filter based on (key, value)') |                          msg='function as query key should perform a filter based on (key, value)') | ||||||
|         self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, |         self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, | ||||||
|                               msg='exceptions in the query function should be catched') |                               msg='exceptions in the query function should be catched') | ||||||
|  |         if __debug__: | ||||||
|  |             with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): | ||||||
|  |                 traverse_obj(_TEST_DATA, lambda a: ...) | ||||||
|  |             with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): | ||||||
|  |                 traverse_obj(_TEST_DATA, lambda a, b, c: ...) | ||||||
|  | 
 | ||||||
|  |         # Test set as key (transformation/type, like `expected_type`) | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'], | ||||||
|  |                          msg='Function in set should be a transformation') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'], | ||||||
|  |                          msg='Type in set should be a type filter') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA, | ||||||
|  |                          msg='A single set should be wrapped into a path') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'], | ||||||
|  |                          msg='Transformation function should not raise') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})), | ||||||
|  |                          [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], | ||||||
|  |                          msg='Function in set should be a transformation') | ||||||
|  |         if __debug__: | ||||||
|  |             with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): | ||||||
|  |                 traverse_obj(_TEST_DATA, set()) | ||||||
|  |             with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): | ||||||
|  |                 traverse_obj(_TEST_DATA, {str.upper, str}) | ||||||
| 
 | 
 | ||||||
|         # Test alternative paths |         # Test alternative paths | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', |         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', | ||||||
| @@ -2106,6 +2130,20 @@ Line 1 | |||||||
|                          msg='wrap expected_type fuction in try_call') |                          msg='wrap expected_type fuction in try_call') | ||||||
|         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], |         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], | ||||||
|                          msg='eliminate items that expected_type fails on') |                          msg='eliminate items that expected_type fails on') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100}, | ||||||
|  |                          msg='type as expected_type should filter dict values') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'}, | ||||||
|  |                          msg='function as expected_type should transform dict values') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1, | ||||||
|  |                          msg='expected_type should not filter non final dict values') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}}, | ||||||
|  |                          msg='expected_type should transform deep dict values') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}], | ||||||
|  |                          msg='expected_type should transform branched dict values') | ||||||
|  |         self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4], | ||||||
|  |                          msg='expected_type regression for type matching in tuple branching') | ||||||
|  |         self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [], | ||||||
|  |                          msg='expected_type regression for type matching in dict result') | ||||||
| 
 | 
 | ||||||
|         # Test get_all behavior |         # Test get_all behavior | ||||||
|         _GET_ALL_DATA = {'key': [0, 1, 2]} |         _GET_ALL_DATA = {'key': [0, 1, 2]} | ||||||
| @@ -2189,6 +2227,8 @@ Line 1 | |||||||
|                          msg='failing str key on a `re.Match` should return `default`') |                          msg='failing str key on a `re.Match` should return `default`') | ||||||
|         self.assertEqual(traverse_obj(mobj, 8), None, |         self.assertEqual(traverse_obj(mobj, 8), None, | ||||||
|                          msg='failing int key on a `re.Match` should return `default`') |                          msg='failing int key on a `re.Match` should return `default`') | ||||||
|  |         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], | ||||||
|  |                          msg='function on a `re.Match` should give group name as well') | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -5424,6 +5424,9 @@ def traverse_obj( | |||||||
| 
 | 
 | ||||||
|     The keys in the path can be one of: |     The keys in the path can be one of: | ||||||
|         - `None`:           Return the current object. |         - `None`:           Return the current object. | ||||||
|  |         - `set`:            Requires the only item in the set to be a type or function, | ||||||
|  |                             like `{type}`/`{func}`. If a `type`, returns only values | ||||||
|  |                             of this type. If a function, returns `func(obj)`. | ||||||
|         - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`. |         - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`. | ||||||
|         - `slice`:          Branch out and return all values in `obj[key]`. |         - `slice`:          Branch out and return all values in `obj[key]`. | ||||||
|         - `Ellipsis`:       Branch out and return a list of all values. |         - `Ellipsis`:       Branch out and return a list of all values. | ||||||
| @@ -5432,6 +5435,8 @@ def traverse_obj( | |||||||
|         - `function`:       Branch out and return values filtered by the function. |         - `function`:       Branch out and return values filtered by the function. | ||||||
|                             Read as: `[value for key, value in obj if function(key, value)]`. |                             Read as: `[value for key, value in obj if function(key, value)]`. | ||||||
|                             For `Sequence`s, `key` is the index of the value. |                             For `Sequence`s, `key` is the index of the value. | ||||||
|  |                             For `re.Match`es, `key` is the group number (0 = full match) | ||||||
|  |                             as well as additionally any group names, if given. | ||||||
|         - `dict`            Transform the current object and return a matching dict. |         - `dict`            Transform the current object and return a matching dict. | ||||||
|                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. |                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. | ||||||
| 
 | 
 | ||||||
| @@ -5441,6 +5446,8 @@ def traverse_obj( | |||||||
|     @param default          Value to return if the paths do not match. |     @param default          Value to return if the paths do not match. | ||||||
|     @param expected_type    If a `type`, only accept final values of this type. |     @param expected_type    If a `type`, only accept final values of this type. | ||||||
|                             If any other callable, try to call the function on each result. |                             If any other callable, try to call the function on each result. | ||||||
|  |                             If the last key in the path is a `dict`, it will apply to each value inside | ||||||
|  |                             the dict instead, recursively. This does respect branching paths. | ||||||
|     @param get_all          If `False`, return the first matching result, otherwise all matching ones. |     @param get_all          If `False`, return the first matching result, otherwise all matching ones. | ||||||
|     @param casesense        If `False`, consider string dictionary keys as case insensitive. |     @param casesense        If `False`, consider string dictionary keys as case insensitive. | ||||||
| 
 | 
 | ||||||
| @@ -5466,16 +5473,25 @@ def traverse_obj( | |||||||
|     else: |     else: | ||||||
|         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) |         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) | ||||||
| 
 | 
 | ||||||
|     def apply_key(key, obj): |     def apply_key(key, test_type, obj): | ||||||
|         if obj is None: |         if obj is None: | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         elif key is None: |         elif key is None: | ||||||
|             yield obj |             yield obj | ||||||
| 
 | 
 | ||||||
|  |         elif isinstance(key, set): | ||||||
|  |             assert len(key) == 1, 'Set should only be used to wrap a single item' | ||||||
|  |             item = next(iter(key)) | ||||||
|  |             if isinstance(item, type): | ||||||
|  |                 if isinstance(obj, item): | ||||||
|  |                     yield obj | ||||||
|  |             else: | ||||||
|  |                 yield try_call(item, args=(obj,)) | ||||||
|  | 
 | ||||||
|         elif isinstance(key, (list, tuple)): |         elif isinstance(key, (list, tuple)): | ||||||
|             for branch in key: |             for branch in key: | ||||||
|                 _, result = apply_path(obj, branch) |                 _, result = apply_path(obj, branch, test_type) | ||||||
|                 yield from result |                 yield from result | ||||||
| 
 | 
 | ||||||
|         elif key is ...: |         elif key is ...: | ||||||
| @@ -5494,7 +5510,9 @@ def traverse_obj( | |||||||
|             elif isinstance(obj, collections.abc.Mapping): |             elif isinstance(obj, collections.abc.Mapping): | ||||||
|                 iter_obj = obj.items() |                 iter_obj = obj.items() | ||||||
|             elif isinstance(obj, re.Match): |             elif isinstance(obj, re.Match): | ||||||
|                 iter_obj = enumerate((obj.group(), *obj.groups())) |                 iter_obj = itertools.chain( | ||||||
|  |                     enumerate((obj.group(), *obj.groups())), | ||||||
|  |                     obj.groupdict().items()) | ||||||
|             elif traverse_string: |             elif traverse_string: | ||||||
|                 iter_obj = enumerate(str(obj)) |                 iter_obj = enumerate(str(obj)) | ||||||
|             else: |             else: | ||||||
| @@ -5502,7 +5520,7 @@ def traverse_obj( | |||||||
|             yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) |             yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) | ||||||
| 
 | 
 | ||||||
|         elif isinstance(key, dict): |         elif isinstance(key, dict): | ||||||
|             iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) |             iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items()) | ||||||
|             yield {k: v if v is not None else default for k, v in iter_obj |             yield {k: v if v is not None else default for k, v in iter_obj | ||||||
|                    if v is not None or default is not NO_DEFAULT} |                    if v is not None or default is not NO_DEFAULT} | ||||||
| 
 | 
 | ||||||
| @@ -5537,11 +5555,24 @@ def traverse_obj( | |||||||
|             with contextlib.suppress(IndexError): |             with contextlib.suppress(IndexError): | ||||||
|                 yield obj[key] |                 yield obj[key] | ||||||
| 
 | 
 | ||||||
|     def apply_path(start_obj, path): |     def lazy_last(iterable): | ||||||
|  |         iterator = iter(iterable) | ||||||
|  |         prev = next(iterator, NO_DEFAULT) | ||||||
|  |         if prev is NO_DEFAULT: | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         for item in iterator: | ||||||
|  |             yield False, prev | ||||||
|  |             prev = item | ||||||
|  | 
 | ||||||
|  |         yield True, prev | ||||||
|  | 
 | ||||||
|  |     def apply_path(start_obj, path, test_type=False): | ||||||
|         objs = (start_obj,) |         objs = (start_obj,) | ||||||
|         has_branched = False |         has_branched = False | ||||||
| 
 | 
 | ||||||
|         for key in variadic(path): |         key = None | ||||||
|  |         for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): | ||||||
|             if is_user_input and key == ':': |             if is_user_input and key == ':': | ||||||
|                 key = ... |                 key = ... | ||||||
| 
 | 
 | ||||||
| @@ -5551,14 +5582,21 @@ def traverse_obj( | |||||||
|             if key is ... or isinstance(key, (list, tuple)) or callable(key): |             if key is ... or isinstance(key, (list, tuple)) or callable(key): | ||||||
|                 has_branched = True |                 has_branched = True | ||||||
| 
 | 
 | ||||||
|             key_func = functools.partial(apply_key, key) |             if __debug__ and callable(key): | ||||||
|  |                 # Verify function signature | ||||||
|  |                 inspect.signature(key).bind(None, None) | ||||||
|  | 
 | ||||||
|  |             key_func = functools.partial(apply_key, key, last) | ||||||
|             objs = itertools.chain.from_iterable(map(key_func, objs)) |             objs = itertools.chain.from_iterable(map(key_func, objs)) | ||||||
| 
 | 
 | ||||||
|  |         if test_type and not isinstance(key, (dict, list, tuple)): | ||||||
|  |             objs = map(type_test, objs) | ||||||
|  | 
 | ||||||
|         return has_branched, objs |         return has_branched, objs | ||||||
| 
 | 
 | ||||||
|     def _traverse_obj(obj, path, use_list=True): |     def _traverse_obj(obj, path, use_list=True, test_type=True): | ||||||
|         has_branched, results = apply_path(obj, path) |         has_branched, results = apply_path(obj, path, test_type) | ||||||
|         results = LazyList(x for x in map(type_test, results) if x is not None) |         results = LazyList(x for x in results if x is not None) | ||||||
| 
 | 
 | ||||||
|         if get_all and has_branched: |         if get_all and has_branched: | ||||||
|             return results.exhaust() if results or use_list else None |             return results.exhaust() if results or use_list else None | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Simon Sawicki
					Simon Sawicki