mirror of
				https://github.com/yt-dlp/yt-dlp.git
				synced 2025-10-31 14:45:14 +00:00 
			
		
		
		
	[utils] traverse_obj: Support xml.etree.ElementTree.Element (#8911)
				
					
				
			Authored by: Grub4K
This commit is contained in:
		| @@ -2340,6 +2340,58 @@ Line 1 | |||||||
|         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], |         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], | ||||||
|                          msg='function on a `re.Match` should give group name as well') |                          msg='function on a `re.Match` should give group name as well') | ||||||
| 
 | 
 | ||||||
|  |         # Test xml.etree.ElementTree.Element as input obj | ||||||
|  |         etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?> | ||||||
|  |         <data> | ||||||
|  |             <country name="Liechtenstein"> | ||||||
|  |                 <rank>1</rank> | ||||||
|  |                 <year>2008</year> | ||||||
|  |                 <gdppc>141100</gdppc> | ||||||
|  |                 <neighbor name="Austria" direction="E"/> | ||||||
|  |                 <neighbor name="Switzerland" direction="W"/> | ||||||
|  |             </country> | ||||||
|  |             <country name="Singapore"> | ||||||
|  |                 <rank>4</rank> | ||||||
|  |                 <year>2011</year> | ||||||
|  |                 <gdppc>59900</gdppc> | ||||||
|  |                 <neighbor name="Malaysia" direction="N"/> | ||||||
|  |             </country> | ||||||
|  |             <country name="Panama"> | ||||||
|  |                 <rank>68</rank> | ||||||
|  |                 <year>2011</year> | ||||||
|  |                 <gdppc>13600</gdppc> | ||||||
|  |                 <neighbor name="Costa Rica" direction="W"/> | ||||||
|  |                 <neighbor name="Colombia" direction="E"/> | ||||||
|  |             </country> | ||||||
|  |         </data>''') | ||||||
|  |         self.assertEqual(traverse_obj(etree, ''), etree, | ||||||
|  |                          msg='empty str key should return the element itself') | ||||||
|  |         self.assertEqual(traverse_obj(etree, 'country'), list(etree), | ||||||
|  |                          msg='str key should lead all children with that tag name') | ||||||
|  |         self.assertEqual(traverse_obj(etree, ...), list(etree), | ||||||
|  |                          msg='`...` as key should return all children') | ||||||
|  |         self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]], | ||||||
|  |                          msg='function as key should get element as value') | ||||||
|  |         self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]], | ||||||
|  |                          msg='function as key should get index as key') | ||||||
|  |         self.assertEqual(traverse_obj(etree, 0), etree[0], | ||||||
|  |                          msg='int key should return the nth child') | ||||||
|  |         self.assertEqual(traverse_obj(etree, './/neighbor/@name'), | ||||||
|  |                          ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'], | ||||||
|  |                          msg='`@<attribute>` at end of path should give that attribute') | ||||||
|  |         self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None], | ||||||
|  |                          msg='`@<nonexistant>` at end of path should give `None`') | ||||||
|  |         self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'}, | ||||||
|  |                          msg='`@` should give the full attribute dict') | ||||||
|  |         self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'], | ||||||
|  |                          msg='`text()` at end of path should give the inner text') | ||||||
|  |         self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'], | ||||||
|  |                          msg='full python xpath features should be supported') | ||||||
|  |         self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein', | ||||||
|  |                          msg='special transformations should act on current element') | ||||||
|  |         self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100], | ||||||
|  |                          msg='special transformations should act on current element') | ||||||
|  | 
 | ||||||
|     def test_http_header_dict(self): |     def test_http_header_dict(self): | ||||||
|         headers = HTTPHeaderDict() |         headers = HTTPHeaderDict() | ||||||
|         headers['ytdl-test'] = b'0' |         headers['ytdl-test'] = b'0' | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ import contextlib | |||||||
| import inspect | import inspect | ||||||
| import itertools | import itertools | ||||||
| import re | import re | ||||||
|  | import xml.etree.ElementTree | ||||||
| 
 | 
 | ||||||
| from ._utils import ( | from ._utils import ( | ||||||
|     IDENTITY, |     IDENTITY, | ||||||
| @@ -118,7 +119,7 @@ def traverse_obj( | |||||||
|             branching = True |             branching = True | ||||||
|             if isinstance(obj, collections.abc.Mapping): |             if isinstance(obj, collections.abc.Mapping): | ||||||
|                 result = obj.values() |                 result = obj.values() | ||||||
|             elif is_iterable_like(obj): |             elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): | ||||||
|                 result = obj |                 result = obj | ||||||
|             elif isinstance(obj, re.Match): |             elif isinstance(obj, re.Match): | ||||||
|                 result = obj.groups() |                 result = obj.groups() | ||||||
| @@ -132,7 +133,7 @@ def traverse_obj( | |||||||
|             branching = True |             branching = True | ||||||
|             if isinstance(obj, collections.abc.Mapping): |             if isinstance(obj, collections.abc.Mapping): | ||||||
|                 iter_obj = obj.items() |                 iter_obj = obj.items() | ||||||
|             elif is_iterable_like(obj): |             elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): | ||||||
|                 iter_obj = enumerate(obj) |                 iter_obj = enumerate(obj) | ||||||
|             elif isinstance(obj, re.Match): |             elif isinstance(obj, re.Match): | ||||||
|                 iter_obj = itertools.chain( |                 iter_obj = itertools.chain( | ||||||
| @@ -168,7 +169,7 @@ def traverse_obj( | |||||||
|                 result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) |                 result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) | ||||||
| 
 | 
 | ||||||
|         elif isinstance(key, (int, slice)): |         elif isinstance(key, (int, slice)): | ||||||
|             if is_iterable_like(obj, collections.abc.Sequence): |             if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)): | ||||||
|                 branching = isinstance(key, slice) |                 branching = isinstance(key, slice) | ||||||
|                 with contextlib.suppress(IndexError): |                 with contextlib.suppress(IndexError): | ||||||
|                     result = obj[key] |                     result = obj[key] | ||||||
| @@ -176,6 +177,34 @@ def traverse_obj( | |||||||
|                 with contextlib.suppress(IndexError): |                 with contextlib.suppress(IndexError): | ||||||
|                     result = str(obj)[key] |                     result = str(obj)[key] | ||||||
| 
 | 
 | ||||||
|  |         elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str): | ||||||
|  |             xpath, _, special = key.rpartition('/') | ||||||
|  |             if not special.startswith('@') and special != 'text()': | ||||||
|  |                 xpath = key | ||||||
|  |                 special = None | ||||||
|  | 
 | ||||||
|  |             # Allow abbreviations of relative paths, absolute paths error | ||||||
|  |             if xpath.startswith('/'): | ||||||
|  |                 xpath = f'.{xpath}' | ||||||
|  |             elif xpath and not xpath.startswith('./'): | ||||||
|  |                 xpath = f'./{xpath}' | ||||||
|  | 
 | ||||||
|  |             def apply_specials(element): | ||||||
|  |                 if special is None: | ||||||
|  |                     return element | ||||||
|  |                 if special == '@': | ||||||
|  |                     return element.attrib | ||||||
|  |                 if special.startswith('@'): | ||||||
|  |                     return try_call(element.attrib.get, args=(special[1:],)) | ||||||
|  |                 if special == 'text()': | ||||||
|  |                     return element.text | ||||||
|  |                 assert False, f'apply_specials is missing case for {special!r}' | ||||||
|  | 
 | ||||||
|  |             if xpath: | ||||||
|  |                 result = list(map(apply_specials, obj.iterfind(xpath))) | ||||||
|  |             else: | ||||||
|  |                 result = apply_specials(obj) | ||||||
|  | 
 | ||||||
|         return branching, result if branching else (result,) |         return branching, result if branching else (result,) | ||||||
| 
 | 
 | ||||||
|     def lazy_last(iterable): |     def lazy_last(iterable): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Simon Sawicki
					Simon Sawicki